From d6e0c84587a0ed8d6230f8968249e330aed6b8fa Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Mon, 23 Sep 2024 14:57:32 -0700 Subject: [PATCH] server,setectest: allow creating a server with a pre-made database Particularly for tests, it is helpful if the database can be given to the server at construction time, rather than read in off disk. Add a new DB field to the server.Config. If this field is populated, it takes precedence over the DBPath and related fields. Otherwise, the DBPath, Key, and AuditLog fields are all required (as before) and have the same semantics. --- server/server.go | 30 ++++++++++++++++++++++++------ server/server_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ setectest/server.go | 3 +-- 3 files changed, 66 insertions(+), 8 deletions(-) diff --git a/server/server.go b/server/server.go index f9698a1..cc2ae05 100644 --- a/server/server.go +++ b/server/server.go @@ -33,23 +33,37 @@ import ( // Config is the configuration for a Server. type Config struct { - // DBPath is the path to the secrets database. + // DB, if set, is used as the secrets database for the server. + // If non-nil, the DBPath and Key fields are ignored. + // If nil, then DBPath, Key, and AuditLog must all be set. + DB *db.DB + + // DBPath, if non-empty, is the path to the secrets database. + // It must be set if DB is nil. DBPath string + // Key is the AEAD used to encrypt/decrypt the database. + // It must be set if DB is nil. Key tink.AEAD + // AuditLog is the writer to use for audit logs. + // It must be set if DB is nil. AuditLog *audit.Writer + // WhoIs is a function that reports an identity for a client IP // address. Outside of tests, it will be the WhoIs of a Tailscale // LocalClient. WhoIs func(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) + // Mux is the http.ServeMux on which the server registers its HTTP - // handlers. + // handlers. It must be non-nil. Mux *http.ServeMux + // BackupBucket is an AWS S3 bucket name to which database // backups should be saved. If empty, the database is not backed // up. BackupBucket string + // BackupBucketRegion is the AWS region that the S3 bucket is in. // // You would think that one could derive this automatically given @@ -91,9 +105,13 @@ var staticFiles embed.FS // New creates a secret server and makes it ready to serve. func New(ctx context.Context, cfg Config) (*Server, error) { - db, err := db.Open(cfg.DBPath, cfg.Key, cfg.AuditLog) - if err != nil { - return nil, fmt.Errorf("opening DB: %w", err) + kdb := cfg.DB + if kdb == nil { + var err error + kdb, err = db.Open(cfg.DBPath, cfg.Key, cfg.AuditLog) + if err != nil { + return nil, fmt.Errorf("opening DB: %w", err) + } } tmpl := template.New("").Funcs(template.FuncMap{ @@ -106,7 +124,7 @@ func New(ctx context.Context, cfg Config) (*Server, error) { } ret := &Server{ - db: db, + db: kdb, whois: cfg.WhoIs, tmpl: tmpl, diff --git a/server/server_test.go b/server/server_test.go index 1419aac..0695d90 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -7,18 +7,59 @@ import ( "context" "encoding/json" "errors" + "io" + "net/http" "net/http/httptest" + "path/filepath" "testing" "github.com/tailscale/setec/acl" + "github.com/tailscale/setec/audit" "github.com/tailscale/setec/client/setec" + "github.com/tailscale/setec/db" "github.com/tailscale/setec/server" "github.com/tailscale/setec/setectest" "github.com/tailscale/setec/types/api" + "github.com/tink-crypto/tink-go/v2/testutil" "tailscale.com/client/tailscale/apitype" "tailscale.com/tailcfg" ) +func TestNew(t *testing.T) { + ctx := context.Background() + t.Run("NoDB", func(t *testing.T) { + d, err := server.New(ctx, server.Config{}) + if err == nil { + t.Errorf("New with no DB: got %+v, want error", d) + } + }) + t.Run("PathKey", func(t *testing.T) { + path := filepath.Join(t.TempDir(), "test.db") + _, err := server.New(ctx, server.Config{ + DBPath: path, + Key: &testutil.DummyAEAD{Name: t.Name()}, + AuditLog: audit.New(io.Discard), + Mux: http.NewServeMux(), + }) + if err != nil { + t.Errorf("New: unexpected error: %v", err) + } + }) + t.Run("DB", func(t *testing.T) { + path := filepath.Join(t.TempDir(), "test.db") + kdb, err := db.Open(path, &testutil.DummyAEAD{Name: t.Name()}, audit.New(io.Discard)) + if err != nil { + t.Fatalf("Open database: %v", err) + } + if _, err := server.New(ctx, server.Config{ + DB: kdb, + Mux: http.NewServeMux(), + }); err != nil { + t.Errorf("New: unexpected error: %v", err) + } + }) +} + func TestServerGetChanged(t *testing.T) { d := setectest.NewDB(t, nil) v1 := d.MustPut(d.Superuser, "test", "v1") // active diff --git a/setectest/server.go b/setectest/server.go index 9e2c4a5..01c8ff2 100644 --- a/setectest/server.go +++ b/setectest/server.go @@ -75,8 +75,7 @@ func NewServer(t *testing.T, db *DB, opts *ServerOptions) *Server { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) s, err := server.New(ctx, server.Config{ - DBPath: db.Path, - Key: db.Key, + DB: db.Actual, AuditLog: opts.auditLog(), WhoIs: opts.whoIs(), Mux: mux,