Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//

import Foundation
import Logging

/// The base protocol which all data stream-related classes conform to.
///
Expand All @@ -14,6 +15,7 @@ import Foundation
/// managing the connection (opening, closing, and reconnecting), creating parameters for allowing
/// and disallowing content, and handling sequences.
public protocol ATEventStreamConfiguration: AnyObject {
var logger: Logger { get }

/// The URL of the relay.
///
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

import Foundation
import SwiftCBOR
import Logging

extension ATEventStreamConfiguration {

/// Connects the client to the event stream.
///
/// Normally, when connecting to the event stream, it will start from the first message the event stream gets. The client will always look at the last successful
Expand All @@ -22,10 +22,13 @@ extension ATEventStreamConfiguration {
///
/// - Parameter cursor: The mark used to indicate the starting point for the next set of results. Optional.
public func connect(cursor: Int64? = nil) async {
logger.trace("In connect()")
self.isConnected = true
self.webSocketTask.resume()
logger.debug("WebSocketTask resumed.", metadata: ["isConnected": "\(self.isConnected)"])

await self.receiveMessages()
logger.trace("Exiting connect()")
}

/// Disconnects the client from the event stream.
Expand All @@ -34,7 +37,10 @@ extension ATEventStreamConfiguration {
/// - closeCode: A code that indicates why the event stream connection closed.
/// - reason: The reason why the client disconnected from the server.
public func disconnect(with closeCode: URLSessionWebSocketTask.CloseCode, reason: Data) {
logger.trace("In disconnect()")
logger.debug("Closing websocket", metadata: ["closeCode": "\(closeCode)", "reason": "\(reason)"])
webSocketTask.cancel(with: closeCode, reason: reason)
logger.trace("Exiting disconnect()")
}

/// Attempts to reconnect the client to the event stream after a disconnect.
Expand All @@ -45,18 +51,19 @@ extension ATEventStreamConfiguration {
/// - cursor: The mark used to indicate the starting point for the next set of results. Optional.
/// - retry: The number of times the connection attempts can be retried.
func reconnect(cursor: Int64?, retry: Int) async {
logger.trace("In reconnect()")
guard isConnected == false else {
print("Already connected. No need to reconnect.")
logger.debug("Already connected. No need to reconnect.")
return
}

let lastCursor: Int64 = sequencePosition ?? 0

if lastCursor > 0 {
logger.debug("Fetching missed messages", metadata: ["lastCursor": "\(lastCursor)"])
await fetchMissedMessages(fromSequence: lastCursor)
}


logger.trace("Exiting reconnect()")
}

/// Receives decoded messages and manages the sequence number.
Expand All @@ -65,26 +72,32 @@ extension ATEventStreamConfiguration {
///
/// [DAG_CBOR]: https://ipld.io/docs/codecs/known/dag-cbor/
public func receiveMessages() async {
logger.trace("In receiveMessages()")
while isConnected {
do {
let message = try await webSocketTask.receive()

switch message {
case .string(let base64String):
logger.debug("Received a string message", metadata: ["length": "\(base64String.count)"])
ATCBORManager().decodeCBOR(from: base64String)
case .data(let data):
logger.debug("Received a data message", metadata: ["length": "\(data.count)"])
let base64String = data.base64EncodedString()
ATCBORManager().decodeCBOR(from: base64String)
@unknown default:
print("Received an unknown type of message.")
logger.warning("Received an unknown type of message.")
}
} catch {
print("Error receiving message: \(error)")
logger.error("Error while receiving message.", metadata: ["error": "\(error)"])
}
}
logger.trace("Exiting receiveMessages()")
}

public func fetchMissedMessages(fromSequence lastCursor: Int64) async {
logger.trace("In fetchMissedMessages()")

logger.trace("Exiting fetchMissedMessages()")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
//

import Foundation
import Logging

/// The base class for the AT Protocol's Firehose event stream.
class ATFirehoseStream: ATEventStreamConfiguration {

internal var logger = Logger(label: "ATFirehoseStream")
/// Indicates whether the event stream is connected. Defaults to `false`.
internal var isConnected: Bool = false

/// The URL of the relay. Defaults to `wss://bsky.network`.
public var relayURL: String = "wss://bsky.network"

Expand Down Expand Up @@ -49,15 +49,24 @@ class ATFirehoseStream: ATEventStreamConfiguration {
/// to `URLSessionConfiguration.default`.
required init(relayURL: String, namespacedIdentifiertURL: String, cursor: Int64?, sequencePosition: Int64?,
urlSessionConfiguration: URLSessionConfiguration = .default, webSocketTask: URLSessionWebSocketTask) async throws {
logger.trace("In init()")
logger.trace("Initializing the ATEventStreamConfiguration")
self.relayURL = relayURL
self.namespacedIdentifiertURL = namespacedIdentifiertURL
self.cursor = cursor
self.sequencePosition = sequencePosition
self.urlSessionConfiguration = urlSessionConfiguration
self.urlSession = URLSession(configuration: urlSessionConfiguration)
self.webSocketTask = webSocketTask

guard let webSocketURL = URL(string: "\(relayURL)/xrpc/\(namespacedIdentifiertURL)") else { throw ATRequestPrepareError.invalidFormat }

logger.debug("Opening a websocket", metadata: ["relayUrl": "\(relayURL)", "namespacedIdentifiertURL": "\(namespacedIdentifiertURL)"])
guard let webSocketURL = URL(string: "\(relayURL)/xrpc/\(namespacedIdentifiertURL)") else {
logger.error("Unable to create the websocket URL due to an invalid format.")
throw ATRequestPrepareError.invalidFormat
}

logger.debug("Creating the websocket task")
self.webSocketTask = urlSession.webSocketTask(with: webSocketURL)
logger.trace("Exiting init()")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,11 @@ public class ATProtocolConfiguration: ProtocolConfiguration {
self.logIdentifier = logIdentifier ?? Bundle.main.bundleIdentifier ?? "com.cjrriley.ATProtoKit"
self.logCategory = logCategory ?? "ATProtoKit"
self.logLevel = logLevel

#if canImport(os)
// Create the logger and bootstrap it for use in the library
LoggingSystem.bootstrap { label in
ATLogHandler(subsystem: label, category: logCategory ?? "ATProtoKit")
}
#else
LoggingSystem.bootstrap(StreamLogHandler.standardOutput)
#endif

logger = Logger(label: logIdentifier ?? "com.cjrriley.ATProtoKit")
logger?.logLevel = logLevel ?? .info
Expand All @@ -94,10 +91,14 @@ public class ATProtocolConfiguration: ProtocolConfiguration {
/// - Throws: An ``ATProtoError``-conforming error type, depending on the issye. Go to
/// ``ATAPIError`` and ``ATRequestPrepareError`` for more details.
public func authenticate(authenticationFactorToken: String? = nil) async throws -> Result<UserSession, Error> {
logger?.trace("In authenticate()")

guard let requestURL = URL(string: "\(self.pdsURL)/xrpc/com.atproto.server.createSession") else {
logger?.error("Error while authenticating with the server", metadata: ["error": "\(ATRequestPrepareError.invalidRequestURL)"])
return .failure(ATRequestPrepareError.invalidRequestURL)
}

logger?.debug("Setting the session credentials")
let credentials = ComAtprotoLexicon.Server.CreateSessionRequestBody(
identifier: handle,
password: appPassword,
Expand All @@ -107,6 +108,8 @@ public class ATProtocolConfiguration: ProtocolConfiguration {
do {
let request = APIClientService.createRequest(forRequest: requestURL,
andMethod: .post)

logger?.debug("Authenticating with the server.", metadata: ["requestURL": "\(requestURL)"])
var response = try await APIClientService.sendRequest(request,
withEncodingBody: credentials,
decodeTo: UserSession.self)
Expand All @@ -115,11 +118,16 @@ public class ATProtocolConfiguration: ProtocolConfiguration {
if self.logger != nil {
response.logger = self.logger
}


logger?.debug("Authentication successful")
logger?.trace("Exiting authenticate()")
return .success(response)
} catch {
logger?.error("Authentication request failed with error.", metadata: ["error": "\(error)"])
logger?.trace("Exiting authenticate()")
return .failure(error)
}

}

/// Creates an a new account for the user.
Expand Down Expand Up @@ -165,11 +173,14 @@ public class ATProtocolConfiguration: ProtocolConfiguration {
recoveryKey: String? = nil,
plcOp: UnknownType? = nil
) async throws -> Result<UserSession, Error> {
logger?.trace("In createAccount()"])
guard let requestURL = URL(string: "\(self.pdsURL)/xrpc/com.atproto.server.createAccount") else {
logger?.error("Error while creating account", metadata: ["error": "\(ATRequestPrepareError.invalidRequestURL)"])
return .failure(ATRequestPrepareError.invalidRequestURL)
}

let requestBody = ComAtprotoLexicon.Server.CreateAccountRequestBody(
develop
email: email,
handle: handle,
existingDID: existingDID,
Expand All @@ -187,6 +198,8 @@ public class ATProtocolConfiguration: ProtocolConfiguration {
acceptValue: nil,
contentTypeValue: nil,
authorizationValue: nil)

logger?.debug("Crreating user account", metadata: ["handle": "\(handle)"])
var response = try await APIClientService.sendRequest(request,
withEncodingBody: requestBody,
decodeTo: UserSession.self)
Expand All @@ -195,9 +208,13 @@ public class ATProtocolConfiguration: ProtocolConfiguration {
if self.logger != nil {
response.logger = self.logger
}


logger?.debug("User account creation successful", metadata: ["handle": "\(handle)"])
logger?.trace("Exiting createAccount()")
return .success(response)
} catch {
logger?.error("Account creation failed with error.", metadata: ["error": "\(error)"])
logger?.trace("Exiting createAccount()")
return .failure(error)
}
}
Expand All @@ -220,20 +237,27 @@ public class ATProtocolConfiguration: ProtocolConfiguration {
by accessToken: String,
pdsURL: String? = nil
) async throws -> Result<SessionResponse, Error> {
logger?.trace("In getSession()")
guard let sessionURL = pdsURL != nil ? pdsURL : self.pdsURL,
let requestURL = URL(string: "\(sessionURL)/xrpc/com.atproto.server.getSession") else {
logger?.error("Error while obtaining session", metadata: ["error": "\(ATRequestPrepareError.invalidRequestURL)"])
return .failure(ATRequestPrepareError.invalidRequestURL)
}

do {
let request = APIClientService.createRequest(forRequest: requestURL,
andMethod: .get,
authorizationValue: "Bearer \(accessToken)")
logger?.debug("Obtaining the session")
let response = try await APIClientService.sendRequest(request,
decodeTo: SessionResponse.self)

logger?.debug("Session obtained successfully")
logger?.trace("Exiting getSession()")
return .success(response)
} catch {
logger?.error("Error while obtaining session", metadata: ["error": "\(error)"])
logger?.trace("Exiting getSession()")
return .failure(error)
}
}
Expand All @@ -256,15 +280,18 @@ public class ATProtocolConfiguration: ProtocolConfiguration {
using refreshToken: String,
pdsURL: String? = nil
) async throws -> Result<UserSession, Error> {
logger?.info("In refreshSession()")
guard let sessionURL = pdsURL != nil ? pdsURL : self.pdsURL,
let requestURL = URL(string: "\(sessionURL)/xrpc/com.atproto.server.refreshSession") else {
logger?.error("Error while refreshing the session", metadata: ["error": "\(ATRequestPrepareError.invalidRequestURL)"])
return .failure(ATRequestPrepareError.invalidRequestURL)
}

do {
let request = APIClientService.createRequest(forRequest: requestURL,
andMethod: .post,
authorizationValue: "Bearer \(refreshToken)")
logger?.debug("Refreshing the session")
var response = try await APIClientService.sendRequest(request,
decodeTo: UserSession.self)
response.pdsURL = self.pdsURL
Expand All @@ -273,8 +300,12 @@ public class ATProtocolConfiguration: ProtocolConfiguration {
response.logger = self.logger
}

logger?.debug("Session refreshed successfully")
logger?.trace("Exiting refreshSession()")
return .success(response)
} catch {
logger?.error("Error while refreshing the session", metadata: ["error": "\(error)"])
logger?.trace("Exiting refreshSession()")
return .failure(error)
}
}
Expand All @@ -295,19 +326,22 @@ public class ATProtocolConfiguration: ProtocolConfiguration {
using accessToken: String,
pdsURL: String? = nil
) async throws {
logger?.trace("In deleteSession()")
guard let sessionURL = pdsURL != nil ? pdsURL : self.pdsURL,
let requestURL = URL(string: "\(sessionURL)/xrpc/com.atproto.server.deleteSession") else {
logger?.error("Error while deleting the session", metadata: ["error": "\(ATRequestPrepareError.invalidRequestURL)"])
throw ATRequestPrepareError.invalidRequestURL
}

do {
let request = APIClientService.createRequest(forRequest: requestURL,
andMethod: .post,
authorizationValue: "Bearer \(accessToken)")

logger?.debug("Deleting the session")
_ = try await APIClientService.sendRequest(request,
withEncodingBody: nil)
} catch {
logger?.error("Error while deleting the session", metadata: ["error": "\(error)"])
throw error
}
}
Expand Down
Loading