Skip to content

Commit

Permalink
server,setectest: allow creating a server with a pre-made database
Browse files Browse the repository at this point in the history
Particularly for tests, it is helpful if the database can be given to the
server at construction time, rather than read in off disk.

Add a new DB field to the server.Config. If this field is populated, it takes
precedence over the DBPath and related fields. Otherwise, the DBPath, Key, and
AuditLog fields are all required (as before) and have the same semantics.
  • Loading branch information
creachadair committed Sep 24, 2024
1 parent cd5c757 commit d6e0c84
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 8 deletions.
30 changes: 24 additions & 6 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,37 @@ import (

// Config is the configuration for a Server.
type Config struct {
// DBPath is the path to the secrets database.
// DB, if set, is used as the secrets database for the server.
// If non-nil, the DBPath and Key fields are ignored.
// If nil, then DBPath, Key, and AuditLog must all be set.
DB *db.DB

// DBPath, if non-empty, is the path to the secrets database.
// It must be set if DB is nil.
DBPath string

// Key is the AEAD used to encrypt/decrypt the database.
// It must be set if DB is nil.
Key tink.AEAD

// AuditLog is the writer to use for audit logs.
// It must be set if DB is nil.
AuditLog *audit.Writer

// WhoIs is a function that reports an identity for a client IP
// address. Outside of tests, it will be the WhoIs of a Tailscale
// LocalClient.
WhoIs func(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error)

// Mux is the http.ServeMux on which the server registers its HTTP
// handlers.
// handlers. It must be non-nil.
Mux *http.ServeMux

// BackupBucket is an AWS S3 bucket name to which database
// backups should be saved. If empty, the database is not backed
// up.
BackupBucket string

// BackupBucketRegion is the AWS region that the S3 bucket is in.
//
// You would think that one could derive this automatically given
Expand Down Expand Up @@ -91,9 +105,13 @@ var staticFiles embed.FS

// New creates a secret server and makes it ready to serve.
func New(ctx context.Context, cfg Config) (*Server, error) {
db, err := db.Open(cfg.DBPath, cfg.Key, cfg.AuditLog)
if err != nil {
return nil, fmt.Errorf("opening DB: %w", err)
kdb := cfg.DB
if kdb == nil {
var err error
kdb, err = db.Open(cfg.DBPath, cfg.Key, cfg.AuditLog)
if err != nil {
return nil, fmt.Errorf("opening DB: %w", err)
}
}

tmpl := template.New("").Funcs(template.FuncMap{
Expand All @@ -106,7 +124,7 @@ func New(ctx context.Context, cfg Config) (*Server, error) {
}

ret := &Server{
db: db,
db: kdb,
whois: cfg.WhoIs,
tmpl: tmpl,

Expand Down
41 changes: 41 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,59 @@ import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"

"github.com/tailscale/setec/acl"
"github.com/tailscale/setec/audit"
"github.com/tailscale/setec/client/setec"
"github.com/tailscale/setec/db"
"github.com/tailscale/setec/server"
"github.com/tailscale/setec/setectest"
"github.com/tailscale/setec/types/api"
"github.com/tink-crypto/tink-go/v2/testutil"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/tailcfg"
)

func TestNew(t *testing.T) {
ctx := context.Background()
t.Run("NoDB", func(t *testing.T) {
d, err := server.New(ctx, server.Config{})
if err == nil {
t.Errorf("New with no DB: got %+v, want error", d)
}
})
t.Run("PathKey", func(t *testing.T) {
path := filepath.Join(t.TempDir(), "test.db")
_, err := server.New(ctx, server.Config{
DBPath: path,
Key: &testutil.DummyAEAD{Name: t.Name()},
AuditLog: audit.New(io.Discard),
Mux: http.NewServeMux(),
})
if err != nil {
t.Errorf("New: unexpected error: %v", err)
}
})
t.Run("DB", func(t *testing.T) {
path := filepath.Join(t.TempDir(), "test.db")
kdb, err := db.Open(path, &testutil.DummyAEAD{Name: t.Name()}, audit.New(io.Discard))
if err != nil {
t.Fatalf("Open database: %v", err)
}
if _, err := server.New(ctx, server.Config{
DB: kdb,
Mux: http.NewServeMux(),
}); err != nil {
t.Errorf("New: unexpected error: %v", err)
}
})
}

func TestServerGetChanged(t *testing.T) {
d := setectest.NewDB(t, nil)
v1 := d.MustPut(d.Superuser, "test", "v1") // active
Expand Down
3 changes: 1 addition & 2 deletions setectest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ func NewServer(t *testing.T, db *DB, opts *ServerOptions) *Server {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
s, err := server.New(ctx, server.Config{
DBPath: db.Path,
Key: db.Key,
DB: db.Actual,
AuditLog: opts.auditLog(),
WhoIs: opts.whoIs(),
Mux: mux,
Expand Down

0 comments on commit d6e0c84

Please sign in to comment.