diff --git a/neo4j/internal/errorutil/bolt.go b/neo4j/internal/errorutil/bolt.go index c2852d3e..1ba02f57 100644 --- a/neo4j/internal/errorutil/bolt.go +++ b/neo4j/internal/errorutil/bolt.go @@ -45,6 +45,10 @@ func (crt *ConnectionReadTimeout) Error() string { crt.Err) } +func (crt *ConnectionReadTimeout) Unwrap() error { + return crt.Err +} + type ConnectionWriteTimeout struct { UserContext context.Context Err error @@ -58,6 +62,10 @@ func (cwt *ConnectionWriteTimeout) Error() string { return fmt.Sprintf("Timeout while writing to connection [user-provided context deadline: %s]: %s", userDeadline, cwt.Err) } +func (crt *ConnectionWriteTimeout) Unwrap() error { + return crt.Err +} + type ConnectionReadCanceled struct { Err error } @@ -66,6 +74,10 @@ func (crc *ConnectionReadCanceled) Error() string { return fmt.Sprintf("Reading from connection has been canceled: %s", crc.Err) } +func (crt *ConnectionReadCanceled) Unwrap() error { + return crt.Err +} + type ConnectionWriteCanceled struct { Err error } @@ -74,6 +86,10 @@ func (cwc *ConnectionWriteCanceled) Error() string { return fmt.Sprintf("Writing to connection has been canceled: %s", cwc.Err) } +func (crt *ConnectionWriteCanceled) Unwrap() error { + return crt.Err +} + type timeout interface { Timeout() bool } diff --git a/neo4j/internal/errorutil/bolt_test.go b/neo4j/internal/errorutil/bolt_test.go index 5dc2313f..12379928 100644 --- a/neo4j/internal/errorutil/bolt_test.go +++ b/neo4j/internal/errorutil/bolt_test.go @@ -19,10 +19,12 @@ package errorutil_test import ( "context" + "errors" + "testing" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" - "testing" ) func TestIsFatalDuringDiscovery(outer *testing.T) { @@ -138,3 +140,35 @@ func TestIsFatalDuringDiscovery(outer *testing.T) { }) } } + +func TestErrorSupportsUnwrap(outer *testing.T) { + type testCase struct { + description string + err error + } + + inner := errors.New("the inner error") + + testCases := []testCase{ + { + description: "ConnectionReadTimeout support Unwrap", + err: &errorutil.ConnectionReadTimeout{Err: inner}, + }, { + description: "ConnectionWriteTimeout support Unwrap", + err: &errorutil.ConnectionWriteTimeout{Err: inner}, + }, { + description: "ConnectionReadCanceled support Unwrap", + err: &errorutil.ConnectionReadCanceled{Err: inner}, + }, { + description: "ConnectionWriteCanceled support Unwrap", + err: &errorutil.ConnectionWriteCanceled{Err: inner}, + }, + } + + for _, testCase := range testCases { + outer.Run(testCase.description, func(t *testing.T) { + + testutil.AssertBoolEqual(t, errors.Is(testCase.err, inner), true) + }) + } +}