Skip to content

Commit 7ef02f8

Browse files
authored
Database interface refactor (#544)
Signed-off-by: Micah Hausler <[email protected]> ## Description This includes 4 changes that are preparatory for future work to migrate to the Kubernetes data model. * Propagate stream context in streaming API. Previously `context.Background()` was used, but the stream provides a context object * Refactored shadowed "context" package name. By naming a method variable "context", no `context` package calls could be made in the method * Added `context.Context` to `GetWorkflowsForWorker()` database API. This plumbs down the context from the API call into the `d.instance.QueryContext()` call. * Refactor the database interface. I added a new interface `WorkerWorkflow` with the methods that get used by APIs the Tink Worker invokes. This is essentially a no-op for now. ## Why is this needed See tinkerbell/proposals#46 ## How Has This Been Tested? Locally ran tests. ## How are existing users impacted? What migration steps/scripts do we need? No impact to existing users ## Checklist: I have: - [ ] updated the documentation and/or roadmap (if required) - [ ] added unit or e2e tests - [ ] provided instructions on how to upgrade
2 parents 3743d31 + d932a47 commit 7ef02f8

File tree

6 files changed

+49
-44
lines changed

6 files changed

+49
-44
lines changed

db/db.go

+10-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ type Database interface {
2222
hardware
2323
template
2424
workflow
25+
WorkerWorkflow
2526
}
2627

2728
type hardware interface {
@@ -43,20 +44,24 @@ type template interface {
4344

4445
type workflow interface {
4546
CreateWorkflow(ctx context.Context, wf Workflow, data string, id uuid.UUID) error
46-
InsertIntoWfDataTable(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error
47-
GetfromWfDataTable(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error)
4847
GetWorkflowMetadata(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error)
4948
GetWorkflowDataVersion(ctx context.Context, workflowID string) (int32, error)
50-
GetWorkflowsForWorker(id string) ([]string, error)
5149
GetWorkflow(ctx context.Context, id string) (Workflow, error)
5250
DeleteWorkflow(ctx context.Context, id string, state int32) error
5351
ListWorkflows(fn func(wf Workflow) error) error
5452
UpdateWorkflow(ctx context.Context, wf Workflow, state int32) error
53+
InsertIntoWorkflowEventTable(ctx context.Context, wfEvent *pb.WorkflowActionStatus, time time.Time) error
54+
ShowWorkflowEvents(wfID string, fn func(wfs *pb.WorkflowActionStatus) error) error
55+
}
56+
57+
// WorkerWorkflow is an interface for methods invoked by APIs that the worker calls
58+
type WorkerWorkflow interface {
59+
InsertIntoWfDataTable(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error
60+
GetfromWfDataTable(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error)
61+
GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error)
5562
UpdateWorkflowState(ctx context.Context, wfContext *pb.WorkflowContext) error
5663
GetWorkflowContexts(ctx context.Context, wfID string) (*pb.WorkflowContext, error)
5764
GetWorkflowActions(ctx context.Context, wfID string) (*pb.WorkflowActionList, error)
58-
InsertIntoWorkflowEventTable(ctx context.Context, wfEvent *pb.WorkflowActionStatus, time time.Time) error
59-
ShowWorkflowEvents(wfID string, fn func(wfs *pb.WorkflowActionStatus) error) error
6065
}
6166

6267
// TinkDB implements the Database interface

db/mock/mock.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ type DB struct {
1919
InsertIntoWfDataTableFunc func(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error
2020
GetWorkflowMetadataFunc func(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error)
2121
GetWorkflowDataVersionFunc func(ctx context.Context, workflowID string) (int32, error)
22-
GetWorkflowsForWorkerFunc func(id string) ([]string, error)
22+
GetWorkflowsForWorkerFunc func(ctx context.Context, id string) ([]string, error)
2323
GetWorkflowContextsFunc func(ctx context.Context, wfID string) (*pb.WorkflowContext, error)
2424
GetWorkflowActionsFunc func(ctx context.Context, wfID string) (*pb.WorkflowActionList, error)
2525
UpdateWorkflowStateFunc func(ctx context.Context, wfContext *pb.WorkflowContext) error

db/mock/workflow.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ func (d DB) GetWorkflowDataVersion(ctx context.Context, workflowID string) (int3
3535
}
3636

3737
// GetWorkflowsForWorker : returns the list of workflows for a particular worker
38-
func (d DB) GetWorkflowsForWorker(id string) ([]string, error) {
39-
return d.GetWorkflowsForWorkerFunc(id)
38+
func (d DB) GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error) {
39+
return d.GetWorkflowsForWorkerFunc(ctx, id)
4040
}
4141

4242
// GetWorkflow returns a workflow

db/workflow.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,8 @@ func (d TinkDB) GetWorkflowDataVersion(ctx context.Context, workflowID string) (
304304
}
305305

306306
// GetWorkflowsForWorker : returns the list of workflows for a particular worker
307-
func (d TinkDB) GetWorkflowsForWorker(id string) ([]string, error) {
308-
rows, err := d.instance.Query(`
307+
func (d TinkDB) GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error) {
308+
rows, err := d.instance.QueryContext(ctx, `
309309
SELECT workflow_id
310310
FROM workflow_worker_map
311311
WHERE

grpc-server/tinkerbell.go

+27-27
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,16 @@ const (
3030

3131
// GetWorkflowContexts implements tinkerbell.GetWorkflowContexts
3232
func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.WorkflowService_GetWorkflowContextsServer) error {
33-
wfs, err := getWorkflowsForWorker(s.db, req.WorkerId)
33+
wfs, err := getWorkflowsForWorker(stream.Context(), s.db, req.WorkerId)
3434
if err != nil {
3535
return err
3636
}
3737
for _, wf := range wfs {
38-
wfContext, err := s.db.GetWorkflowContexts(context.Background(), wf)
38+
wfContext, err := s.db.GetWorkflowContexts(stream.Context(), wf)
3939
if err != nil {
4040
return status.Errorf(codes.Aborted, err.Error())
4141
}
42-
if isApplicableToSend(context.Background(), s.logger, wfContext, req.WorkerId, s.db) {
42+
if isApplicableToSend(stream.Context(), s.logger, wfContext, req.WorkerId, s.db) {
4343
if err := stream.Send(wfContext); err != nil {
4444
return err
4545
}
@@ -49,16 +49,16 @@ func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.W
4949
}
5050

5151
// GetWorkflowContextList implements tinkerbell.GetWorkflowContextList
52-
func (s *server) GetWorkflowContextList(context context.Context, req *pb.WorkflowContextRequest) (*pb.WorkflowContextList, error) {
53-
wfs, err := getWorkflowsForWorker(s.db, req.WorkerId)
52+
func (s *server) GetWorkflowContextList(ctx context.Context, req *pb.WorkflowContextRequest) (*pb.WorkflowContextList, error) {
53+
wfs, err := getWorkflowsForWorker(ctx, s.db, req.WorkerId)
5454
if err != nil {
5555
return nil, err
5656
}
5757

5858
if wfs != nil {
5959
wfContexts := []*pb.WorkflowContext{}
6060
for _, wf := range wfs {
61-
wfContext, err := s.db.GetWorkflowContexts(context, wf)
61+
wfContext, err := s.db.GetWorkflowContexts(ctx, wf)
6262
if err != nil {
6363
return nil, status.Errorf(codes.Aborted, err.Error())
6464
}
@@ -72,16 +72,16 @@ func (s *server) GetWorkflowContextList(context context.Context, req *pb.Workflo
7272
}
7373

7474
// GetWorkflowActions implements tinkerbell.GetWorkflowActions
75-
func (s *server) GetWorkflowActions(context context.Context, req *pb.WorkflowActionsRequest) (*pb.WorkflowActionList, error) {
75+
func (s *server) GetWorkflowActions(ctx context.Context, req *pb.WorkflowActionsRequest) (*pb.WorkflowActionList, error) {
7676
wfID := req.GetWorkflowId()
7777
if wfID == "" {
7878
return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId)
7979
}
80-
return getWorkflowActions(context, s.db, wfID)
80+
return getWorkflowActions(ctx, s.db, wfID)
8181
}
8282

8383
// ReportActionStatus implements tinkerbell.ReportActionStatus
84-
func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowActionStatus) (*pb.Empty, error) {
84+
func (s *server) ReportActionStatus(ctx context.Context, req *pb.WorkflowActionStatus) (*pb.Empty, error) {
8585
wfID := req.GetWorkflowId()
8686
if wfID == "" {
8787
return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId)
@@ -96,11 +96,11 @@ func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowAct
9696
l := s.logger.With("actionName", req.GetActionName(), "workflowID", req.GetWorkflowId())
9797
l.Info(fmt.Sprintf(msgReceivedStatus, req.GetActionStatus()))
9898

99-
wfContext, err := s.db.GetWorkflowContexts(context, wfID)
99+
wfContext, err := s.db.GetWorkflowContexts(ctx, wfID)
100100
if err != nil {
101101
return nil, status.Errorf(codes.Aborted, err.Error())
102102
}
103-
wfActions, err := s.db.GetWorkflowActions(context, wfID)
103+
wfActions, err := s.db.GetWorkflowActions(ctx, wfID)
104104
if err != nil {
105105
return nil, status.Errorf(codes.Aborted, err.Error())
106106
}
@@ -123,14 +123,14 @@ func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowAct
123123
wfContext.CurrentAction = req.GetActionName()
124124
wfContext.CurrentActionState = req.GetActionStatus()
125125
wfContext.CurrentActionIndex = actionIndex
126-
err = s.db.UpdateWorkflowState(context, wfContext)
126+
err = s.db.UpdateWorkflowState(ctx, wfContext)
127127
if err != nil {
128128
return &pb.Empty{}, status.Errorf(codes.Aborted, err.Error())
129129
}
130130

131131
// TODO the below "time" would be a part of the request which is coming form worker.
132132
time := time.Now()
133-
err = s.db.InsertIntoWorkflowEventTable(context, req, time)
133+
err = s.db.InsertIntoWorkflowEventTable(ctx, req, time)
134134
if err != nil {
135135
return &pb.Empty{}, status.Error(codes.Aborted, err.Error())
136136
}
@@ -149,7 +149,7 @@ func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowAct
149149
}
150150

151151
// UpdateWorkflowData updates workflow ephemeral data
152-
func (s *server) UpdateWorkflowData(context context.Context, req *pb.UpdateWorkflowDataRequest) (*pb.Empty, error) {
152+
func (s *server) UpdateWorkflowData(ctx context.Context, req *pb.UpdateWorkflowDataRequest) (*pb.Empty, error) {
153153
wfID := req.GetWorkflowId()
154154
if wfID == "" {
155155
return &pb.Empty{}, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId)
@@ -158,57 +158,57 @@ func (s *server) UpdateWorkflowData(context context.Context, req *pb.UpdateWorkf
158158
if !ok {
159159
workflowData[wfID] = 1
160160
}
161-
err := s.db.InsertIntoWfDataTable(context, req)
161+
err := s.db.InsertIntoWfDataTable(ctx, req)
162162
if err != nil {
163163
return &pb.Empty{}, status.Errorf(codes.Aborted, err.Error())
164164
}
165165
return &pb.Empty{}, nil
166166
}
167167

168168
// GetWorkflowData gets the ephemeral data for a workflow
169-
func (s *server) GetWorkflowData(context context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
169+
func (s *server) GetWorkflowData(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
170170
wfID := req.GetWorkflowId()
171171
if wfID == "" {
172172
return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId)
173173
}
174-
data, err := s.db.GetfromWfDataTable(context, req)
174+
data, err := s.db.GetfromWfDataTable(ctx, req)
175175
if err != nil {
176176
return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.Aborted, err.Error())
177177
}
178178
return &pb.GetWorkflowDataResponse{Data: data}, nil
179179
}
180180

181181
// GetWorkflowMetadata returns metadata wrt to the ephemeral data of a workflow
182-
func (s *server) GetWorkflowMetadata(context context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
183-
data, err := s.db.GetWorkflowMetadata(context, req)
182+
func (s *server) GetWorkflowMetadata(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
183+
data, err := s.db.GetWorkflowMetadata(ctx, req)
184184
if err != nil {
185185
return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.Aborted, err.Error())
186186
}
187187
return &pb.GetWorkflowDataResponse{Data: data}, nil
188188
}
189189

190190
// GetWorkflowDataVersion returns the latest version of data for a workflow
191-
func (s *server) GetWorkflowDataVersion(context context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
192-
version, err := s.db.GetWorkflowDataVersion(context, req.WorkflowId)
191+
func (s *server) GetWorkflowDataVersion(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
192+
version, err := s.db.GetWorkflowDataVersion(ctx, req.WorkflowId)
193193
if err != nil {
194194
return &pb.GetWorkflowDataResponse{Version: version}, status.Errorf(codes.Aborted, err.Error())
195195
}
196196
return &pb.GetWorkflowDataResponse{Version: version}, nil
197197
}
198198

199-
func getWorkflowsForWorker(db db.Database, id string) ([]string, error) {
199+
func getWorkflowsForWorker(ctx context.Context, db db.Database, id string) ([]string, error) {
200200
if id == "" {
201201
return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkerID)
202202
}
203-
wfs, err := db.GetWorkflowsForWorker(id)
203+
wfs, err := db.GetWorkflowsForWorker(ctx, id)
204204
if err != nil {
205205
return nil, status.Errorf(codes.Aborted, err.Error())
206206
}
207207
return wfs, nil
208208
}
209209

210-
func getWorkflowActions(context context.Context, db db.Database, wfID string) (*pb.WorkflowActionList, error) {
211-
actions, err := db.GetWorkflowActions(context, wfID)
210+
func getWorkflowActions(ctx context.Context, db db.Database, wfID string) (*pb.WorkflowActionList, error) {
211+
actions, err := db.GetWorkflowActions(ctx, wfID)
212212
if err != nil {
213213
return nil, status.Errorf(codes.Aborted, errInvalidWorkflowId)
214214
}
@@ -217,12 +217,12 @@ func getWorkflowActions(context context.Context, db db.Database, wfID string) (*
217217

218218
// isApplicableToSend checks if a particular workflow context is applicable or if it is needed to
219219
// be sent to a worker based on the state of the current action and the targeted workerID
220-
func isApplicableToSend(context context.Context, logger log.Logger, wfContext *pb.WorkflowContext, workerID string, db db.Database) bool {
220+
func isApplicableToSend(ctx context.Context, logger log.Logger, wfContext *pb.WorkflowContext, workerID string, db db.Database) bool {
221221
if wfContext.GetCurrentActionState() == pb.State_STATE_FAILED ||
222222
wfContext.GetCurrentActionState() == pb.State_STATE_TIMEOUT {
223223
return false
224224
}
225-
actions, err := getWorkflowActions(context, db, wfContext.GetWorkflowId())
225+
actions, err := getWorkflowActions(ctx, db, wfContext.GetWorkflowId())
226226
if err != nil {
227227
return false
228228
}

grpc-server/tinkerbell_test.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func TestGetWorkflowContextList(t *testing.T) {
6767
"database failure": {
6868
args: args{
6969
db: &mock.DB{
70-
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
70+
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
7171
return []string{workflowID}, nil
7272
},
7373
GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) {
@@ -83,7 +83,7 @@ func TestGetWorkflowContextList(t *testing.T) {
8383
"no workflows found": {
8484
args: args{
8585
db: &mock.DB{
86-
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
86+
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
8787
return nil, nil
8888
},
8989
GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) {
@@ -99,7 +99,7 @@ func TestGetWorkflowContextList(t *testing.T) {
9999
"workflows found": {
100100
args: args{
101101
db: &mock.DB{
102-
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
102+
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
103103
return []string{workflowID}, nil
104104
},
105105
GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) {
@@ -758,7 +758,7 @@ func TestGetWorkflowsForWorker(t *testing.T) {
758758
"database failure": {
759759
args: args{
760760
db: &mock.DB{
761-
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
761+
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
762762
return nil, errors.New("database failed")
763763
},
764764
},
@@ -771,7 +771,7 @@ func TestGetWorkflowsForWorker(t *testing.T) {
771771
"no workflows found": {
772772
args: args{
773773
db: &mock.DB{
774-
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
774+
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
775775
return nil, nil
776776
},
777777
},
@@ -784,7 +784,7 @@ func TestGetWorkflowsForWorker(t *testing.T) {
784784
"workflows found": {
785785
args: args{
786786
db: &mock.DB{
787-
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
787+
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
788788
return []string{workflowID}, nil
789789
},
790790
},
@@ -799,7 +799,7 @@ func TestGetWorkflowsForWorker(t *testing.T) {
799799
for name, tc := range testCases {
800800
t.Run(name, func(t *testing.T) {
801801
s := testServer(t, tc.args.db)
802-
res, err := getWorkflowsForWorker(s.db, tc.args.workerID)
802+
res, err := getWorkflowsForWorker(context.Background(), s.db, tc.args.workerID)
803803
if err != nil {
804804
assert.True(t, tc.want.expectedError)
805805
assert.Error(t, err)

0 commit comments

Comments
 (0)