From 4ca29121fed60a6cf16861c1e110002d2384b7a1 Mon Sep 17 00:00:00 2001 From: David Piegza <697113+davidpiegza@users.noreply.github.com> Date: Mon, 25 Sep 2023 10:53:31 +0000 Subject: [PATCH] Add shutdown state in MySQL server plugin --- go/mysql/conn.go | 29 +++++++++++++++++++++++++++++ go/mysql/server.go | 3 ++- go/vt/vtgate/plugin_mysql_server.go | 11 ++++++++--- 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 9fb47da189e..f8e98ab42ec 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -199,6 +199,12 @@ type Conn struct { // enableQueryInfo controls whether we parse the INFO field in QUERY_OK packets // See: ConnParams.EnableQueryInfo enableQueryInfo bool + + // mu protects the fields below + mu sync.Mutex + // this is used to mark the connection to be closed so that the command phase for the connection can be stopped and + // the connection gets closed. + closing bool } // splitStatementFunciton is the function that is used to split the statement in case of a multi-statement query. @@ -897,6 +903,11 @@ func (c *Conn) handleNextCommand(handler Handler) bool { return false } + // before continue to process the packet, check if the connection should be closed or not. + if c.IsMarkedForClose() { + return false + } + switch data[0] { case ComQuit: c.recycleReadPacket() @@ -1632,3 +1643,21 @@ func (c *Conn) IsUnixSocket() bool { func (c *Conn) GetRawConn() net.Conn { return c.conn } + +// MarkForClose marks the connection for close. +func (c *Conn) MarkForClose() { + c.mu.Lock() + defer c.mu.Unlock() + c.closing = true +} + +// IsMarkedForClose return true if the connection should be closed. +func (c *Conn) IsMarkedForClose() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closing +} + +func (c *Conn) IsShuttingDown() bool { + return c.listener.shutdown.Load() +} diff --git a/go/mysql/server.go b/go/mysql/server.go index e17bd82ef90..09ddb95955c 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -525,7 +525,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti for { kontinue := c.handleNextCommand(l.handler) - if !kontinue { + // before going for next command check if the connection should be closed or not. + if !kontinue || c.IsMarkedForClose() { return } } diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index c7d4c53785c..9f138b006fd 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -196,6 +196,12 @@ func startSpan(ctx context.Context, query, label string) (trace.Span, context.Co } func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { + session := vh.session(c) + if c.IsShuttingDown() && !session.InTransaction { + c.MarkForClose() + return mysql.NewSQLError(mysql.ERServerShutdown, mysql.SSNetError, "Server shutdown in progress") + } + ctx := context.Background() var cancel context.CancelFunc if mysqlQueryTimeout != 0 { @@ -223,7 +229,6 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq "VTGate MySQL Connector" /* subcomponent: part of the client */) ctx = callerid.NewContext(ctx, ef, im) - session := vh.session(c) if !session.InTransaction { atomic.AddInt32(&busyConnections, 1) } @@ -565,11 +570,11 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys func shutdownMysqlProtocolAndDrain() { if mysqlListener != nil { - mysqlListener.Close() + mysqlListener.Shutdown() mysqlListener = nil } if mysqlUnixListener != nil { - mysqlUnixListener.Close() + mysqlUnixListener.Shutdown() mysqlUnixListener = nil } if sigChan != nil {