diff --git a/ci/all.sh b/ci/all.sh index efd56b61..1ee7640f 100755 --- a/ci/all.sh +++ b/ci/all.sh @@ -6,7 +6,7 @@ main() { ./ci/fmt.sh ./ci/lint.sh - ./ci/test.sh + ./ci/test.sh "$@" } main "$@" diff --git a/write.go b/write.go index 60a4fba0..2210cf81 100644 --- a/write.go +++ b/write.go @@ -246,13 +246,24 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco if err != nil { return 0, err } - defer func() { - // We leave it locked when writing the close frame to avoid - // any other goroutine writing any other frame. - if opcode != opClose { - c.writeFrameMu.unlock() + defer c.writeFrameMu.unlock() + + // If the state says a close has already been written, we wait until + // the connection is closed and return that error. + // + // However, if the frame being written is a close, that means its the close from + // the state being set so we let it go through. + c.closeMu.Lock() + wroteClose := c.wroteClose + c.closeMu.Unlock() + if wroteClose && opcode != opClose { + select { + case <-ctx.Done(): + return 0, ctx.Err() + case <-c.closed: + return 0, c.closeErr } - }() + } select { case <-c.closed: