diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index eb9dc791..a6efcfdf 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -222,7 +222,7 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) return promise.futureResult } @@ -239,7 +239,8 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) + return promise.futureResult.map { rowDescription in PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) } @@ -255,7 +256,8 @@ public final class PostgresConnection: @unchecked Sendable { logger: logger, promise: promise) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) + return promise.futureResult } @@ -263,7 +265,8 @@ public final class PostgresConnection: @unchecked Sendable { let promise = self.channel.eventLoop.makePromise(of: Void.self) let context = CloseCommandContext(target: target, logger: logger, promise: promise) - self.channel.write(HandlerTask.closeCommand(context), promise: nil) + self.write(.closeCommand(context), cascadingFailureTo: promise) + return promise.futureResult } @@ -426,7 +429,7 @@ extension PostgresConnection { promise: promise ) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) do { return try await promise.futureResult.map({ $0.asyncSequence() }).get() @@ -455,7 +458,11 @@ extension PostgresConnection { let task = HandlerTask.startListening(listener) - self.channel.write(task, promise: nil) + let writePromise = self.channel.eventLoop.makePromise(of: Void.self) + self.channel.write(task, promise: writePromise) + writePromise.futureResult.whenFailure { error in + listener.failed(error) + } } } onCancel: { let task = HandlerTask.cancelListening(channel, id) @@ -480,7 +487,9 @@ extension PostgresConnection { logger: logger, promise: promise )) - self.channel.write(task, promise: nil) + + self.write(task, cascadingFailureTo: promise) + do { return try await promise.futureResult .map { $0.asyncSequence() } @@ -515,7 +524,9 @@ extension PostgresConnection { logger: logger, promise: promise )) - self.channel.write(task, promise: nil) + + self.write(task, cascadingFailureTo: promise) + do { return try await promise.futureResult .map { $0.commandTag } @@ -530,6 +541,12 @@ extension PostgresConnection { throw error // rethrow with more metadata } } + + private func write(_ task: HandlerTask, cascadingFailureTo promise: EventLoopPromise) { + let writePromise = self.channel.eventLoop.makePromise(of: Void.self) + self.channel.write(task, promise: writePromise) + writePromise.futureResult.cascadeFailure(to: promise) + } } // MARK: EventLoopFuture interface @@ -674,7 +691,7 @@ internal enum PostgresCommands: PostgresRequest { /// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. public final class PostgresListenContext: Sendable { - private let promise: EventLoopPromise + let promise: EventLoopPromise var future: EventLoopFuture { self.promise.futureResult @@ -713,8 +730,7 @@ extension PostgresConnection { closure: notificationHandler ) - let task = HandlerTask.startListening(listener) - self.channel.write(task, promise: nil) + self.write(.startListening(listener), cascadingFailureTo: listenContext.promise) listenContext.future.whenComplete { _ in let task = HandlerTask.cancelListening(channel, id) @@ -761,3 +777,4 @@ extension PostgresConnection { #endif } } + diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 57939c06..913d91b2 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -359,5 +359,4 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(obj?.bar, 2) } } - } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 0bc61efd..5c7d4c83 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -224,6 +224,63 @@ class PostgresConnectionTests: XCTestCase { } } + func testSimpleListenFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + _ = try await connection.listen("test_channel") + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testSimpleListenFailsIfConnectionIsClosedWhileListening() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + var iterator = events.makeAsyncIterator() + let first = try await iterator.next() + XCTAssertEqual(first?.payload, "wooohooo") + do { + _ = try await iterator.next() + XCTFail("Did not expect to not throw") + } catch let error as PSQLError { + XCTAssertEqual(error.code, .clientClosedConnection) + } + } + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + + try await connection.close() + + XCTAssertEqual(channel.isActive, false) + + switch await taskGroup.nextResult()! { + case .success: + break + case .failure(let failure): + XCTFail("Unexpected error: \(failure)") + } + } + } + func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in @@ -638,6 +695,118 @@ class PostgresConnectionTests: XCTestCase { } } + func testQueryFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + _ = try await connection.query("SELECT version;", logger: self.logger) + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testPrepareStatementFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + _ = try await connection.prepareStatement("SELECT version;", with: "test_query", logger: .psqlTest).get() + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testExecuteFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + let statement = PSQLExecuteStatement(name: "SELECT version;", binds: .init(), rowDescription: nil) + _ = try await connection.execute(statement, logger: .psqlTest).get() + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testExecutePreparedStatementFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + struct TestPreparedStatement: PostgresPreparedStatement { + static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" + typealias Row = (Int, String) + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.state) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + + do { + let preparedStatement = TestPreparedStatement(state: "active") + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + struct TestPreparedStatement: PostgresPreparedStatement { + static let sql = "SELECT * FROM pg_stat_activity WHERE state = $1" + typealias Row = () + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.state) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + do { + let preparedStatement = TestPreparedStatement(state: "active") + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = await NIOAsyncTestingChannel(handlers: [