Separate download.Get() function and add ~name parameter to git downloads

This commit is contained in:
Elara 2022-10-03 15:38:38 -07:00
parent 2b6815e287
commit c0e535c630
1 changed files with 168 additions and 137 deletions

View File

@ -23,6 +23,7 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"errors" "errors"
"hash"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@ -67,6 +68,22 @@ func Get(ctx context.Context, opts GetOptions) error {
} }
query := src.Query() query := src.Query()
if strings.HasPrefix(src.Scheme, "git+") {
err = getGit(ctx, src, query, opts)
if err != nil {
return err
}
} else {
err = getFile(ctx, src, query, opts)
if err != nil {
return err
}
}
return nil
}
func getGit(ctx context.Context, src *url.URL, query url.Values, opts GetOptions) (err error) {
tag := query.Get("~tag") tag := query.Get("~tag")
query.Del("~tag") query.Del("~tag")
@ -79,6 +96,9 @@ func Get(ctx context.Context, opts GetOptions) error {
depthStr := query.Get("~depth") depthStr := query.Get("~depth")
query.Del("~depth") query.Del("~depth")
name := query.Get("~name")
query.Del("~name")
var refName plumbing.ReferenceName var refName plumbing.ReferenceName
if tag != "" { if tag != "" {
refName = plumbing.NewTagReferenceName(tag) refName = plumbing.NewTagReferenceName(tag)
@ -86,169 +106,180 @@ func Get(ctx context.Context, opts GetOptions) error {
refName = plumbing.NewBranchReferenceName(branch) refName = plumbing.NewBranchReferenceName(branch)
} }
if strings.HasPrefix(src.Scheme, "git+") { src.Scheme = strings.TrimPrefix(src.Scheme, "git+")
src.Scheme = strings.TrimPrefix(src.Scheme, "git+") src.RawQuery = query.Encode()
src.RawQuery = query.Encode()
name := path.Base(src.Path) if name == "" {
name = path.Base(src.Path)
name = strings.TrimSuffix(name, ".git") name = strings.TrimSuffix(name, ".git")
}
dstDir := opts.Destination dstDir := opts.Destination
if opts.EncloseGit { if opts.EncloseGit {
dstDir = filepath.Join(opts.Destination, name) dstDir = filepath.Join(opts.Destination, name)
} }
depth := 0 depth := 0
if depthStr != "" { if depthStr != "" {
depth, err = strconv.Atoi(depthStr) depth, err = strconv.Atoi(depthStr)
if err != nil {
return err
}
}
cloneOpts := &git.CloneOptions{
URL: src.String(),
Progress: os.Stderr,
Depth: depth,
}
repo, err := git.PlainCloneContext(ctx, dstDir, false, cloneOpts)
if err != nil { if err != nil {
return err return err
} }
}
w, err := repo.Worktree() cloneOpts := &git.CloneOptions{
if err != nil { URL: src.String(),
return err Progress: os.Stderr,
} Depth: depth,
}
checkoutOpts := &git.CheckoutOptions{} repo, err := git.PlainCloneContext(ctx, dstDir, false, cloneOpts)
if refName != "" { if err != nil {
checkoutOpts.Branch = refName return err
} else if commit != "" { }
checkoutOpts.Hash = plumbing.NewHash(commit)
} else {
return nil
}
return w.Checkout(checkoutOpts) w, err := repo.Worktree()
if err != nil {
return err
}
checkoutOpts := &git.CheckoutOptions{}
if refName != "" {
checkoutOpts.Branch = refName
} else if commit != "" {
checkoutOpts.Hash = plumbing.NewHash(commit)
} else { } else {
name := query.Get("~name") return nil
query.Del("~name") }
archive := query.Get("~archive") return w.Checkout(checkoutOpts)
query.Del("~archive") }
src.RawQuery = query.Encode() func getFile(ctx context.Context, src *url.URL, query url.Values, opts GetOptions) error {
name := query.Get("~name")
query.Del("~name")
if name == "" { archive := query.Get("~archive")
name = path.Base(src.Path) query.Del("~archive")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, src.String(), nil) src.RawQuery = query.Encode()
if name == "" {
name = path.Base(src.Path)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, src.String(), nil)
if err != nil {
return err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
hash := sha256.New()
format, input, err := archiver.Identify(name, res.Body)
if err == archiver.ErrNoMatch || archive == "false" {
fl, err := os.Create(filepath.Join(opts.Destination, name))
if err != nil { if err != nil {
return err return err
} }
res, err := http.DefaultClient.Do(req) w := io.MultiWriter(hash, fl)
_, err = io.Copy(w, input)
if err != nil { if err != nil {
return err return err
} }
hash := sha256.New() res.Body.Close()
fl.Close()
format, input, err := archiver.Identify(name, res.Body) if opts.SHA256Sum != nil {
if err == archiver.ErrNoMatch || archive == "false" { sum := hash.Sum(nil)
fl, err := os.Create(filepath.Join(opts.Destination, name)) if !bytes.Equal(opts.SHA256Sum, sum) {
if err != nil { return ErrChecksumMismatch
return err
} }
}
w := io.MultiWriter(hash, fl) } else if err != nil {
return err
_, err = io.Copy(w, input) } else {
if err != nil { err = extractFile(ctx, input, hash, format, name, opts)
return err if err != nil {
}
res.Body.Close()
fl.Close()
if opts.SHA256Sum != nil {
sum := hash.Sum(nil)
if !bytes.Equal(opts.SHA256Sum, sum) {
return ErrChecksumMismatch
}
}
} else if err != nil {
return err return err
} else { }
r := io.TeeReader(input, hash) }
fname := format.Name()
return nil
switch format := format.(type) { }
case archiver.Extractor:
err = format.Extract(ctx, r, nil, func(ctx context.Context, f archiver.File) error { func extractFile(ctx context.Context, input io.Reader, hash hash.Hash, format archiver.Format, name string, opts GetOptions) (err error) {
fr, err := f.Open() r := io.TeeReader(input, hash)
if err != nil { fname := format.Name()
return err
} switch format := format.(type) {
defer fr.Close() case archiver.Extractor:
err = format.Extract(ctx, r, nil, func(ctx context.Context, f archiver.File) error {
path := filepath.Join(opts.Destination, f.NameInArchive) fr, err := f.Open()
if err != nil {
err = os.MkdirAll(filepath.Dir(path), 0o755) return err
if err != nil { }
return err defer fr.Close()
}
path := filepath.Join(opts.Destination, f.NameInArchive)
if f.IsDir() {
err = os.Mkdir(path, 0o755) err = os.MkdirAll(filepath.Dir(path), 0o755)
if err != nil { if err != nil {
return err return err
} }
} else {
outFl, err := os.Create(path) if f.IsDir() {
if err != nil { err = os.Mkdir(path, 0o755)
return err if err != nil {
} return err
defer outFl.Close() }
} else {
_, err = io.Copy(outFl, fr) outFl, err := os.Create(path)
return err if err != nil {
} return err
return nil }
}) defer outFl.Close()
if err != nil {
return err _, err = io.Copy(outFl, fr)
} return err
case archiver.Decompressor: }
rc, err := format.OpenReader(r) return nil
if err != nil { })
return err if err != nil {
} return err
defer rc.Close() }
case archiver.Decompressor:
path := filepath.Join(opts.Destination, name) rc, err := format.OpenReader(r)
path = strings.TrimSuffix(path, fname) if err != nil {
return err
outFl, err := os.Create(path) }
if err != nil { defer rc.Close()
return err
} path := filepath.Join(opts.Destination, name)
path = strings.TrimSuffix(path, fname)
_, err = io.Copy(outFl, rc)
if err != nil { outFl, err := os.Create(path)
return err if err != nil {
} return err
} }
if opts.SHA256Sum != nil { _, err = io.Copy(outFl, rc)
sum := hash.Sum(nil) if err != nil {
if !bytes.Equal(opts.SHA256Sum, sum) { return err
return ErrChecksumMismatch }
} }
}
if opts.SHA256Sum != nil {
sum := hash.Sum(nil)
if !bytes.Equal(opts.SHA256Sum, sum) {
return ErrChecksumMismatch
} }
} }