diff --git a/go/go.mod b/go/go.mod index be770b0d3..eea5ef032 100644 --- a/go/go.mod +++ b/go/go.mod @@ -24,6 +24,7 @@ require ( go.opentelemetry.io/otel/sdk/metric v1.26.0 go.opentelemetry.io/otel/trace v1.26.0 golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 + golang.org/x/tools v0.23.0 google.golang.org/api v0.188.0 google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v2 v2.4.0 diff --git a/go/go.sum b/go/go.sum index 826d38989..1f845f631 100644 --- a/go/go.sum +++ b/go/go.sum @@ -223,6 +223,10 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.23.0 h1:SGsXPZ+2l4JsgaCKkx+FQ9YZ5XEtA1GZYuoDjenLjvg= +golang.org/x/tools v0.23.0/go.mod h1:pnu6ufv6vQkll6szChhK3C3L/ruaIv5eBeztNG8wtsI= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.188.0 h1:51y8fJ/b1AaaBRJr4yWm96fPcuxSo0JcegXE3DaHQHw= google.golang.org/api v0.188.0/go.mod h1:VR0d+2SIiWOYG3r/jdm7adPW9hI2aRv9ETOSCQ9Beag= diff --git a/go/internal/cmd/copy/copy.go b/go/internal/cmd/copy/copy.go new file mode 100644 index 000000000..f6192a320 --- /dev/null +++ b/go/internal/cmd/copy/copy.go @@ -0,0 +1,392 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// copy is a tool for copying parts of files. +// It reads a set of source files, collecting named sequences of lines to copy +// to a destination file. +// It then reads the destination files, replacing the named sections there, called sinks, +// with sections of the same name from the source files. +// +// Files involved in the copy contain comments of the form +// +// //copy:COMMAND ARGS... +// +// There can't be spaces between the "//", the word "copy:", and the command name. +// +// Commands that appear in source files are: +// +// start FILENAME NAME +// Start copying to the sink NAME in FILENAME. +// stop +// Stop copying. +// +// Commands that appear in destination files are: +// +// sink NAME +// Where to start copying lines marked with NAME. +// endsink NAME +// Where to end a replacement. Inserted by the tool. +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "log" + "os" + "path/filepath" + "slices" + "strings" + + "golang.org/x/exp/maps" +) + +var ( + destDir = flag.String("dest", "", "destination directory") +) + +func main() { + flag.Parse() + log.SetFlags(0) + log.SetPrefix("copy: ") + var chunks []*chunk + for _, sourceFilename := range flag.Args() { + cs, err := parseSourceFile(sourceFilename) + if err != nil { + log.Fatalf("%s: %v", sourceFilename, err) + } + chunks = append(chunks, cs...) + } + if err := writeChunks(chunks); err != nil { + log.Fatal(err) + } +} + +type command struct { + name string + file string + sink string +} + +// A chunk is a sequence of bytes to be copied to a sink in a file. +type chunk struct { + srcFile string + destFile string + sink string + data []byte +} + +// parseSourceFile parses the named file into chunks. +func parseSourceFile(filename string) ([]*chunk, error) { + data, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + chunks, err := parseSource(data) + if err != nil { + return nil, err + } + dir := *destDir + if dir == "" { + dir = filepath.Dir(filename) + } + if err := setChunkFilenames(chunks, filename, dir); err != nil { + return nil, err + } + return chunks, nil +} + +// parseSource parses the contents of a source file into chunks. +func parseSource(src []byte) ([]*chunk, error) { + lines := bytes.SplitAfter(src, []byte("\n")) + var chunks []*chunk + var curChunk *chunk + for ln, line := range lines { + cmd, err := parseCommand(line) + if err != nil { + return nil, fmt.Errorf("%d: %w", ln, err) + } + if cmd == nil { + if curChunk != nil { + curChunk.data = append(curChunk.data, line...) + } + continue + } + switch cmd.name { + case "start": + if curChunk != nil { + return nil, fmt.Errorf("%d: start without preceding stop", ln) + } + curChunk = &chunk{ + destFile: cmd.file, + sink: cmd.sink, + } + case "stop": + if curChunk == nil { + return nil, fmt.Errorf("%d: stop without preceding start", ln) + } + chunks = append(chunks, curChunk) + curChunk = nil + default: + return nil, fmt.Errorf("%d: unexpected copy command %q in source file", ln, cmd.name) + } + } + if curChunk != nil { + return nil, errors.New("missing stop at end of file") + } + return chunks, nil +} + +// setChunkFilenames fills out the filename parts of a chunk with information about +// the source and destination. +func setChunkFilenames(chunks []*chunk, srcFile, destDir string) error { + var err error + for _, c := range chunks { + c.destFile = filepath.Join(destDir, c.destFile) + c.srcFile, err = relativePathTo(srcFile, c.destFile) + if err != nil { + return err + } + } + return nil +} + +var prefix = []byte("//copy:") + +// parseCommand parses a copy command. +// If the line does not contain a command, it returns (ni, nil). +func parseCommand(line []byte) (*command, error) { + s := bytes.TrimSpace(line) + after, found := bytes.CutPrefix(s, prefix) + if !found { + return nil, nil + } + fields := strings.Fields(string(after)) + if len(fields) == 0 { + return nil, errors.New("empty command") + } + + checkArgs := func(want int) error { + if got := len(fields) - 1; got != want { + return fmt.Errorf("command %q should have %d args, not %d", after, want, got) + } + return nil + } + + d := &command{name: fields[0]} + switch fields[0] { + default: + return nil, fmt.Errorf("unknown command %q", fields[0]) + case "start": + if err := checkArgs(2); err != nil { + return nil, err + } + d.file = fields[1] + d.sink = fields[2] + case "stop": + if err := checkArgs(0); err != nil { + return nil, err + } + case "sink": + // sink may have "from src1, src2, ..." after its name. + if len(fields)-1 < 1 { + return nil, fmt.Errorf("command %q should have at least one arg", after) + } + d.sink = fields[1] + case "endsink": + if err := checkArgs(1); err != nil { + return nil, err + } + d.sink = fields[1] + } + return d, nil +} + +// writeChunks writes the chunks to the destination files. +func writeChunks(chunks []*chunk) error { + // Collect chunks by destination file. + byFile := map[string][]*chunk{} + for _, c := range chunks { + byFile[c.destFile] = append(byFile[c.destFile], c) + } + + for file, cs := range byFile { + if err := writeChunksToFile(file, cs); err != nil { + return fmt.Errorf("%s: %w", file, err) + } + } + return nil +} + +func writeChunksToFile(file string, chunks []*chunk) (err error) { + // Parse the destination file into pieces. + pieces, err := parseDestFile(file) + if err != nil { + return err + } + if err := insertChunksIntoPieces(pieces, chunks); err != nil { + return err + } + data := concatPieces(pieces) + return os.WriteFile(file, data, 0644) +} + +// A piece is a contiguous section of a destination file. +// It is either literal data (sink == "") or a named sink. +type piece struct { + srcFiles map[string]bool + sink string + data []byte +} + +// parseDestFile parses a destination file into pieces. +func parseDestFile(file string) ([]*piece, error) { + data, err := os.ReadFile(file) + if err != nil { + return nil, err + } + return parseDest(data) +} + +// parseDest parses the contents of a destination file into pieces. +func parseDest(data []byte) ([]*piece, error) { + var pieces []*piece + lines := bytes.SplitAfter(data, []byte("\n")) + cur := &piece{} + for ln, line := range lines { + cmd, err := parseCommand(line) + if err != nil { + return nil, fmt.Errorf("%d: %w", ln, err) + } + if cmd == nil { + cur.data = append(cur.data, line...) + continue + } + switch cmd.name { + case "sink": + // If the current piece is a sink, then we know now that it is a new sink + // and its contents is empty; everything we've seen since the previous sink command + // is part of a literal piece. + if cur.sink != "" { + pieces = append(pieces, &piece{sink: cur.sink}) + cur.sink = "" + } + pieces = append(pieces, cur) + cur = &piece{sink: cmd.sink} + + case "endsink": + if cur.sink == "" { + return nil, fmt.Errorf("%d: endsink command without preceding sink", ln) + } + if cur.sink != cmd.sink { + return nil, fmt.Errorf("%d: sink name %q does not match endsink name %q", ln, cur.sink, cmd.sink) + } + pieces = append(pieces, cur) + cur = &piece{} + + default: + return nil, fmt.Errorf("%d: unexpected copy command %q in source file", ln, cmd.name) + } + } + // Same situation as in case "sink" above. + if cur.sink != "" { + pieces = append(pieces, &piece{sink: cur.sink}) + cur.sink = "" + } + return append(pieces, cur), nil +} + +// insertChunksIntoPieces inserts the chunks into sink pieces of the same name. +// It returns an error if there is a chunk with no matching sink. +func insertChunksIntoPieces(pieces []*piece, chunks []*chunk) error { + // Group chunks by sink name. + bySink := map[string][]*chunk{} + for _, c := range chunks { + bySink[c.sink] = append(bySink[c.sink], c) + } + // For each piece with a corresponding chunk, replace the piece's contents. + used := map[string]bool{} + for _, p := range pieces { + if p.sink == "" { + continue + } + cs := bySink[p.sink] + if len(cs) == 0 { + continue + } + p.data = nil + for _, c := range cs { + p.data = append(p.data, c.data...) + if p.srcFiles == nil { + p.srcFiles = map[string]bool{} + } + p.srcFiles[c.srcFile] = true + } + used[p.sink] = true + } + + // Fail if a sink with chunks wasn't used. + for sink := range bySink { + if !used[sink] { + return fmt.Errorf("sink %q unused", sink) + } + } + return nil +} + +// concatPieces concatenates the contents of the pieces together, inserting +// markers for sinks. +func concatPieces(pieces []*piece) []byte { + var buf bytes.Buffer + for _, p := range pieces { + if p.sink != "" { + srcFiles := maps.Keys(p.srcFiles) + slices.Sort(srcFiles) + fmt.Fprintf(&buf, "//copy:sink %s from %s\n", p.sink, strings.Join(srcFiles, ", ")) + fmt.Fprintf(&buf, "// DO NOT MODIFY below vvvv\n") + } + buf.Write(p.data) + if p.sink != "" { + fmt.Fprintf(&buf, "// DO NOT MODIFY above ^^^^\n") + fmt.Fprintf(&buf, "//copy:endsink %s\n", p.sink) + } + } + return buf.Bytes() +} + +// relativePathTo returns a path to src that is relative to dest. +// For example if src is d1/src.go and dest is d2/dest.go, then +// relativePathTo returns ../d1/src.go. +func relativePathTo(src, dest string) (string, error) { + asrc, err := filepath.Abs(src) + if err != nil { + return "", err + } + adest, err := filepath.Abs(dest) + if err != nil { + return "", err + } + sep := string([]byte{filepath.Separator}) + ddir := filepath.Dir(adest) + nups := 0 + for ddir != "." && ddir != sep { + if strings.HasPrefix(asrc, ddir+sep) { + break + } + ddir = filepath.Dir(ddir) + nups++ + } + return strings.Repeat(".."+sep, nups) + strings.TrimPrefix(asrc, ddir+sep), nil +} diff --git a/go/internal/cmd/copy/copy_test.go b/go/internal/cmd/copy/copy_test.go new file mode 100644 index 000000000..24cabaa0f --- /dev/null +++ b/go/internal/cmd/copy/copy_test.go @@ -0,0 +1,142 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "path/filepath" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/tools/txtar" +) + +func TestParseCommand(t *testing.T) { + for _, test := range []struct { + line string + want *command + wantErr bool + }{ + {"", nil, false}, + {"x", nil, false}, + {"//coopy:x", nil, false}, + {"// copy:x", nil, false}, + {"//copy:start", nil, true}, + {"//copy:start file sink", &command{"start", "file", "sink"}, false}, + {"//copy:stop", &command{name: "stop"}, false}, + {"//copy:sink foo", &command{name: "sink", sink: "foo"}, false}, + {"//copy:endsink bar", &command{name: "endsink", sink: "bar"}, false}, + } { + got, err := parseCommand([]byte(test.line)) + if err != nil { + if !test.wantErr { + t.Fatalf("%q: got error %q", test.line, err) + } + continue + } + if !cmp.Equal(got, test.want, cmp.AllowUnexported(command{})) { + t.Errorf("%q:\ngot %+v\nwant %+v", test.line, got, test.want) + } + + } +} + +func TestFull(t *testing.T) { + files, err := filepath.Glob(filepath.Join("testdata", "*.txt")) + if err != nil { + t.Fatal(err) + } + + doit := func(t *testing.T, dest, src []byte) []byte { + chunks, err := parseSource(src) + if err != nil { + t.Fatal(err) + } + if err := setChunkFilenames(chunks, "source", ""); err != nil { + t.Fatal(err) + } + pieces, err := parseDest(dest) + if err != nil { + t.Fatal(err) + } + if err := insertChunksIntoPieces(pieces, chunks); err != nil { + t.Fatal(err) + } + return concatPieces(pieces) + } + + for _, file := range files { + t.Run(strings.TrimPrefix(filepath.Base(file), ".txt"), func(t *testing.T) { + ar, err := txtar.ParseFile(file) + if err != nil { + t.Fatal(err) + } + var source, dest, want txtar.File + for _, f := range ar.Files { + switch f.Name { + case "source": + source = f + case "dest": + dest = f + case "want": + want = f + default: + t.Fatal("unknown txtar filename") + } + } + got := doit(t, dest.Data, source.Data) + if diff := cmp.Diff(want.Data, got); diff != "" { + t.Errorf("mismatch (-want, +got)\n%s", diff) + } + + // Running it on the output should produce the same result. + got = doit(t, got, source.Data) + if diff := cmp.Diff(want.Data, got); diff != "" { + t.Errorf("second time: mismatch (-want, +got)\n%s", diff) + } + }) + } +} + +func TestUnused(t *testing.T) { + chunks := []*chunk{{sink: "S"}} + pieces := []*piece{{sink: "T"}} + err := insertChunksIntoPieces(pieces, chunks) + if err == nil { + t.Fatal("want error") + } +} + +func TestRelativePathTo(t *testing.T) { + for _, test := range []struct { + p1, p2 string + want string + }{ + {"a", "b", "a"}, + {"d/a", "d/b", "a"}, + {"d1/a", "b", "d1/a"}, + {"a", "d1/b", "../a"}, + {"d1/a", "d2/b", "../d1/a"}, + {"g.go", "../vertexai/v.go", "../copy/g.go"}, + } { + got, err := relativePathTo(test.p1, test.p2) + if err != nil { + t.Fatal(err) + } + if got != test.want { + t.Errorf("relativePathTo(%q, %q) = %q, want %q", test.p1, test.p2, got, test.want) + } + } +} diff --git a/go/internal/cmd/copy/testdata/multiple.txt b/go/internal/cmd/copy/testdata/multiple.txt new file mode 100644 index 000000000..a27825e96 --- /dev/null +++ b/go/internal/cmd/copy/testdata/multiple.txt @@ -0,0 +1,45 @@ +In this test, there are more than one sink with the same name in the +destination, and more than one source with that name in the source file. +-- source -- +//copy:start dest one +a +b +c +//copy:stop +d +//copy:start dest two +e +//copy:stop +//copy:start dest one +f +g +//copy:stop +h +-- dest -- +//copy:sink one +//copy:sink two +//copy:sink one +-- want -- +//copy:sink one from source +// DO NOT MODIFY below vvvv +a +b +c +f +g +// DO NOT MODIFY above ^^^^ +//copy:endsink one +//copy:sink two from source +// DO NOT MODIFY below vvvv +e +// DO NOT MODIFY above ^^^^ +//copy:endsink two +//copy:sink one from source +// DO NOT MODIFY below vvvv +a +b +c +f +g +// DO NOT MODIFY above ^^^^ +//copy:endsink one diff --git a/go/internal/cmd/copy/testdata/simple.txt b/go/internal/cmd/copy/testdata/simple.txt new file mode 100644 index 000000000..42adb3a10 --- /dev/null +++ b/go/internal/cmd/copy/testdata/simple.txt @@ -0,0 +1,31 @@ +-- source -- +first +second +//copy:start dest foo +third +fourth +//copy:stop +fifth +//copy:start dest bar +sixth +//copy:stop +seventh +-- dest -- +line1 +//copy:sink bar +line2 +//copy:sink foo +-- want -- +line1 +//copy:sink bar from source +// DO NOT MODIFY below vvvv +sixth +// DO NOT MODIFY above ^^^^ +//copy:endsink bar +line2 +//copy:sink foo from source +// DO NOT MODIFY below vvvv +third +fourth +// DO NOT MODIFY above ^^^^ +//copy:endsink foo diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 444a71c13..3dbcc59b7 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Parts of this file are copied into vertexai, because the code is identical +// except for the import path of the Gemini SDK. +//go:generate go run ../../internal/cmd/copy -dest ../vertexai googleai.go + package googleai import ( @@ -345,6 +349,8 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb return r, nil } +//copy:start vertexai.go translateCandidate + // translateCandidate translates from a genai.GenerateContentResponse to an ai.GenerateResponse. func translateCandidate(cand *genai.Candidate) *ai.Candidate { c := &ai.Candidate{} @@ -386,6 +392,10 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate { return c } +//copy:stop + +//copy:start vertexai.go translateResponse + // Translate from a genai.GenerateContentResponse to a ai.GenerateResponse. func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse { r := &ai.GenerateResponse{} @@ -401,6 +411,10 @@ func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse return r } +//copy:stop + +//copy:start vertexai.go convertParts + // convertParts converts a slice of *ai.Part to a slice of genai.Part. func convertParts(parts []*ai.Part) ([]genai.Part, error) { res := make([]genai.Part, 0, len(parts)) @@ -426,7 +440,7 @@ func convertPart(p *ai.Part) (genai.Part, error) { } return genai.Blob{MIMEType: contentType, Data: data}, nil case p.IsData(): - panic("googleai does not support Data parts") + panic(fmt.Sprintf("%s does not support Data parts", provider)) case p.IsToolResponse(): toolResp := p.ToolResponse fr := genai.FunctionResponse{ @@ -445,3 +459,5 @@ func convertPart(p *ai.Part) (genai.Part, error) { panic("unknown part type in a request") } } + +//copy:stop diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 0ca6ad6d4..93a2988d5 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -358,6 +358,9 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb return r, nil } +//copy:sink translateCandidate from ../googleai/googleai.go +// DO NOT MODIFY below vvvv + // translateCandidate translates from a genai.GenerateContentResponse to an ai.GenerateResponse. func translateCandidate(cand *genai.Candidate) *ai.Candidate { c := &ai.Candidate{} @@ -399,6 +402,12 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate { return c } +// DO NOT MODIFY above ^^^^ +//copy:endsink translateCandidate + +//copy:sink translateResponse from ../googleai/googleai.go +// DO NOT MODIFY below vvvv + // Translate from a genai.GenerateContentResponse to a ai.GenerateResponse. func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse { r := &ai.GenerateResponse{} @@ -414,6 +423,12 @@ func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse return r } +// DO NOT MODIFY above ^^^^ +//copy:endsink translateResponse + +//copy:sink convertParts from ../googleai/googleai.go +// DO NOT MODIFY below vvvv + // convertParts converts a slice of *ai.Part to a slice of genai.Part. func convertParts(parts []*ai.Part) ([]genai.Part, error) { res := make([]genai.Part, 0, len(parts)) @@ -439,7 +454,7 @@ func convertPart(p *ai.Part) (genai.Part, error) { } return genai.Blob{MIMEType: contentType, Data: data}, nil case p.IsData(): - panic("vertexai does not support Data parts") + panic(fmt.Sprintf("%s does not support Data parts", provider)) case p.IsToolResponse(): toolResp := p.ToolResponse fr := genai.FunctionResponse{ @@ -458,3 +473,6 @@ func convertPart(p *ai.Part) (genai.Part, error) { panic("unknown part type in a request") } } + +// DO NOT MODIFY above ^^^^ +//copy:endsink convertParts