From c381928c9f8077c6b61efa8a297264e3e9c88ade Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 22 Sep 2019 23:28:41 -0500 Subject: [PATCH] Implement more of the API for WASM Realized I can at least make the Reader/Writer/SetReadLimit methods work as expected even if they're not perfect. --- conn.go | 22 +++--------- netconn.go => conn_common.go | 15 +++++++- doc.go | 12 ++++++- websocket_js.go | 66 +++++++++++++++++++++++++++++++++-- wsjson/wsjson.go | 6 +--- wsjson/wsjson_js.go | 58 ------------------------------- wspb/wspb.go | 6 +--- wspb/wspb_js.go | 67 ------------------------------------ 8 files changed, 95 insertions(+), 157 deletions(-) rename netconn.go => conn_common.go (91%) delete mode 100644 wsjson/wsjson_js.go delete mode 100644 wspb/wspb_js.go diff --git a/conn.go b/conn.go index e12e1443..20dbece2 100644 --- a/conn.go +++ b/conn.go @@ -59,7 +59,7 @@ type Conn struct { msgReadLimit int64 // Used to ensure a previous writer is not used after being closed. - activeWriter *messageWriter + activeWriter atomic.Value // messageWriter state. writeMsgOpcode opcode writeMsgCtx context.Context @@ -526,16 +526,6 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { return n, err } -// SetReadLimit sets the max number of bytes to read for a single message. -// It applies to the Reader and Read methods. -// -// By default, the connection has a message read limit of 32768 bytes. -// -// When the limit is hit, the connection will be closed with StatusMessageTooBig. -func (c *Conn) SetReadLimit(n int64) { - c.msgReadLimit = n -} - // Read is a convenience method to read a single message from the connection. // // See the Reader method if you want to be able to reuse buffers or want to stream a message. @@ -575,7 +565,7 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err w := &messageWriter{ c: c, } - c.activeWriter = w + c.activeWriter.Store(w) return w, nil } @@ -607,7 +597,7 @@ type messageWriter struct { } func (w *messageWriter) closed() bool { - return w != w.c.activeWriter + return w != w.c.activeWriter.Load() } // Write writes the given bytes to the WebSocket connection. @@ -645,7 +635,7 @@ func (w *messageWriter) close() error { if w.closed() { return fmt.Errorf("cannot use closed writer") } - w.c.activeWriter = nil + w.c.activeWriter.Store((*messageWriter)(nil)) _, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil) if err != nil { @@ -925,7 +915,3 @@ func (c *Conn) extractBufioWriterBuf(w io.Writer) { c.bw.Reset(w) } - -func (c *netConn) netConnReader(ctx context.Context) (MessageType, io.Reader, error) { - return c.c.Reader(c.readContext) -} diff --git a/netconn.go b/conn_common.go similarity index 91% rename from netconn.go rename to conn_common.go index c5c0e17b..771db26b 100644 --- a/netconn.go +++ b/conn_common.go @@ -1,3 +1,6 @@ +// This file contains *Conn symbols relevant to both +// WASM and non WASM builds. + package websocket import ( @@ -99,7 +102,7 @@ func (c *netConn) Read(p []byte) (int, error) { } if c.reader == nil { - typ, r, err := c.netConnReader(c.readContext) + typ, r, err := c.c.Reader(c.readContext) if err != nil { var ce CloseError if errors.As(err, &ce) && (ce.Code == StatusNormalClosure) || (ce.Code == StatusGoingAway) { @@ -189,3 +192,13 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { }() return ctx } + +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusMessageTooBig. +func (c *Conn) SetReadLimit(n int64) { + c.msgReadLimit = n +} diff --git a/doc.go b/doc.go index 2a5a0a1a..7753afc7 100644 --- a/doc.go +++ b/doc.go @@ -26,13 +26,23 @@ // See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket // // Thus the unsupported features (not compiled in) for WASM are: +// // - Accept and AcceptOptions -// - Conn's Reader, Writer, SetReadLimit and Ping methods +// - Conn.Ping // - HTTPClient and HTTPHeader fields in DialOptions // // The *http.Response returned by Dial will always either be nil or &http.Response{} as // we do not have access to the handshake response in the browser. // +// The Writer method on the Conn buffers everything in memory and then sends it as a message +// when the writer is closed. +// +// The Reader method also reads the entire response and then returns a reader that +// reads from the byte slice. +// +// SetReadLimit cannot actually limit the number of bytes read from the connection so instead +// when a message beyond the limit is fully read, it throws an error. +// // Writes are also always async so the passed context is no-op. // // Everything else is fully supported. This includes the wsjson and wspb helper packages. diff --git a/websocket_js.go b/websocket_js.go index 123bc8f4..4ed49d97 100644 --- a/websocket_js.go +++ b/websocket_js.go @@ -13,6 +13,7 @@ import ( "sync/atomic" "syscall/js" + "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" ) @@ -20,6 +21,8 @@ import ( type Conn struct { ws wsjs.WebSocket + msgReadLimit int64 + readClosed int64 closeOnce sync.Once closed chan struct{} @@ -43,6 +46,7 @@ func (c *Conn) close(err error) { func (c *Conn) init() { c.closed = make(chan struct{}) c.readch = make(chan wsjs.MessageEvent, 1) + c.msgReadLimit = 32768 c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { cerr := CloseError{ @@ -77,6 +81,10 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { if err != nil { return 0, nil, fmt.Errorf("failed to read: %w", err) } + if int64(len(p)) > c.msgReadLimit { + c.Close(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", c.msgReadLimit)) + return 0, nil, c.closeErr + } return typ, p, nil } @@ -106,6 +114,11 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { err := c.write(ctx, typ, p) if err != nil { + // Have to ensure the WebSocket is closed after a write error + // to match the Go API. It can only error if the message type + // is unexpected or the passed bytes contain invalid UTF-8 for + // MessageText. + c.Close(StatusInternalError, "something went wrong") return fmt.Errorf("failed to write: %w", err) } return nil @@ -216,8 +229,10 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp return c, &http.Response{}, nil } -func (c *netConn) netConnReader(ctx context.Context) (MessageType, io.Reader, error) { - typ, p, err := c.c.Read(ctx) +// Reader attempts to read a message from the connection. +// The maximum time spent waiting is bounded by the context. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + typ, p, err := c.Read(ctx) if err != nil { return 0, nil, err } @@ -228,3 +243,50 @@ func (c *netConn) netConnReader(ctx context.Context) (MessageType, io.Reader, er func (c *Conn) reader(ctx context.Context) { c.read(ctx) } + +// Writer returns a writer to write a WebSocket data message to the connection. +// It buffers the entire message in memory and then sends it when the writer +// is closed. +func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + return writer{ + c: c, + ctx: ctx, + typ: typ, + b: bpool.Get(), + }, nil +} + +type writer struct { + closed bool + + c *Conn + ctx context.Context + typ MessageType + + b *bytes.Buffer +} + +func (w writer) Write(p []byte) (int, error) { + if w.closed { + return 0, errors.New("cannot write to closed writer") + } + n, err := w.b.Write(p) + if err != nil { + return n, fmt.Errorf("failed to write message: %w", err) + } + return n, nil +} + +func (w writer) Close() error { + if w.closed { + return errors.New("cannot close closed writer") + } + w.closed = true + defer bpool.Put(w.b) + + err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) + if err != nil { + return fmt.Errorf("failed to close writer: %w", err) + } + return nil +} diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index ffdd24ac..fe935fa1 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -1,5 +1,3 @@ -// +build !js - // Package wsjson provides websocket helpers for JSON messages. package wsjson // import "nhooyr.io/websocket/wsjson" @@ -34,9 +32,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { } b := bpool.Get() - defer func() { - bpool.Put(b) - }() + defer bpool.Put(b) _, err = b.ReadFrom(r) if err != nil { diff --git a/wsjson/wsjson_js.go b/wsjson/wsjson_js.go deleted file mode 100644 index 5b88ce3b..00000000 --- a/wsjson/wsjson_js.go +++ /dev/null @@ -1,58 +0,0 @@ -// +build js - -package wsjson - -import ( - "context" - "encoding/json" - "fmt" - - "nhooyr.io/websocket" -) - -// Read reads a json message from c into v. -func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { - err := read(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to read json: %w", err) - } - return nil -} - -func read(ctx context.Context, c *websocket.Conn, v interface{}) error { - typ, b, err := c.Read(ctx) - if err != nil { - return err - } - - if typ != websocket.MessageText { - c.Close(websocket.StatusUnsupportedData, "can only accept text messages") - return fmt.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ) - } - - err = json.Unmarshal(b, v) - if err != nil { - c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") - return fmt.Errorf("failed to unmarshal json: %w", err) - } - - return nil -} - -// Write writes the json message v to c. -func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { - err := write(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to write json: %w", err) - } - return nil -} - -func write(ctx context.Context, c *websocket.Conn, v interface{}) error { - b, err := json.Marshal(v) - if err != nil { - return err - } - - return c.Write(ctx, websocket.MessageText, b) -} diff --git a/wspb/wspb.go b/wspb/wspb.go index b32b0c1b..3c9e0f76 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -1,5 +1,3 @@ -// +build !js - // Package wspb provides websocket helpers for protobuf messages. package wspb // import "nhooyr.io/websocket/wspb" @@ -36,9 +34,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { } b := bpool.Get() - defer func() { - bpool.Put(b) - }() + defer bpool.Put(b) _, err = b.ReadFrom(r) if err != nil { diff --git a/wspb/wspb_js.go b/wspb/wspb_js.go deleted file mode 100644 index 6f69eddd..00000000 --- a/wspb/wspb_js.go +++ /dev/null @@ -1,67 +0,0 @@ -// +build js - -package wspb // import "nhooyr.io/websocket/wspb" - -import ( - "bytes" - "context" - "fmt" - - "github.com/golang/protobuf/proto" - - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/bpool" -) - -// Read reads a protobuf message from c into v. -func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { - err := read(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to read protobuf: %w", err) - } - return nil -} - -func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { - typ, p, err := c.Read(ctx) - if err != nil { - return err - } - - if typ != websocket.MessageBinary { - c.Close(websocket.StatusUnsupportedData, "can only accept binary messages") - return fmt.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ) - } - - err = proto.Unmarshal(p, v) - if err != nil { - c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") - return fmt.Errorf("failed to unmarshal protobuf: %w", err) - } - - return nil -} - -// Write writes the protobuf message v to c. -func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { - err := write(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to write protobuf: %w", err) - } - return nil -} - -func write(ctx context.Context, c *websocket.Conn, v proto.Message) error { - b := bpool.Get() - pb := proto.NewBuffer(b.Bytes()) - defer func() { - bpool.Put(bytes.NewBuffer(pb.Bytes())) - }() - - err := pb.Marshal(v) - if err != nil { - return fmt.Errorf("failed to marshal protobuf: %w", err) - } - - return c.Write(ctx, websocket.MessageBinary, pb.Bytes()) -}