diff --git a/client.go b/client.go index 3aa5dd1d5..ed84783f8 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "io/ioutil" "os" "path/filepath" @@ -317,11 +318,43 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, * return nil, &getError{true, err} } - err = copyDir(ctx, req.realDst, subDir, false, req.DisableSymlinks, req.umask()) - if err != nil { - return nil, &getError{false, err} + if stat, err := os.Stat(subDir); err != nil { + return nil, &getError{false, fmt.Errorf("failed to stat '%s': %w", subDir, err)} + } else if stat.IsDir() { + err = copyDir(ctx, req.realDst, subDir, false, req.DisableSymlinks, req.umask()) + if err != nil { + return nil, &getError{false, err} + } + return &GetResult{req.realDst}, nil + } else { + src, err := os.Open(subDir) + if err != nil { + return nil, &getError{false, fmt.Errorf("failed to open local source file at '%s': %w", subDir, err)} + } + //goland:noinspection GoUnhandledErrorResult + defer src.Close() + + target := filepath.Join(req.realDst, filepath.Base(subDir)) + dst, err := os.Create(target) + if err != nil { + return nil, &getError{false, fmt.Errorf("failed to open local target file at '%s': %w", target, err)} + } + //goland:noinspection GoUnhandledErrorResult + defer dst.Close() + + buf := make([]byte, 1024*20) // 20k buffer should usually suffice for 99% of files + for { + n, err := src.Read(buf) + if err != nil && err != io.EOF { + return nil, &getError{false, fmt.Errorf("failed to read local source file at '%s': %w", subDir, err)} + } else if n == 0 { + break + } else if _, err := dst.Write(buf[:n]); err != nil { + return nil, &getError{false, fmt.Errorf("failed to write to local source file at '%s': %w", target, err)} + } + } + return &GetResult{target}, nil } - return &GetResult{req.realDst}, nil } return &GetResult{req.Dst}, nil diff --git a/get_github_test.go b/get_github_test.go new file mode 100644 index 000000000..bb4f80087 --- /dev/null +++ b/get_github_test.go @@ -0,0 +1,76 @@ +package getter + +import ( + "context" + testing_helper "github.com/hashicorp/go-getter/v2/helper/testing" + "os" + "path/filepath" + "testing" +) + +const basicMainTFExpectedContents = `# Hello + +module "foo" { + source = "./foo" +} +` + +func TestGitGetter_githubDirWithModeAny(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + ctx := context.Background() + dst := testing_helper.TempDir(t) + defer os.RemoveAll(dst) + + req := &Request{ + Src: "git::https://github.com/arikkfir/go-getter.git//testdata/basic?ref=v2", + Dst: dst, + GetMode: ModeAny, + Copy: true, + } + client := Client{} + result, err := client.Get(ctx, req) + if err != nil { + t.Fatalf("Failed fetching GitHub directory: %s", err) + } else if stat, err := os.Stat(result.Dst); err != nil { + t.Fatalf("Failed stat dst at '%s': %s", result.Dst, err) + } else if !stat.IsDir() { + t.Fatalf("Expected '%s' to be a directory", result.Dst) + } else if entries, err := os.ReadDir(result.Dst); err != nil { + t.Fatalf("Failed listing directory '%s': %s", result.Dst, err) + } else if len(entries) != 3 { + t.Fatalf("Expected dir '%s' to contain 3 items: %s", result.Dst, err) + } else { + testing_helper.AssertContents(t, filepath.Join(result.Dst, "main.tf"), basicMainTFExpectedContents) + } +} + +func TestGitGetter_githubFileWithModeAny(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + ctx := context.Background() + dst := testing_helper.TempDir(t) + defer os.RemoveAll(dst) + + req := &Request{ + Src: "git::https://github.com/arikkfir/go-getter.git//testdata/basic/main.tf?ref=v2", + Dst: dst, + GetMode: ModeAny, + Copy: true, + } + client := Client{} + result, err := client.Get(ctx, req) + if err != nil { + t.Fatalf("Failed fetching GitHub file: %s", err) + } else if stat, err := os.Stat(result.Dst); err != nil { + t.Fatalf("Failed stat dst at '%s': %s", result.Dst, err) + } else if stat.IsDir() { + t.Fatalf("Expected '%s' to be a file", result.Dst) + } else { + testing_helper.AssertContents(t, result.Dst, basicMainTFExpectedContents) + } +}