diff --git a/.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata b/.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata new file mode 100644 index 00000000..919434a6 --- /dev/null +++ b/.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ + + + + + diff --git a/Package.resolved b/Package.resolved new file mode 100644 index 00000000..dc809063 --- /dev/null +++ b/Package.resolved @@ -0,0 +1,16 @@ +{ + "object": { + "pins": [ + { + "package": "swift-argument-parser", + "repositoryURL": "https://github.com/apple/swift-argument-parser", + "state": { + "branch": null, + "revision": "9564d61b08a5335ae0a36f789a7d71493eacadfc", + "version": "0.3.2" + } + } + ] + }, + "version": 1 +} diff --git a/Package.swift b/Package.swift index 4c6e4ab2..179db22b 100644 --- a/Package.swift +++ b/Package.swift @@ -12,6 +12,10 @@ import PackageDescription +let settings: [SwiftSetting]? = [ + .define("SYSTEM_PACKAGE") +] + let targets: [PackageDescription.Target] = [ .target( name: "SystemPackage", @@ -24,19 +28,43 @@ let targets: [PackageDescription.Target] = [ .target( name: "CSystem", dependencies: []), + + .target( + name: "SystemSockets", + dependencies: ["SystemPackage"]), + .testTarget( name: "SystemTests", dependencies: ["SystemPackage"], - swiftSettings: [ - .define("SYSTEM_PACKAGE") - ]), + swiftSettings: settings + ), + + .testTarget( + name: "SystemSocketsTests", + dependencies: ["SystemSockets"], + swiftSettings: settings + ), + + .target( + name: "Samples", + dependencies: [ + "SystemPackage", + "SystemSockets", + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ], + path: "Sources/Samples", + swiftSettings: settings + ), ] let package = Package( name: "swift-system", products: [ - .library(name: "SystemPackage", targets: ["SystemPackage"]), + .library(name: "SystemPackage", targets: ["SystemPackage"]), + .executable(name: "system-samples", targets: ["Samples"]), + ], + dependencies: [ + .package(url: "https://github.com/apple/swift-argument-parser", from: "0.3.0"), ], - dependencies: [], targets: targets ) diff --git a/Sources/Samples/Connect.swift b/Sources/Samples/Connect.swift new file mode 100644 index 00000000..bbda9261 --- /dev/null +++ b/Sources/Samples/Connect.swift @@ -0,0 +1,113 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +import ArgumentParser +#if SYSTEM_PACKAGE +import SystemPackage +import SystemSockets +#else +import System +#error("No socket support") +#endif + +struct Connect: ParsableCommand { + static var configuration = CommandConfiguration( + commandName: "connect", + abstract: "Open a connection and send lines from stdin over it." + ) + + @Argument(help: "The hostname to connect to.") + var hostname: String + + @Argument(help: "The port number (or service name) to connect to.") + var service: String + + @Flag(help: "Use IPv4") + var ipv4: Bool = false + + @Flag(help: "Use IPv6") + var ipv6: Bool = false + + @Flag(help: "Use UDP") + var udp: Bool = false + + @Flag(help: "Send data out-of-band") + var outOfBand: Bool = false + + @Option(help: "TCP connection timeout, in seconds") + var connectionTimeout: CInt? + + @Option(help: "Socket send timeout, in seconds") + var sendTimeout: CInt? + + @Option(help: "Socket receive timeout, in seconds") + var receiveTimeout: CInt? + + @Option(help: "TCP connection keepalive interval, in seconds") + var keepalive: CInt? + + func connect( + to addresses: [SocketAddress.Info] + ) throws -> (SocketDescriptor, SocketAddress)? { + // Only try the first address for now + guard let addressinfo = addresses.first else { return nil } + print(addressinfo) + let socket = try SocketDescriptor.open( + addressinfo.domain, + addressinfo.type, + addressinfo.protocol) + do { + if let connectionTimeout = connectionTimeout { + try socket.setOption(.tcp, .tcpConnectionTimeout, to: connectionTimeout) + } + if let sendTimeout = sendTimeout { + try socket.setOption(.socketOption, .sendTimeout, to: sendTimeout) + } + if let receiveTimeout = receiveTimeout { + try socket.setOption(.socketOption, .receiveTimeout, to: receiveTimeout) + } + if let keepalive = keepalive { + try socket.setOption(.tcp, .tcpKeepAlive, to: keepalive) + } + try socket.connect(to: addressinfo.address) + return (socket, addressinfo.address) + } + catch { + try? socket.close() + throw error + } + } + + func run() throws { + let addresses = try SocketAddress.resolveName( + hostname: hostname, + service: service, + family: ipv6 ? .ipv6 : .ipv4, + type: udp ? .datagram : .stream) + + guard let (socket, address) = try connect(to: addresses) else { + complain("Can't connect to \(hostname)") + throw ExitCode.failure + } + complain("Connected to \(address.niceDescription)") + + let flags: SocketDescriptor.MessageFlags = outOfBand ? .outOfBand : .none + try socket.closeAfter { + while var line = readLine(strippingNewline: false) { + try line.withUTF8 { buffer in + var buffer = UnsafeRawBufferPointer(buffer) + while !buffer.isEmpty { + let c = try socket.send(buffer, flags: flags) + buffer = .init(rebasing: buffer[c...]) + } + } + } + } + } +} diff --git a/Sources/Samples/Listen.swift b/Sources/Samples/Listen.swift new file mode 100644 index 00000000..f5466712 --- /dev/null +++ b/Sources/Samples/Listen.swift @@ -0,0 +1,124 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +import ArgumentParser +#if SYSTEM_PACKAGE +import SystemPackage +import SystemSockets +#else +import System +#error("No socket support") +#endif + +struct Listen: ParsableCommand { + static var configuration = CommandConfiguration( + commandName: "listen", + abstract: "Listen for an incoming connection and print received text to stdout." + ) + + @Argument(help: "The port number (or service name) to listen on.") + var service: String + + @Flag(help: "Use IPv4") + var ipv4: Bool = false + + @Flag(help: "Use IPv6") + var ipv6: Bool = false + + @Flag(help: "Use UDP") + var udp: Bool = false + + func startServer( + on addresses: [SocketAddress.Info] + ) throws -> (SocketDescriptor, SocketAddress.Info)? { + for info in addresses { + do { + let socket = try SocketDescriptor.open( + info.domain, + info.type, + info.protocol) + do { + try socket.bind(to: info.address) + if !info.type.isConnectionless { + try socket.listen(backlog: 10) + } + return (socket, info) + } + catch { + try? socket.close() + throw error + } + } + catch { + continue + } + } + return nil + } + + func prefix( + client: SocketAddress, + flags: SocketDescriptor.MessageFlags + ) -> String { + var prefix: [String] = [] + if client.family != .unspecified { + prefix.append("client: \(client.niceDescription)") + } + if flags != .none { + prefix.append("flags: \(flags)") + } + guard !prefix.isEmpty else { return "" } + return "<\(prefix.joined(separator: ", "))> " + } + + func run() throws { + let addresses = try SocketAddress.resolveName( + hostname: nil, + service: service, + flags: .canonicalName, + family: ipv6 ? .ipv6 : .ipv4, + type: udp ? .datagram : .stream) + + + guard let (socket, address) = try startServer(on: addresses) else { + complain("Can't listen on \(service)") + throw ExitCode.failure + } + complain("Listening on \(address.address.niceDescription)") + + var client = SocketAddress() + let buffer = UnsafeMutableRawBufferPointer.allocate(byteCount: 1024, alignment: 1) + defer { buffer.deallocate() } + + var ancillary = SocketDescriptor.AncillaryMessageBuffer() + try socket.closeAfter { + if udp { + while true { + let (count, flags) = + try socket.receive(into: buffer, sender: &client, ancillary: &ancillary) + print(prefix(client: client, flags: flags), terminator: "") + try FileDescriptor.standardOutput.writeAll(buffer[.. 0 else { break } + print(prefix(client: client, flags: flags), terminator: "") + try FileDescriptor.standardOutput.writeAll(buffer[.. = [] + for info in infos { + // Now try a reverse lookup. + var flags: SocketAddress.AddressResolverFlags = [] + if nofqdn { + flags.insert(.noFullyQualifiedDomain) + } + if numericHost { + flags.insert(.numericHost) + } + if numericService { + flags.insert(.numericService) + } + if datagram { + flags.insert(.datagram) + } + if scopeid { + flags.insert(.scopeIdentifier) + } + let (hostname, service) = try SocketAddress.resolveAddress(info.address, flags: flags) + results.insert("\(hostname) \(service)") + } + for r in results.sorted() { + print(r) + } + } +} diff --git a/Sources/Samples/Util.swift b/Sources/Samples/Util.swift new file mode 100644 index 00000000..a6d8ad76 --- /dev/null +++ b/Sources/Samples/Util.swift @@ -0,0 +1,70 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +#if SYSTEM_PACKAGE +import SystemPackage +import SystemSockets +#else +import System +#error("No socket support") +#endif + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import Darwin +#elseif os(Linux) || os(FreeBSD) || os(Android) +import CSystem +import Glibc +#elseif os(Windows) +import CSystem +import ucrt +#else +#error("Unsupported Platform") +#endif + +/// Turn off libc's input/output buffering. +internal func disableBuffering() { + // FIXME: We should probably be able to do this from System. + setbuf(stdin, nil) + setbuf(stdout, nil) + setbuf(stderr, nil) +} + +internal func complain(_ message: String) { + var message = message + "\n" + message.withUTF8 { buffer in + _ = try? FileDescriptor.standardError.writeAll(buffer) + } +} + +extension SocketAddress.Info { + var niceDescription: String { + var proto = "" + switch self.protocol { + case .udp: proto = "udp" + case .tcp: proto = "tcp" + default: proto = "\(self.protocol)" + } + return "\(address.niceDescription) (\(proto))" + } +} + +extension SocketAddress { + var niceDescription: String { + if let ipv4 = self.ipv4 { return ipv4.description } + if let ipv6 = self.ipv6 { return ipv6.description } + if let local = self.local { return local.description } + return self.description + } +} + +extension SocketDescriptor.ConnectionType { + var isConnectionless: Bool { + self == .datagram || self == .reliablyDeliveredMessage + } +} diff --git a/Sources/Samples/main.swift b/Sources/Samples/main.swift new file mode 100644 index 00000000..92014377 --- /dev/null +++ b/Sources/Samples/main.swift @@ -0,0 +1,16 @@ +import ArgumentParser + +internal struct SystemSamples: ParsableCommand { + public static var configuration = CommandConfiguration( + commandName: "system-samples", + abstract: "A collection of little programs exercising some System features.", + subcommands: [ + Resolve.self, + ReverseResolve.self, + Connect.self, + Listen.self, + ]) +} + +disableBuffering() +SystemSamples.main() diff --git a/Sources/System/Internals/CInterop.swift b/Sources/System/Internals/CInterop.swift index 7f3b96d7..1c4b0000 100644 --- a/Sources/System/Internals/CInterop.swift +++ b/Sources/System/Internals/CInterop.swift @@ -1,7 +1,7 @@ /* This source file is part of the Swift System open source project - Copyright (c) 2020 Apple Inc. and the Swift System project authors + Copyright (c) 2020 - 2021 Apple Inc. and the Swift System project authors Licensed under Apache License v2.0 with Runtime Library Exception See https://swift.org/LICENSE.txt for license information @@ -29,12 +29,6 @@ import ucrt /// A namespace for C and platform types // @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) public enum CInterop { -#if os(Windows) - public typealias Mode = CInt -#else - public typealias Mode = mode_t -#endif - /// The C `char` type public typealias Char = CChar @@ -63,4 +57,7 @@ public enum CInterop { /// on API. public typealias PlatformUnicodeEncoding = UTF8 #endif + + public typealias Mode = mode_t } + diff --git a/Sources/System/Internals/Constants.swift b/Sources/System/Internals/Constants.swift index f5a64b0b..b04ae73a 100644 --- a/Sources/System/Internals/Constants.swift +++ b/Sources/System/Internals/Constants.swift @@ -1,7 +1,7 @@ /* This source file is part of the Swift System open source project - Copyright (c) 2020 Apple Inc. and the Swift System project authors + Copyright (c) 2020 - 2021 Apple Inc. and the Swift System project authors Licensed under Apache License v2.0 with Runtime Library Exception See https://swift.org/LICENSE.txt for license information diff --git a/Sources/System/Internals/Exports.swift b/Sources/System/Internals/Exports.swift index 2b4ae3be..0a5a26eb 100644 --- a/Sources/System/Internals/Exports.swift +++ b/Sources/System/Internals/Exports.swift @@ -1,7 +1,7 @@ /* This source file is part of the Swift System open source project - Copyright (c) 2020 Apple Inc. and the Swift System project authors + Copyright (c) 2020 - 2021 Apple Inc. and the Swift System project authors Licensed under Apache License v2.0 with Runtime Library Exception See https://swift.org/LICENSE.txt for license information @@ -60,7 +60,10 @@ internal func system_strerror(_ __errnum: Int32) -> UnsafeMutablePointer! strerror(__errnum) } -internal func system_strlen(_ s: UnsafePointer) -> Int { +internal func system_strlen(_ s: UnsafePointer) -> Int { + strlen(s) +} +internal func system_strlen(_ s: UnsafeMutablePointer) -> Int { strlen(s) } @@ -78,6 +81,15 @@ internal func system_platform_strlen(_ s: UnsafePointer) #endif } +// memset for raw buffers +// FIXME: Do we really not have something like this in the stdlib already? +internal func system_memset( + _ buffer: UnsafeMutableRawBufferPointer, + to byte: UInt8 +) { + memset(buffer.baseAddress, CInt(byte), buffer.count) +} + // Interop between String and platfrom string extension String { internal func _withPlatformString( diff --git a/Sources/System/Internals/Mocking.swift b/Sources/System/Internals/Mocking.swift index 4c74b6a4..764f20b8 100644 --- a/Sources/System/Internals/Mocking.swift +++ b/Sources/System/Internals/Mocking.swift @@ -19,9 +19,10 @@ #if ENABLE_MOCKING internal struct Trace { - internal struct Entry: Hashable { - private var name: String - private var arguments: [AnyHashable] + internal struct Entry { + + internal var name: String + internal var arguments: [AnyHashable] internal init(name: String, _ arguments: [AnyHashable]) { self.name = name diff --git a/Sources/System/Internals/Syscalls.swift b/Sources/System/Internals/Syscalls.swift index ecfdc843..4501d3c5 100644 --- a/Sources/System/Internals/Syscalls.swift +++ b/Sources/System/Internals/Syscalls.swift @@ -1,7 +1,7 @@ /* This source file is part of the Swift System open source project - Copyright (c) 2020 Apple Inc. and the Swift System project authors + Copyright (c) 2020 - 2021 Apple Inc. and the Swift System project authors Licensed under Apache License v2.0 with Runtime Library Exception See https://swift.org/LICENSE.txt for license information @@ -54,7 +54,7 @@ internal func system_close(_ fd: Int32) -> Int32 { // read internal func system_read( - _ fd: Int32, _ buf: UnsafeMutableRawPointer!, _ nbyte: Int + _ fd: Int32, _ buf: UnsafeMutableRawPointer?, _ nbyte: Int ) -> Int { #if ENABLE_MOCKING if mockingEnabled { return _mockInt(fd, buf, nbyte) } @@ -64,7 +64,7 @@ internal func system_read( // pread internal func system_pread( - _ fd: Int32, _ buf: UnsafeMutableRawPointer!, _ nbyte: Int, _ offset: off_t + _ fd: Int32, _ buf: UnsafeMutableRawPointer?, _ nbyte: Int, _ offset: off_t ) -> Int { #if ENABLE_MOCKING if mockingEnabled { return _mockInt(fd, buf, nbyte, offset) } @@ -84,7 +84,7 @@ internal func system_lseek( // write internal func system_write( - _ fd: Int32, _ buf: UnsafeRawPointer!, _ nbyte: Int + _ fd: Int32, _ buf: UnsafeRawPointer?, _ nbyte: Int ) -> Int { #if ENABLE_MOCKING if mockingEnabled { return _mockInt(fd, buf, nbyte) } @@ -94,7 +94,7 @@ internal func system_write( // pwrite internal func system_pwrite( - _ fd: Int32, _ buf: UnsafeRawPointer!, _ nbyte: Int, _ offset: off_t + _ fd: Int32, _ buf: UnsafeRawPointer?, _ nbyte: Int, _ offset: off_t ) -> Int { #if ENABLE_MOCKING if mockingEnabled { return _mockInt(fd, buf, nbyte, offset) } @@ -115,3 +115,4 @@ internal func system_dup2(_ fd: Int32, _ fd2: Int32) -> Int32 { #endif return dup2(fd, fd2) } + diff --git a/Sources/System/Util.swift b/Sources/System/Util.swift index c038d461..3e492d91 100644 --- a/Sources/System/Util.swift +++ b/Sources/System/Util.swift @@ -127,3 +127,14 @@ extension MutableCollection where Element: Equatable { } } } + +internal func _withOptionalUnsafePointerOrNull( + to value: T?, + _ body: (UnsafePointer?) throws -> R +) rethrows -> R { + guard let value = value else { + return try body(nil) + } + return try withUnsafePointer(to: value, body) +} + diff --git a/Sources/SystemSockets/Backcompat.swift b/Sources/SystemSockets/Backcompat.swift new file mode 100644 index 00000000..a376dd2e --- /dev/null +++ b/Sources/SystemSockets/Backcompat.swift @@ -0,0 +1,29 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +extension String { + internal init( + _unsafeUninitializedCapacity capacity: Int, + initializingUTF8With body: (UnsafeMutableBufferPointer) throws -> Int + ) rethrows { + if #available(macOS 11, iOS 14.0, watchOS 7.0, tvOS 14.0, *) { + self = try String( + unsafeUninitializedCapacity: capacity, + initializingUTF8With: body) + return + } + + let array = try Array( + unsafeUninitializedCapacity: capacity + ) { buffer, count in + count = try body(buffer) + } + self = String(decoding: array, as: UTF8.self) + } +} diff --git a/Sources/SystemSockets/CInterop.swift b/Sources/SystemSockets/CInterop.swift new file mode 100644 index 00000000..22a12e30 --- /dev/null +++ b/Sources/SystemSockets/CInterop.swift @@ -0,0 +1,89 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2020 - 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +// MARK: - Public typealiases + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import Darwin +#elseif os(Linux) || os(FreeBSD) || os(Android) +import CSystem +import Glibc +#elseif os(Windows) +import CSystem +import ucrt +#else +#error("Unsupported Platform") +#endif + +import SystemPackage + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension CInterop { + public typealias SockAddr = sockaddr + public typealias SockLen = socklen_t + public typealias SAFamily = sa_family_t + + public typealias SockAddrIn = sockaddr_in + public typealias InAddr = in_addr + public typealias InAddrT = in_addr_t + + public typealias In6Addr = in6_addr + + public typealias InPort = in_port_t + + public typealias SockAddrIn6 = sockaddr_in6 + public typealias SockAddrUn = sockaddr_un + + public typealias IOVec = iovec + public typealias MsgHdr = msghdr + public typealias CMsgHdr = cmsghdr // Note: c is for "control", not "C" + + public typealias AddrInfo = addrinfo +} + +// memset for raw buffers +// FIXME: Do we really not have something like this in the stdlib already? +internal func system_memset( + _ buffer: UnsafeMutableRawBufferPointer, + to byte: UInt8 +) { + memset(buffer.baseAddress, CInt(byte), buffer.count) +} + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +internal var system_errno: CInt { + get { Darwin.errno } + set { Darwin.errno = newValue } +} +#elseif os(Windows) +internal var system_errno: CInt { + get { + var value: CInt = 0 + // TODO(compnerd) handle the error? + _ = ucrt._get_errno(&value) + return value + } + set { + _ = ucrt._set_errno(newValue) + } +} +#else +internal var system_errno: CInt { + get { Glibc.errno } + set { Glibc.errno = newValue } +} +#endif + +internal func system_strlen(_ s: UnsafePointer) -> Int { + strlen(s) +} +internal func system_strlen(_ s: UnsafeMutablePointer) -> Int { + strlen(s) +} + diff --git a/Sources/SystemSockets/NetworkOrder.swift b/Sources/SystemSockets/NetworkOrder.swift new file mode 100644 index 00000000..a76cf7ee --- /dev/null +++ b/Sources/SystemSockets/NetworkOrder.swift @@ -0,0 +1,22 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +extension FixedWidthInteger { + @_alwaysEmitIntoClient + @inline(__always) + internal var _networkOrder: Self { + bigEndian + } + + @_alwaysEmitIntoClient + @inline(__always) + internal init(_networkOrder value: Self) { + self.init(bigEndian: value) + } +} diff --git a/Sources/SystemSockets/RawBuffer.swift b/Sources/SystemSockets/RawBuffer.swift new file mode 100644 index 00000000..7f9461c3 --- /dev/null +++ b/Sources/SystemSockets/RawBuffer.swift @@ -0,0 +1,102 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +// A copy-on-write fixed-size buffer of raw memory. +internal struct _RawBuffer { + internal var _storage: Storage? + + internal init() { + self._storage = nil + } + + internal init(minimumCapacity: Int) { + if minimumCapacity > 0 { + self._storage = Storage.create(minimumCapacity: minimumCapacity) + } else { + self._storage = nil + } + } +} + +extension _RawBuffer { + internal var capacity: Int { + _storage?.header ?? 0 // Note: not capacity! + } + + internal mutating func ensureUnique() { + guard _storage != nil else { return } + let unique = isKnownUniquelyReferenced(&_storage) + if !unique { + _storage = _copy(capacity: capacity) + } + } + + internal func _grow(desired: Int) -> Int { + let next = Int(1.75 * Double(self.capacity)) + return Swift.max(next, desired) + } + + internal mutating func ensureUnique(capacity: Int) { + let unique = isKnownUniquelyReferenced(&_storage) + if !unique || self.capacity < capacity { + _storage = _copy(capacity: _grow(desired: capacity)) + } + } + + internal func withUnsafeBytes( + _ body: (UnsafeRawBufferPointer) throws -> R + ) rethrows -> R { + guard let storage = _storage else { + return try body(UnsafeRawBufferPointer(start: nil, count: 0)) + } + return try storage.withUnsafeMutablePointers { count, bytes in + let buffer = UnsafeRawBufferPointer(start: bytes, count: count.pointee) + return try body(buffer) + } + } + + internal mutating func withUnsafeMutableBytes( + _ body: (UnsafeMutableRawBufferPointer) throws -> R + ) rethrows -> R { + guard _storage != nil else { + return try body(UnsafeMutableRawBufferPointer(start: nil, count: 0)) + } + ensureUnique() + return try _storage!.withUnsafeMutablePointers { count, bytes in + let buffer = UnsafeMutableRawBufferPointer(start: bytes, count: count.pointee) + return try body(buffer) + } + } +} + +extension _RawBuffer { + internal class Storage: ManagedBuffer { + internal static func create(minimumCapacity: Int) -> Storage { + Storage.create( + minimumCapacity: minimumCapacity, + makingHeaderWith: { $0.capacity } + ) as! Storage + } + } + + internal func _copy(capacity: Int) -> Storage { + let copy = Storage.create(minimumCapacity: capacity) + copy.withUnsafeMutablePointers { dstlen, dst in + self.withUnsafeBytes { src in + guard src.count > 0 else { return } + assert(src.count <= dstlen.pointee) + UnsafeMutableRawPointer(dst) + .copyMemory( + from: src.baseAddress!, + byteCount: Swift.min(src.count, dstlen.pointee)) + } + } + return copy + } +} diff --git a/Sources/SystemSockets/Sockets/SocketAddress+Family.swift b/Sources/SystemSockets/Sockets/SocketAddress+Family.swift new file mode 100644 index 00000000..8375d026 --- /dev/null +++ b/Sources/SystemSockets/Sockets/SocketAddress+Family.swift @@ -0,0 +1,104 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import Darwin +#elseif os(Linux) || os(FreeBSD) || os(Android) +import CSystem +import Glibc +#elseif os(Windows) +import CSystem +import ucrt +#else +#error("Unsupported Platform") +#endif + +import SystemPackage + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + @frozen + /// The address family identifier + public struct Family: RawRepresentable, Hashable { + public let rawValue: CInterop.SAFamily + + @_alwaysEmitIntoClient + public init(rawValue: CInterop.SAFamily) { self.rawValue = rawValue } + + @_alwaysEmitIntoClient + private init(_ rawValue: CInterop.SAFamily) { self.init(rawValue: rawValue) } + + /// Unspecified address family. + /// + /// The corresponding C constant is `AF_UNSPEC`. + @_alwaysEmitIntoClient + public static var unspecified: Family { Family(CInterop.SAFamily(AF_UNSPEC)) } + + /// Local address family. + /// + /// The corresponding C constant is `AF_LOCAL`. + @_alwaysEmitIntoClient + public static var local: Family { Family(CInterop.SAFamily(AF_LOCAL)) } + + /// UNIX address family. (Renamed `local`.) + /// + /// The corresponding C constant is `AF_UNIX`. + @_alwaysEmitIntoClient + @available(*, unavailable, renamed: "local") + public static var unix: Family { Family(CInterop.SAFamily(AF_UNIX)) } + + /// IPv4 address family. + /// + /// The corresponding C constant is `AF_INET`. + @_alwaysEmitIntoClient + public static var ipv4: Family { Family(CInterop.SAFamily(AF_INET)) } + + /// Internal routing address family. + /// + /// The corresponding C constant is `AF_ROUTE`. + @_alwaysEmitIntoClient + public static var routing: Family { Family(CInterop.SAFamily(AF_ROUTE)) } + + /// IPv6 address family. + /// + /// The corresponding C constant is `AF_INET6`. + @_alwaysEmitIntoClient + public static var ipv6: Family { Family(CInterop.SAFamily(AF_INET6)) } + + /// System address family. + /// + /// The corresponding C constant is `AF_SYSTEM`. + @_alwaysEmitIntoClient + public static var system: Family { Family(CInterop.SAFamily(AF_SYSTEM)) } + + /// Raw network device address family. + /// + /// The corresponding C constant is `AF_NDRV` + @_alwaysEmitIntoClient + public static var networkDevice: Family { Family(CInterop.SAFamily(AF_NDRV)) } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.Family: CustomStringConvertible { + public var description: String { + switch self { + case .unspecified: return "unspecified" + case .local: return "local" + //case .unix: return "unix" + case .ipv4: return "ipv4" + case .routing: return "routing" + case .ipv6: return "ipv6" + case .system: return "system" + case .networkDevice: return "networkDevice" + default: + return rawValue.description + } + } +} diff --git a/Sources/SystemSockets/Sockets/SocketAddress+IPv4.swift b/Sources/SystemSockets/Sockets/SocketAddress+IPv4.swift new file mode 100644 index 00000000..71181f4c --- /dev/null +++ b/Sources/SystemSockets/Sockets/SocketAddress+IPv4.swift @@ -0,0 +1,220 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +import SystemPackage + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import Darwin +#elseif os(Linux) || os(FreeBSD) || os(Android) +import CSystem +import Glibc +#elseif os(Windows) +import CSystem +import ucrt +#else +#error("Unsupported Platform") +#endif + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + @frozen + /// An IPv4 address and port number. + public struct IPv4: RawRepresentable { + @_alwaysEmitIntoClient + public var rawValue: CInterop.SockAddrIn + + @_alwaysEmitIntoClient + public init(rawValue: CInterop.SockAddrIn) { + self.rawValue = rawValue + self.rawValue.sin_family = Family.ipv4.rawValue + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// Create a SocketAddress from an IPv4 address and port number. + @_alwaysEmitIntoClient + public init(_ ipv4: IPv4) { + self = Swift.withUnsafeBytes(of: ipv4.rawValue) { buffer in + SocketAddress(buffer) + } + } + + /// If `self` holds an IPv4 address, extract it, otherwise return `nil`. + @_alwaysEmitIntoClient + public var ipv4: IPv4? { + guard family == .ipv4 else { return nil } + let value: CInterop.SockAddrIn? = self.withUnsafeBytes { buffer in + guard buffer.count >= MemoryLayout.size else { + return nil + } + return buffer.baseAddress!.load(as: CInterop.SockAddrIn.self) + } + guard let value = value else { return nil } + return IPv4(rawValue: value) + } + + /// Construct a `SocketAddress` holding an IPv4 address and port number. + @_alwaysEmitIntoClient + public init(ipv4 address: IPv4.Address, port: Port) { + self.init(IPv4(address: address, port: port)) + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv4 { + /// Create a socket address from a given Internet address and port number. + @_alwaysEmitIntoClient + public init(address: Address, port: SocketAddress.Port) { + rawValue = CInterop.SockAddrIn() + rawValue.sin_family = CInterop.SAFamily(SocketDescriptor.Domain.ipv4.rawValue); + rawValue.sin_port = port.rawValue._networkOrder + rawValue.sin_addr = CInterop.InAddr(s_addr: address.rawValue._networkOrder) + } + + /// Create a socket address by parsing an IPv4 address from `address` + /// and a port number. + @_alwaysEmitIntoClient + public init?(address: String, port: SocketAddress.Port) { + guard let address = Address(address) else { return nil } + self.init(address: address, port: port) + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv4: Hashable { + @_alwaysEmitIntoClient + public static func ==(left: Self, right: Self) -> Bool { + left.address == right.address && left.port == right.port + } + + @_alwaysEmitIntoClient + public func hash(into hasher: inout Hasher) { + hasher.combine(address) + hasher.combine(port) + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv4: CustomStringConvertible { + public var description: String { "\(address):\(port)" } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv4 { + @frozen + /// A 32-bit IPv4 address. + public struct Address: RawRepresentable, Hashable { + /// The raw internet address value, in host byte order. + @_alwaysEmitIntoClient + public var rawValue: CInterop.InAddrT + + @_alwaysEmitIntoClient + public init(rawValue: CInterop.InAddrT) { + self.rawValue = rawValue + } + } + + /// The 32-bit IPv4 address. + @_alwaysEmitIntoClient + public var address: Address { + get { + let value = CInterop.InAddrT(_networkOrder: rawValue.sin_addr.s_addr) + return Address(rawValue: value) + } + set { + rawValue.sin_addr.s_addr = newValue.rawValue._networkOrder + } + } + + /// The port number associated with this address. + @_alwaysEmitIntoClient + public var port: SocketAddress.Port { + get { SocketAddress.Port(CInterop.InPort(_networkOrder: rawValue.sin_port)) } + set { rawValue.sin_port = newValue.rawValue._networkOrder } + } +} + +extension SocketAddress.IPv4.Address { + /// The IPv4 address 0.0.0.0. + /// + /// This corresponds to the C constant `INADDR_ANY`. + @_alwaysEmitIntoClient + public static var any: Self { Self(rawValue: INADDR_ANY) } + + /// The IPv4 loopback address 127.0.0.1. + /// + /// This corresponds to the C constant `INADDR_LOOPBACK`. + @_alwaysEmitIntoClient + public static var loopback: Self { Self(rawValue: INADDR_LOOPBACK) } + + /// The IPv4 broadcast address 255.255.255.255. + /// + /// This corresponds to the C constant `INADDR_BROADCAST`. + @_alwaysEmitIntoClient + public static var broadcast: Self { Self(rawValue: INADDR_BROADCAST) } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv4.Address: CustomStringConvertible { + public var description: String { _inet_ntop() } + + internal func _inet_ntop() -> String { + let addr = CInterop.InAddr(s_addr: rawValue._networkOrder) + return withUnsafeBytes(of: addr) { src in + String(_unsafeUninitializedCapacity: Int(INET_ADDRSTRLEN)) { dst in + dst.baseAddress!.withMemoryRebound( + to: CChar.self, + capacity: Int(INET_ADDRSTRLEN) + ) { dst in + let res = system_inet_ntop( + PF_INET, + src.baseAddress!, + dst, + CInterop.SockLen(INET_ADDRSTRLEN)) + if res == -1 { + let errno = Errno.current + fatalError("Failed to convert IPv4 address to string: \(errno)") + } + let length = system_strlen(dst) + assert(length <= INET_ADDRSTRLEN) + return length + } + } + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv4.Address: LosslessStringConvertible { + public init?(_ description: String) { + guard let value = Self._inet_pton(description) else { return nil } + self = value + } + + internal static func _inet_pton(_ string: String) -> Self? { + string.withCString { ptr in + var addr = CInterop.InAddr() + let res = system_inet_pton(PF_INET, ptr, &addr) + guard res == 1 else { return nil } + return Self(rawValue: CInterop.InAddrT(_networkOrder: addr.s_addr)) + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv4.Address: ExpressibleByStringLiteral { + public init(stringLiteral value: String) { + guard let address = Self(value) else { + preconditionFailure("'\(value)' is not a valid IPv4 address string") + } + self = address + } +} diff --git a/Sources/SystemSockets/Sockets/SocketAddress+IPv6.swift b/Sources/SystemSockets/Sockets/SocketAddress+IPv6.swift new file mode 100644 index 00000000..59a206df --- /dev/null +++ b/Sources/SystemSockets/Sockets/SocketAddress+IPv6.swift @@ -0,0 +1,264 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import Darwin +#elseif os(Linux) || os(FreeBSD) || os(Android) +import CSystem +import Glibc +#elseif os(Windows) +import CSystem +import ucrt +#else +#error("Unsupported Platform") +#endif + +import SystemPackage + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// An IPv6 address and port number. + @frozen + public struct IPv6: RawRepresentable { + @_alwaysEmitIntoClient + public var rawValue: CInterop.SockAddrIn6 + + @_alwaysEmitIntoClient + public init(rawValue: CInterop.SockAddrIn6) { + self.rawValue = rawValue + self.rawValue.sin6_family = Family.ipv6.rawValue + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// Create a SocketAddress from an IPv6 address and port number. + @_alwaysEmitIntoClient + public init(_ address: IPv6) { + self = Swift.withUnsafeBytes(of: address.rawValue) { buffer in + SocketAddress(buffer) + } + } + + /// If `self` holds an IPv6 address, extract it, otherwise return `nil`. + @_alwaysEmitIntoClient + public var ipv6: IPv6? { + guard family == .ipv6 else { return nil } + let value: CInterop.SockAddrIn6? = self.withUnsafeBytes { buffer in + guard buffer.count >= MemoryLayout.size else { + return nil + } + return buffer.baseAddress!.load(as: CInterop.SockAddrIn6.self) + } + guard let value = value else { return nil } + return IPv6(rawValue: value) + } + + /// Construct a `SocketAddress` holding an IPv6 address and port + @_alwaysEmitIntoClient + public init(ipv6 address: IPv6.Address, port: Port) { + self.init(IPv6(address: address, port: port)) + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv6 { + /// Create a socket address from an IPv6 address and port number. + @_alwaysEmitIntoClient + public init(address: Address, port: SocketAddress.Port) { + // FIXME: We aren't modeling flowinfo & scope_id yet. + // If we need to do that, we can add new arguments or define new + // initializers/accessors later. + rawValue = CInterop.SockAddrIn6() + rawValue.sin6_family = SocketAddress.Family.ipv6.rawValue + rawValue.sin6_port = port.rawValue._networkOrder + rawValue.sin6_flowinfo = 0 + rawValue.sin6_addr = address.rawValue + rawValue.sin6_scope_id = 0 + } + + /// Create a socket address by parsing an IPv6 address from `address` and a + /// given port number. + @_alwaysEmitIntoClient + public init?(address: String, port: SocketAddress.Port) { + guard let address = Address(address) else { return nil } + self.init(address: address, port: port) + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv6: Hashable { + @_alwaysEmitIntoClient + public static func ==(left: Self, right: Self) -> Bool { + left.address == right.address + && left.port == right.port + && left.rawValue.sin6_flowinfo == right.rawValue.sin6_flowinfo + && left.rawValue.sin6_scope_id == right.rawValue.sin6_scope_id + } + + @_alwaysEmitIntoClient + public func hash(into hasher: inout Hasher) { + hasher.combine(port) + hasher.combine(rawValue.sin6_flowinfo) + hasher.combine(address) + hasher.combine(rawValue.sin6_scope_id) + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv6: CustomStringConvertible { + public var description: String { + "[\(address)]:\(port)" + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv6 { + /// The port number associated with this address. + @_alwaysEmitIntoClient + public var port: SocketAddress.Port { + get { SocketAddress.Port(CInterop.InPort(_networkOrder: rawValue.sin6_port)) } + set { rawValue.sin6_port = newValue.rawValue._networkOrder } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv6 { + /// A 128-bit IPv6 address. + @frozen + public struct Address: RawRepresentable { + /// The raw IPv6 address value. (16 bytes in network byte order.) + @_alwaysEmitIntoClient + public var rawValue: CInterop.In6Addr + + @_alwaysEmitIntoClient + public init(rawValue: CInterop.In6Addr) { + self.rawValue = rawValue + } + } + + /// The 128-bit IPv6 address. + @_alwaysEmitIntoClient + public var address: Address { + get { Address(rawValue: rawValue.sin6_addr) } + set { rawValue.sin6_addr = newValue.rawValue } + } +} + +extension SocketAddress.IPv6.Address { + /// The IPv6 address "::" (i.e., all zero). + /// + /// This corresponds to the C constant `IN6ADDR_ANY_INIT`. + @_alwaysEmitIntoClient + public static var any: Self { + Self(rawValue: CInterop.In6Addr()) + } + + /// The IPv6 loopback address "::1". + /// + /// This corresponds to the C constant `IN6ADDR_LOOPBACK_INIT`. + @_alwaysEmitIntoClient + public static var loopback: Self { + var addr = CInterop.In6Addr() + addr.__u6_addr.__u6_addr8.15 = 1 + return Self(rawValue: addr) + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv6.Address { + /// Create a 128-bit IPv6 address from raw bytes in memory. + @_alwaysEmitIntoClient + public init(bytes: UnsafeRawBufferPointer) { + precondition(bytes.count == MemoryLayout.size) + var addr = CInterop.In6Addr() + withUnsafeMutableBytes(of: &addr) { target in + target.baseAddress!.copyMemory( + from: bytes.baseAddress!, + byteCount: bytes.count) + } + self.rawValue = addr + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv6.Address: Hashable { + @_alwaysEmitIntoClient + public static func ==(left: Self, right: Self) -> Bool { + let l = left.rawValue.__u6_addr.__u6_addr32 + let r = right.rawValue.__u6_addr.__u6_addr32 + return l.0 == r.0 && l.1 == r.1 && l.2 == r.2 && l.3 == r.3 + } + + @_alwaysEmitIntoClient + public func hash(into hasher: inout Hasher) { + let t = rawValue.__u6_addr.__u6_addr32 + hasher.combine(t.0) + hasher.combine(t.1) + hasher.combine(t.2) + hasher.combine(t.3) + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv6.Address: CustomStringConvertible { + public var description: String { _inet_ntop() } + + internal func _inet_ntop() -> String { + return withUnsafeBytes(of: rawValue) { src in + String(_unsafeUninitializedCapacity: Int(INET6_ADDRSTRLEN)) { dst in + dst.baseAddress!.withMemoryRebound( + to: CChar.self, + capacity: Int(INET6_ADDRSTRLEN) + ) { dst in + let res = system_inet_ntop( + PF_INET6, + src.baseAddress!, + dst, + CInterop.SockLen(INET6_ADDRSTRLEN)) + if res == -1 { + let errno = Errno.current + fatalError("Failed to convert IPv6 address to string: \(errno)") + } + let length = system_strlen(dst) + assert(length <= INET6_ADDRSTRLEN) + return length + } + } + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv6.Address: LosslessStringConvertible { + public init?(_ description: String) { + guard let value = Self._inet_pton(description) else { return nil } + self = value + } + + internal static func _inet_pton(_ string: String) -> Self? { + string.withCString { ptr in + var addr = CInterop.In6Addr() + let res = system_inet_pton(PF_INET6, ptr, &addr) + guard res == 1 else { return nil } + return Self(rawValue: addr) + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.IPv6.Address: ExpressibleByStringLiteral { + public init(stringLiteral value: String) { + guard let address = Self(value) else { + preconditionFailure("'\(value)' is not a valid IPv6 address string") + } + self = address + } +} diff --git a/Sources/SystemSockets/Sockets/SocketAddress+Local.swift b/Sources/SystemSockets/Sockets/SocketAddress+Local.swift new file mode 100644 index 00000000..d8c9ab02 --- /dev/null +++ b/Sources/SystemSockets/Sockets/SocketAddress+Local.swift @@ -0,0 +1,87 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +import SystemPackage + +private var _pathOffset: Int { + // FIXME: If this isn't just a constant, use `offsetof` in C. + MemoryLayout.offset(of: \CInterop.SockAddrUn.sun_path)! +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// A socket address in the local (a.k.a. UNIX) domain, for inter-process + /// communication on the same machine. + /// + /// The corresponding C type is `sockaddr_un`. + public struct Local: Hashable { + internal let _path: FilePath + + /// Creates a socket address in the local (a.k.a. UNIX) domain, + /// for inter-process communication on the same machine. + /// + /// The corresponding C type is `sockaddr_un`. + public init(_ path: FilePath) { self._path = path } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// Create a `SocketAddress` from a local (i.e. UNIX domain) socket address. + public init(_ local: Local) { + let offset = _pathOffset + let length = offset + local._path.length + 1 + self.init(unsafeUninitializedCapacity: length) { target in + let addr = target.baseAddress!.assumingMemoryBound(to: CInterop.SockAddr.self) + addr.pointee.sa_len = UInt8(exactly: length) ?? 255 + addr.pointee.sa_family = CInterop.SAFamily(Family.local.rawValue) + let path = (target.baseAddress! + offset) + .assumingMemoryBound(to: CInterop.PlatformChar.self) + local._path.withPlatformString { + path.initialize(from: $0, count: local._path.length) + } + target[length-1] = 0 + return length + } + } + + /// If `self` holds a local address, extract it, otherwise return `nil`. + public var local: Local? { + guard family == .local else { return nil } + let path: FilePath? = self.withUnsafeBytes { buffer in + guard buffer.count >= _pathOffset + 1 else { + return nil + } + let path = (buffer.baseAddress! + _pathOffset) + .assumingMemoryBound(to: CInterop.PlatformChar.self) + return FilePath(platformString: path) + } + guard path != nil else { return nil } + return Local(path!) + } + + /// Construct an address in the Local domain from the given file path. + @_alwaysEmitIntoClient + public init(local path: FilePath) { + self.init(Local(path)) + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.Local { + /// The path representing the socket name in the local filesystem namespace. + /// + /// The corresponding C struct member is `sun_path`. + public var path: FilePath { _path } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.Local: CustomStringConvertible { + public var description: String { _path.description } +} diff --git a/Sources/SystemSockets/Sockets/SocketAddress+Resolution.swift b/Sources/SystemSockets/Sockets/SocketAddress+Resolution.swift new file mode 100644 index 00000000..bb7524ce --- /dev/null +++ b/Sources/SystemSockets/Sockets/SocketAddress+Resolution.swift @@ -0,0 +1,576 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import Darwin +#elseif os(Linux) || os(FreeBSD) || os(Android) +import CSystem +import Glibc +#elseif os(Windows) +import CSystem +import ucrt +#else +#error("Unsupported Platform") +#endif + +import SystemPackage + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// Information about a resolved address. + /// + /// The members of this struct can be passed directly to + /// `SocketDescriptor.open()`, `SocketDescriptor.connect()` + /// or `SocketDescriptor.bind()` to initiate connections. + /// + /// This loosely corresponds to the C `struct addrinfo`. + public struct Info { + /// Address family. + public var family: Family { Family(rawValue: CInterop.SAFamily(domain.rawValue)) } + + /// Communications domain. + public let domain: SocketDescriptor.Domain + + /// Socket type. + public let type: SocketDescriptor.ConnectionType + /// Protocol for socket. + public let `protocol`: SocketDescriptor.ProtocolID + /// Socket address. + public let address: SocketAddress + /// Canonical name for service location. + public let canonicalName: String? + + internal init( + domain: SocketDescriptor.Domain, + type: SocketDescriptor.ConnectionType, + protocol: SocketDescriptor.ProtocolID, + address: SocketAddress, + canonicalName: String? = nil + ) { + self.domain = domain + self.type = type + self.protocol = `protocol` + self.address = address + self.canonicalName = canonicalName + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// Name resolution flags. + @frozen + public struct NameResolverFlags: + OptionSet, RawRepresentable, CustomStringConvertible + { + @_alwaysEmitIntoClient + public let rawValue: CInt + + @_alwaysEmitIntoClient + public init(rawValue: CInt) { + self.rawValue = rawValue + } + + @_alwaysEmitIntoClient + private init(_ raw: CInt) { + self.init(rawValue: raw) + } + + @_alwaysEmitIntoClient + public init() { + self.rawValue = 0 + } + + /// Return IPv4 (or IPv6) addresses only if the local system is + /// configured with an IPv4 (or IPv6) address of its own. + /// + /// This corresponds to the C constant `AI_ADDRCONFIG`. + @_alwaysEmitIntoClient + public static var configuredAddress: Self { Self(AI_ADDRCONFIG) } + + /// If `.ipv4Mapped` is also present, then return also return all + /// matching IPv4 addresses in addition to IPv6 addresses. + /// + /// If `.ipv4Mapped` is not present, then this flag is ignored. + /// + /// This corresponds to the C constant `AI_ALL`. + @_alwaysEmitIntoClient + public static var all: Self { Self(AI_ALL) } + + /// If this flag is present, then name resolution returns the canonical + /// name of the specified hostname in the `canonicalName` field of the + /// first `Info` structure of the returned array. + /// + /// This corresponds to the C constant `AI_CANONNAME`. + @_alwaysEmitIntoClient + public static var canonicalName: Self { Self(AI_CANONNAME) } + + /// Indicates that the specified hostname string contains an IPv4 or + /// IPv6 address in numeric string representation. No name resolution + /// will be attempted. + /// + /// This corresponds to the C constant `AI_NUMERICHOST`. + @_alwaysEmitIntoClient + public static var numericHost: Self { Self(AI_NUMERICHOST) } + + /// Indicates that the specified service string contains a numerical port + /// value. This prevents having to resolve the port number using a + /// resolution service. + /// + /// This corresponds to the C constant `AI_NUMERICSERV`. + @_alwaysEmitIntoClient + public static var numericService: Self { Self(AI_NUMERICSERV) } + + /// Indicates that the returned address is intended for use in + /// a call to `SocketDescriptor.bind()`. In this case, a + /// `nil` hostname resolves to `SocketAddress.IPv4.Address.any` or + /// `SocketAddress.IPv6.Address.any`. + /// + /// If this flag is not present, the returned socket address will be ready + /// for use as the recipient address in a call to `connect()` or + /// `sendMessage()`. In this case a `nil` hostname resolves to + /// `SocketAddress.IPv4.Address.loopback`, or + /// `SocketAddress.IPv6.Address.loopback`. + /// + /// This corresponds to the C constant `AI_PASSIVE`. + @_alwaysEmitIntoClient + public static var passive: Self { Self(AI_PASSIVE) } + + /// This flag indicates that name resolution should return IPv4-mapped + /// IPv6 addresses if no matching IPv6 addresses are found. + /// + /// This flag is ignored unless resolution is performed with the IPv6 + /// family. + /// + /// This corresponds to the C constant `AI_V4MAPPED`. + @_alwaysEmitIntoClient + public static var ipv4Mapped: Self { Self(AI_V4MAPPED) } + + /// This behaves the same as `.ipv4Mapped` if the kernel supports + /// IPv4-mapped IPv6 addresses. Otherwise this flag is ignored. + /// + /// This corresponds to the C constant `AI_V4MAPPED_CFG`. + @_alwaysEmitIntoClient + public static var ipv4MappedIfSupported: Self { Self(AI_V4MAPPED_CFG) } + + /// This is the combination of the flags + /// `.ipv4MappedIfSupported` and `.configuredAddress`, + /// used by default if no flags are specified. + /// + /// This behavior can be overridden by setting the `.unusable` flag. + /// + /// This corresponds to the C constant `AI_DEFAULT`. + @_alwaysEmitIntoClient + public static var `default`: Self { Self(AI_DEFAULT) } + + /// Adding this flag suppresses the implicit default setting of + /// `.ipv4MappedIfSupported` and `.configuredAddress` for an empty `Flags` + /// value, allowing unusuable addresses to be included in the results. + /// + /// This corresponds to the C constant `AI_UNUSABLE`. + @_alwaysEmitIntoClient + public static var unusable: Self { Self(AI_UNUSABLE) } + + public var description: String { + let descriptions: [(Element, StaticString)] = [ + (.configuredAddress, ".configuredAddress"), + (.all, ".all"), + (.canonicalName, ".canonicalName"), + (.numericHost, ".numericHost"), + (.numericService, ".numericService"), + (.passive, ".passive"), + (.ipv4Mapped, ".ipv4Mapped"), + (.ipv4MappedIfSupported, ".ipv4MappedIfSupported"), + (.default, ".default"), + (.unusable, ".unusable"), + ] + return _buildDescription(descriptions) + } + + } + + /// Address resolution flags. + @frozen + public struct AddressResolverFlags: + OptionSet, RawRepresentable, CustomStringConvertible + { + @_alwaysEmitIntoClient + public let rawValue: CInt + + @_alwaysEmitIntoClient + public init(rawValue: CInt) { + self.rawValue = rawValue + } + + @_alwaysEmitIntoClient + private init(_ raw: CInt) { + self.init(rawValue: raw) + } + + @_alwaysEmitIntoClient + public init() { + self.rawValue = 0 + } + + /// A fully qualified domain name is not required for local hosts. + /// + /// This corresponds to the C constant `NI_NOFQDN`. + @_alwaysEmitIntoClient + public static var noFullyQualifiedDomain: Self { Self(NI_NOFQDN) } + + /// Return the address in numeric form, instead of a host name. + /// + /// This corresponds to the C constant `NI_NUMERICHOST`. + @_alwaysEmitIntoClient + public static var numericHost: Self { Self(NI_NUMERICHOST) } + + /// Indicates that a name is required; if the host name cannot be found, + /// an error will be thrown. If this option is not present, then a + /// numerical address is returned. + /// + /// This corresponds to the C constant `NI_NAMEREQD`. + @_alwaysEmitIntoClient + public static var nameRequired: Self { Self(NI_NAMEREQD) } + + /// The service name is returned as a digit string representing the port + /// number. + /// + /// This corresponds to the C constant `NI_NUMERICSERV`. + @_alwaysEmitIntoClient + public static var numericService: Self { Self(NI_NUMERICSERV) } + + /// Specifies that the service being looked up is a datagram service. + /// This is useful in case a port number is used for different services + /// over TCP & UDP. + /// + /// This corresponds to the C constant `NI_DGRAM`. + @_alwaysEmitIntoClient + public static var datagram: Self { Self(NI_DGRAM) } + + /// Enable IPv6 address notation with scope identifiers. + /// + /// This corresponds to the C constant `NI_WITHSCOPEID`. + @_alwaysEmitIntoClient + public static var scopeIdentifier: Self { Self(NI_WITHSCOPEID) } + + public var description: String { + let descriptions: [(Element, StaticString)] = [ + (.noFullyQualifiedDomain, ".noFullyQualifiedDomain"), + (.numericHost, ".numericHost"), + (.nameRequired, ".nameRequired"), + (.numericService, ".numericService"), + (.datagram, ".datagram"), + (.scopeIdentifier, ".scopeIdentifier"), + ] + return _buildDescription(descriptions) + } + + } +} + +extension SocketAddress { + /// An address resolution failure. + /// + /// This corresponds to the error returned by the C function `getaddrinfo`. + @frozen + public struct ResolverError + : Error, RawRepresentable, Hashable, CustomStringConvertible + { + @_alwaysEmitIntoClient + public var rawValue: CInt + + @_alwaysEmitIntoClient + public init(rawValue: CInt) { + self.rawValue = rawValue + } + + @_alwaysEmitIntoClient + private init(_ raw: CInt) { + self.init(rawValue: raw) + } + + // Use "hidden" entry points for `NSError` bridging + @_alwaysEmitIntoClient + public var _code: Int { Int(rawValue) } + + @_alwaysEmitIntoClient + public var _domain: String { + // FIXME: See if there is an existing domain for these. + "System.SocketAddress.ResolverError" + } + + public var description: String { + String(cString: system_gai_strerror(rawValue)) + } + + @_alwaysEmitIntoClient + public static func ~=(_ lhs: ResolverError, _ rhs: Error) -> Bool { + guard let value = rhs as? ResolverError else { return false } + return lhs == value + } + + /// Address family not supported for the specific hostname. + /// + /// The corresponding C constant is `EAI_ADDRFAMILY`. + @_alwaysEmitIntoClient + public static var unsupportedAddressFamilyForHost: Self { Self(EAI_ADDRFAMILY) } + + /// Temporary failure in name resolution. + /// + /// The corresponding C constant is `EAI_AGAIN`. + @_alwaysEmitIntoClient + public static var temporaryFailure: Self { Self(EAI_AGAIN) } + + /// Invalid resolver flags. + /// + /// The corresponding C constant is `EAI_BADFLAGS`. + @_alwaysEmitIntoClient + public static var badFlags: Self { Self(EAI_BADFLAGS) } + + /// Non-recoverable failure in name resolution. + /// + /// The corresponding C constant is `EAI_FAIL`. + @_alwaysEmitIntoClient + public static var nonrecoverableFailure: Self { Self(EAI_FAIL) } + + /// Unsupported address family. + /// + /// The corresponding C constant is `EAI_FAMILY`. + @_alwaysEmitIntoClient + public static var unsupportedAddressFamily: Self { Self(EAI_FAMILY) } + + /// Memory allocation failure. + /// + /// The corresponding C constant is `EAI_MEMORY`. + @_alwaysEmitIntoClient + public static var memoryAllocation: Self { Self(EAI_MEMORY) } + + /// No data associated with hostname. + /// + /// The corresponding C constant is `EAI_NODATA`. + @_alwaysEmitIntoClient + public static var noData: Self { Self(EAI_NODATA) } + + /// Hostname nor service name provided, or not known. + /// + /// The corresponding C constant is `EAI_NONAME`. + @_alwaysEmitIntoClient + public static var noName: Self { Self(EAI_NONAME) } + + /// Service name not supported for specified socket type. + /// + /// The corresponding C constant is `EAI_SERVICE`. + @_alwaysEmitIntoClient + public static var unsupportedServiceForSocketType: Self { Self(EAI_SERVICE) } + + /// Socket type not supported. + /// + /// The corresponding C constant is `EAI_SOCKTYPE`. + @_alwaysEmitIntoClient + public static var unsupportedSocketType: Self { Self(EAI_SOCKTYPE) } + + /// System error. + /// + /// The corresponding C constant is `EAI_SYSTEM`. + @_alwaysEmitIntoClient + public static var system: Self { Self(EAI_SYSTEM) } + + /// Invalid hints. + /// + /// The corresponding C constant is `EAI_BADHINTS`. + @_alwaysEmitIntoClient + public static var badHints: Self { Self(EAI_BADHINTS) } + + /// Unsupported protocol value. + /// + /// The corresponding C constant is `EAI_PROTOCOL`. + @_alwaysEmitIntoClient + public static var unsupportedProtocol: Self { Self(EAI_PROTOCOL) } + + /// Argument buffer overflow. + /// + /// The corresponding C constant is `EAI_OVERFLOW`. + @_alwaysEmitIntoClient + public static var overflow: Self { Self(EAI_OVERFLOW) } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// Get a list of IP addresses and port numbers for a host and service. + /// + /// On failure, this throws either a `ResolverError` or an `Errno`, + /// depending on the error code returned by the underlying `getaddrinfo` + /// function. + /// + /// The method corresponds to the C function `getaddrinfo`. + public static func resolveName( + hostname: String?, + service: String?, + flags: NameResolverFlags? = nil, + family: Family? = nil, + type: SocketDescriptor.ConnectionType? = nil, + protocol: SocketDescriptor.ProtocolID? = nil + ) throws -> [Info] { + // Note: I'm assuming getaddrinfo will never fail with EINTR. + // It it turns out it can, we should add a `retryIfInterrupted` argument. + let result = _resolve( + hostname: hostname, + service: service, + flags: flags, + family: family, + type: type, + protocol: `protocol`) + if let error = result.error { + if let errno = error.1 { throw errno} + throw error.0 + } + return result.results + } + + /// The method corresponds to the C function `getaddrinfo`. + internal static func _resolve( + hostname: String?, + service: String?, + flags: NameResolverFlags?, + family: Family?, + type: SocketDescriptor.ConnectionType?, + protocol: SocketDescriptor.ProtocolID? + ) -> (results: [Info], error: (Error, Errno?)?) { + var hints: CInterop.AddrInfo = CInterop.AddrInfo() + var haveHints = false + if let flags = flags { + hints.ai_flags = flags.rawValue + haveHints = true + } + if let family = family { + hints.ai_family = CInt(family.rawValue) + haveHints = true + } + if let type = type { + hints.ai_socktype = type.rawValue + haveHints = true + } + if let proto = `protocol` { + hints.ai_protocol = proto.rawValue + haveHints = true + } + + var entries: UnsafeMutablePointer? = nil + let error = _withOptionalUnsafePointerOrNull( + to: haveHints ? hints : nil + ) { hints in + _getaddrinfo(hostname, service, hints, &entries) + } + + // Handle errors. + if let error = error { + return ([], error) + } + + // Count number of entries. + var count = 0 + var p: UnsafeMutablePointer? = entries + while let entry = p { + count += 1 + p = entry.pointee.ai_next + } + + // Convert entries to `Info`. + var result: [Info] = [] + result.reserveCapacity(count) + p = entries + while let entry = p { + let info = Info( + domain: SocketDescriptor.Domain(entry.pointee.ai_family), + type: SocketDescriptor.ConnectionType(entry.pointee.ai_socktype), + protocol: SocketDescriptor.ProtocolID(entry.pointee.ai_protocol), + address: SocketAddress(address: entry.pointee.ai_addr, + length: entry.pointee.ai_addrlen), + canonicalName: entry.pointee.ai_canonname.map { String(cString: $0) }) + result.append(info) + p = entry.pointee.ai_next + } + + // Release resources. + system_freeaddrinfo(entries) + + return (result, nil) + } + + internal static func _getaddrinfo( + _ hostname: UnsafePointer?, + _ servname: UnsafePointer?, + _ hints: UnsafePointer?, + _ res: inout UnsafeMutablePointer? + ) -> (ResolverError, Errno?)? { + let r = system_getaddrinfo(hostname, servname, hints, &res) + if r == 0 { return nil } + let error = ResolverError(rawValue: r) + if error == .system { + return (error, Errno.current) + } + return (error, nil) + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// Resolve a socket address to hostname and service name. + /// + /// On failure, this throws either a `ResolverError` or an `Errno`, + /// depending on the error code returned by the underlying `getnameinfo` + /// function. + /// + /// This method corresponds to the C function `getnameinfo`. + public static func resolveAddress( + _ address: SocketAddress, + flags: AddressResolverFlags = [] + ) throws -> (hostname: String, service: String) { + let (result, error) = _resolveAddress(address, flags) + if let error = error { + if let errno = error.1 { throw errno } + throw error.0 + } + return result + } + + internal static func _resolveAddress( + _ address: SocketAddress, + _ flags: AddressResolverFlags + ) -> (results: (hostname: String, service: String), error: (ResolverError, Errno?)?) { + address.withUnsafeCInterop { adr, adrlen in + var r: CInt = 0 + var service: String = "" + let host = String(_unsafeUninitializedCapacity: Int(NI_MAXHOST)) { host in + let h = UnsafeMutableRawPointer(host.baseAddress!) + .assumingMemoryBound(to: CChar.self) + service = String(_unsafeUninitializedCapacity: Int(NI_MAXSERV)) { serv in + let s = UnsafeMutableRawPointer(serv.baseAddress!) + .assumingMemoryBound(to: CChar.self) + r = system_getnameinfo( + adr, adrlen, + h, CInterop.SockLen(host.count), + s, CInterop.SockLen(serv.count), + flags.rawValue) + if r != 0 { return 0 } + return system_strlen(s) + } + if r != 0 { return 0 } + return system_strlen(h) + } + var error: (ResolverError, Errno?)? = nil + if r != 0 { + let err = ResolverError(rawValue: r) + error = (err, err == .system ? Errno.current : nil) + } + return ((host, service), error) + } + } +} diff --git a/Sources/SystemSockets/Sockets/SocketAddress.swift b/Sources/SystemSockets/Sockets/SocketAddress.swift new file mode 100644 index 00000000..37cd47c3 --- /dev/null +++ b/Sources/SystemSockets/Sockets/SocketAddress.swift @@ -0,0 +1,363 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +import SystemPackage + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +/// An opaque type representing a socket address in some address family, +/// such as an IP address along with a port number. +/// +/// `SocketAddress` values can be passed directly to `SocketDescriptor.connect` +/// or `.bind` to establish a network connection. +/// +/// We can use the `SocketAddress.resolveName` method resolve a pair of +/// host/service name strings to a list of socket addresses: +/// +/// let results = +/// try SocketAddress.resolveName(hostname: "swift.org", service: "https") +/// for result in results { +/// let try socket = +/// SocketDescriptor.open(result.domain, result.type, result.protocol) +/// do { +/// try socket.connect(to: result.address) +/// } catch { +/// try? socket.close() +/// throw error +/// } +/// return socket +/// } +/// +/// To create an IPv4, IPv6 or Local domain address, we can use convenience +/// initializers that take the corresponding information: +/// +/// let ipv4 = SocketAddress(ipv4: "127.0.0.1", port: 8080)! +/// let ipv6 = SocketAddress(ipv6: "::1", port: 80)! +/// let local = SocketAddress(local: "/var/run/example.sock") +/// +/// (Note that you may prefer to use the concrete address types +/// `SocketAddress.IPv4`, `SocketAddress.IPv6` and `SocketAddress.Local` +/// instead -- they provide easy access to the address parameters.) +/// +/// `SocketAddress` also provides ways to access its underlying contents +/// as a raw unsafe memory buffer. This is useful for dealing with address +/// families that `System` doesn't model, or for passing the socket address +/// to C functions that expect a pointer to a `sockaddr` value. +/// +/// `SocketAddress` stores its contents in a managed storage buffer, and +/// it can serve as a reusable receptacle for addresses that are returned +/// by system calls. You can use the `init(minimumCapacity:)` initializer +/// to create an empty socket address with the specified storage capacity, +/// then you can pass it to functions like `.accept(client:)` to retrieve +/// addresses without repeatedly allocating memory. +/// +/// `SocketAddress` is able to hold any IPv4 or IPv6 address without allocating +/// any memory. For other address families, it may need to heap allocate a +/// storage buffer, depending on the size of the stored value. +/// +/// The corresponding C type is `sockaddr_t`. +public struct SocketAddress { + internal var _variant: _Variant + + /// Create an address from raw bytes + public init( + address: UnsafePointer, + length: CInterop.SockLen + ) { + self.init(UnsafeRawBufferPointer(start: address, count: Int(length))) + } + + /// Create an address from raw bytes + public init(_ buffer: UnsafeRawBufferPointer) { + self.init(unsafeUninitializedCapacity: buffer.count) { target in + target.baseAddress!.copyMemory( + from: buffer.baseAddress!, + byteCount: buffer.count) + return buffer.count + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// Initialize an empty socket address with the specified minimum capacity. + /// The default capacity makes enough storage space to fit any IPv4/IPv6 address. + /// + /// Addresses with storage preallocated this way can be repeatedly passed to + /// `SocketDescriptor.receiveMessage`, eliminating the need for a potential + /// allocation each time it is called. + public init(minimumCapacity: Int = Self.defaultCapacity) { + self.init(unsafeUninitializedCapacity: minimumCapacity) { buffer in + system_memset(buffer, to: 0) + return 0 + } + } + + /// Reserve storage capacity such that `self` is able to store addresses + /// of at least `minimumCapacity` bytes without any additional allocation. + /// + /// Addresses with storage preallocated this way can be repeatedly passed to + /// `SocketDescriptor.receiveMessage`, eliminating the need for a potential + /// allocation each time it is called. + public mutating func reserveCapacity(_ minimumCapacity: Int) { + guard minimumCapacity > _capacity else { return } + let length = _length + var buffer = _RawBuffer(minimumCapacity: minimumCapacity) + buffer.withUnsafeMutableBytes { target in + self.withUnsafeBytes { source in + assert(source.count == length) + assert(target.count > source.count) + if source.count > 0 { + target.baseAddress!.copyMemory( + from: source.baseAddress!, + byteCount: source.count) + } + } + } + self._variant = .large(length: length, bytes: buffer) + } + + /// Reset this address value to an empty address of unspecified family, + /// filling the underlying storage with zero bytes. + public mutating func clear() { + self._withUnsafeMutableBytes(entireCapacity: true) { buffer in + system_memset(buffer, to: 0) + } + self._length = 0 + } + + /// Creates a socket address with the specified capacity, then calls the + /// given closure with a buffer covering the socket address's uninitialized + /// memory. + public init( + unsafeUninitializedCapacity capacity: Int, + initializingWith body: (UnsafeMutableRawBufferPointer) throws -> Int + ) rethrows { + if capacity <= MemoryLayout<_InlineStorage>.size { + var storage = _InlineStorage() + let length: Int = try withUnsafeMutableBytes(of: &storage) { bytes in + let buffer = UnsafeMutableRawBufferPointer(rebasing: bytes[..= 0 && length <= capacity) + self._variant = .small(length: UInt8(length), bytes: storage) + } else { + var buffer = _RawBuffer(minimumCapacity: capacity) + let count = try buffer.withUnsafeMutableBytes { target in + try body(target) + } + precondition(count >= 0 && count <= capacity) + self._variant = .large(length: count, bytes: buffer) + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// A default capacity, with enough storage space to fit any IPv4/IPv6 address. + public static var defaultCapacity: Int { MemoryLayout<_InlineStorage>.size } + + @_alignment(8) // This must be large enough to cover any sockaddr variant + internal struct _InlineStorage { + /// A chunk of 28 bytes worth of integers, treated as inline storage for + /// short `sockaddr` values. + /// + /// Note: 28 bytes is just enough to cover socketaddr_in6 on Darwin. + /// The length of this struct may need to be adjusted on other platforms. + internal let bytes: (UInt32, UInt32, UInt32, UInt32, UInt32, UInt32, UInt32) + + internal init() { + bytes = (0, 0, 0, 0, 0, 0, 0) + } + } +} + +extension SocketAddress { + internal enum _Variant { + case small(length: UInt8, bytes: _InlineStorage) + case large(length: Int, bytes: _RawBuffer) + } + + internal var _length: Int { + get { + switch _variant { + case let .small(length: length, bytes: _): + return Int(length) + case let .large(length: length, bytes: _): + return length + } + } + set { + assert(newValue <= _capacity) + switch _variant { + case let .small(length: _, bytes: bytes): + self._variant = .small(length: UInt8(newValue), bytes: bytes) + case let .large(length: _, bytes: bytes): + self._variant = .large(length: newValue, bytes: bytes) + } + } + } + + internal var _capacity: Int { + switch _variant { + case .small(length: _, bytes: _): + return MemoryLayout<_InlineStorage>.size + case .large(length: _, bytes: let bytes): + return bytes.capacity + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + /// Calls `body` with an unsafe raw buffer pointer to the raw bytes of this + /// address. This is useful when you need to pass an address to a function + /// that treats socket addresses as untyped raw data. + public func withUnsafeBytes( + _ body: (UnsafeRawBufferPointer) throws -> R + ) rethrows -> R { + switch _variant { + case let .small(length: length, bytes: bytes): + let length = Int(length) + assert(length <= MemoryLayout<_InlineStorage>.size) + return try Swift.withUnsafeBytes(of: bytes) { buffer in + try body(UnsafeRawBufferPointer(rebasing: buffer[..( + _ body: (UnsafePointer?, CInterop.SockLen) throws -> R + ) rethrows -> R { + try withUnsafeBytes { bytes in + let start = bytes.baseAddress?.assumingMemoryBound(to: CInterop.SockAddr.self) + let length = CInterop.SockLen(bytes.count) + if length >= MemoryLayout.size { + return try body(start, length) + } else { + return try body(nil, 0) + } + } + } + + internal mutating func _withUnsafeMutableBytes( + entireCapacity: Bool, + _ body: (UnsafeMutableRawBufferPointer) throws -> R + ) rethrows -> R { + switch _variant { + case .small(length: let length, bytes: var bytes): + assert(length <= MemoryLayout<_InlineStorage>.size) + defer { self._variant = .small(length: length, bytes: bytes) } + return try Swift.withUnsafeMutableBytes(of: &bytes) { buffer in + if entireCapacity { + return try body(buffer) + } else { + return try body(.init(rebasing: buffer[..( + entireCapacity: Bool, + _ body: ( + UnsafeMutablePointer?, + inout CInterop.SockLen + ) throws -> R + ) rethrows -> R { + let (result, length): (R, Int) = try _withUnsafeMutableBytes( + entireCapacity: true + ) { bytes in + let start = bytes.baseAddress?.assumingMemoryBound(to: CInterop.SockAddr.self) + var length = CInterop.SockLen(bytes.count) + let result = try body(start, &length) + precondition(length >= 0 && length <= bytes.count, "\(length) \(bytes.count)") + return (result, Int(length)) + } + self._length = length + return result + } + + + /// The address family identifier of this socket address. + public var family: Family { + withUnsafeCInterop { addr, length in + guard let addr = addr else { return .unspecified } + return Family(rawValue: addr.pointee.sa_family) + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress: CustomStringConvertible { + public var description: String { + if let address = self.ipv4 { + return "SocketAddress(family: \(family), address: \(address))" + } + if let address = self.ipv6 { + return "SocketAddress(family: \(family), address: \(address))" + } + if let address = self.local { + return "SocketAddress(family: \(family), address: \(address))" + } + return "SocketAddress(family: \(family), \(self._length) bytes)" + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress { + @frozen + /// The port number on which the socket is listening. + public struct Port: RawRepresentable, ExpressibleByIntegerLiteral, Hashable { + /// The port number, in host byte order. + @_alwaysEmitIntoClient + public var rawValue: CInterop.InPort + + @_alwaysEmitIntoClient + public init(_ value: CInterop.InPort) { + self.rawValue = value + } + + @_alwaysEmitIntoClient + public init(rawValue: CInterop.InPort) { + self.init(rawValue) + } + + @_alwaysEmitIntoClient + public init(integerLiteral value: CInterop.InPort) { + self.init(value) + } + } +} +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketAddress.Port: CustomStringConvertible { + public var description: String { + rawValue.description + } +} diff --git a/Sources/SystemSockets/Sockets/SocketDescriptor.swift b/Sources/SystemSockets/Sockets/SocketDescriptor.swift new file mode 100644 index 00000000..b8001df3 --- /dev/null +++ b/Sources/SystemSockets/Sockets/SocketDescriptor.swift @@ -0,0 +1,369 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +import SystemPackage + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import Darwin +#elseif os(Linux) || os(FreeBSD) || os(Android) +import Glibc +#elseif os(Windows) +import ucrt +#else +#error("Unsupported Platform") +#endif + +// TODO: @available(...) + +// TODO: Windows uses a SOCKET type; doesn't have file descriptor +// equivalence + +/// TODO +@frozen +public struct SocketDescriptor: RawRepresentable, Hashable { + /// The raw C socket. + @_alwaysEmitIntoClient + public let rawValue: CInt + + /// Creates a strongly-typed socket from a raw C socket. + @_alwaysEmitIntoClient + public init(rawValue: CInt) { self.rawValue = rawValue } +} + +extension SocketDescriptor { + /// The file descriptor for `self`. + @_alwaysEmitIntoClient + public var fileDescriptor: FileDescriptor { + FileDescriptor(rawValue: rawValue) + } + + /// Treat `fd` as a socket descriptor, without checking with the operating + /// system that it actually refers to a socket. + @_alwaysEmitIntoClient + public init(unchecked fd: FileDescriptor) { + self.init(rawValue: fd.rawValue) + } +} + +extension SocketDescriptor { + /// Communications domain, identifying the protocol family that is being used. + @frozen + public struct Domain: RawRepresentable, Hashable, CustomStringConvertible { + @_alwaysEmitIntoClient + public var rawValue: CInt + + @_alwaysEmitIntoClient + public init(rawValue: CInt) { self.rawValue = rawValue } + + @_alwaysEmitIntoClient + internal init(_ rawValue: CInt) { self.init(rawValue: rawValue) } + + /// Unspecified protocol. + /// + /// The corresponding C constant is `PF_UNSPEC`. + @_alwaysEmitIntoClient + public static var unspecified: Domain { Domain(PF_UNSPEC) } + + /// Host-internal protocols, formerly called PF_UNIX. + /// + /// The corresponding C constant is `PF_LOCAL`. + @_alwaysEmitIntoClient + public static var local: Domain { Domain(PF_LOCAL) } + + @_alwaysEmitIntoClient + @available(*, unavailable, renamed: "local") + public static var unix: Domain { Domain(PF_UNIX) } + + /// Internet version 4 protocols. + /// + /// The corresponding C constant is `PF_INET`. + @_alwaysEmitIntoClient + public static var ipv4: Domain { Domain(PF_INET) } + + /// Internal Routing protocol. + /// + /// The corresponding C constant is `PF_ROUTE`. + @_alwaysEmitIntoClient + public static var routing: Domain { Domain(PF_ROUTE) } + + /// Internal key-management function. + /// + /// The corresponding C constant is `PF_KEY`. + @_alwaysEmitIntoClient + public static var keyManagement: Domain { Domain(PF_KEY) } + + /// Internet version 6 protocols. + /// + /// The corresponding C constant is `PF_INET6`. + @_alwaysEmitIntoClient + public static var ipv6: Domain { Domain(PF_INET6) } + + /// System domain. + /// + /// The corresponding C constant is `PF_SYSTEM`. + @_alwaysEmitIntoClient + public static var system: Domain { Domain(PF_SYSTEM) } + + /// Raw access to network device. + /// + /// The corresponding C constant is `PF_NDRV`. + @_alwaysEmitIntoClient + public static var networkDevice: Domain { Domain(PF_NDRV) } + + public var description: String { + switch self { + case .unspecified: return "unspecified" + case .local: return "local" + case .ipv4: return "ipv4" + case .ipv6: return "ipv6" + case .routing: return "routing" + case .keyManagement: return "keyManagement" + case .system: return "system" + case .networkDevice: return "networkDevice" + default: return rawValue.description + } + } + } + + /// The socket type, specifying the semantics of communication. + @frozen + public struct ConnectionType: RawRepresentable, Hashable, CustomStringConvertible { + @_alwaysEmitIntoClient + public var rawValue: CInt + + @_alwaysEmitIntoClient + public init(rawValue: CInt) { self.rawValue = rawValue } + + @_alwaysEmitIntoClient + internal init(_ rawValue: CInt) { self.init(rawValue: rawValue) } + + /// Sequenced, reliable, two-way connection based byte streams. + /// + /// The corresponding C constant is `SOCK_STREAM`. + @_alwaysEmitIntoClient + public static var stream: ConnectionType { ConnectionType(SOCK_STREAM) } + + /// Datagrams (connectionless, unreliable messages of a fixed (typically + /// small) maximum length). + /// + /// The corresponding C constant is `SOCK_DGRAM`. + @_alwaysEmitIntoClient + public static var datagram: ConnectionType { ConnectionType(SOCK_DGRAM) } + + /// Raw protocol interface. Only available to the super user. + /// + /// The corresponding C constant is `SOCK_RAW`. + @_alwaysEmitIntoClient + public static var raw: ConnectionType { ConnectionType(SOCK_RAW) } + + /// Reliably delivered message. + /// + /// The corresponding C constant is `SOCK_RDM`. + @_alwaysEmitIntoClient + public static var reliablyDeliveredMessage: ConnectionType { + ConnectionType(SOCK_RDM) + } + + /// Sequenced packet stream. + /// + /// The corresponding C constant is `SOCK_SEQPACKET`. + @_alwaysEmitIntoClient + public static var sequencedPacketStream: ConnectionType { + ConnectionType(SOCK_SEQPACKET) + } + + public var description: String { + switch self { + case .stream: return "stream" + case .datagram: return "datagram" + case .raw: return "raw" + case .reliablyDeliveredMessage: return "rdm" + case .sequencedPacketStream: return "seqpacket" + default: return rawValue.description + } + } + } + + /// Identifies a particular protocol to be used for communication. + /// + /// Note that protocol numbers are particular to the communication domain + /// that is being used. Accordingly, some of the symbolic names provided + /// here may have the same underlying value -- they are provided merely + /// for convenience. + @frozen + public struct ProtocolID: RawRepresentable, Hashable, CustomStringConvertible { + @_alwaysEmitIntoClient + public var rawValue: CInt + + @_alwaysEmitIntoClient + public init(rawValue: CInt) { self.rawValue = rawValue } + + @_alwaysEmitIntoClient + internal init(_ rawValue: CInt) { self.init(rawValue: rawValue) } + + /// Internet Protocol (IP). + /// + /// This corresponds to the C constant `IPPROTO_IP`. + @_alwaysEmitIntoClient + public static var ip: ProtocolID { Self(IPPROTO_IP) } + + /// Transmission Control Protocol (TCP). + /// + /// This corresponds to the C constant `IPPROTO_TCP`. + @_alwaysEmitIntoClient + public static var tcp: ProtocolID { Self(IPPROTO_TCP) } + + /// User Datagram Protocol (UDP). + /// + /// This corresponds to the C constant `IPPROTO_UDP`. + @_alwaysEmitIntoClient + public static var udp: ProtocolID { Self(IPPROTO_UDP) } + + /// IPv4 encapsulation. + /// + /// This corresponds to the C constant `IPPROTO_IPV4`. + @_alwaysEmitIntoClient + public static var ipv4: ProtocolID { Self(IPPROTO_IPV4) } + + /// IPv6 header. + /// + /// This corresponds to the C constant `IPPROTO_IPV6`. + @_alwaysEmitIntoClient + public static var ipv6: ProtocolID { Self(IPPROTO_IPV6) } + + /// Raw IP packet. + /// + /// This corresponds to the C constant `IPPROTO_RAW`. + @_alwaysEmitIntoClient + public static var raw: ProtocolID { Self(IPPROTO_RAW) } + + /// Special protocol value representing socket-level options. + /// + /// The corresponding C constant is `SOL_SOCKET`. + @_alwaysEmitIntoClient + public static var socketOption: ProtocolID { Self(SOL_SOCKET) } + + public var description: String { + // Note: Can't return symbolic names here -- values have multiple + // meanings based on the domain. + rawValue.description + } + } + + /// Message flags. + @frozen + public struct MessageFlags: OptionSet, CustomStringConvertible { + @_alwaysEmitIntoClient + public var rawValue: CInt + + @_alwaysEmitIntoClient + public init(rawValue: CInt) { self.rawValue = rawValue } + + @_alwaysEmitIntoClient + private init(_ raw: CInt) { self.init(rawValue: raw) } + + @_alwaysEmitIntoClient + public static var none: MessageFlags { MessageFlags(0) } + + /// Process out-of-band data. + /// + /// The corresponding C constant is `MSG_OOB`. + @_alwaysEmitIntoClient + public static var outOfBand: MessageFlags { MessageFlags(MSG_OOB) } + + /// Bypass routing, use direct interface. + /// + /// The corresponding C constant is `MSG_DONTROUTE`. + @_alwaysEmitIntoClient + public static var doNotRoute: MessageFlags { MessageFlags(MSG_DONTROUTE) } + + /// Peek at incoming message. + /// + /// The corresponding C constant is `MSG_PEEK`. + @_alwaysEmitIntoClient + public static var peek: MessageFlags { MessageFlags(MSG_PEEK) } + + /// Wait for full request or error. + /// + /// The corresponding C constant is `MSG_WAITALL`. + @_alwaysEmitIntoClient + public static var waitForAll: MessageFlags { MessageFlags(MSG_WAITALL) } + + /// End-of-record condition -- the associated data completed a + /// full record. + /// + /// The corresponding C constant is `MSG_EOR`. + @_alwaysEmitIntoClient + public static var endOfRecord: MessageFlags { MessageFlags(MSG_EOR) } + + /// Datagram was truncated because it didn't fit in the supplied + /// buffer. + /// + /// The corresponding C constant is `MSG_TRUNC`. + @_alwaysEmitIntoClient + public static var dataTruncated: MessageFlags { MessageFlags(MSG_TRUNC) } + + /// Some ancillary data was discarded because it didn't fit + /// in the supplied buffer. + /// + /// The corresponding C constant is `MSG_CTRUNC`. + @_alwaysEmitIntoClient + public static var ancillaryTruncated: MessageFlags { MessageFlags(MSG_CTRUNC) } + + public var description: String { + let descriptions: [(Element, StaticString)] = [ + (.outOfBand, ".outOfBand"), + (.doNotRoute, ".doNotRoute"), + (.peek, ".peek"), + (.waitForAll, ".waitForAll"), + (.endOfRecord, ".endOfRecord"), + (.dataTruncated, ".dataTruncated"), + (.ancillaryTruncated, ".ancillaryTruncated"), + ] + return _buildDescription(descriptions) + } + } + + /// Specify the part (or all) of a full-duplex connection to shutdown. + @frozen + public struct ShutdownKind: RawRepresentable, Hashable, CustomStringConvertible { + @_alwaysEmitIntoClient + public var rawValue: CInt + + @_alwaysEmitIntoClient + public init(rawValue: CInt) { self.rawValue = rawValue } + + /// Further receives will be disallowed + /// + /// The corresponding C constant is `SHUT_RD`. + @_alwaysEmitIntoClient + public static var read: ShutdownKind { ShutdownKind(rawValue: SHUT_RD) } + + /// Further sends will be disallowed + /// + /// The corresponding C constant is `SHUT_RD`. + @_alwaysEmitIntoClient + public static var write: ShutdownKind { ShutdownKind(rawValue: SHUT_WR) } + + /// Further sends and receives will be disallowed + /// + /// The corresponding C constant is `SHUT_RDWR`. + @_alwaysEmitIntoClient + public static var readWrite: ShutdownKind { ShutdownKind(rawValue: SHUT_RDWR) } + + public var description: String { + switch self { + case .read: return "read" + case .write: return "write" + case .readWrite: return "readWrite" + default: return rawValue.description + } + } + } +} diff --git a/Sources/SystemSockets/Sockets/SocketHelpers.swift b/Sources/SystemSockets/Sockets/SocketHelpers.swift new file mode 100644 index 00000000..e36d034e --- /dev/null +++ b/Sources/SystemSockets/Sockets/SocketHelpers.swift @@ -0,0 +1,43 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +extension SocketDescriptor { + /// Writes a sequence of bytes to the socket + /// + /// This is equivalent to calling `fileDescriptor.writeAll(_:)` + /// + /// - Parameter sequence: The bytes to write. + /// - Returns: The number of bytes written, equal to the number of elements in `sequence`. + @_alwaysEmitIntoClient + @discardableResult + public func writeAll( + _ sequence: S + ) throws -> Int where S.Element == UInt8 { + try fileDescriptor.writeAll(sequence) + } + + /// Runs a closure and then closes the socket, even if an error occurs. + /// + /// This is equivalent to calling `fileDescriptor.closeAfter(_:)` + /// + /// - Parameter body: The closure to run. + /// If the closure throws an error, + /// this method closes the socket before it rethrows that error. + /// + /// - Returns: The value returned by the closure. + /// + /// If `body` throws an error + /// or an error occurs while closing the socket, + /// this method rethrows that error. + @_alwaysEmitIntoClient + public func closeAfter(_ body: () throws -> R) throws -> R { + try fileDescriptor.closeAfter(body) + } +} + diff --git a/Sources/SystemSockets/Sockets/SocketMessages.swift b/Sources/SystemSockets/Sockets/SocketMessages.swift new file mode 100644 index 00000000..f65b916c --- /dev/null +++ b/Sources/SystemSockets/Sockets/SocketMessages.swift @@ -0,0 +1,324 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +import SystemPackage + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketDescriptor { + /// A reusable collection of variable-sized ancillary messages + /// sent or received over a socket. These represent protocol control + /// related messages or other miscellaneous ancillary data. + /// + /// Corresponds to a buffer of `struct cmsghdr` messages in C, as used + /// by `sendmsg` and `recmsg`. + public struct AncillaryMessageBuffer { + internal var _buffer: _RawBuffer + internal var _endOffset: Int + + /// Initialize a new empty ancillary message buffer with no preallocated + /// storage. + public init() { + _buffer = _RawBuffer() + _endOffset = 0 + } + + /// Initialize a new empty ancillary message buffer of the + /// specified minimum capacity (in bytes). + internal init(minimumCapacity: Int) { + let headerSize = MemoryLayout.size + let capacity = Swift.max(headerSize + 1, minimumCapacity) + _buffer = _RawBuffer(minimumCapacity: capacity) + _endOffset = 0 + } + + internal var _headerSize: Int { MemoryLayout.size } + internal var _capacity: Int { _buffer.capacity } + + /// Remove all messages currently in this buffer, preserving storage + /// capacity. + /// + /// This invalidates all indices in the collection. + /// + /// - Complexity: O(1). Does not reallocate the buffer. + public mutating func removeAll() { + _endOffset = 0 + } + + /// Reserve enough storage capacity to hold `minimumCapacity` bytes' worth + /// of messages without having to reallocate storage. + /// + /// This does not invalidate any indices. + /// + /// - Complexity: O(max(`minimumCapacity`, `capacity`)), where `capacity` is + /// the current storage capacity. This potentially needs to reallocate + /// the buffer and copy existing messages. + public mutating func reserveCapacity(_ minimumCapacity: Int) { + _buffer.ensureUnique(capacity: minimumCapacity) + } + + /// Append a message with the specified data to the end of this buffer, + /// resizing it if necessary. + /// + /// This does not invalidate any existing indices, but it updates `endIndex`. + /// + /// - Complexity: Amortized O(`data.count`), when averaged over multiple + /// calls. This method reallocates the buffer if there isn't enough + /// capacity or if the storage is shared with another value. + public mutating func appendMessage( + level: SocketDescriptor.ProtocolID, + type: SocketDescriptor.Option, + bytes: UnsafeRawBufferPointer + ) { + appendMessage( + level: level, + type: type, + unsafeUninitializedCapacity: bytes.count + ) { buffer in + assert(buffer.count >= bytes.count) + if bytes.count > 0 { + buffer.baseAddress!.copyMemory( + from: bytes.baseAddress!, + byteCount: bytes.count) + } + return bytes.count + } + } + + /// Append a message with the supplied data to the end of this buffer, + /// resizing it if necessary. The message payload is initialized with the + /// supplied closure, which needs to return the final message length. + /// + /// This does not invalidate any existing indices, but it updates `endIndex`. + /// + /// - Complexity: Amortized O(`data.count`), when averaged over multiple + /// calls. This method reallocates the buffer if there isn't enough + /// capacity or if the storage is shared with another value. + public mutating func appendMessage( + level: SocketDescriptor.ProtocolID, + type: SocketDescriptor.Option, + unsafeUninitializedCapacity capacity: Int, + initializingWith body: (UnsafeMutableRawBufferPointer) throws -> Int + ) rethrows { + precondition(capacity >= 0) + let headerSize = _headerSize + let delta = _headerSize + capacity + _buffer.ensureUnique(capacity: _endOffset + delta) + let messageLength: Int = try _buffer.withUnsafeMutableBytes { buffer in + assert(buffer.count >= _endOffset + delta) + let p = buffer.baseAddress! + _endOffset + let header = p.bindMemory(to: CInterop.CMsgHdr.self, capacity: 1) + header.pointee = CInterop.CMsgHdr() + header.pointee.cmsg_level = level.rawValue + header.pointee.cmsg_type = type.rawValue + let length = try body( + UnsafeMutableRawBufferPointer(start: p + headerSize, count: capacity)) + precondition(length >= 0 && length <= capacity) + header.pointee.cmsg_len = CInterop.SockLen(headerSize + length) + return headerSize + length + } + _endOffset += messageLength + } + + internal func _withUnsafeBytes( + _ body: (UnsafeRawBufferPointer) throws -> R + ) rethrows -> R { + try _buffer.withUnsafeBytes { buffer in + assert(buffer.count >= _endOffset) + let buffer = UnsafeRawBufferPointer(rebasing: buffer.prefix(_endOffset)) + return try body(buffer) + } + } + + internal mutating func _withUnsafeMutableBytes( + entireCapacity: Bool, + _ body: (UnsafeMutableRawBufferPointer) throws -> R + ) rethrows -> R { + return try _buffer.withUnsafeMutableBytes { buffer in + assert(buffer.count >= _endOffset) + if entireCapacity { + return try body(buffer) + } else { + return try body(.init(rebasing: buffer.prefix(_endOffset))) + } + } + } + + internal mutating func _withMutableCInterop( + entireCapacity: Bool, + _ body: (UnsafeMutableRawPointer?, inout CInterop.SockLen) throws -> R + ) rethrows -> R { + let (result, length): (R, Int) = try _withUnsafeMutableBytes( + entireCapacity: entireCapacity + ) { buffer in + var length = CInterop.SockLen(buffer.count) + let result = try body(buffer.baseAddress, &length) + precondition(length >= 0 && length <= buffer.count) + return (result, Int(length)) + } + _endOffset = length + return result + } + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketDescriptor.AncillaryMessageBuffer: Collection { + /// The index type in an ancillary message buffer. + @frozen + public struct Index: Comparable, Hashable { + @usableFromInline + var _offset: Int + + @inlinable + internal init(_offset: Int) { + self._offset = _offset + } + + @inlinable + public static func == (left: Self, right: Self) -> Bool { + left._offset == right._offset + } + + @inlinable + public static func < (left: Self, right: Self) -> Bool { + left._offset < right._offset + } + + @inlinable + public func hash(into hasher: inout Hasher) { + hasher.combine(_offset) + } + } + + /// An individual message inside an ancillary message buffer. + /// + /// Note that this is merely a reference to a slice of the underlying buffer, + /// so it contains a shared copy of its entire storage. To prevent buffer + /// reallocations due to copy-on-write copies, do not save instances + /// of this type. Instead, immediately copy out any data you need to hold onto + /// into standalone buffers. + public struct Message { + internal var _base: SocketDescriptor.AncillaryMessageBuffer + internal var _offset: Int + + internal init(_base: SocketDescriptor.AncillaryMessageBuffer, offset: Int) { + self._base = _base + self._offset = offset + } + } + + /// The index of the first message in the collection, or `endIndex` if + /// the collection contains no messages. + /// + /// This roughly corresponds to the C macro `CMSG_FIRSTHDR`. + public var startIndex: Index { Index(_offset: 0) } + + /// The index after the last message in the collection. + public var endIndex: Index { Index(_offset: _endOffset) } + + /// True if the collection contains no elements. + public var isEmpty: Bool { _endOffset == 0 } + + /// Return the length (in bytes) of the message at the specified index, or + /// nil if the index isn't valid, or it addresses a corrupt message. + internal func _length(at i: Index) -> Int? { + _withUnsafeBytes { buffer in + guard i._offset >= 0 && i._offset + _headerSize <= buffer.count else { + return nil + } + let p = (buffer.baseAddress! + i._offset) + .assumingMemoryBound(to: CInterop.CMsgHdr.self) + let length = Int(p.pointee.cmsg_len) + + // Cut the list short at the first sign of corrupt data. + // Messages must not be shorter than their header, and they must fit + // entirely in the buffer. + if length < _headerSize || i._offset + length > buffer.count { + return nil + } + return length + } + } + + /// Returns the index immediately following `i` in the collection. + /// + /// This roughly corresponds to the C macro `CMSG_NXTHDR`. + /// + /// - Complexity: O(1) + public func index(after i: Index) -> Index { + precondition(i._offset != _endOffset, "Can't advance past endIndex") + precondition(i._offset >= 0 && i._offset + _headerSize <= _endOffset, + "Invalid index") + guard let length = _length(at: i) else { return endIndex } + return Index(_offset: i._offset + length) + } + + /// Returns the message at the given position, which must be a valid index + /// in this collection. + /// + /// The returned value merely refers to a slice of the entire buffer, so + /// it contains a shared regerence to it. + /// + /// To reduce memory use and to prevent unnecessary copy-on-write copying, do + /// not save `Message` values -- instead, copy out the data you need to hold + /// on to into standalone storage. + public subscript(position: Index) -> Message { + guard let _ = _length(at: position) else { + preconditionFailure("Invalid index") + } + return Element(_base: self, offset: position._offset) + } +} + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension SocketDescriptor.AncillaryMessageBuffer.Message { + internal var _header: CInterop.CMsgHdr { + _base._withUnsafeBytes { buffer in + assert(_offset + _base._headerSize <= buffer.count) + let p = buffer.baseAddress! + _offset + let header = p.assumingMemoryBound(to: CInterop.CMsgHdr.self) + return header.pointee + } + } + + /// The protocol level of the message. Socket-level command messages + /// use the special protocol value `SocketDescriptor.ProtocolID.socketOption`. + public var level: SocketDescriptor.ProtocolID { + .init(rawValue: _header.cmsg_level) + } + + /// The protocol-specific type of the message. + public var type: SocketDescriptor.Option { + .init(rawValue: _header.cmsg_type) + } + + /// Calls `body` with an unsafe raw buffer pointer containing the + /// message payload. + /// + /// This roughly corresponds to the C macro `CMSG_DATA`. + /// + /// - Note: The buffer passed to `body` does not include storage reserved + /// for holding the message header, such as the `level` and `type` values. + /// To access header information, you have to use the corresponding + /// properties. + public func withUnsafeBytes( + _ body: (UnsafeRawBufferPointer) throws -> R + ) rethrows -> R { + try _base._withUnsafeBytes { buffer in + let headerSize = _base._headerSize + assert(_offset + headerSize <= buffer.count) + let p = buffer.baseAddress! + _offset + let header = p.assumingMemoryBound(to: CInterop.CMsgHdr.self) + let data = p + headerSize + let count = Swift.min(Int(header.pointee.cmsg_len) - headerSize, + buffer.count) + return try body(UnsafeRawBufferPointer(start: data, count: count)) + } + } +} diff --git a/Sources/SystemSockets/Sockets/SocketOperations.swift b/Sources/SystemSockets/Sockets/SocketOperations.swift new file mode 100644 index 00000000..00544328 --- /dev/null +++ b/Sources/SystemSockets/Sockets/SocketOperations.swift @@ -0,0 +1,691 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +import SystemPackage + +extension SocketDescriptor { + /// Create an endpoint for communication. + /// + /// - Parameters: + /// - domain: Select the protocol family which should be used for + /// communication + /// - type: Specify the semantics of communication + /// - protocol: Specify a particular protocol to use with the socket. + /// (Zero by default, which often indicates a wildcard value in + /// domain/type combinations that only support a single protocol, + /// such as TCP for IPv4/stream.) + /// - retryOnInterrupt: Whether to retry the open operation + /// if it throws ``Errno/interrupted``. + /// The default is `true`. + /// Pass `false` to try only once and throw an error upon interruption. + /// + /// The corresponding C function is `socket` + @_alwaysEmitIntoClient + public static func open( + _ domain: Domain, + _ type: ConnectionType, + _ protocol: ProtocolID = ProtocolID(rawValue: 0), + retryOnInterrupt: Bool = true + ) throws -> SocketDescriptor { + try SocketDescriptor._open( + domain, type, `protocol`, retryOnInterrupt: retryOnInterrupt + ).get() + } + + @usableFromInline + internal static func _open( + _ domain: Domain, + _ type: ConnectionType, + _ protocol: ProtocolID, + retryOnInterrupt: Bool + ) -> Result { + valueOrErrno(retryOnInterrupt: retryOnInterrupt) { + system_socket(domain.rawValue, type: type.rawValue, protocol: `protocol`.rawValue) + }.map(SocketDescriptor.init(rawValue:)) + } + + /// Bind a socket to an address. + /// + /// The corresponding C function is `bind`. + @_alwaysEmitIntoClient + public func bind(to address: SocketAddress) throws { + try _bind(to: address).get() + } + + /// Bind a socket to an IPv4 address. + /// + /// The corresponding C function is `bind`. + @_alwaysEmitIntoClient + public func bind(to address: SocketAddress.IPv4) throws { + try _bind(to: SocketAddress(address)).get() + } + + /// Bind a socket to an IPv6 address. + /// + /// The corresponding C function is `bind`. + @_alwaysEmitIntoClient + public func bind(to address: SocketAddress.IPv6) throws { + try _bind(to: SocketAddress(address)).get() + } + + /// Bind a socket to an address in the local domain. + /// + /// The corresponding C function is `bind`. + @_alwaysEmitIntoClient + public func bind(to address: SocketAddress.Local) throws { + try _bind(to: SocketAddress(address)).get() + } + + @usableFromInline + internal func _bind(to address: SocketAddress) -> Result<(), Errno> { + nothingOrErrno(retryOnInterrupt: false) { + address.withUnsafeCInterop { addr, len in + system_bind(self.rawValue, addr, len) + } + } + } + + /// Listen for connections on a socket. + /// + /// Only applies to sockets of connection type `.stream`. + /// + /// - Parameters: + /// - backlog: the maximum length for the queue of pending connections + /// + /// The corresponding C function is `listen`. + @_alwaysEmitIntoClient + public func listen(backlog: Int) throws { + try _listen(backlog: backlog).get() + } + + @usableFromInline + internal func _listen(backlog: Int) -> Result<(), Errno> { + nothingOrErrno(retryOnInterrupt: false) { + system_listen(self.rawValue, CInt(backlog)) + } + } + + /// Accept a connection on a socket. + /// + /// The corresponding C function is `accept`. + @_alwaysEmitIntoClient + public func accept(retryOnInterrupt: Bool = true) throws -> SocketDescriptor { + try _accept(retryOnInterrupt: retryOnInterrupt).get() + } + + @usableFromInline + internal func _accept( + retryOnInterrupt: Bool + ) -> Result { + let fd = valueOrErrno(retryOnInterrupt: retryOnInterrupt) { + return system_accept(self.rawValue, nil, nil) + } + return fd.map { SocketDescriptor(rawValue: $0) } + } + + /// Accept a connection on a socket. + /// + /// The corresponding C function is `accept`. + /// + /// - Parameter client: A socket address with enough capacity to hold an + /// address for the current socket domain/type. On return, `accept` + /// overwrites the contents with the address of the remote client. + /// + /// Having this as an inout parameter allows you to reuse the same address + /// value across multiple connections, without reallocating it. + @_alwaysEmitIntoClient + public func accept( + client: inout SocketAddress, + retryOnInterrupt: Bool = true + ) throws -> SocketDescriptor { + try _accept(client: &client, retryOnInterrupt: retryOnInterrupt).get() + } + + @usableFromInline + internal func _accept( + client: inout SocketAddress, + retryOnInterrupt: Bool + ) -> Result { + client._withMutableCInterop(entireCapacity: true) { adr, adrlen in + let fd = valueOrErrno(retryOnInterrupt: retryOnInterrupt) { + return system_accept(self.rawValue, adr, &adrlen) + } + return fd.map { SocketDescriptor(rawValue: $0) } + } + } + + /// Initiate a connection on a socket. + /// + /// The corresponding C function is `connect`. + @_alwaysEmitIntoClient + public func connect(to address: SocketAddress) throws { + try _connect(to: address).get() + } + + /// Initiate a connection to an IPv4 address. + /// + /// The corresponding C function is `connect`. + @_alwaysEmitIntoClient + public func connect(to address: SocketAddress.IPv4) throws { + try _connect(to: SocketAddress(address)).get() + } + + /// Initiate a connection to an IPv6 address. + /// + /// The corresponding C function is `connect`. + @_alwaysEmitIntoClient + public func connect(to address: SocketAddress.IPv6) throws { + try _connect(to: SocketAddress(address)).get() + } + + /// Initiate a connection to an address in the local domain. + /// + /// The corresponding C function is `connect`. + @_alwaysEmitIntoClient + public func connect(to address: SocketAddress.Local) throws { + try _connect(to: SocketAddress(address)).get() + } + + @usableFromInline + internal func _connect(to address: SocketAddress) -> Result<(), Errno> { + nothingOrErrno(retryOnInterrupt: false) { + address.withUnsafeCInterop { addr, len in + system_connect(self.rawValue, addr, len) + } + } + } + + /// Shutdown part of a full-duplex connection + /// + /// The corresponding C function is `shutdown` + @_alwaysEmitIntoClient + public func shutdown(_ how: ShutdownKind) throws { + try _shutdown(how).get() + } + + @usableFromInline + internal func _shutdown(_ how: ShutdownKind) -> Result<(), Errno> { + nothingOrErrno(retryOnInterrupt: false) { + system_shutdown(self.rawValue, how.rawValue) + } + } + + // MARK: - Send and receive + + /// Send a message from a socket. + /// + /// - Parameters: + /// - buffer: The region of memory that contains the data being sent. + /// - flags: see `send(2)` + /// - retryOnInterrupt: Whether to retry the send operation + /// if it throws ``Errno/interrupted``. + /// The default is `true`. + /// Pass `false` to try only once and throw an error upon interruption. + /// - Returns: The number of bytes that were sent. + /// + /// The corresponding C function is `send`. + @_alwaysEmitIntoClient + public func send( + _ buffer: UnsafeRawBufferPointer, + flags: MessageFlags = .none, + retryOnInterrupt: Bool = true + ) throws -> Int { + try _send(buffer, flags: flags, retryOnInterrupt: retryOnInterrupt).get() + } + + @usableFromInline + internal func _send( + _ buffer: UnsafeRawBufferPointer, + flags: MessageFlags, + retryOnInterrupt: Bool + ) -> Result { + valueOrErrno(retryOnInterrupt: retryOnInterrupt) { + system_send(self.rawValue, buffer.baseAddress!, buffer.count, flags.rawValue) + } + } + + /// Send a message from a socket. + /// + /// - Parameters: + /// - buffer: The region of memory that contains the data being sent. + /// - recipient: The socket address of the recipient. + /// - flags: see `send(2)` + /// - retryOnInterrupt: Whether to retry the send operation + /// if it throws ``Errno/interrupted``. + /// The default is `true`. + /// Pass `false` to try only once and throw an error upon interruption. + /// - Returns: The number of bytes that were sent. + /// + /// The corresponding C function is `sendto`. + @_alwaysEmitIntoClient + public func send( + _ buffer: UnsafeRawBufferPointer, + to recipient: SocketAddress, + flags: MessageFlags = .none, + retryOnInterrupt: Bool = true + ) throws -> Int { + try _send( + buffer, + to: recipient, + flags: flags, + retryOnInterrupt: retryOnInterrupt + ).get() + } + + @usableFromInline + internal func _send( + _ buffer: UnsafeRawBufferPointer, + to recipient: SocketAddress, + flags: MessageFlags, + retryOnInterrupt: Bool + ) throws -> Result { + recipient.withUnsafeCInterop { adr, adrlen in + valueOrErrno(retryOnInterrupt: retryOnInterrupt) { + system_sendto( + self.rawValue, + buffer.baseAddress, + buffer.count, + flags.rawValue, + adr, + adrlen) + } + } + } + + /// Send a message from a socket. + /// + /// - Parameters: + /// - buffer: The region of memory that contains the data being sent. + /// - recipient: The socket address of the recipient. + /// - ancillary: A buffer of ancillary/control messages. + /// - flags: see `send(2)` + /// - retryOnInterrupt: Whether to retry the send operation + /// if it throws ``Errno/interrupted``. + /// The default is `true`. + /// Pass `false` to try only once and throw an error upon interruption. + /// - Returns: The number of bytes that were sent. + /// + /// The corresponding C function is `sendmsg`. + @_alwaysEmitIntoClient + public func send( + _ bytes: UnsafeRawBufferPointer, + to recipient: SocketAddress? = nil, + ancillary: AncillaryMessageBuffer, + flags: MessageFlags = .none, + retryOnInterrupt: Bool = true + ) throws -> Int { + try _send( + bytes, + to : recipient, + ancillary: ancillary, + flags: flags, + retryOnInterrupt: retryOnInterrupt + ).get() + } + + @usableFromInline + internal func _send( + _ bytes: UnsafeRawBufferPointer, + to recipient: SocketAddress?, + ancillary: AncillaryMessageBuffer?, + flags: MessageFlags, + retryOnInterrupt: Bool + ) -> Result { + recipient._withUnsafeBytesOrNull { recipient in + ancillary._withUnsafeBytesOrNull { ancillary in + var iov = CInterop.IOVec() + iov.iov_base = UnsafeMutableRawPointer(mutating: bytes.baseAddress) + iov.iov_len = bytes.count + return withUnsafePointer(to: &iov) { iov in + var m = CInterop.MsgHdr() + m.msg_name = UnsafeMutableRawPointer(mutating: recipient.baseAddress) + m.msg_namelen = UInt32(recipient.count) + m.msg_iov = UnsafeMutablePointer(mutating: iov) + m.msg_iovlen = 1 + m.msg_control = UnsafeMutableRawPointer(mutating: ancillary.baseAddress) + m.msg_controllen = CInterop.SockLen(ancillary.count) + m.msg_flags = 0 + return withUnsafePointer(to: &m) { message in + _sendmsg(message, flags.rawValue, + retryOnInterrupt: retryOnInterrupt) + } + } + } + } + } + + private func _sendmsg( + _ message: UnsafePointer, + _ flags: CInt, + retryOnInterrupt: Bool + ) -> Result { + return valueOrErrno(retryOnInterrupt: retryOnInterrupt) { + system_sendmsg(self.rawValue, message, flags) + } + } + + /// Receive a message from a socket. + /// + /// - Parameters: + /// - buffer: The region of memory to receive into. + /// - flags: see `recv(2)` + /// - retryOnInterrupt: Whether to retry the receive operation + /// if it throws ``Errno/interrupted``. + /// The default is `true`. + /// Pass `false` to try only once and throw an error upon interruption. + /// - Returns: The number of bytes that were received. + /// + /// The corresponding C function is `recv`. + @_alwaysEmitIntoClient + public func receive( + into buffer: UnsafeMutableRawBufferPointer, + flags: MessageFlags = .none, + retryOnInterrupt: Bool = true + ) throws -> Int { + try _receive( + into: buffer, flags: flags, retryOnInterrupt: retryOnInterrupt + ).get() + } + + @usableFromInline + internal func _receive( + into buffer: UnsafeMutableRawBufferPointer, + flags: MessageFlags, + retryOnInterrupt: Bool + ) -> Result { + valueOrErrno(retryOnInterrupt: retryOnInterrupt) { + system_recv(self.rawValue, buffer.baseAddress!, buffer.count, flags.rawValue) + } + } + + /// Receive a message from a socket. + /// + /// - Parameters: + /// - buffer: The region of memory to receive into. + /// - flags: see `recv(2)` + /// - sender: A socket address with enough capacity to hold an + /// address for the current socket domain/type. On return, `receive` + /// overwrites the contents with the address of the remote client. + /// - retryOnInterrupt: Whether to retry the receive operation + /// if it throws ``Errno/interrupted``. + /// The default is `true`. + /// Pass `false` to try only once and throw an error upon interruption. + /// - Returns: The number of bytes that were received. + /// + /// The corresponding C function is `recvfrom`. + @_alwaysEmitIntoClient + public func receive( + into buffer: UnsafeMutableRawBufferPointer, + sender: inout SocketAddress, + flags: MessageFlags = .none, + retryOnInterrupt: Bool = true + ) throws -> Int { + try _receive( + into: buffer, + sender: &sender, + flags: flags, + retryOnInterrupt: retryOnInterrupt + ).get() + } + + @usableFromInline + internal func _receive( + into buffer: UnsafeMutableRawBufferPointer, + sender: inout SocketAddress, + flags: MessageFlags, + retryOnInterrupt: Bool + ) throws -> Result { + sender._withMutableCInterop(entireCapacity: true) { adr, adrlen in + valueOrErrno(retryOnInterrupt: retryOnInterrupt) { + system_recvfrom( + self.rawValue, + buffer.baseAddress, + buffer.count, + flags.rawValue, + adr, + &adrlen) + } + } + } + + /// Receive a message from a socket. + /// + /// - Parameters: + /// - buffer: The region of memory to receive into. + /// - flags: see `recv(2)` + /// - ancillary: A buffer of ancillary messages. On return, `receive` + /// overwrites the contents with received ancillary messages (if any). + /// - retryOnInterrupt: Whether to retry the receive operation + /// if it throws ``Errno/interrupted``. + /// The default is `true`. + /// Pass `false` to try only once and throw an error upon interruption. + /// - Returns: The number of bytes that were received, and the flags that + /// describe the received message. + /// + /// The corresponding C function is `recvmsg`. + @_alwaysEmitIntoClient + public func receive( + into bytes: UnsafeMutableRawBufferPointer, + ancillary: inout AncillaryMessageBuffer, + flags: MessageFlags = [], + retryOnInterrupt: Bool = true + ) throws -> (received: Int, flags: MessageFlags) { + return try _receive( + into: bytes, + sender: nil, + ancillary: &ancillary, + flags: flags, + retryOnInterrupt: retryOnInterrupt + ).get() + } + + /// Receive a message from a socket. + /// + /// - Parameters: + /// - buffer: The region of memory to receive into. + /// - flags: see `recv(2)` + /// - sender: A socket address with enough capacity to hold an + /// address for the current socket domain/type. On return, `receive` + /// overwrites the contents with the address of the remote client. + /// - ancillary: A buffer of ancillary messages. On return, `receive` + /// overwrites the contents with received ancillary messages (if any). + /// - retryOnInterrupt: Whether to retry the receive operation + /// if it throws ``Errno/interrupted``. + /// The default is `true`. + /// Pass `false` to try only once and throw an error upon interruption. + /// - Returns: The number of bytes that were received, and the flags that + /// describe the received message. + /// + /// The corresponding C function is `recvmsg`. + @_alwaysEmitIntoClient + public func receive( + into bytes: UnsafeMutableRawBufferPointer, + sender: inout SocketAddress, + ancillary: inout AncillaryMessageBuffer, + flags: MessageFlags = [], + retryOnInterrupt: Bool = true + ) throws -> (received: Int, flags: MessageFlags) { + return try _receive( + into: bytes, + sender: &sender, + ancillary: &ancillary, + flags: flags, + retryOnInterrupt: retryOnInterrupt + ).get() + } + + @usableFromInline + internal func _receive( + into bytes: UnsafeMutableRawBufferPointer, + sender: UnsafeMutablePointer?, + ancillary: UnsafeMutablePointer?, + flags: MessageFlags, + retryOnInterrupt: Bool + ) -> Result<(Int, MessageFlags), Errno> { + let result: Result + let receivedFlags: CInt + (result, receivedFlags) = + sender._withMutableCInteropOrNull(entireCapacity: true) { adr, adrlen in + ancillary._withMutableCInterop(entireCapacity: true) { anc, anclen in + var iov = CInterop.IOVec() + iov.iov_base = bytes.baseAddress + iov.iov_len = bytes.count + return withUnsafePointer(to: &iov) { iov in + var m = CInterop.MsgHdr() + m.msg_name = UnsafeMutableRawPointer(adr) + m.msg_namelen = adrlen + m.msg_iov = UnsafeMutablePointer(mutating: iov) + m.msg_iovlen = 1 + m.msg_control = anc + m.msg_controllen = anclen + m.msg_flags = 0 + let result = withUnsafeMutablePointer(to: &m) { m in + _recvmsg(m, flags.rawValue, retryOnInterrupt: retryOnInterrupt) + } + if case .failure = result { + adrlen = 0 + anclen = 0 + } else { + adrlen = m.msg_namelen + anclen = m.msg_controllen + } + return (result, m.msg_flags) + } + } + } + return result.map { ($0, MessageFlags(rawValue: receivedFlags)) } + } + + private func _recvmsg( + _ message: UnsafeMutablePointer, + _ flags: CInt, + retryOnInterrupt: Bool + ) -> Result { + return valueOrErrno(retryOnInterrupt: retryOnInterrupt) { + system_recvmsg(self.rawValue, message, flags) + } + } + +} + +// MARK: - Forward FileDescriptor methods +extension SocketDescriptor { + /// Deletes a socket's file descriptor. + /// + /// This is equivalent to calling `fileDescriptor.close()` + @_alwaysEmitIntoClient + public func close() throws { try fileDescriptor.close() } + + /// Reads bytes from a socket. + /// + /// This is equivalent to calling `fileDescriptor.read(into:retryOnInterrupt:)` + /// + /// - Parameters: + /// - buffer: The region of memory to read into. + /// - retryOnInterrupt: Whether to retry the read operation + /// if it throws ``Errno/interrupted``. + /// The default is `true`. + /// Pass `false` to try only once and throw an error upon interruption. + /// - Returns: The number of bytes that were read. + /// + /// The corresponding C function is `read`. + @_alwaysEmitIntoClient + public func read( + into buffer: UnsafeMutableRawBufferPointer, retryOnInterrupt: Bool = true + ) throws -> Int { + try fileDescriptor.read(into: buffer, retryOnInterrupt: retryOnInterrupt) + } + + /// Writes the contents of a buffer to the socket. + /// + /// This is equivalent to `fileDescriptor.write(_:retryOnInterrupt:)` + /// + /// - Parameters: + /// - buffer: The region of memory that contains the data being written. + /// - retryOnInterrupt: Whether to retry the write operation + /// if it throws ``Errno/interrupted``. + /// The default is `true`. + /// Pass `false` to try only once and throw an error upon interruption. + /// - Returns: The number of bytes that were written. + /// + /// After writing, + /// this method increments the file's offset by the number of bytes written. + /// To change the file's offset, + /// call the ``seek(offset:from:)`` method. + /// + /// The corresponding C function is `write`. + @_alwaysEmitIntoClient + public func write( + _ buffer: UnsafeRawBufferPointer, retryOnInterrupt: Bool = true + ) throws -> Int { + try fileDescriptor.write(buffer, retryOnInterrupt: retryOnInterrupt) + } +} + +// Optional mapper helpers, for use in setting up message header structs. +extension Optional where Wrapped == SocketDescriptor.AncillaryMessageBuffer { + fileprivate func _withUnsafeBytesOrNull( + _ body: (UnsafeRawBufferPointer) throws -> R + ) rethrows -> R { + guard let buffer = self else { + return try body(UnsafeRawBufferPointer(start: nil, count: 0)) + } + return try buffer._withUnsafeBytes(body) + } +} + +extension Optional where Wrapped == SocketAddress { + fileprivate func _withUnsafeBytesOrNull( + _ body: (UnsafeRawBufferPointer) throws -> R + ) rethrows -> R { + guard let address = self else { + return try body(UnsafeRawBufferPointer(start: nil, count: 0)) + } + return try address.withUnsafeBytes(body) + } +} +extension Optional where Wrapped == UnsafeMutablePointer { + fileprivate func _withMutableCInteropOrNull( + entireCapacity: Bool, + _ body: ( + UnsafeMutablePointer?, + inout CInterop.SockLen + ) throws -> R + ) rethrows -> R { + guard let ptr = self else { + var c: CInterop.SockLen = 0 + let result = try body(nil, &c) + precondition(c == 0) + return result + } + return try ptr.pointee._withMutableCInterop( + entireCapacity: entireCapacity, + body) + } +} + +extension Optional +where Wrapped == UnsafeMutablePointer +{ + internal func _withMutableCInterop( + entireCapacity: Bool, + _ body: (UnsafeMutableRawPointer?, inout CInterop.SockLen) throws -> R + ) rethrows -> R { + guard let buffer = self else { + var length: CInterop.SockLen = 0 + let r = try body(nil, &length) + precondition(length == 0) + return r + } + return try buffer.pointee._withMutableCInterop( + entireCapacity: entireCapacity, + body + ) + } +} diff --git a/Sources/SystemSockets/Sockets/SocketOptions.swift b/Sources/SystemSockets/Sockets/SocketOptions.swift new file mode 100644 index 00000000..59372342 --- /dev/null +++ b/Sources/SystemSockets/Sockets/SocketOptions.swift @@ -0,0 +1,621 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import Darwin +#elseif os(Linux) || os(FreeBSD) || os(Android) +import CSystem +import Glibc +#elseif os(Windows) +import CSystem +import ucrt +#else +#error("Unsupported Platform") +#endif + +import SystemPackage + +extension SocketDescriptor { + // Options associated with a socket. + @frozen + public struct Option: RawRepresentable, Hashable, CustomStringConvertible { + @_alwaysEmitIntoClient + public var rawValue: CInt + + @_alwaysEmitIntoClient + public init(rawValue: CInt) { self.rawValue = rawValue } + + @_alwaysEmitIntoClient + private init(_ rawValue: CInt) { self.init(rawValue: rawValue) } + + public var description: String { rawValue.description } + + // MARK: - Socket-level + + /// Enables recording of debugging information. + /// + /// The corresponding C constant is `SO_DEBUG`. + @_alwaysEmitIntoClient + public static var debug: Option { Option(SO_DEBUG) } + + /// Enables local address reuse. + /// + /// The corresponding C constant is `SO_REUSEADDR`. + @_alwaysEmitIntoClient + public static var reuseAddress: Option { Option(SO_REUSEADDR) } + + /// Enables duplicate address and port bindings. + /// + /// The corresponding C constant is `SO_REUSEPORT`. + @_alwaysEmitIntoClient + public static var reusePort: Option { Option(SO_REUSEPORT) } + + /// Enables keep connections alive. + /// + /// The corresponding C constant is `SO_KEEPALIVE`. + @_alwaysEmitIntoClient + public static var keepAlive: Option { Option(SO_KEEPALIVE) } + + /// Enables routing bypass for outgoing messages. + /// + /// The corresponding C constant is `SO_DONTROUTE`. + @_alwaysEmitIntoClient + public static var doNotRoute: Option { Option(SO_DONTROUTE) } + + /// linger on close if data present + /// + /// The corresponding C constant is `SO_LINGER`. + @_alwaysEmitIntoClient + public static var linger: Option { Option(SO_LINGER) } + + /// Enables permission to transmit broadcast messages. + /// + /// The corresponding C constant is `SO_BROADCAST`. + @_alwaysEmitIntoClient + public static var broadcast: Option { Option(SO_BROADCAST) } + + /// Enables reception of out-of-band data in band. + /// + /// The corresponding C constant is `SO_OOBINLINE`. + @_alwaysEmitIntoClient + public static var outOfBand: Option { Option(SO_OOBINLINE) } + + /// Set buffer size for output. + /// + /// The corresponding C constant is `SO_SNDBUF`. + @_alwaysEmitIntoClient + public static var sendBufferSize: Option { Option(SO_SNDBUF) } + + /// Set buffer size for input. + /// + /// The corresponding C constant is `SO_RCVBUF`. + @_alwaysEmitIntoClient + public static var receiveBufferSize: Option { Option(SO_RCVBUF) } + + /// Set minimum count for output. + /// + /// The corresponding C constant is `SO_SNDLOWAT`. + @_alwaysEmitIntoClient + public static var sendLowWaterMark: Option { Option(SO_SNDLOWAT) } + + /// Set minimum count for input. + /// + /// The corresponding C constant is `SO_RCVLOWAT`. + @_alwaysEmitIntoClient + public static var receiveLowWaterMark: Option { Option(SO_RCVLOWAT) } + + /// Set timeout value for output. + /// + /// The corresponding C constant is `SO_SNDTIMEO`. + @_alwaysEmitIntoClient + public static var sendTimeout: Option { Option(SO_SNDTIMEO) } + + /// Set timeout value for input. + /// + /// The corresponding C constant is `SO_RCVTIMEO`. + @_alwaysEmitIntoClient + public static var receiveTimeout: Option { Option(SO_RCVTIMEO) } + + /// Get the type of the socket (get only). + /// + /// The corresponding C constant is `SO_TYPE`. + @_alwaysEmitIntoClient + public static var getType: Option { Option(SO_TYPE) } + + /// Get and clear error on the socket (get only). + /// + /// The corresponding C constant is `SO_ERROR`. + @_alwaysEmitIntoClient + public static var getError: Option { Option(SO_ERROR) } + + /// Do not generate SIGPIPE, instead return EPIPE. + /// + /// TODO: better name... + /// + /// The corresponding C constant is `SO_NOSIGPIPE`. + @_alwaysEmitIntoClient + public static var noSignal: Option { Option(SO_NOSIGPIPE) } + + /// Number of bytes to be read (get only). + /// + /// For datagram oriented sockets, returns the size of the first packet. + /// + /// TODO: better name... + /// + /// The corresponding C constant is `SO_NREAD`. + @_alwaysEmitIntoClient + public static var getNumBytesToReceive: Option { Option(SO_NREAD) } + + /// Number of bytes written not yet sent by the protocol (get only). + /// + /// The corresponding C constant is `SO_NWRITE`. + @_alwaysEmitIntoClient + public static var getNumByteToSend: Option { Option(SO_NWRITE) } + + /// Linger on close if data present with timeout in seconds. + /// + /// The corresponding C constant is `SO_LINGER_SEC`. + @_alwaysEmitIntoClient + public static var longerSeconds: Option { Option(SO_LINGER_SEC) } + + // + // MARK: - TCP options + // + + /// Send data before receiving a reply. + /// + /// The corresponding C constant is `TCP_NODELAY`. + @_alwaysEmitIntoClient + public static var tcpNoDelay: Option { Option(TCP_NODELAY) } + + /// Set the maximum segment size. + /// + /// The corresponding C constant is `TCP_MAXSEG`. + @_alwaysEmitIntoClient + public static var tcpMaxSegmentSize: Option { Option(TCP_MAXSEG) } + + /// Disable TCP option use. + /// + /// The corresponding C constant is `TCP_NOOPT`. + @_alwaysEmitIntoClient + public static var tcpNoOptions: Option { Option(TCP_NOOPT) } + + /// Delay sending any data until socket is closed or send buffer is filled. + /// + /// The corresponding C constant is `TCP_NOPUSH`. + @_alwaysEmitIntoClient + public static var tcpNoPush: Option { Option(TCP_NOPUSH) } + + /// Specify the amount of idle time (in seconds) before keepalive probes. + /// + /// The corresponding C constant is `TCP_KEEPALIVE`. + @_alwaysEmitIntoClient + public static var tcpKeepAlive: Option { Option(TCP_KEEPALIVE) } + + /// Specify the timeout (in seconds) for new non-established TCP connections. + /// + /// The corresponding C constant is `TCP_CONNECTIONTIMEOUT`. + @_alwaysEmitIntoClient + public static var tcpConnectionTimeout: Option { Option(TCP_CONNECTIONTIMEOUT) } + + /// Set the amout of time (in seconds) between successive keepalives sent to + /// probe an unresponsive peer. + /// + /// The corresponding C constant is `TCP_KEEPINTVL`. + @_alwaysEmitIntoClient + public static var tcpKeepAliveInterval: Option { Option(TCP_KEEPINTVL) } + + /// Set the number of times keepalive probe should be repeated if peer is not + /// responding. + /// + /// The corresponding C constant is `TCP_KEEPCNT`. + @_alwaysEmitIntoClient + public static var tcpKeepAliveCount: Option { Option(TCP_KEEPCNT) } + + /// Send a TCP acknowledgement for every other data packaet in a stream of + /// received data packets, rather than for every 8. + /// + /// TODO: better name + /// + /// The corresponding C constant is `TCP_SENDMOREACKS`. + @_alwaysEmitIntoClient + public static var tcpSendMoreAcks: Option { Option(TCP_SENDMOREACKS) } + + /// Use Explicit Congestion Notification (ECN). + /// + /// The corresponding C constant is `TCP_ENABLE_ECN`. + @_alwaysEmitIntoClient + public static var tcpUseExplicitCongestionNotification: Option { Option(TCP_ENABLE_ECN) } + + /// Specify the maximum amount of unsent data in the send socket buffer. + /// + /// The corresponding C constant is `TCP_NOTSENT_LOWAT`. + @_alwaysEmitIntoClient + public static var tcpMaxUnsent: Option { Option(TCP_NOTSENT_LOWAT) } + + /// Use TCP Fast Open feature. Accpet may return a socket that is in + /// SYN_RECEIVED state but is readable and writable. + /// + /// The corresponding C constant is `TCP_FASTOPEN`. + @_alwaysEmitIntoClient + public static var tcpFastOpen: Option { Option(TCP_FASTOPEN) } + + /// Optain TCP connection-level statistics. + /// + /// The corresponding C constant is `TCP_CONNECTION_INFO`. + @_alwaysEmitIntoClient + public static var tcpConnectionInfo: Option { Option(TCP_CONNECTION_INFO) } + + // + // MARK: - IP Options + // + + /// Set to null to disable previously specified options. + /// + /// The corresponding C constant is `IP_OPTIONS`. + @_alwaysEmitIntoClient + public static var ipOptions: Option { Option(IP_OPTIONS) } + + /// Set the type-of-service. + /// + /// The corresponding C constant is `IP_TOS`. + @_alwaysEmitIntoClient + public static var ipTypeOfService: Option { Option(IP_TOS) } + + /// Set the time-to-live. + /// + /// The corresponding C constant is `IP_TTL`. + @_alwaysEmitIntoClient + public static var ipTimeToLive: Option { Option(IP_TTL) } + + /// Causes `recvmsg` to return the destination IP address for a UPD + /// datagram. + /// + /// The corresponding C constant is `IP_RECVDSTADDR`. + @_alwaysEmitIntoClient + public static var ipReceiveDestinationAddress: Option { Option(IP_RECVDSTADDR) } + + /// Causes `recvmsg` to return the type-of-service filed of the ip header. + /// + /// The corresponding C constant is `IP_RECVTOS`. + @_alwaysEmitIntoClient + public static var ipReceiveTypeOfService: Option { Option(IP_RECVTOS) } + + /// Change the time-to-live for outgoing multicast datagrams. + /// + /// The corresponding C constant is `IP_MULTICAST_TTL`. + @_alwaysEmitIntoClient + public static var ipMulticastTimeToLive: Option { Option(IP_MULTICAST_TTL) } + + /// Override the default network interface for subsequent transmissions. + /// + /// The corresponding C constant is `IP_MULTICAST_IF`. + @_alwaysEmitIntoClient + public static var ipMulticastInterface: Option { Option(IP_MULTICAST_IF) } + + /// Control whether or not subsequent datagrams are looped back. + /// + /// The corresponding C constant is `IP_MULTICAST_LOOP`. + @_alwaysEmitIntoClient + public static var ipMulticastLoop: Option { Option(IP_MULTICAST_LOOP) } + + /// Join a multicast group. + /// + /// The corresponding C constant is `IP_ADD_MEMBERSHIP`. + @_alwaysEmitIntoClient + public static var ipAddMembership: Option { Option(IP_ADD_MEMBERSHIP) } + + /// Leave a multicast group. + /// + /// The corresponding C constant is `IP_DROP_MEMBERSHIP`. + @_alwaysEmitIntoClient + public static var ipDropMembership: Option { Option(IP_DROP_MEMBERSHIP) } + + /// Indicates the complete IP header is included with the data. + /// + /// Can only be used with `ConnectionType.raw` sockets. + /// + /// The corresponding C constant is `IP_HDRINCL`. + @_alwaysEmitIntoClient + public static var ipHeaderIncluded: Option { Option(IP_HDRINCL) } + + // + // MARK: - IPv6 Options + // + + /// The default hop limit header field for outgoing unicast datagrams. + /// + /// A value of -1 resets to the default value. + /// + /// The corresponding C constant is `IPV6_UNICAST_HOPS`. + @_alwaysEmitIntoClient + public static var ipv6UnicastHops: Option { Option(IPV6_UNICAST_HOPS) } + + /// The interface from which multicast packets will be sent. + /// + /// A value of 0 specifies the default interface. + /// + /// The corresponding C constant is `IPV6_MULTICAST_IF`. + @_alwaysEmitIntoClient + public static var ipv6MulticastInterface: Option { Option(IPV6_MULTICAST_IF) } + + /// The default hop limit header field for outgoing multicast datagrams. + /// + /// The corresponding C constant is `IPV6_MULTICAST_HOPS`. + @_alwaysEmitIntoClient + public static var ipv6MulticastHops: Option { Option(IPV6_MULTICAST_HOPS) } + + /// Whether multicast datagrams will be looped back. + /// + /// The corresponding C constant is `IPV6_MULTICAST_LOOP`. + @_alwaysEmitIntoClient + public static var ipv6MulticastLoop: Option { Option(IPV6_MULTICAST_LOOP) } + + /// Join a multicast group. + /// + /// The corresponding C constant is `IPV6_JOIN_GROUP`. + @_alwaysEmitIntoClient + public static var ipv6JoinGroup: Option { Option(IPV6_JOIN_GROUP) } + + /// Leave a multicast group. + /// + /// The corresponding C constant is `IPV6_LEAVE_GROUP`. + @_alwaysEmitIntoClient + public static var ipv6LeaveGroup: Option { Option(IPV6_LEAVE_GROUP) } + + /// Allocation policy of ephemeral ports for when the kernel automatically + /// binds a local address to this socket. + /// + /// TODO: portrange struct somewhere, with _DEFAULT, _HIGH, _LOW + /// + /// The corresponding C constant is `IPV6_PORTRANGE`. + @_alwaysEmitIntoClient + public static var ipv6PortRange: Option { Option(IPV6_PORTRANGE) } + +// /// Whether additional information about subsequent packets will be +// /// provided in `recvmsg` calls. +// /// +// /// The corresponding C constant is `IPV6_PKTINFO`. +// @_alwaysEmitIntoClient +// public static var ipv6ReceivePacketInfo: Option { Option(IPV6_PKTINFO) } +// +// /// Whether the hop limit header field from subsequent packets will +// /// be provided in `recvmsg` calls. +// /// +// /// The corresponding C constant is `IPV6_HOPLIMIT`. +// @_alwaysEmitIntoClient +// public static var ipv6ReceiveHopLimit: Option { Option(IPV6_HOPLIMIT) } +// +// /// Whether hop-by-hop options from subsequent packets will +// /// be provided in `recvmsg` calls. +// /// +// /// The corresponding C constant is `IPV6_HOPOPTS`. +// @_alwaysEmitIntoClient +// public static var ipv6ReceiveHopOptions: Option { Option(IPV6_HOPOPTS) } +// +// /// Whether destination options from subsequent packets will +// /// be provided in `recvmsg` calls. +// /// +// /// The corresponding C constant is `IPV6_DSTOPTS`. +// @_alwaysEmitIntoClient +// public static var ipv6ReceiveDestinationOptions: Option { Option(IPV6_DSTOPTS) } + + /// The value of the traffic class field for outgoing datagrams. + /// + /// The corresponding C constant is `IPV6_TCLASS`. + @_alwaysEmitIntoClient + public static var ipv6TrafficClass: Option { Option(IPV6_TCLASS) } + + /// Whether traffic class header field from subsequent packets will + /// be provided in `recvmsg` calls. + /// + /// The corresponding C constant is `IPV6_RECVTCLASS`. + @_alwaysEmitIntoClient + public static var ipv6ReceiveTrafficClass: Option { Option(IPV6_RECVTCLASS) } + +// /// Whether the routing header from subsequent packets will +// /// be provided in `recvmsg` calls. +// /// +// /// The corresponding C constant is `IPV6_RTHDR`. +// @_alwaysEmitIntoClient +// public static var ipv6ReceiveRoutingHeader: Option { Option(IPV6_RTHDR) } +// +// /// Get or set all header options and extension headers at one time +// /// on the last packet sent or received. +// /// +// /// The corresponding C constant is `IPV6_PKTOPTIONS`. +// @_alwaysEmitIntoClient +// public static var ipv6PacketOptions: Option { Option(IPV6_PKTOPTIONS) } + + /// The byte offset into a packet where 16-bit checksum is located. + /// + /// The corresponding C constant is `IPV6_CHECKSUM`. + @_alwaysEmitIntoClient + public static var ipv6Checksum: Option { Option(IPV6_CHECKSUM) } + + /// Whether only IPv6 connections can be made to this socket. + /// + /// The corresponding C constant is `IPV6_V6ONLY`. + @_alwaysEmitIntoClient + public static var ipv6Only: Option { Option(IPV6_V6ONLY) } + +// /// Whether the minimal IPv6 maximum transmission unit (MTU) size +// /// will be used to avoid fragmentation for subsequenet outgoing +// /// datagrams. +// /// +// /// The corresponding C constant is `IPV6_USE_MIN_MTU`. +// @_alwaysEmitIntoClient +// public static var ipv6UseMinimalMTU: Option { Option(IPV6_USE_MIN_MTU) } + } +} + +extension SocketDescriptor { + // TODO: Wrappers and convenience overloads for other concrete types + // (timeval, linger) + // For now, clients can use the UMRBP-based variants below. + + /// Copy an option associated with this socket into the specified buffer. + /// + /// The method corresponds to the C function `getsockopt`. + /// + /// - Parameters: + /// - level: The option level. To get a socket-level option, specify `.socketLevel`. + /// Otherwise use the protocol value that defines your desired option. + /// - option: The option identifier within the level. + /// - buffer: The buffer into which to copy the option value. + /// + /// - Returns: The number of bytes copied into the supplied buffer. + @_alwaysEmitIntoClient + public func getOption( + _ level: ProtocolID, + _ option: Option, + into buffer: UnsafeMutableRawBufferPointer + ) throws -> Int { + try _getOption(level, option, into: buffer).get() + } + + /// Return the value of an option associated with this socket as a `CInt` value. + /// + /// The method corresponds to the C function `getsockopt`. + /// + /// - Parameters: + /// - level: The option level. To get a socket-level option, specify `.socketLevel`. + /// Otherwise use the protocol value that defines your desired option. + /// - option: The option identifier within the level. + /// - type: The type to return. Must be set to `CInt.self` (the default). + /// + /// - Returns: The current value of the option. + @_alwaysEmitIntoClient + public func getOption( + _ level: ProtocolID, + _ option: Option, + as type: CInt.Type = CInt.self + ) throws -> CInt { + var value: CInt = 0 + try withUnsafeMutableBytes(of: &value) { buffer in + // Note: return value is intentionally ignored. + _ = try _getOption(level, option, into: buffer).get() + } + return value + } + + /// Return the value of an option associated with this socket as a `Bool` value. + /// + /// The method corresponds to the C function `getsockopt`. + /// + /// - Parameters: + /// - level: The option level. To get a socket-level option, specify `.socketLevel`. + /// Otherwise use the protocol value that defines your desired option. + /// - option: The option identifier within the level. + /// - type: The type to return. Must be set to `Bool.self` (the default). + /// + /// - Returns: True if the current value is not zero; otherwise false. + @_alwaysEmitIntoClient + public func getOption( + _ level: ProtocolID, + _ option: Option, + as type: Bool.Type = Bool.self + ) throws -> Bool { + try 0 != getOption(level, option, as: CInt.self) + } + + @usableFromInline + internal func _getOption( + _ level: ProtocolID, + _ option: Option, + into buffer: UnsafeMutableRawBufferPointer + ) -> Result { + var length = CInterop.SockLen(buffer.count) + return nothingOrErrno(retryOnInterrupt: false) { + system_getsockopt( + self.rawValue, + level.rawValue, + option.rawValue, + buffer.baseAddress, &length) + }.map { _ in Int(length) } + } +} + +extension SocketDescriptor { + /// Set the value of an option associated with this socket to the contents + /// of the specified buffer. + /// + /// The method corresponds to the C function `setsockopt`. + /// + /// - Parameters: + /// - level: The option level. To set a socket-level option, specify `.socketLevel`. + /// Otherwise use the protocol value that defines your desired option. + /// - option: The option identifier within the level. + /// - buffer: The buffer that contains the desired option value. + @_alwaysEmitIntoClient + public func setOption( + _ level: ProtocolID, + _ option: Option, + from buffer: UnsafeRawBufferPointer + ) throws { + try _setOption(level, option, from: buffer).get() + } + + /// Set the value of an option associated with this socket to the supplied + /// `CInt` value. + /// + /// The method corresponds to the C function `setsockopt`. + /// + /// - Parameters: + /// - level: The option level. To set a socket-level option, specify `.socketLevel`. + /// Otherwise use the protocol value that defines your desired option. + /// - option: The option identifier within the level. + /// - value: The desired new value for the option. + @_alwaysEmitIntoClient + public func setOption( + _ level: ProtocolID, + _ option: Option, + to value: CInt + ) throws { + return try withUnsafeBytes(of: value) { buffer in + // Note: return value is intentionally ignored. + _ = try _setOption(level, option, from: buffer).get() + } + } + + /// Set the value of an option associated with this socket to the supplied + /// `Bool` value. + /// + /// The method corresponds to the C function `setsockopt`. + /// + /// - Parameters: + /// - level: The option level. To set a socket-level option, specify `.socketLevel`. + /// Otherwise use the protocol value that defines your desired option. + /// - option: The option identifier within the level. + /// - value: The desired new value for the option. (`true` gets stored + /// as `(1 as CInt)`. `false` is represented by `(0 as CInt)`). + @_alwaysEmitIntoClient + public func setOption( + _ level: ProtocolID, + _ option: Option, + to value: Bool + ) throws { + try setOption(level, option, to: (value ? 1 : 0) as CInt) + } + + @usableFromInline + internal func _setOption( + _ level: ProtocolID, + _ option: Option, + from buffer: UnsafeRawBufferPointer + ) -> Result { + nothingOrErrno(retryOnInterrupt: false) { + system_setsockopt( + self.rawValue, + level.rawValue, + option.rawValue, + buffer.baseAddress, CInterop.SockLen(buffer.count)) + } + } +} diff --git a/Sources/SystemSockets/Syscalls.swift b/Sources/SystemSockets/Syscalls.swift new file mode 100644 index 00000000..275254e0 --- /dev/null +++ b/Sources/SystemSockets/Syscalls.swift @@ -0,0 +1,240 @@ +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import Darwin +#elseif os(Linux) || os(FreeBSD) || os(Android) +import CSystem +import Glibc +#elseif os(Windows) +import CSystem +import ucrt +#else +#error("Unsupported Platform") +#endif + +import SystemPackage + +internal func system_socket(_ domain: CInt, type: CInt, protocol: CInt) -> CInt { + #if ENABLE_MOCKING + if mockingEnabled { return _mock(domain, type, `protocol`) } + #endif + return socket(domain, type, `protocol`) +} + +internal func system_shutdown(_ socket: CInt, _ how: CInt) -> CInt { + #if ENABLE_MOCKING + if mockingEnabled { return _mock(socket, how) } + #endif + return shutdown(socket, how) +} + +internal func system_listen(_ socket: CInt, _ backlog: CInt) -> CInt { + #if ENABLE_MOCKING + if mockingEnabled { return _mock(socket, backlog) } + #endif + return listen(socket, backlog) +} + +internal func system_send( + _ socket: Int32, _ buffer: UnsafeRawPointer?, _ len: Int, _ flags: Int32 +) -> Int { + #if ENABLE_MOCKING + if mockingEnabled { return _mockInt(socket, buffer, len, flags) } + #endif + return send(socket, buffer, len, flags) +} + +internal func system_recv( + _ socket: Int32, + _ buffer: UnsafeMutableRawPointer?, + _ len: Int, + _ flags: Int32 +) -> Int { + #if ENABLE_MOCKING + if mockingEnabled { return _mockInt(socket, buffer, len, flags) } + #endif + return recv(socket, buffer, len, flags) +} + +internal func system_sendto( + _ socket: CInt, + _ buffer: UnsafeRawPointer?, + _ length: Int, + _ flags: CInt, + _ dest_addr: UnsafePointer?, + _ dest_len: CInterop.SockLen +) -> Int { + #if ENABLE_MOCKING + if mockingEnabled { + return _mockInt(socket, buffer, length, flags, dest_addr, dest_len) + } + #endif + return sendto(socket, buffer, length, flags, dest_addr, dest_len) +} + +internal func system_recvfrom( + _ socket: CInt, + _ buffer: UnsafeMutableRawPointer?, + _ length: Int, + _ flags: CInt, + _ address: UnsafeMutablePointer?, + _ addres_len: UnsafeMutablePointer? +) -> Int { + #if ENABLE_MOCKING + if mockingEnabled { + return _mockInt(socket, buffer, length, flags, address, addres_len) + } + #endif + return recvfrom(socket, buffer, length, flags, address, addres_len) +} + +internal func system_sendmsg( + _ socket: CInt, + _ message: UnsafePointer?, + _ flags: CInt +) -> Int { + #if ENABLE_MOCKING + if mockingEnabled { return _mockInt(socket, message, flags) } + #endif + return sendmsg(socket, message, flags) +} + +internal func system_recvmsg( + _ socket: CInt, + _ message: UnsafeMutablePointer?, + _ flags: CInt +) -> Int { + #if ENABLE_MOCKING + if mockingEnabled { return _mockInt(socket, message, flags) } + #endif + return recvmsg(socket, message, flags) +} + +internal func system_getsockopt( + _ socket: CInt, + _ level: CInt, + _ option: CInt, + _ value: UnsafeMutableRawPointer?, + _ length: UnsafeMutablePointer? +) -> CInt { + #if ENABLE_MOCKING + if mockingEnabled { return _mock(socket, level, option, value, length) } + #endif + return getsockopt(socket, level, option, value, length) +} + +internal func system_setsockopt( + _ socket: CInt, + _ level: CInt, + _ option: CInt, + _ value: UnsafeRawPointer?, + _ length: socklen_t +) -> CInt { + #if ENABLE_MOCKING + if mockingEnabled { return _mock(socket, level, option, value, length) } + #endif + return setsockopt(socket, level, option, value, length) +} + +internal func system_inet_ntop( + _ af: CInt, + _ src: UnsafeRawPointer, + _ dst: UnsafeMutablePointer, + _ size: CInterop.SockLen +) -> CInt { // Note: returns 0 on success, -1 on failure, unlike the original + #if ENABLE_MOCKING + if mockingEnabled { return _mock(af, src, dst, size) } + #endif + let res = inet_ntop(af, src, dst, size) + if Int(bitPattern: res) == 0 { return -1 } + assert(Int(bitPattern: res) == Int(bitPattern: dst)) + return 0 +} + +internal func system_inet_pton( + _ af: CInt, _ src: UnsafePointer, _ dst: UnsafeMutableRawPointer +) -> CInt { + #if ENABLE_MOCKING + if mockingEnabled { return _mock(af, src, dst) } + #endif + return inet_pton(af, src, dst) +} + +internal func system_bind( + _ socket: CInt, _ addr: UnsafePointer?, _ len: socklen_t +) -> CInt { + #if ENABLE_MOCKING + if mockingEnabled { return _mock(socket, addr, len) } + #endif + return bind(socket, addr, len) +} + +internal func system_connect( + _ socket: CInt, _ addr: UnsafePointer?, _ len: socklen_t +) -> CInt { + #if ENABLE_MOCKING + if mockingEnabled { return _mock(socket, addr, len) } + #endif + return connect(socket, addr, len) +} + +internal func system_accept( + _ socket: CInt, + _ addr: UnsafeMutablePointer?, + _ len: UnsafeMutablePointer? +) -> CInt { + #if ENABLE_MOCKING + if mockingEnabled { return _mock(socket, addr, len) } + #endif + return accept(socket, addr, len) +} + +internal func system_getaddrinfo( + _ hostname: UnsafePointer?, + _ servname: UnsafePointer?, + _ hints: UnsafePointer?, + _ res: UnsafeMutablePointer?>? +) -> CInt { + #if ENABLE_MOCKING + if mockingEnabled { + return _mock(hostname.map { String(cString: $0) }, + servname.map { String(cString: $0) }, + hints, res) + } + #endif + return getaddrinfo(hostname, servname, hints, res) +} + +internal func system_getnameinfo( + _ sa: UnsafePointer?, + _ salen: CInterop.SockLen, + _ host: UnsafeMutablePointer?, + _ hostlen: CInterop.SockLen, + _ serv: UnsafeMutablePointer?, + _ servlen: CInterop.SockLen, + _ flags: CInt +) -> CInt { + #if ENABLE_MOCKING + if mockingEnabled { + return _mock(sa, salen, host, hostlen, serv, servlen, flags) + } + #endif + return getnameinfo(sa, salen, host, hostlen, serv, servlen, flags) +} + +internal func system_freeaddrinfo( + _ addrinfo: UnsafeMutablePointer? +) { + #if ENABLE_MOCKING + if mockingEnabled { + _ = _mock(addrinfo) + return + } + #endif + return freeaddrinfo(addrinfo) +} + +internal func system_gai_strerror(_ error: CInt) -> UnsafePointer { + #if ENABLE_MOCKING + // FIXME + #endif + return gai_strerror(error) +} diff --git a/Sources/SystemSockets/Util.swift b/Sources/SystemSockets/Util.swift new file mode 100644 index 00000000..1491a177 --- /dev/null +++ b/Sources/SystemSockets/Util.swift @@ -0,0 +1,100 @@ +import SystemPackage + +extension Errno { + internal static var current: Errno { + get { Errno(rawValue: system_errno) } + set { system_errno = newValue.rawValue } + } +} + +// Results in errno if i == -1 +// @available(macOS 10.16, iOS 14.0, watchOS 7.0, tvOS 14.0, *) +private func valueOrErrno( + _ i: I +) -> Result { + i == -1 ? .failure(Errno.current) : .success(i) +} + +// @available(macOS 10.16, iOS 14.0, watchOS 7.0, tvOS 14.0, *) +private func nothingOrErrno( + _ i: I +) -> Result<(), Errno> { + valueOrErrno(i).map { _ in () } +} + +// @available(macOS 10.16, iOS 14.0, watchOS 7.0, tvOS 14.0, *) +internal func valueOrErrno( + retryOnInterrupt: Bool, _ f: () -> I +) -> Result { + repeat { + switch valueOrErrno(f()) { + case .success(let r): return .success(r) + case .failure(let err): + guard retryOnInterrupt && err == .interrupted else { return .failure(err) } + break + } + } while true +} + +// @available(macOS 10.16, iOS 14.0, watchOS 7.0, tvOS 14.0, *) +internal func nothingOrErrno( + retryOnInterrupt: Bool, _ f: () -> I +) -> Result<(), Errno> { + valueOrErrno(retryOnInterrupt: retryOnInterrupt, f).map { _ in () } +} + +// Run a precondition for debug client builds +internal func _debugPrecondition( + _ condition: @autoclosure () -> Bool, + _ message: StaticString = StaticString(), + file: StaticString = #file, line: UInt = #line +) { + // Only check in debug mode. + if _slowPath(_isDebugAssertConfiguration()) { + precondition( + condition(), String(describing: message), file: file, line: line) + } +} + +extension OptionSet { + // Helper method for building up a comma-separated list of options + // + // Taking an array of descriptions reduces code size vs + // a series of calls due to avoiding register copies. Make sure + // to pass an array literal and not an array built up from a series of + // append calls, else that will massively bloat code size. This takes + // StaticStrings because otherwise we get a warning about getting evicted + // from the shared cache. + @inline(never) + internal func _buildDescription( + _ descriptions: [(Element, StaticString)] + ) -> String { + var copy = self + var result = "[" + + for (option, name) in descriptions { + if _slowPath(copy.contains(option)) { + result += name.description + copy.remove(option) + if !copy.isEmpty { result += ", " } + } + } + + if _slowPath(!copy.isEmpty) { + result += "\(Self.self)(rawValue: \(copy.rawValue))" + } + result += "]" + return result + } +} + +internal func _withOptionalUnsafePointerOrNull( + to value: T?, + _ body: (UnsafePointer?) throws -> R +) rethrows -> R { + guard let value = value else { + return try body(nil) + } + return try withUnsafePointer(to: value, body) +} + diff --git a/Tests/SystemSocketsTests/AncillaryMessageBufferTests.swift b/Tests/SystemSocketsTests/AncillaryMessageBufferTests.swift new file mode 100644 index 00000000..866e011a --- /dev/null +++ b/Tests/SystemSocketsTests/AncillaryMessageBufferTests.swift @@ -0,0 +1,48 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +import XCTest + +#if SYSTEM_PACKAGE +import SystemPackage +@testable import SystemSockets +#else +import System +#error("No socket support") +#endif + +// @available(...) +final class AncillaryMessageBufferTest: XCTestCase { + func testAppend() { + // Create a buffer of 100 messages, with varying payload lengths. + var buffer = SocketDescriptor.AncillaryMessageBuffer(minimumCapacity: 0) + for i in 0 ..< 100 { + let bytes = UnsafeMutableRawBufferPointer.allocate(byteCount: i, alignment: 1) + defer { bytes.deallocate() } + system_memset(bytes, to: UInt8(i)) + buffer.appendMessage(level: .init(rawValue: CInt(100 * i)), + type: .init(rawValue: CInt(1000 * i)), + bytes: UnsafeRawBufferPointer(bytes)) + } + // Check that we can access appended messages. + var i = 0 + for message in buffer { + XCTAssertEqual(Int(message.level.rawValue), 100 * i) + XCTAssertEqual(Int(message.type.rawValue), 1000 * i) + message.withUnsafeBytes { buffer in + XCTAssertEqual(buffer.count, i) + for idx in buffer.indices { + XCTAssertEqual(buffer[idx], UInt8(i), "byte #\(idx)") + } + } + i += 1 + } + XCTAssertEqual(i, 100, "Too many messages in buffer") + } +} diff --git a/Tests/SystemSocketsTests/SocketAddressTest.swift b/Tests/SystemSocketsTests/SocketAddressTest.swift new file mode 100644 index 00000000..44682ba6 --- /dev/null +++ b/Tests/SystemSocketsTests/SocketAddressTest.swift @@ -0,0 +1,228 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +import XCTest + +#if SYSTEM_PACKAGE +import SystemPackage +@testable import SystemSockets +#else +import System +#error("No socket support") +#endif + +// @available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +final class SocketAddressTest: XCTestCase { + func test_addressWithArbitraryData() { + for length in MemoryLayout.size ... 255 { + let range = 0 ..< UInt8(truncatingIfNeeded: length) + let data = Array(range) + data.withUnsafeBytes { source in + let address = SocketAddress(source) + address.withUnsafeBytes { copy in + XCTAssertEqual(copy.count, length) + XCTAssertTrue(range.elementsEqual(copy), "\(length)") + } + } + } + } + + func test_addressWithSockAddr() { + for length in MemoryLayout.size ... 255 { + let range = 0 ..< UInt8(truncatingIfNeeded: length) + let data = Array(range) + data.withUnsafeBytes { source in + let p = source.baseAddress!.assumingMemoryBound(to: CInterop.SockAddr.self) + let address = SocketAddress( + address: p, + length: CInterop.SockLen(source.count)) + address.withUnsafeBytes { copy in + XCTAssertEqual(copy.count, length) + XCTAssertTrue(range.elementsEqual(copy), "\(length)") + } + } + } + } + + func test_description() { + let ipv4 = SocketAddress(SocketAddress.IPv4(address: "1.2.3.4", port: 80)!) + let desc4 = "\(ipv4)" + XCTAssertEqual(desc4, "SocketAddress(family: ipv4, address: 1.2.3.4:80)") + + let ipv6 = SocketAddress(SocketAddress.IPv6(address: "1234::ff", port: 80)!) + let desc6 = "\(ipv6)" + XCTAssertEqual(desc6, "SocketAddress(family: ipv6, address: [1234::ff]:80)") + + let local = SocketAddress(SocketAddress.Local("/tmp/test.sock")) + let descl = "\(local)" + XCTAssertEqual(descl, "SocketAddress(family: local, address: /tmp/test.sock)") + } + + // MARK: IPv4 + + func test_addressWithIPv4Address() { + let ipv4 = SocketAddress.IPv4(address: "1.2.3.4", port: 42)! + let address = SocketAddress(ipv4) + if case .large = address._variant { + XCTFail("IPv4 address in big representation") + } + XCTAssertEqual(address.family, .ipv4) + if let extracted = address.ipv4 { + XCTAssertEqual(extracted, ipv4) + } else { + XCTFail("Cannot extract IPv4 address") + } + } + + func test_ipv4_address_string_conversions() { + typealias Address = SocketAddress.IPv4.Address + + func check( + _ string: String, + _ value: UInt32?, + file: StaticString = #file, + line: UInt = #line + ) { + switch (Address(string), value) { + case let (address?, value?): + XCTAssertEqual(address.rawValue, value, file: file, line: line) + case let (address?, nil): + let s = String(address.rawValue, radix: 16) + XCTFail("Got \(s), expected nil", file: file, line: line) + case let (nil, value?): + let s = String(value, radix: 16) + XCTFail("Got nil, expected \(s), file: file, line: line") + case (nil, nil): + // OK + break + } + + if let value = value { + let address = Address(rawValue: value) + let actual = "\(address)" + XCTAssertEqual( + actual, string, + "Mismatching description. Expected: \(string), actual: \(actual)", + file: file, line: line) + } + } + check("0.0.0.0", 0) + check("0.0.0.1", 1) + check("1.2.3.4", 0x01020304) + check("255.255.255.255", 0xFFFFFFFF) + check("apple.com", nil) + check("256.0.0.0", nil) + } + + func test_ipv4_description() { + let a1 = SocketAddress.IPv4(address: "1.2.3.4", port: 42)! + XCTAssertEqual("\(a1)", "1.2.3.4:42") + + let a2 = SocketAddress.IPv4(address: "192.168.1.1", port: 80)! + XCTAssertEqual("\(a2)", "192.168.1.1:80") + } + + // MARK: IPv6 + + func test_addressWithIPv6Address() { + let ipv6 = SocketAddress.IPv6(address: "2001:db8::", port: 42)! + let address = SocketAddress(ipv6) + if case .large = address._variant { + XCTFail("IPv6 address in big representation") + } + XCTAssertEqual(address.family, .ipv6) + if let extracted = address.ipv6 { + XCTAssertEqual(extracted, ipv6) + } else { + XCTFail("Cannot extract IPv6 address") + } + } + + func test_ipv6_address_string_conversions() { + typealias Address = SocketAddress.IPv6.Address + + func check( + _ string: String, + _ value: [UInt8]?, + file: StaticString = #file, + line: UInt = #line + ) { + let value = value.map { value in + value.withUnsafeBytes { bytes in + Address(bytes: bytes) + } + } + switch (Address(string), value) { + case let (address?, value?): + XCTAssertEqual(address, value, file: file, line: line) + case let (address?, nil): + XCTFail("Got \(address), expected nil", file: file, line: line) + case let (nil, value?): + XCTFail("Got nil, expected \(value), file: file, line: line") + case (nil, nil): + // OK + break + } + } + check( + "::", + [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + check( + "0011:2233:4455:6677:8899:aabb:ccdd:eeff", + [0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]) + check( + "1:203:405:607:809:a0b:c0d:e0f", + [0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f]) + check("1.2.3.4", nil) + check("apple.com", nil) + } + + func test_ipv6_description() { + let a1 = SocketAddress.IPv6(address: "2001:db8:85a3:8d3:1319:8a2e:370:7348", port: 42)! + XCTAssertEqual("\(a1)", "[2001:db8:85a3:8d3:1319:8a2e:370:7348]:42") + + let a2 = SocketAddress.IPv6(address: "2001::42", port: 80)! + XCTAssertEqual("\(a2)", "[2001::42]:80") + } + + // MARK: Local + + func test_addressWithLocalAddress_smol() { + let smolLocal = SocketAddress.Local("/tmp/test.sock") + let smol = SocketAddress(smolLocal) + if case .large = smol._variant { + XCTFail("Local address with short path in big representation") + } + XCTAssertEqual(smol.family, .local) + if let extracted = smol.local { + XCTAssertEqual(extracted, smolLocal) + } else { + XCTFail("Cannot extract Local address") + } + } + + func test_addressWithLocalAddress_large() { + let largeLocal = SocketAddress.Local( + "This is a really long filename, it almost doesn't fit on one line.sock") + let large = SocketAddress(largeLocal) + if case .small = large._variant { + XCTFail("Local address with long path in small representation") + } + XCTAssertEqual(large.family, .local) + if let extracted = large.local { + XCTAssertEqual(extracted, largeLocal) + } else { + XCTFail("Cannot extract Local address") + } + } + +} diff --git a/Tests/SystemSocketsTests/SocketTest.swift b/Tests/SystemSocketsTests/SocketTest.swift new file mode 100644 index 00000000..b4c84c59 --- /dev/null +++ b/Tests/SystemSocketsTests/SocketTest.swift @@ -0,0 +1,109 @@ +/* + This source file is part of the Swift System open source project + + Copyright (c) 2021 Apple Inc. and the Swift System project authors + Licensed under Apache License v2.0 with Runtime Library Exception + + See https://swift.org/LICENSE.txt for license information +*/ + +import XCTest + +#if SYSTEM_PACKAGE +import SystemPackage +import SystemSockets +#else +import System +#error("No socket support") +#endif + +// FIXME: Need collaborative mocking between systempackage and systemsockets +/* +// @available(...) +final class SocketTest: XCTestCase { + + func testSyscalls() { + + let socket = SocketDescriptor(rawValue: 3) + let rawSocket = socket.rawValue + let rawBuf = UnsafeMutableRawBufferPointer.allocate(byteCount: 100, alignment: 4) + defer { rawBuf.deallocate() } + let bufAddr = rawBuf.baseAddress + let bufCount = rawBuf.count + let writeBuf = UnsafeRawBufferPointer(rawBuf) + let writeBufAddr = writeBuf.baseAddress + + let syscallTestCases: Array = [ + MockTestCase(name: "socket", .interruptable, PF_INET6, SOCK_STREAM, 0) { + retryOnInterrupt in + _ = try SocketDescriptor.open(.ipv6, .stream, retryOnInterrupt: retryOnInterrupt) + }, + MockTestCase(name: "shutdown", .noInterrupt, rawSocket, SHUT_RD) { + retryOnInterrupt in + _ = try socket.shutdown(.read) + }, + MockTestCase(name: "listen", .noInterrupt, rawSocket, 999) { + retryOnInterrupt in + _ = try socket.listen(backlog: 999) + }, + MockTestCase( + name: "recv", .interruptable, rawSocket, bufAddr, bufCount, MSG_PEEK + ) { + retryOnInterrupt in + _ = try socket.receive( + into: rawBuf, flags: .peek, retryOnInterrupt: retryOnInterrupt) + }, + MockTestCase( + name: "send", .interruptable, rawSocket, writeBufAddr, bufCount, MSG_DONTROUTE + ) { + retryOnInterrupt in + _ = try socket.send( + writeBuf, flags: .doNotRoute, retryOnInterrupt: retryOnInterrupt) + }, + MockTestCase( + name: "recvfrom", .interruptable, rawSocket, Wildcard(), Wildcard(), 42, Wildcard(), Wildcard() + ) { retryOnInterrupt in + var sender = SocketAddress() + _ = try socket.receive(into: rawBuf, + sender: &sender, + flags: .init(rawValue: 42), + retryOnInterrupt: retryOnInterrupt) + }, + MockTestCase( + name: "sendto", .interruptable, rawSocket, Wildcard(), Wildcard(), 42, Wildcard(), Wildcard() + ) { retryOnInterrupt in + let recipient = SocketAddress(ipv4: .loopback, port: 123) + _ = try socket.send(UnsafeRawBufferPointer(rawBuf), + to: recipient, + flags: .init(rawValue: 42), + retryOnInterrupt: retryOnInterrupt) + }, + MockTestCase( + name: "recvmsg", .interruptable, rawSocket, Wildcard(), 42 + ) { retryOnInterrupt in + var sender = SocketAddress() + var ancillary = SocketDescriptor.AncillaryMessageBuffer() + _ = try socket.receive(into: rawBuf, + sender: &sender, + ancillary: &ancillary, + flags: .init(rawValue: 42), + retryOnInterrupt: retryOnInterrupt) + }, + MockTestCase( + name: "sendmsg", .interruptable, rawSocket, Wildcard(), 42 + ) { retryOnInterrupt in + let recipient = SocketAddress(ipv4: .loopback, port: 123) + let ancillary = SocketDescriptor.AncillaryMessageBuffer() + _ = try socket.send(UnsafeRawBufferPointer(rawBuf), + to: recipient, + ancillary: ancillary, + flags: .init(rawValue: 42), + retryOnInterrupt: retryOnInterrupt) + }, + ] + + syscallTestCases.forEach { $0.runAllTests() } + + } +} +*/ diff --git a/Tests/SystemTests/TestingInfrastructure.swift b/Tests/SystemTests/TestingInfrastructure.swift index 36e90be0..b8905fc4 100644 --- a/Tests/SystemTests/TestingInfrastructure.swift +++ b/Tests/SystemTests/TestingInfrastructure.swift @@ -15,6 +15,25 @@ import XCTest @testable import System #endif +internal struct Wildcard: Hashable {} + +extension Trace.Entry { + /// This implements `==` with wildcard matching. + /// (`Entry` cannot conform to `Equatable`/`Hashable` this way because + /// the wildcard matching `==` relation isn't transitive.) + internal func matches(_ other: Self) -> Bool { + guard self.name == other.name else { return false } + guard self.arguments.count == other.arguments.count else { return false } + for i in self.arguments.indices { + if self.arguments[i] is Wildcard || other.arguments[i] is Wildcard { + continue + } + guard self.arguments[i] == other.arguments[i] else { return false } + } + return true + } +} + // To aid debugging, force failures to fatal error internal var forceFatalFailures = false @@ -40,7 +59,7 @@ extension TestCase { _ message: String? = nil ) where S1.Element: Equatable, S1.Element == S2.Element { if !expected.elementsEqual(actual) { - defer { print("expected: \(expected), actual: \(actual)") } + defer { print("expected: \(expected)\n actual: \(actual)") } fail(message) } } @@ -49,7 +68,7 @@ extension TestCase { _ message: String? = nil ) { if actual != expected { - defer { print("expected: \(expected), actual: \(actual)") } + defer { print("expected: \(expected)\n actual: \(actual)") } fail(message) } } @@ -62,6 +81,27 @@ extension TestCase { fail(message) } } + func expectMatch( + _ expected: Trace.Entry?, _ actual: Trace.Entry?, + _ message: String? = nil + ) { + func check() -> Bool { + switch (expected, actual) { + case let (expected?, actual?): + return expected.matches(actual) + case (nil, nil): + return true + default: + return false + } + } + if !check() { + let e = expected.map { "\($0)" } ?? "nil" + let a = actual.map { "\($0)" } ?? "nil" + defer { print("expected: \(e)\n actual: \(a)") } + fail(message) + } + } func expectNil( _ actual: T?, _ message: String? = nil @@ -149,7 +189,7 @@ internal struct MockTestCase: TestCase { // Test our API mappings to the lower-level syscall invocation do { try body(true) - self.expectEqual(self.expected, mocking.trace.dequeue()) + self.expectMatch(self.expected, mocking.trace.dequeue()) } catch { self.fail() } @@ -158,9 +198,9 @@ internal struct MockTestCase: TestCase { guard interruptBehavior != .noError else { do { try body(interruptable) - self.expectEqual(self.expected, mocking.trace.dequeue()) + self.expectMatch(self.expected, mocking.trace.dequeue()) try body(!interruptable) - self.expectEqual(self.expected, mocking.trace.dequeue()) + self.expectMatch(self.expected, mocking.trace.dequeue()) } catch { self.fail() } @@ -177,7 +217,7 @@ internal struct MockTestCase: TestCase { self.fail() } catch Errno.interrupted { // Success! - self.expectEqual(self.expected, mocking.trace.dequeue()) + self.expectMatch(self.expected, mocking.trace.dequeue()) } catch { self.fail() } @@ -188,13 +228,13 @@ internal struct MockTestCase: TestCase { mocking.forceErrno = .counted(errno: EINTR, count: 3) try body(interruptable) - self.expectEqual(self.expected, mocking.trace.dequeue()) // EINTR - self.expectEqual(self.expected, mocking.trace.dequeue()) // EINTR - self.expectEqual(self.expected, mocking.trace.dequeue()) // EINTR - self.expectEqual(self.expected, mocking.trace.dequeue()) // Success + self.expectMatch(self.expected, mocking.trace.dequeue()) // EINTR + self.expectMatch(self.expected, mocking.trace.dequeue()) // EINTR + self.expectMatch(self.expected, mocking.trace.dequeue()) // EINTR + self.expectMatch(self.expected, mocking.trace.dequeue()) // Success } catch Errno.interrupted { self.expectFalse(interruptable) - self.expectEqual(self.expected, mocking.trace.dequeue()) // EINTR + self.expectMatch(self.expected, mocking.trace.dequeue()) // EINTR } catch { self.fail() }