diff --git a/pkgerrors.go b/pkgerrors.go index 785a12b..bd1624e 100644 --- a/pkgerrors.go +++ b/pkgerrors.go @@ -1,6 +1,7 @@ package fail import ( + "reflect" "strings" pkgerrors "github.com/pkg/errors" @@ -79,8 +80,15 @@ func extractPkgError(err error) *pkgError { break } - if len(stackTraces) == 0 && rootErr == err { - return nil + if len(stackTraces) == 0 { + ret, et := reflect.TypeOf(rootErr), reflect.TypeOf(err) + if ret != nil && et != nil && ret.Comparable() && et.Comparable() { + if rootErr == err { + return nil + } + } else { + return nil + } } // Extract annotated messages by removing the trailing message. diff --git a/pkgerrors_test.go b/pkgerrors_test.go index 41d41ce..e9dda20 100644 --- a/pkgerrors_test.go +++ b/pkgerrors_test.go @@ -2,6 +2,7 @@ package fail import ( "errors" + "strings" "testing" pkgerrors "github.com/pkg/errors" @@ -20,6 +21,13 @@ func TestExtractPkgError(t *testing.T) { assert.Nil(t, pkgErr) }) + t.Run("slice error", func(t *testing.T) { + err := errorSlice{errors.New("error")} + + pkgErr := extractPkgError(err) + assert.Nil(t, pkgErr) + }) + t.Run("pkg/errors.New", func(t *testing.T) { err := pkgErrorsNew("message") @@ -84,6 +92,16 @@ func TestExtractPkgError(t *testing.T) { assert.NotEmpty(t, pkgErr.StackTrace) assert.Equal(t, "pkgErrorsWrap", pkgErr.StackTrace[0].Func) }) + + t.Run("with slice error", func(t *testing.T) { + err0 := errorSlice{errors.New("error")} + err1 := pkgErrorsWrap(err0, "message") + + pkgErr := extractPkgError(err1) + assert.NotNil(t, pkgErr) + assert.Equal(t, err0, pkgErr.Err) + assert.NotEmpty(t, pkgErr.StackTrace) + }) }) t.Run("pkg/errors.WithMessage", func(t *testing.T) { @@ -156,3 +174,13 @@ func pkgErrorsNew(msg string) error { func pkgErrorsWrap(err error, msg string) error { return pkgerrors.Wrap(err, msg) } + +type errorSlice []error + +func (s errorSlice) Error() string { + msg := make([]string, len(s)) + for i, e := range s { + msg[i] = e.Error() + } + return strings.Join(msg, ": ") +}