Skip to content

Commit

Permalink
fix: release semaphore on panic
Browse files Browse the repository at this point in the history
  • Loading branch information
costela committed Oct 24, 2023
1 parent 78e4bfe commit 7be9ede
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 26 deletions.
3 changes: 2 additions & 1 deletion breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ import (
"time"
)

// untypedCircuit is used to avoid type annotations when implementing a breaker.
// untypedCircuit is used to avoid type annotations when implementing breakers.
type untypedCircuit interface {
stateForCall() State
setOpenedAt(int64)
}

// observer is used to observe the result of a single wrapped call through the circuit breaker.
type observer interface {
// observe is called after the wrapped function returns. If [Circuit.Do] returns a non-nil [Observable], it will be
// called exactly once.
Expand Down
7 changes: 4 additions & 3 deletions hoglet.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ type options struct {
// limited (~1) amount of calls are allowed that - if successful - may re-close the breaker.
halfOpenDelay time.Duration

// observerForCall is a function that returns an observer for the next call.
// Usually, this is implemented by the breaker, but it can be overridden for testing purposes.
observerForCall observerFactory
}

// observerFactory is a function that returns one observer for each call going through the circuit.
// It is used analogously to a http.Handler, allowing different plugins to "wrap" each execution.
type observerFactory func(context.Context) (observer, error)

// Breaker is the interface implemented by the different breakers, responsible for actually opening the circuit.
// Each implementation behaves differently when deciding whether to open the breaker upon failure.
type Breaker interface {
Expand All @@ -48,8 +51,6 @@ type Breaker interface {
observerForCall(context.Context) (observer, error)
}

type observerFactory func(context.Context) (observer, error)

// BreakableFunc is the type of the function wrapped by a Breaker.
type BreakableFunc[IN, OUT any] func(context.Context, IN) (OUT, error)

Expand Down
2 changes: 1 addition & 1 deletion limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ func newLimiter(origFactory observerFactory, limit int64, block bool) observerFa
return nil, err
}
return observableCall(func(b bool) {
defer sem.Release(1)
o.observe(b)
sem.Release(1)
}), nil
}

Expand Down
114 changes: 93 additions & 21 deletions limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,52 +7,124 @@ import (
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type mockPanickingObservable struct{}

func (mo *mockPanickingObservable) observe(shouldPanic bool) {
// abuse the observer interface to signal a panic
if shouldPanic {
panic("mockObservable meant to panic")
}
}

func Test_newLimiter(t *testing.T) {
orig := func(context.Context) (observer, error) {
return &mockObservable{}, nil
orig := func() observerFactory {
return func(context.Context) (observer, error) {
return &mockPanickingObservable{}, nil
}
}

type args struct {
limit int64
block bool
}
tests := []struct {
name string
args args
calls int
cancel bool
wantErr error
name string
args args
calls int
cancel bool
wantPanicOn *int // which call to panic on (if at all)
wantErr error
}{
{"under limit", args{limit: 1, block: false}, 0, false, nil},
{"over limit; non-blocking", args{limit: 1, block: false}, 2, false, ErrConcurrencyLimitReached},
{"over limit; blocking", args{limit: 1, block: true}, 2, false, ErrWaitingForSlot},
{
name: "under limit",
args: args{limit: 1, block: false},
calls: 0,
wantErr: nil,
},
{
name: "over limit; non-blocking",
args: args{limit: 1, block: false},
calls: 1,
wantErr: ErrConcurrencyLimitReached,
},
{
name: "on limit; blocking",
args: args{limit: 1, block: true},
calls: 1,
cancel: true, // cancel simulates a timeout in this case
wantErr: ErrWaitingForSlot,
},
{
name: "cancelation releases with error",
args: args{limit: 1, block: true},
calls: 1,
cancel: true,
wantErr: context.Canceled,
},
{
name: "panic releases",
args: args{limit: 1, block: true},
calls: 1,
cancel: false,
wantPanicOn: ptr(0),
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
ctxCalls, cancelCalls := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancelCalls()

wg := &sync.WaitGroup{}
wgStart := &sync.WaitGroup{}
wgStop := &sync.WaitGroup{}
defer wgStop.Wait()

of := newLimiter(orig, tt.args.limit, tt.args.block)
of := newLimiter(orig(), tt.args.limit, tt.args.block)
for i := 0; i < tt.calls; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = of(ctx)
}()
wantPanic := tt.wantPanicOn != nil && *tt.wantPanicOn == i

f := func() {
defer wgStop.Done()
o, err := of(ctxCalls)
wgStart.Done()
require.NoError(t, err)

<-ctxCalls.Done()

o.observe(wantPanic)
}

wgStart.Add(1)
wgStop.Add(1)
if wantPanic {
go assert.Panics(t, f)
} else {
go f()
}
}

wg.Wait()
ctx, cancel := context.WithCancel(context.Background())

if tt.cancel {
cancel()
} else {
defer cancel()
}

_, err := of(ctx)
wgStart.Wait() // ensure all calls are started

o, err := of(ctx)
assert.ErrorIs(t, err, tt.wantErr)
if tt.wantErr == nil {
assert.NotNil(t, o)
}
})
}
}

func ptr[T any](in T) *T {
return &in
}

0 comments on commit 7be9ede

Please sign in to comment.