diff --git a/go.mod b/go.mod index 4f915d7..ddcf7e2 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/schollz/progressbar/v3 v3.13.0 github.com/twitchtv/twirp v8.1.3+incompatible github.com/urfave/cli/v2 v2.23.7 + github.com/vmihailenco/msgpack/v5 v5.3.5 go.arsenm.dev/logger v0.0.0-20230126004036-a8cbbe3b6fe6 go.arsenm.dev/translate v0.0.0-20230113025904-5ad1ec0ed296 golang.org/x/exp v0.0.0-20220916125017-b168a2c6b86b @@ -86,6 +87,7 @@ require ( github.com/sirupsen/logrus v1.9.0 // indirect github.com/therootcompany/xz v1.0.1 // indirect github.com/ulikunitz/xz v0.5.10 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/xanzy/ssh-agent v0.3.1 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect diff --git a/go.sum b/go.sum index 4cb18d2..f3bdebf 100644 --- a/go.sum +++ b/go.sum @@ -251,6 +251,10 @@ github.com/ulikunitz/xz v0.5.10 h1:t92gobL9l3HE202wg3rlk19F6X+JOxl9BBrCCMYEYd8= github.com/ulikunitz/xz v0.5.10/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY= github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= +github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= +github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/xanzy/ssh-agent v0.3.0/go.mod h1:3s9xbODqPuuhK9JV1R321M/FlMZSBvE5aY6eAcqrDh0= github.com/xanzy/ssh-agent v0.3.1 h1:AmzO1SSWxw73zxFZPRwaMN1MohDw8UyHnmuxyceTEGo= github.com/xanzy/ssh-agent v0.3.1/go.mod h1:QIE4lCeL7nkC25x+yA3LBIYfwCc1TFziCtG7cBAac6w= diff --git a/internal/dl/dl.go b/internal/dl/dl.go index 645eb47..44a7c35 100644 --- a/internal/dl/dl.go +++ b/internal/dl/dl.go @@ -2,15 +2,22 @@ package dl import ( "context" + "errors" "io" "os" "path/filepath" + "github.com/vmihailenco/msgpack/v5" "go.arsenm.dev/logger/log" "go.arsenm.dev/lure/internal/dlcache" ) +const manifestFileName = ".lure_cache_manifest" + +var ErrChecksumMismatch = errors.New("dl: checksums did not match") + var Downloaders = []Downloader{ + GitDownloader{}, FileDownloader{}, } @@ -32,48 +39,85 @@ func (t Type) String() string { } type Options struct { - ID string - Name string - URL string - Destination string - Progress io.Writer + SHA256 []byte + Name string + URL string + Destination string + CacheDisabled bool + PostprocDisabled bool + Progress io.Writer +} + +type Manifest struct { + Type Type + Name string } type Downloader interface { Name() string - Type() Type MatchURL(string) bool - Download(Options) error + Download(Options) (Type, string, error) } type UpdatingDownloader interface { Downloader - Update(Options) error + Update(Options) (bool, error) } -func Download(ctx context.Context, opts Options) error { +func Download(ctx context.Context, opts Options) (err error) { d := getDownloader(opts.URL) - cacheDir, ok := dlcache.Get(opts.ID) + + if opts.CacheDisabled { + _, _, err = d.Download(opts) + return err + } + + var t Type + cacheDir, ok := dlcache.Get(opts.URL) if ok { - ok, err := handleCache(cacheDir, opts.Destination, d.Type()) + var updated bool + if d, ok := d.(UpdatingDownloader); ok { + log.Info("Source can be updated, updating if required").Str("source", opts.Name).Str("downloader", d.Name()).Send() + + updated, err = d.Update(Options{ + Name: opts.Name, + URL: opts.URL, + Destination: cacheDir, + Progress: opts.Progress, + }) + if err != nil { + return err + } + } + + m, err := getManifest(cacheDir) + if err != nil { + return err + } + t = m.Type + + dest := filepath.Join(opts.Destination, m.Name) + ok, err := handleCache(cacheDir, dest, t) if err != nil { return err } - if ok { - log.Info("Source found in cache, linked to destination").Str("source", opts.Name).Stringer("type", d.Type()).Send() + if ok && !updated { + log.Info("Source found in cache, linked to destination").Str("source", opts.Name).Stringer("type", t).Send() + return nil + } else if ok { return nil } } log.Info("Downloading source").Str("source", opts.Name).Str("downloader", d.Name()).Send() - cacheDir, err := dlcache.New(opts.ID) + cacheDir, err = dlcache.New(opts.URL) if err != nil { return err } - err = d.Download(Options{ + t, name, err := d.Download(Options{ Name: opts.Name, URL: opts.URL, Destination: cacheDir, @@ -83,10 +127,36 @@ func Download(ctx context.Context, opts Options) error { return err } - _, err = handleCache(cacheDir, opts.Destination, d.Type()) + err = writeManifest(cacheDir, Manifest{t, name}) + if err != nil { + return err + } + + dest := filepath.Join(opts.Destination, name) + _, err = handleCache(cacheDir, dest, t) return err } +func writeManifest(cacheDir string, m Manifest) error { + fl, err := os.Create(filepath.Join(cacheDir, manifestFileName)) + if err != nil { + return err + } + defer fl.Close() + return msgpack.NewEncoder(fl).Encode(m) +} + +func getManifest(cacheDir string) (m Manifest, err error) { + fl, err := os.Open(filepath.Join(cacheDir, manifestFileName)) + if err != nil { + return Manifest{}, err + } + defer fl.Close() + + err = msgpack.NewDecoder(fl).Decode(&m) + return +} + func handleCache(cacheDir, dest string, t Type) (bool, error) { switch t { case TypeFile: @@ -95,24 +165,28 @@ func handleCache(cacheDir, dest string, t Type) (bool, error) { return false, err } - names, err := cd.Readdirnames(1) - if err != nil && err != io.EOF { - return false, err - } - - // If the cache dir contains no files, - // assume there is no cache entry - if len(names) == 0 { + names, err := cd.Readdirnames(2) + if err == io.EOF { break + } else if err != nil { + return false, err } - err = os.Link(filepath.Join(cacheDir, names[0]), filepath.Join(dest, filepath.Base(names[0]))) - if err != nil { - return false, err + cd.Close() + + for _, name := range names { + if name == manifestFileName { + continue + } + + err = os.Link(filepath.Join(cacheDir, names[0]), filepath.Join(dest, filepath.Base(names[0]))) + if err != nil { + return false, err + } } return true, nil case TypeDir: - err := os.Link(cacheDir, dest) + err := linkDir(cacheDir, dest) if err != nil { return false, err } @@ -121,6 +195,30 @@ func handleCache(cacheDir, dest string, t Type) (bool, error) { return false, nil } +func linkDir(src, dest string) error { + return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.Name() == manifestFileName { + return nil + } + + rel, err := filepath.Rel(src, path) + if err != nil { + return err + } + + newPath := filepath.Join(dest, rel) + if info.IsDir() { + return os.Mkdir(newPath, info.Mode()) + } + + return os.Link(path, newPath) + }) +} + func getDownloader(u string) Downloader { for _, d := range Downloaders { if d.MatchURL(u) { diff --git a/internal/dl/file.go b/internal/dl/file.go index 2196157..02eb99e 100644 --- a/internal/dl/file.go +++ b/internal/dl/file.go @@ -1,14 +1,21 @@ package dl import ( + "bytes" + "context" + "crypto/sha256" "io" "net/http" "os" "path" "path/filepath" "regexp" + "strings" + "time" + "github.com/mholt/archiver/v4" "github.com/schollz/progressbar/v3" + "go.arsenm.dev/lure/internal/shutils" ) type FileDownloader struct{} @@ -25,27 +32,155 @@ func (FileDownloader) MatchURL(string) bool { return true } -func (FileDownloader) Download(opts Options) error { +func (FileDownloader) Download(opts Options) (Type, string, error) { res, err := http.Get(opts.URL) if err != nil { - return err + return 0, "", err } - defer res.Body.Close() name := getFilename(res) - fl, err := os.Create(filepath.Join(opts.Destination, name)) + path := filepath.Join(opts.Destination, name) + fl, err := os.Create(path) if err != nil { - return err + return 0, "", err + } + defer fl.Close() + + var bar io.WriteCloser + if opts.Progress != nil { + bar = progressbar.NewOptions64( + res.ContentLength, + progressbar.OptionSetDescription(name), + progressbar.OptionSetWriter(opts.Progress), + progressbar.OptionShowBytes(true), + progressbar.OptionSetWidth(10), + progressbar.OptionThrottle(65*time.Millisecond), + progressbar.OptionShowCount(), + progressbar.OptionOnCompletion(func() { + _, _ = io.WriteString(opts.Progress, "\n") + }), + progressbar.OptionSpinnerType(14), + progressbar.OptionFullWidth(), + progressbar.OptionSetRenderBlankState(true), + ) + defer bar.Close() + } else { + bar = shutils.NopRWC{} } - bar := progressbar.DefaultBytes( - res.ContentLength, - "downloading "+name, - ) - defer bar.Close() + h := sha256.New() - _, err = io.Copy(io.MultiWriter(fl, bar), res.Body) - return err + var w io.Writer + if opts.SHA256 != nil { + w = io.MultiWriter(fl, h, bar) + } else { + w = io.MultiWriter(fl, bar) + } + + _, err = io.Copy(w, res.Body) + if err != nil { + return 0, "", err + } + res.Body.Close() + + if opts.SHA256 != nil { + sum := h.Sum(nil) + if !bytes.Equal(sum, opts.SHA256) { + return 0, "", ErrChecksumMismatch + } + } + + if opts.PostprocDisabled { + return TypeFile, name, nil + } + + _, err = fl.Seek(0, io.SeekStart) + if err != nil { + return 0, "", err + } + + format, r, err := archiver.Identify(name, fl) + if err == archiver.ErrNoMatch { + return TypeFile, name, nil + } else if err != nil { + return 0, "", err + } + + err = extractFile(r, format, name, opts) + if err != nil { + return 0, "", err + } + + err = os.Remove(path) + return TypeDir, strings.TrimSuffix(name, format.Name()), err +} + +func extractFile(r io.Reader, format archiver.Format, name string, opts Options) (err error) { + fname := format.Name() + + switch format := format.(type) { + case archiver.Extractor: + err = format.Extract(context.Background(), r, nil, func(ctx context.Context, f archiver.File) error { + fr, err := f.Open() + if err != nil { + return err + } + defer fr.Close() + fi, err := f.Stat() + if err != nil { + return err + } + fm := fi.Mode() + + path := filepath.Join(opts.Destination, f.NameInArchive) + + err = os.MkdirAll(filepath.Dir(path), 0o755) + if err != nil { + return err + } + + if f.IsDir() { + err = os.Mkdir(path, 0o755) + if err != nil { + return err + } + } else { + outFl, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fm.Perm()) + if err != nil { + return err + } + defer outFl.Close() + + _, err = io.Copy(outFl, fr) + return err + } + return nil + }) + if err != nil { + return err + } + case archiver.Decompressor: + rc, err := format.OpenReader(r) + if err != nil { + return err + } + defer rc.Close() + + path := filepath.Join(opts.Destination, name) + path = strings.TrimSuffix(path, fname) + + outFl, err := os.Create(path) + if err != nil { + return err + } + + _, err = io.Copy(outFl, rc) + if err != nil { + return err + } + } + + return nil } var cdHeaderRgx = regexp.MustCompile(`filename="(.+)"`) diff --git a/internal/shutils/nop.go b/internal/shutils/nop.go index a72747a..d4822eb 100644 --- a/internal/shutils/nop.go +++ b/internal/shutils/nop.go @@ -47,7 +47,7 @@ func (NopRWC) Read([]byte) (int, error) { } func (NopRWC) Write([]byte) (int, error) { - return 0, io.EOF + return 0, nil } func (NopRWC) Close() error {