Skip to content
This repository has been archived by the owner on Mar 27, 2024. It is now read-only.

refactor: Add handling inbound problem report messages inside legacy-connection protocol #3369

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pkg/didcomm/protocol/legacyconnection/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ type legacyDoc struct {
Proof []interface{} `json:"proof,omitempty"`
}

type problemReport struct {
Type string `json:"@type,omitempty"`
ID string `json:"@id,omitempty"`
Thread *decorator.Thread `json:"~thread,omitempty"`
ProblemCode string `json:"problem-code,omitempty"`
Explain string `json:"explain,omitempty"`
Localization struct {
Locale string `json:"locale,omitempty"`
} `json:"~l10n,omitempty"`
}

// JSONBytes converts Connection to json bytes.
func (con *Connection) toLegacyJSONBytes() ([]byte, error) {
if con.DIDDoc == nil {
Expand Down
13 changes: 10 additions & 3 deletions pkg/didcomm/protocol/legacyconnection/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ const (
// ResponseMsgType defines the legacy-connection response message type.
ResponseMsgType = PIURI + "/response"
// AckMsgType defines the legacy-connection ack message type.
AckMsgType = "https://didcomm.org/notification/1.0/ack"
AckMsgType = "https://didcomm.org/notification/1.0/ack"
// ProblemReportMsgType defines the protocol problem-report message type.
ProblemReportMsgType = PIURI + "/problem-report"
routerConnsMetadataKey = "routerConnections"
)

Expand Down Expand Up @@ -295,7 +297,8 @@ func (s *Service) Accept(msgType string) bool {
return msgType == InvitationMsgType ||
msgType == RequestMsgType ||
msgType == ResponseMsgType ||
msgType == AckMsgType
msgType == AckMsgType ||
msgType == ProblemReportMsgType
}

// HandleOutbound handles outbound connection messages.
Expand All @@ -318,6 +321,10 @@ func (s *Service) nextState(msgType, thID string) (state, error) {

logger.Debugf("retrieved current state [%s] using nsThID [%s]", current.Name(), nsThID)

if msgType == ProblemReportMsgType {
return &responded{}, nil
}

next, err := stateFromMsgType(msgType)
if err != nil {
return nil, err
Expand Down Expand Up @@ -636,7 +643,7 @@ func (s *Service) connectionRecord(msg service.DIDCommMsg, ctx service.DIDCommCo
return s.requestMsgRecord(msg, ctx)
case ResponseMsgType:
return s.responseMsgRecord(msg)
case AckMsgType:
case AckMsgType, ProblemReportMsgType:
return s.fetchConnectionRecord(theirNSPrefix, msg)
}

Expand Down
154 changes: 150 additions & 4 deletions pkg/didcomm/protocol/legacyconnection/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func TestService_Handle_Inviter(t *testing.T) {
require.NoError(t, err)

completedFlag := make(chan struct{})
respondedFlag := make(chan struct{})
respondedFlag := make(chan string)

go msgEventListener(t, statusCh, respondedFlag, completedFlag)

Expand Down Expand Up @@ -247,7 +247,153 @@ func TestService_Handle_Inviter(t *testing.T) {
validateState(t, s, thid, findNamespace(AckMsgType), (&completed{}).Name())
}

func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFlag, completedFlag chan struct{}) {
func TestService_Handle_Inviter_With_ProblemReport(t *testing.T) {
mockStore := &mockstorage.MockStore{Store: make(map[string]mockstorage.DBEntry)}
storeProv := mockstorage.NewCustomMockStoreProvider(mockStore)
k := newKMS(t, storeProv)
prov := &protocol.MockProvider{
StoreProvider: storeProv,
ServiceMap: map[string]interface{}{
mediator.Coordination: &mockroute.MockMediatorSvc{},
},
CustomKMS: k,
KeyTypeValue: kms.ED25519Type,
KeyAgreementTypeValue: kms.X25519ECDHKWType,
}

ctx := &context{
outboundDispatcher: prov.OutboundDispatcher(),
crypto: &tinkcrypto.Crypto{},
kms: k,
keyType: kms.ED25519Type,
keyAgreementType: kms.X25519ECDHKWType,
}

_, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type)
require.NoError(t, err)

ctx.vdRegistry = &mockvdr.MockVDRegistry{CreateValue: createDIDDocWithKey(pubKey)}

connRec, err := connection.NewRecorder(prov)
require.NoError(t, err)
require.NotNil(t, connRec)

ctx.connectionRecorder = connRec

doc, err := ctx.vdRegistry.Create(testMethod, nil)
require.NoError(t, err)

s, err := New(prov)
require.NoError(t, err)

actionCh := make(chan service.DIDCommAction, 10)
err = s.RegisterActionEvent(actionCh)
require.NoError(t, err)

statusCh := make(chan service.StateMsg, 10)
err = s.RegisterMsgEvent(statusCh)
require.NoError(t, err)

completedFlag := make(chan struct{})
respondedFlag := make(chan string)

go msgEventListener(t, statusCh, respondedFlag, completedFlag)
go func() { service.AutoExecuteActionEvent(actionCh) }()

invitation := &Invitation{
Type: InvitationMsgType,
ID: randomString(),
Label: "Bob",
RecipientKeys: []string{base58.Encode(pubKey)},
ServiceEndpoint: "http://alice.agent.example.com:8081",
}

err = ctx.connectionRecorder.SaveInvitation(invitation.ID, invitation)
require.NoError(t, err)

thid := randomString()

// Invitation was previously sent by Alice to Bob.
// Bob now sends a connection Request
connRequest, err := json.Marshal(
&Request{
Type: RequestMsgType,
ID: thid,
Label: "Bob",
Thread: &decorator.Thread{
PID: invitation.ID,
},
Connection: &Connection{
DID: doc.DIDDocument.ID,
DIDDoc: doc.DIDDocument,
},
})
require.NoError(t, err)
requestMsg, err := service.ParseDIDCommMsgMap(connRequest)
require.NoError(t, err)
_, err = s.HandleInbound(requestMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil))
require.NoError(t, err)

var connID string
select {
case connID = <-respondedFlag:
case <-time.After(2 * time.Second):
require.Fail(t, "didn't receive connection ID")
}

connRecord, err := s.connectionRecorder.GetConnectionRecord(connID)
require.NoError(t, err)

// Alice automatically sends connection Response to Bob
// Bob replies with Problem Report
prbRpt, err := json.Marshal(
&problemReport{
ID: randomString(),
Type: ProblemReportMsgType,
Thread: &decorator.Thread{ID: connRecord.ThreadID},
})
require.NoError(t, err)

prbRptMsg, err := service.ParseDIDCommMsgMap(prbRpt)
require.NoError(t, err)

_, err = s.HandleInbound(prbRptMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil))
require.NoError(t, err)

validateState(t, s, thid, findNamespace(RequestMsgType), (&responded{}).Name())

_, err = ctx.connectionRecorder.GetConnectionRecord(connID)
require.ErrorContains(t, err, "data not found")

_, err = s.HandleInbound(requestMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil))
require.NoError(t, err)

// Finally Bob replies with an ACK
ack, err := json.Marshal(
&model.Ack{
Type: AckMsgType,
ID: randomString(),
Status: "OK",
Thread: &decorator.Thread{ID: connRecord.ThreadID},
})
require.NoError(t, err)

ackMsg, err := service.ParseDIDCommMsgMap(ack)
require.NoError(t, err)

_, err = s.HandleInbound(ackMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil))
require.NoError(t, err)

select {
case <-completedFlag:
case <-time.After(4 * time.Second):
require.Fail(t, "didn't receive post event complete")
}

validateState(t, s, connRecord.ThreadID, findNamespace(AckMsgType), (&completed{}).Name())
}

func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFlag chan string, completedFlag chan struct{}) { //nolint: lll
for e := range statusCh {
require.Equal(t, LegacyConnection, e.ProtocolName)

Expand All @@ -272,11 +418,11 @@ func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFla
close(completedFlag)
}

if e.StateID == "responded" {
if e.StateID == "responded" && e.Msg.Type() != ProblemReportMsgType {
// validate connectionID received during state transition with original connectionID
require.NotNil(t, prop.ConnectionID())
require.NotNil(t, prop.InvitationID())
close(respondedFlag)
respondedFlag <- prop.ConnectionID()
}
}
}
Expand Down
9 changes: 8 additions & 1 deletion pkg/didcomm/protocol/legacyconnection/states.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func (s *responded) Name() string {
}

func (s *responded) CanTransitionTo(next state) bool {
return StateIDCompleted == next.Name()
return StateIDCompleted == next.Name() || StateIDRequested == next.Name()
}

func (s *responded) ExecuteInbound(msg *stateMachineMsg, _ string, ctx *context) (*connectionstore.Record,
Expand All @@ -226,6 +226,13 @@ func (s *responded) ExecuteInbound(msg *stateMachineMsg, _ string, ctx *context)
return connRecord, &noOp{}, action, nil
case ResponseMsgType:
return msg.connRecord, &completed{}, func() error { return nil }, nil
case ProblemReportMsgType:
err := ctx.connectionRecorder.RemoveConnection(msg.connRecord.ConnectionID)
if err != nil {
return nil, nil, nil, fmt.Errorf("delete connection record is failed: %w", err)
}

return msg.connRecord, &noOp{}, func() error { return nil }, nil
default:
return nil, nil, nil, fmt.Errorf("illegal msg type %s for state %s", msg.Type(), s.Name())
}
Expand Down
28 changes: 27 additions & 1 deletion pkg/didcomm/protocol/legacyconnection/states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestRespondedState(t *testing.T) {
require.Equal(t, "responded", res.Name())
require.False(t, res.CanTransitionTo(&null{}))
require.False(t, res.CanTransitionTo(&invited{}))
require.False(t, res.CanTransitionTo(&requested{}))
require.True(t, res.CanTransitionTo(&requested{}))
require.False(t, res.CanTransitionTo(res))
require.True(t, res.CanTransitionTo(&completed{}))
}
Expand Down Expand Up @@ -388,6 +388,32 @@ func TestRespondedState_Execute(t *testing.T) {
require.NotNil(t, connRec)
require.Equal(t, (&completed{}).Name(), followup.Name())
})
t.Run("followup to 'noop' on inbound problem report message", func(t *testing.T) {
connRec := &connection.Record{
State: (&responded{}).Name(),
ThreadID: request.ID,
ConnectionID: "123",
Namespace: findNamespace(ResponseMsgType),
}
err = ctx.connectionRecorder.SaveConnectionRecordWithMappings(connRec)
require.NoError(t, err)

problemReportPayload, err := json.Marshal(&problemReport{Type: ProblemReportMsgType})
require.NoError(t, err)

connRec, followup, _, e := (&responded{}).ExecuteInbound(
&stateMachineMsg{
DIDCommMsg: bytesToDIDCommMsg(t, problemReportPayload),
connRecord: connRec,
}, "", ctx)
require.NoError(t, e)
require.NotNil(t, connRec)

_, e = ctx.connectionRecorder.GetConnectionRecord(connRec.ConnectionID)
require.Error(t, e)
require.ErrorContains(t, e, "data not found")
require.Equal(t, (&noOp{}).Name(), followup.Name())
})

t.Run("handle inbound request unmarshalling error", func(t *testing.T) {
_, followup, _, err := (&responded{}).ExecuteInbound(&stateMachineMsg{
Expand Down