Skip to content
This repository has been archived by the owner on Mar 27, 2024. It is now read-only.

Commit

Permalink
feat: Add handling inbound problem report messages inside legacy-conn…
Browse files Browse the repository at this point in the history
…ection protocol. Add unit tests

Signed-off-by: Abdulbois <[email protected]>
  • Loading branch information
Abdulbois committed Sep 12, 2022
1 parent 9fc7486 commit 5f500ec
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pkg/didcomm/protocol/didexchange/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ func TestService_Handle_Inviter(t *testing.T) {
validateState(t, s, thid, findNamespace(AckMsgType), (&completed{}).Name())
}

func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFlag, completedFlag chan struct{}) {
func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFlag, completedFlag chan struct{}) { //nolint: lll
for e := range statusCh {
require.Equal(t, DIDExchange, e.ProtocolName)

Expand Down
11 changes: 11 additions & 0 deletions pkg/didcomm/protocol/legacyconnection/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ type legacyDoc struct {
Proof []interface{} `json:"proof,omitempty"`
}

type problemReport struct {
Type string `json:"@type,omitempty"`
ID string `json:"@id,omitempty"`
Thread *decorator.Thread `json:"~thread,omitempty"`
ProblemCode string `json:"problem-code,omitempty"`
Explain string `json:"explain,omitempty"`
Localization struct {
Locale string `json:"locale,omitempty"`
} `json:"~l10n,omitempty"`
}

// JSONBytes converts Connection to json bytes.
func (con *Connection) toLegacyJSONBytes() ([]byte, error) {
if con.DIDDoc == nil {
Expand Down
13 changes: 10 additions & 3 deletions pkg/didcomm/protocol/legacyconnection/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ const (
// ResponseMsgType defines the legacy-connection response message type.
ResponseMsgType = PIURI + "/response"
// AckMsgType defines the legacy-connection ack message type.
AckMsgType = "https://didcomm.org/notification/1.0/ack"
AckMsgType = "https://didcomm.org/notification/1.0/ack"
// ProblemReportMsgType defines the protocol problem-report message type.
ProblemReportMsgType = PIURI + "/problem-report"
routerConnsMetadataKey = "routerConnections"
)

Expand Down Expand Up @@ -295,7 +297,8 @@ func (s *Service) Accept(msgType string) bool {
return msgType == InvitationMsgType ||
msgType == RequestMsgType ||
msgType == ResponseMsgType ||
msgType == AckMsgType
msgType == AckMsgType ||
msgType == ProblemReportMsgType
}

// HandleOutbound handles outbound connection messages.
Expand All @@ -318,6 +321,10 @@ func (s *Service) nextState(msgType, thID string) (state, error) {

logger.Debugf("retrieved current state [%s] using nsThID [%s]", current.Name(), nsThID)

if msgType == ProblemReportMsgType {
return &responded{}, nil
}

next, err := stateFromMsgType(msgType)
if err != nil {
return nil, err
Expand Down Expand Up @@ -636,7 +643,7 @@ func (s *Service) connectionRecord(msg service.DIDCommMsg, ctx service.DIDCommCo
return s.requestMsgRecord(msg, ctx)
case ResponseMsgType:
return s.responseMsgRecord(msg)
case AckMsgType:
case AckMsgType, ProblemReportMsgType:
return s.fetchConnectionRecord(theirNSPrefix, msg)
}

Expand Down
154 changes: 150 additions & 4 deletions pkg/didcomm/protocol/legacyconnection/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func TestService_Handle_Inviter(t *testing.T) {
require.NoError(t, err)

completedFlag := make(chan struct{})
respondedFlag := make(chan struct{})
respondedFlag := make(chan string)

go msgEventListener(t, statusCh, respondedFlag, completedFlag)

Expand Down Expand Up @@ -247,7 +247,153 @@ func TestService_Handle_Inviter(t *testing.T) {
validateState(t, s, thid, findNamespace(AckMsgType), (&completed{}).Name())
}

func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFlag, completedFlag chan struct{}) {
func TestService_Handle_Inviter_With_ProblemReport(t *testing.T) {
mockStore := &mockstorage.MockStore{Store: make(map[string]mockstorage.DBEntry)}
storeProv := mockstorage.NewCustomMockStoreProvider(mockStore)
k := newKMS(t, storeProv)
prov := &protocol.MockProvider{
StoreProvider: storeProv,
ServiceMap: map[string]interface{}{
mediator.Coordination: &mockroute.MockMediatorSvc{},
},
CustomKMS: k,
KeyTypeValue: kms.ED25519Type,
KeyAgreementTypeValue: kms.X25519ECDHKWType,
}

ctx := &context{
outboundDispatcher: prov.OutboundDispatcher(),
crypto: &tinkcrypto.Crypto{},
kms: k,
keyType: kms.ED25519Type,
keyAgreementType: kms.X25519ECDHKWType,
}

_, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type)
require.NoError(t, err)

ctx.vdRegistry = &mockvdr.MockVDRegistry{CreateValue: createDIDDocWithKey(pubKey)}

connRec, err := connection.NewRecorder(prov)
require.NoError(t, err)
require.NotNil(t, connRec)

ctx.connectionRecorder = connRec

doc, err := ctx.vdRegistry.Create(testMethod, nil)
require.NoError(t, err)

s, err := New(prov)
require.NoError(t, err)

actionCh := make(chan service.DIDCommAction, 10)
err = s.RegisterActionEvent(actionCh)
require.NoError(t, err)

statusCh := make(chan service.StateMsg, 10)
err = s.RegisterMsgEvent(statusCh)
require.NoError(t, err)

completedFlag := make(chan struct{})
respondedFlag := make(chan string)

go msgEventListener(t, statusCh, respondedFlag, completedFlag)
go func() { service.AutoExecuteActionEvent(actionCh) }()

invitation := &Invitation{
Type: InvitationMsgType,
ID: randomString(),
Label: "Bob",
RecipientKeys: []string{base58.Encode(pubKey)},
ServiceEndpoint: "http://alice.agent.example.com:8081",
}

err = ctx.connectionRecorder.SaveInvitation(invitation.ID, invitation)
require.NoError(t, err)

thid := randomString()

// Invitation was previously sent by Alice to Bob.
// Bob now sends a connection Request
connRequest, err := json.Marshal(
&Request{
Type: RequestMsgType,
ID: thid,
Label: "Bob",
Thread: &decorator.Thread{
PID: invitation.ID,
},
Connection: &Connection{
DID: doc.DIDDocument.ID,
DIDDoc: doc.DIDDocument,
},
})
require.NoError(t, err)
requestMsg, err := service.ParseDIDCommMsgMap(connRequest)
require.NoError(t, err)
_, err = s.HandleInbound(requestMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil))
require.NoError(t, err)

var connID string
select {
case connID = <-respondedFlag:
case <-time.After(2 * time.Second):
require.Fail(t, "didn't receive connection ID")
}

connRecord, err := s.connectionRecorder.GetConnectionRecord(connID)
require.NoError(t, err)

// Alice automatically sends connection Response to Bob
// Bob replies with Problem Report
prbRpt, err := json.Marshal(
&problemReport{
ID: randomString(),
Type: ProblemReportMsgType,
Thread: &decorator.Thread{ID: connRecord.ThreadID},
})
require.NoError(t, err)

prbRptMsg, err := service.ParseDIDCommMsgMap(prbRpt)
require.NoError(t, err)

_, err = s.HandleInbound(prbRptMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil))
require.NoError(t, err)

validateState(t, s, thid, findNamespace(RequestMsgType), (&responded{}).Name())

_, err = ctx.connectionRecorder.GetConnectionRecord(connID)
require.ErrorContains(t, err, "data not found")

_, err = s.HandleInbound(requestMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil))
require.NoError(t, err)

// Finally Bob replies with an ACK
ack, err := json.Marshal(
&model.Ack{
Type: AckMsgType,
ID: randomString(),
Status: "OK",
Thread: &decorator.Thread{ID: connRecord.ThreadID},
})
require.NoError(t, err)

ackMsg, err := service.ParseDIDCommMsgMap(ack)
require.NoError(t, err)

_, err = s.HandleInbound(ackMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil))
require.NoError(t, err)

select {
case <-completedFlag:
case <-time.After(4 * time.Second):
require.Fail(t, "didn't receive post event complete")
}

validateState(t, s, connRecord.ThreadID, findNamespace(AckMsgType), (&completed{}).Name())
}

func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFlag chan string, completedFlag chan struct{}) {
for e := range statusCh {
require.Equal(t, LegacyConnection, e.ProtocolName)

Expand All @@ -272,11 +418,11 @@ func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFla
close(completedFlag)
}

if e.StateID == "responded" {
if e.StateID == "responded" && e.Msg.Type() != ProblemReportMsgType {
// validate connectionID received during state transition with original connectionID
require.NotNil(t, prop.ConnectionID())
require.NotNil(t, prop.InvitationID())
close(respondedFlag)
respondedFlag <- prop.ConnectionID()
}
}
}
Expand Down
9 changes: 8 additions & 1 deletion pkg/didcomm/protocol/legacyconnection/states.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func (s *responded) Name() string {
}

func (s *responded) CanTransitionTo(next state) bool {
return StateIDCompleted == next.Name()
return StateIDCompleted == next.Name() || StateIDRequested == next.Name()
}

func (s *responded) ExecuteInbound(msg *stateMachineMsg, _ string, ctx *context) (*connectionstore.Record,
Expand All @@ -226,6 +226,13 @@ func (s *responded) ExecuteInbound(msg *stateMachineMsg, _ string, ctx *context)
return connRecord, &noOp{}, action, nil
case ResponseMsgType:
return msg.connRecord, &completed{}, func() error { return nil }, nil
case ProblemReportMsgType:
err := ctx.connectionRecorder.RemoveConnection(msg.connRecord.ConnectionID)
if err != nil {
return nil, nil, nil, fmt.Errorf("delete connection record is failed: %w", err)
}

return msg.connRecord, &noOp{}, func() error { return nil }, nil
default:
return nil, nil, nil, fmt.Errorf("illegal msg type %s for state %s", msg.Type(), s.Name())
}
Expand Down
28 changes: 27 additions & 1 deletion pkg/didcomm/protocol/legacyconnection/states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestRespondedState(t *testing.T) {
require.Equal(t, "responded", res.Name())
require.False(t, res.CanTransitionTo(&null{}))
require.False(t, res.CanTransitionTo(&invited{}))
require.False(t, res.CanTransitionTo(&requested{}))
require.True(t, res.CanTransitionTo(&requested{}))
require.False(t, res.CanTransitionTo(res))
require.True(t, res.CanTransitionTo(&completed{}))
}
Expand Down Expand Up @@ -388,6 +388,32 @@ func TestRespondedState_Execute(t *testing.T) {
require.NotNil(t, connRec)
require.Equal(t, (&completed{}).Name(), followup.Name())
})
t.Run("followup to 'noop' on inbound problem report message", func(t *testing.T) {
connRec := &connection.Record{
State: (&responded{}).Name(),
ThreadID: request.ID,
ConnectionID: "123",
Namespace: findNamespace(ResponseMsgType),
}
err = ctx.connectionRecorder.SaveConnectionRecordWithMappings(connRec)
require.NoError(t, err)

problemReportPayload, err := json.Marshal(&problemReport{Type: ProblemReportMsgType})
require.NoError(t, err)

connRec, followup, _, e := (&responded{}).ExecuteInbound(
&stateMachineMsg{
DIDCommMsg: bytesToDIDCommMsg(t, problemReportPayload),
connRecord: connRec,
}, "", ctx)
require.NoError(t, e)
require.NotNil(t, connRec)

_, e = ctx.connectionRecorder.GetConnectionRecord(connRec.ConnectionID)
require.Error(t, e)
require.ErrorContains(t, e, "data not found")
require.Equal(t, (&noOp{}).Name(), followup.Name())
})

t.Run("handle inbound request unmarshalling error", func(t *testing.T) {
_, followup, _, err := (&responded{}).ExecuteInbound(&stateMachineMsg{
Expand Down

0 comments on commit 5f500ec

Please sign in to comment.