diff --git a/server/server.go b/server/server.go index f9698a1..cc2ae05 100644 --- a/server/server.go +++ b/server/server.go @@ -33,23 +33,37 @@ import ( // Config is the configuration for a Server. type Config struct { - // DBPath is the path to the secrets database. + // DB, if set, is used as the secrets database for the server. + // If non-nil, the DBPath and Key fields are ignored. + // If nil, then DBPath, Key, and AuditLog must all be set. + DB *db.DB + + // DBPath, if non-empty, is the path to the secrets database. + // It must be set if DB is nil. DBPath string + // Key is the AEAD used to encrypt/decrypt the database. + // It must be set if DB is nil. Key tink.AEAD + // AuditLog is the writer to use for audit logs. + // It must be set if DB is nil. AuditLog *audit.Writer + // WhoIs is a function that reports an identity for a client IP // address. Outside of tests, it will be the WhoIs of a Tailscale // LocalClient. WhoIs func(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) + // Mux is the http.ServeMux on which the server registers its HTTP - // handlers. + // handlers. It must be non-nil. Mux *http.ServeMux + // BackupBucket is an AWS S3 bucket name to which database // backups should be saved. If empty, the database is not backed // up. BackupBucket string + // BackupBucketRegion is the AWS region that the S3 bucket is in. // // You would think that one could derive this automatically given @@ -91,9 +105,13 @@ var staticFiles embed.FS // New creates a secret server and makes it ready to serve. func New(ctx context.Context, cfg Config) (*Server, error) { - db, err := db.Open(cfg.DBPath, cfg.Key, cfg.AuditLog) - if err != nil { - return nil, fmt.Errorf("opening DB: %w", err) + kdb := cfg.DB + if kdb == nil { + var err error + kdb, err = db.Open(cfg.DBPath, cfg.Key, cfg.AuditLog) + if err != nil { + return nil, fmt.Errorf("opening DB: %w", err) + } } tmpl := template.New("").Funcs(template.FuncMap{ @@ -106,7 +124,7 @@ func New(ctx context.Context, cfg Config) (*Server, error) { } ret := &Server{ - db: db, + db: kdb, whois: cfg.WhoIs, tmpl: tmpl, diff --git a/server/server_test.go b/server/server_test.go index 1419aac..0695d90 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -7,18 +7,59 @@ import ( "context" "encoding/json" "errors" + "io" + "net/http" "net/http/httptest" + "path/filepath" "testing" "github.com/tailscale/setec/acl" + "github.com/tailscale/setec/audit" "github.com/tailscale/setec/client/setec" + "github.com/tailscale/setec/db" "github.com/tailscale/setec/server" "github.com/tailscale/setec/setectest" "github.com/tailscale/setec/types/api" + "github.com/tink-crypto/tink-go/v2/testutil" "tailscale.com/client/tailscale/apitype" "tailscale.com/tailcfg" ) +func TestNew(t *testing.T) { + ctx := context.Background() + t.Run("NoDB", func(t *testing.T) { + d, err := server.New(ctx, server.Config{}) + if err == nil { + t.Errorf("New with no DB: got %+v, want error", d) + } + }) + t.Run("PathKey", func(t *testing.T) { + path := filepath.Join(t.TempDir(), "test.db") + _, err := server.New(ctx, server.Config{ + DBPath: path, + Key: &testutil.DummyAEAD{Name: t.Name()}, + AuditLog: audit.New(io.Discard), + Mux: http.NewServeMux(), + }) + if err != nil { + t.Errorf("New: unexpected error: %v", err) + } + }) + t.Run("DB", func(t *testing.T) { + path := filepath.Join(t.TempDir(), "test.db") + kdb, err := db.Open(path, &testutil.DummyAEAD{Name: t.Name()}, audit.New(io.Discard)) + if err != nil { + t.Fatalf("Open database: %v", err) + } + if _, err := server.New(ctx, server.Config{ + DB: kdb, + Mux: http.NewServeMux(), + }); err != nil { + t.Errorf("New: unexpected error: %v", err) + } + }) +} + func TestServerGetChanged(t *testing.T) { d := setectest.NewDB(t, nil) v1 := d.MustPut(d.Superuser, "test", "v1") // active diff --git a/setectest/server.go b/setectest/server.go index 9e2c4a5..01c8ff2 100644 --- a/setectest/server.go +++ b/setectest/server.go @@ -75,8 +75,7 @@ func NewServer(t *testing.T, db *DB, opts *ServerOptions) *Server { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) s, err := server.New(ctx, server.Config{ - DBPath: db.Path, - Key: db.Key, + DB: db.Actual, AuditLog: opts.auditLog(), WhoIs: opts.whoIs(), Mux: mux,