diff --git a/Sources/Reachability.swift b/Sources/Reachability.swift index f4c9ce9c..77afa1c2 100644 --- a/Sources/Reachability.swift +++ b/Sources/Reachability.swift @@ -29,11 +29,11 @@ import SystemConfiguration import Foundation public enum ReachabilityError: Error { - case FailedToCreateWithAddress(sockaddr_in) - case FailedToCreateWithHostname(String) - case UnableToSetCallback - case UnableToSetDispatchQueue - case UnableToGetInitialFlags + case failedToCreateWithAddress(sockaddr, Int32) + case failedToCreateWithHostname(String, Int32) + case unableToSetCallback(Int32) + case unableToSetDispatchQueue(Int32) + case unableToGetFlags(Int32) } @available(*, unavailable, renamed: "Notification.Name.reachabilityChanged") @@ -113,35 +113,49 @@ public class Reachability { #endif }() - fileprivate var notifierRunning = false + fileprivate(set) var notifierRunning = false fileprivate let reachabilityRef: SCNetworkReachability fileprivate let reachabilitySerialQueue: DispatchQueue + fileprivate let notificationQueue: DispatchQueue? fileprivate(set) var flags: SCNetworkReachabilityFlags? { didSet { guard flags != oldValue else { return } - reachabilityChanged() + notifyReachabilityChanged() } } - required public init(reachabilityRef: SCNetworkReachability, queueQoS: DispatchQoS = .default, targetQueue: DispatchQueue? = nil) { + required public init(reachabilityRef: SCNetworkReachability, + queueQoS: DispatchQoS = .default, + targetQueue: DispatchQueue? = nil, + notificationQueue: DispatchQueue? = .main) { self.allowsCellularConnection = true self.reachabilityRef = reachabilityRef self.reachabilitySerialQueue = DispatchQueue(label: "uk.co.ashleymills.reachability", qos: queueQoS, target: targetQueue) + self.notificationQueue = notificationQueue } - public convenience init?(hostname: String, queueQoS: DispatchQoS = .default, targetQueue: DispatchQueue? = nil) { - guard let ref = SCNetworkReachabilityCreateWithName(nil, hostname) else { return nil } - self.init(reachabilityRef: ref, queueQoS: queueQoS, targetQueue: targetQueue) + public convenience init(hostname: String, + queueQoS: DispatchQoS = .default, + targetQueue: DispatchQueue? = nil, + notificationQueue: DispatchQueue? = .main) throws { + guard let ref = SCNetworkReachabilityCreateWithName(nil, hostname) else { + throw ReachabilityError.failedToCreateWithHostname(hostname, SCError()) + } + self.init(reachabilityRef: ref, queueQoS: queueQoS, targetQueue: targetQueue, notificationQueue: notificationQueue) } - public convenience init?(queueQoS: DispatchQoS = .default, targetQueue: DispatchQueue? = nil) { + public convenience init(queueQoS: DispatchQoS = .default, + targetQueue: DispatchQueue? = nil, + notificationQueue: DispatchQueue? = .main) throws { var zeroAddress = sockaddr() zeroAddress.sa_len = UInt8(MemoryLayout.size) zeroAddress.sa_family = sa_family_t(AF_INET) - guard let ref = SCNetworkReachabilityCreateWithAddress(nil, &zeroAddress) else { return nil } + guard let ref = SCNetworkReachabilityCreateWithAddress(nil, &zeroAddress) else { + throw ReachabilityError.failedToCreateWithAddress(zeroAddress, SCError()) + } - self.init(reachabilityRef: ref, queueQoS: queueQoS, targetQueue: targetQueue) + self.init(reachabilityRef: ref, queueQoS: queueQoS, targetQueue: targetQueue, notificationQueue: notificationQueue) } deinit { @@ -163,15 +177,16 @@ public extension Reachability { } var context = SCNetworkReachabilityContext(version: 0, info: nil, retain: nil, release: nil, copyDescription: nil) - context.info = UnsafeMutableRawPointer(Unmanaged.passUnretained(self).toOpaque()) + context.info = Unmanaged.passUnretained(self).toOpaque() + if !SCNetworkReachabilitySetCallback(reachabilityRef, callback, &context) { stopNotifier() - throw ReachabilityError.UnableToSetCallback + throw ReachabilityError.unableToSetCallback(SCError()) } if !SCNetworkReachabilitySetDispatchQueue(reachabilityRef, reachabilitySerialQueue) { stopNotifier() - throw ReachabilityError.UnableToSetDispatchQueue + throw ReachabilityError.unableToSetDispatchQueue(SCError()) } // Perform an initial check @@ -205,18 +220,7 @@ public extension Reachability { } var description: String { - guard let flags = flags else { return "unavailable flags" } - let W = isRunningOnDevice ? (flags.isOnWWANFlagSet ? "W" : "-") : "X" - let R = flags.isReachableFlagSet ? "R" : "-" - let c = flags.isConnectionRequiredFlagSet ? "c" : "-" - let t = flags.isTransientConnectionFlagSet ? "t" : "-" - let i = flags.isInterventionRequiredFlagSet ? "i" : "-" - let C = flags.isConnectionOnTrafficFlagSet ? "C" : "-" - let D = flags.isConnectionOnDemandFlagSet ? "D" : "-" - let l = flags.isLocalAddressFlagSet ? "l" : "-" - let d = flags.isDirectFlagSet ? "d" : "-" - - return "\(W)\(R) \(c)\(t)\(i)\(C)\(D)\(l)\(d)" + return flags?.description ?? "unavailable flags" } } @@ -227,21 +231,23 @@ fileprivate extension Reachability { var flags = SCNetworkReachabilityFlags() if !SCNetworkReachabilityGetFlags(self.reachabilityRef, &flags) { self.stopNotifier() - throw ReachabilityError.UnableToGetInitialFlags + throw ReachabilityError.unableToGetFlags(SCError()) } self.flags = flags } } - func reachabilityChanged() { - let block = connection != .none ? whenReachable : whenUnreachable - DispatchQueue.main.async { [weak self] in + func notifyReachabilityChanged() { + let notify = { [weak self] in guard let self = self else { return } - block?(self) + self.connection != .none ? self.whenReachable?(self) : self.whenUnreachable?(self) self.notificationCenter.post(name: .reachabilityChanged, object: self) } + + // notify on the configured `notificationQueue`, or the caller's (i.e. `reachabilitySerialQueue`) + notificationQueue?.async(execute: notify) ?? notify() } } @@ -313,4 +319,18 @@ extension SCNetworkReachabilityFlags { var isConnectionRequiredAndTransientFlagSet: Bool { return intersection([.connectionRequired, .transientConnection]) == [.connectionRequired, .transientConnection] } + + var description: String { + let W = isOnWWANFlagSet ? "W" : "-" + let R = isReachableFlagSet ? "R" : "-" + let c = isConnectionRequiredFlagSet ? "c" : "-" + let t = isTransientConnectionFlagSet ? "t" : "-" + let i = isInterventionRequiredFlagSet ? "i" : "-" + let C = isConnectionOnTrafficFlagSet ? "C" : "-" + let D = isConnectionOnDemandFlagSet ? "D" : "-" + let l = isLocalAddressFlagSet ? "l" : "-" + let d = isDirectFlagSet ? "d" : "-" + + return "\(W)\(R) \(c)\(t)\(i)\(C)\(D)\(l)\(d)" + } } diff --git a/Tests/ReachabilityTests.swift b/Tests/ReachabilityTests.swift index 8200edd4..3bb2cdcf 100644 --- a/Tests/ReachabilityTests.swift +++ b/Tests/ReachabilityTests.swift @@ -14,7 +14,7 @@ class ReachabilityTests: XCTestCase { func testValidHost() { let validHostName = "google.com" - guard let reachability = Reachability(hostname: validHostName) else { + guard let reachability = try? Reachability(hostname: validHostName) else { return XCTFail("Unable to create reachability") } @@ -47,7 +47,7 @@ class ReachabilityTests: XCTestCase { let invalidHostName = "invalidhost" - guard let reachability = Reachability(hostname: invalidHostName) else { + guard let reachability = try? Reachability(hostname: invalidHostName) else { return XCTFail("Unable to create reachability") }