diff --git a/Sources/OpenAISwift/Models/ChatMessage.swift b/Sources/OpenAISwift/Models/ChatMessage.swift index c987cf9..9cd3daf 100644 --- a/Sources/OpenAISwift/Models/ChatMessage.swift +++ b/Sources/OpenAISwift/Models/ChatMessage.swift @@ -34,6 +34,10 @@ public struct ChatMessage: Codable, Identifiable { self.role = role self.content = content } + + private enum CodingKeys: String, CodingKey { + case role, content + } } /// A structure that represents a chat conversation. diff --git a/Sources/OpenAISwift/OpenAIRequestHandler.swift b/Sources/OpenAISwift/OpenAIRequestHandler.swift new file mode 100644 index 0000000..2e760ac --- /dev/null +++ b/Sources/OpenAISwift/OpenAIRequestHandler.swift @@ -0,0 +1,35 @@ +// +// OpenAIRequestHandler.swift +// LumenateApp +// +// Created by Simon Mitchell on 28/11/2023. +// + +import Foundation + +public protocol OpenAIRequestHandler { + + /// Function which performs the request as required from the user. + /// - Note: However the request is made, it must do a few things + /// 1. Call `completionHandler` with any errors or response data + /// 2. Data returned must be decodable to the model types defined in this library + /// - Parameters: + /// - endpoint: The endpoint to make a request to + /// - body: The body of the request to make + /// - completionHandler: A closure to be called once the request has completed + func makeRequest(_ endpoint: OpenAIEndpointProvider.API, body: BodyType, completionHandler: @escaping (Result) -> Void) + + /// Function which streams the request as required by the user. + /// - Note: Only "chat" api is streamable for now, so this always has return type of `StreamMessageResult` + /// - Parameters: + /// - endpoint: The endpoint to stream the request from. Note: currently this is only for "chat" endpoint + /// - body: The body of the request to make + /// - eventReceived: Called Multiple times, returns an OpenAI Data Model + /// - completion: Triggers when sever complete sending the message + func streamRequest( + _ endpoint: OpenAIEndpointProvider.API, + body: BodyType, + eventReceived: ((Result, OpenAIError>) -> Void)?, + completion: (() -> Void)? + ) +} diff --git a/Sources/OpenAISwift/OpenAISwift.swift b/Sources/OpenAISwift/OpenAISwift.swift index 67090ac..e5a1a8d 100644 --- a/Sources/OpenAISwift/OpenAISwift.swift +++ b/Sources/OpenAISwift/OpenAISwift.swift @@ -11,37 +11,23 @@ public enum OpenAIError: Error { } public class OpenAISwift { - fileprivate let config: Config - fileprivate let handler = ServerSentEventsHandler() - - /// Configuration object for the client - public struct Config { - - /// Initialiser - /// - Parameter session: the session to use for network requests. - public init(baseURL: String, endpointPrivider: OpenAIEndpointProvider, session: URLSession, authorizeRequest: @escaping (inout URLRequest) -> Void) { - self.baseURL = baseURL - self.endpointProvider = endpointPrivider - self.authorizeRequest = authorizeRequest - self.session = session - } - let baseURL: String - let endpointProvider: OpenAIEndpointProvider - let session:URLSession - let authorizeRequest: (inout URLRequest) -> Void - - public static func makeDefaultOpenAI(apiKey: String) -> Self { - .init(baseURL: "https://api.openai.com", - endpointPrivider: OpenAIEndpointProvider(source: .openAI), - session: .shared, - authorizeRequest: { request in - request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") - }) - } + + // Typealias for backward compatibility so allowing custom request makers + // doesn't introduce breaking changes to the public API + public typealias Config = URLSessionRequestHandler + + fileprivate let requestHandler: OpenAIRequestHandler + + /// Initialises OpenAISwift with a given request handler + /// - Parameter requestHandler: The request handler to make requests with + public init(requestHandler: OpenAIRequestHandler) { + self.requestHandler = requestHandler } + /// Deprecated initialiser for backwards API support to remove breaking change when introducing OpenAIRequestHandler protocol + /// - Parameter config: The config to initialise with public init(config: Config) { - self.config = config + self.requestHandler = config } } @@ -55,9 +41,8 @@ extension OpenAISwift { public func sendCompletion(with prompt: String, model: OpenAIModelType = .gpt3(.davinci), maxTokens: Int = 16, temperature: Double = 1, completionHandler: @escaping (Result, OpenAIError>) -> Void) { let endpoint = OpenAIEndpointProvider.API.completions let body = Command(prompt: prompt, model: model.modelName, maxTokens: maxTokens, temperature: temperature) - let request = prepareRequest(endpoint, body: body) - makeRequest(request: request) { result in + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): do { @@ -81,9 +66,8 @@ extension OpenAISwift { public func sendEdits(with instruction: String, model: OpenAIModelType = .feature(.davinci), input: String = "", completionHandler: @escaping (Result, OpenAIError>) -> Void) { let endpoint = OpenAIEndpointProvider.API.edits let body = Instruction(instruction: instruction, model: model.modelName, input: input) - let request = prepareRequest(endpoint, body: body) - makeRequest(request: request) { result in + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): do { @@ -106,9 +90,8 @@ extension OpenAISwift { public func sendModerations(with input: String, model: OpenAIModelType = .moderation(.latest), completionHandler: @escaping (Result, OpenAIError>) -> Void) { let endpoint = OpenAIEndpointProvider.API.moderations let body = Moderation(input: input, model: model.modelName) - let request = prepareRequest(endpoint, body: body) - makeRequest(request: request) { result in + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): do { @@ -162,10 +145,8 @@ extension OpenAISwift { frequencyPenalty: frequencyPenalty, logitBias: logitBias, stream: false) - - let request = prepareRequest(endpoint, body: body) - makeRequest(request: request) { result in + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): if let chatErr = try? JSONDecoder().decode(ChatError.self, from: success) as ChatError { @@ -197,9 +178,8 @@ extension OpenAISwift { let endpoint = OpenAIEndpointProvider.API.embeddings let body = EmbeddingsInput(input: input, model: model.modelName) - - let request = prepareRequest(endpoint, body: body) - makeRequest(request: request) { result in + + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): do { @@ -255,10 +235,8 @@ extension OpenAISwift { frequencyPenalty: frequencyPenalty, logitBias: logitBias, stream: true) - let request = prepareRequest(endpoint, body: body) - handler.onEventReceived = onEventReceived - handler.onComplete = onComplete - handler.connect(with: request) + + requestHandler.streamRequest(endpoint, body: body, eventReceived: onEventReceived, completion: onComplete) } @@ -272,9 +250,8 @@ extension OpenAISwift { public func sendImages(with prompt: String, numImages: Int = 1, size: ImageSize = .size1024, user: String? = nil, completionHandler: @escaping (Result, OpenAIError>) -> Void) { let endpoint = OpenAIEndpointProvider.API.images let body = ImageGeneration(prompt: prompt, n: numImages, size: size, user: user) - let request = prepareRequest(endpoint, body: body) - - makeRequest(request: request) { result in + + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): do { @@ -288,37 +265,6 @@ extension OpenAISwift { } } } - - private func makeRequest(request: URLRequest, completionHandler: @escaping (Result) -> Void) { - let session = config.session - let task = session.dataTask(with: request) { (data, response, error) in - if let error = error { - completionHandler(.failure(error)) - } else if let data = data { - completionHandler(.success(data)) - } - } - - task.resume() - } - - private func prepareRequest(_ endpoint: OpenAIEndpointProvider.API, body: BodyType) -> URLRequest { - var urlComponents = URLComponents(url: URL(string: config.baseURL)!, resolvingAgainstBaseURL: true) - urlComponents?.path = config.endpointProvider.getPath(api: endpoint) - var request = URLRequest(url: urlComponents!.url!) - request.httpMethod = config.endpointProvider.getMethod(api: endpoint) - - config.authorizeRequest(&request) - - request.setValue("application/json", forHTTPHeaderField: "content-type") - - let encoder = JSONEncoder() - if let encoded = try? encoder.encode(body) { - request.httpBody = encoded - } - - return request - } } extension OpenAISwift { diff --git a/Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift b/Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift new file mode 100644 index 0000000..5269ac0 --- /dev/null +++ b/Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift @@ -0,0 +1,144 @@ +// +// URLSessionRequestHandler.swift +// LumenateApp +// +// Created by Simon Mitchell on 28/11/2023. +// + +import Foundation + +public final class URLSessionRequestHandler: NSObject, OpenAIRequestHandler { + + let baseURL: String + + let endpointProvider: OpenAIEndpointProvider + + let session: URLSession + + let authorizeRequest: (inout URLRequest) -> Void + + var onEventReceived: ((Result, OpenAIError>) -> Void)? + + var onComplete: (() -> Void)? + + private lazy var streamingSession: URLSession = URLSession(configuration: .default, delegate: self, delegateQueue: nil) + + private var streamingTask: URLSessionDataTask? + + /// Default memberwise initialiser + /// - Parameters: + /// - baseURL: The base url to load data from + /// - endpointPrivider: An endpoint provider for generating full urls for each request + /// - session: The session to use for network requests + /// - authorizeRequest: A closure to authenticate a specific `URLRequest` + public init(baseURL: String, endpointPrivider: OpenAIEndpointProvider, session: URLSession, authorizeRequest: @escaping (inout URLRequest) -> Void) { + self.session = session + self.endpointProvider = endpointPrivider + self.authorizeRequest = authorizeRequest + self.baseURL = baseURL + } + + // MARK: Protocol Conformance + + public func makeRequest(_ endpoint: OpenAIEndpointProvider.API, body: BodyType, completionHandler: @escaping (Result) -> Void) where BodyType : Encodable { + let request = prepareRequest(endpoint, body: body) + makeRequest(request: request, completionHandler: completionHandler) + } + + public func streamRequest(_ endpoint: OpenAIEndpointProvider.API, body: BodyType, eventReceived: ((Result, OpenAIError>) -> Void)?, completion: (() -> Void)?) where BodyType : Encodable { + + let request = prepareRequest(endpoint, body: body) + self.onEventReceived = eventReceived + self.onComplete = completion + connect(with: request) + } + + private func makeRequest(request: URLRequest, completionHandler: @escaping (Result) -> Void) { + let task = session.dataTask(with: request) { (data, response, error) in + if let error = error { + completionHandler(.failure(error)) + } else if let data = data { + completionHandler(.success(data)) + } + } + + task.resume() + } + + private func prepareRequest(_ endpoint: OpenAIEndpointProvider.API, body: BodyType) -> URLRequest { + var urlComponents = URLComponents(url: URL(string: baseURL)!, resolvingAgainstBaseURL: true) + urlComponents?.path = endpointProvider.getPath(api: endpoint) + var request = URLRequest(url: urlComponents!.url!) + request.httpMethod = endpointProvider.getMethod(api: endpoint) + + authorizeRequest(&request) + + request.setValue("application/json", forHTTPHeaderField: "content-type") + + let encoder = JSONEncoder() + if let encoded = try? encoder.encode(body) { + request.httpBody = encoded + } + + return request + } + + private func connect(with request: URLRequest) { + streamingTask = session.dataTask(with: request) + streamingTask?.resume() + } + + fileprivate func disconnect() { + streamingTask?.cancel() + } + + fileprivate func processEvent(_ eventData: Data) { + do { + let res = try JSONDecoder().decode(OpenAI.self, from: eventData) + onEventReceived?(.success(res)) + } catch { + onEventReceived?(.failure(.decodingError(error: error))) + } + } +} + +extension URLSessionRequestHandler: URLSessionDataDelegate { + /// It will be called several times, each time could return one chunk of data or multiple chunk of data + /// The JSON look liks this: + /// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}` + /// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Once"},"index":0,"finish_reason":null}]}` + public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) { + if let eventString = String(data: data, encoding: .utf8) { + let lines = eventString.split(separator: "\n") + for line in lines { + if line.hasPrefix("data:") && line != "data: [DONE]" { + if let eventData = String(line.dropFirst(5)).data(using: .utf8) { + processEvent(eventData) + } else { + disconnect() + } + } + } + } + } + + public func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { + if let error = error { + onEventReceived?(.failure(.genericError(error: error))) + } else { + onComplete?() + } + } +} + +public extension OpenAISwift.Config { + + static func makeDefaultOpenAI(apiKey: String) -> URLSessionRequestHandler { + return URLSessionRequestHandler(baseURL: "https://api.openai.com", + endpointPrivider: OpenAIEndpointProvider(source: .openAI), + session: .shared, + authorizeRequest: { request in + request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") + }) + } +} diff --git a/Sources/OpenAISwift/ServerSentEventsHandler.swift b/Sources/OpenAISwift/ServerSentEventsHandler.swift deleted file mode 100644 index e3f92db..0000000 --- a/Sources/OpenAISwift/ServerSentEventsHandler.swift +++ /dev/null @@ -1,65 +0,0 @@ -// -// ServerSentEventsHandler.swift -// -// -// Created by Vic on 2023-03-25. -// - -import Foundation - -class ServerSentEventsHandler: NSObject { - - var onEventReceived: ((Result, OpenAIError>) -> Void)? - var onComplete: (() -> Void)? - - private lazy var session: URLSession = URLSession(configuration: .default, delegate: self, delegateQueue: nil) - private var task: URLSessionDataTask? - - func connect(with request: URLRequest) { - task = session.dataTask(with: request) - task?.resume() - } - - func disconnect() { - task?.cancel() - } - - func processEvent(_ eventData: Data) { - do { - let res = try JSONDecoder().decode(OpenAI.self, from: eventData) - onEventReceived?(.success(res)) - } catch { - onEventReceived?(.failure(.decodingError(error: error))) - } - } -} - -extension ServerSentEventsHandler: URLSessionDataDelegate { - - /// It will be called several times, each time could return one chunk of data or multiple chunk of data - /// The JSON look liks this: - /// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}` - /// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Once"},"index":0,"finish_reason":null}]}` - func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) { - if let eventString = String(data: data, encoding: .utf8) { - let lines = eventString.split(separator: "\n") - for line in lines { - if line.hasPrefix("data:") && line != "data: [DONE]" { - if let eventData = String(line.dropFirst(5)).data(using: .utf8) { - processEvent(eventData) - } else { - disconnect() - } - } - } - } - } - - func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { - if let error = error { - onEventReceived?(.failure(.genericError(error: error))) - } else { - onComplete?() - } - } -}