From 38b5ec1c178ecbf1e7186ad9f614247c9d6a620f Mon Sep 17 00:00:00 2001 From: Dmitrii Medvedev Date: Tue, 2 Apr 2024 16:37:45 +0300 Subject: [PATCH] Add SSLDelegateProtocol to StreamingSession --- Sources/OpenAI/OpenAI.swift | 14 ++++++++------ Sources/OpenAI/Private/SSLDelegateProtocol.swift | 9 +++++++++ Sources/OpenAI/Private/StreamingSession.swift | 12 +++++++++++- Tests/OpenAITests/OpenAITests.swift | 6 +++--- Tests/OpenAITests/OpenAITestsCombine.swift | 2 +- 5 files changed, 32 insertions(+), 11 deletions(-) create mode 100644 Sources/OpenAI/Private/SSLDelegateProtocol.swift diff --git a/Sources/OpenAI/OpenAI.swift b/Sources/OpenAI/OpenAI.swift index 5ff52833..e6b84552 100644 --- a/Sources/OpenAI/OpenAI.swift +++ b/Sources/OpenAI/OpenAI.swift @@ -38,25 +38,27 @@ final public class OpenAI: OpenAIProtocol { } private let session: URLSessionProtocol + private let sslStreamingDelegate: SSLDelegateProtocol? private var streamingSessions = ArrayWithThreadSafety() public let configuration: Configuration public convenience init(apiToken: String) { - self.init(configuration: Configuration(token: apiToken), session: URLSession.shared) + self.init(configuration: Configuration(token: apiToken), session: URLSession.shared, sslStreamingDelegate: nil) } public convenience init(configuration: Configuration) { - self.init(configuration: configuration, session: URLSession.shared) + self.init(configuration: configuration, session: URLSession.shared, sslStreamingDelegate: nil) } - init(configuration: Configuration, session: URLSessionProtocol) { + init(configuration: Configuration, session: URLSessionProtocol, sslStreamingDelegate: SSLDelegateProtocol?) { self.configuration = configuration self.session = session + self.sslStreamingDelegate = sslStreamingDelegate } - public convenience init(configuration: Configuration, session: URLSession = URLSession.shared) { - self.init(configuration: configuration, session: session as URLSessionProtocol) + public convenience init(configuration: Configuration, session: URLSession = URLSession.shared, sslStreamingDelegate: SSLDelegateProtocol? = nil) { + self.init(configuration: configuration, session: session as URLSessionProtocol, sslStreamingDelegate: sslStreamingDelegate) } public func completions(query: CompletionsQuery, completion: @escaping (Result) -> Void) { @@ -154,7 +156,7 @@ extension OpenAI { let request = try request.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval) - let session = StreamingSession(urlRequest: request) + let session = StreamingSession(urlRequest: request, sslDelegate: sslStreamingDelegate) session.onReceiveContent = {_, object in onResult(.success(object)) } diff --git a/Sources/OpenAI/Private/SSLDelegateProtocol.swift b/Sources/OpenAI/Private/SSLDelegateProtocol.swift new file mode 100644 index 00000000..3c8cdc5e --- /dev/null +++ b/Sources/OpenAI/Private/SSLDelegateProtocol.swift @@ -0,0 +1,9 @@ +import Foundation + +public protocol SSLDelegateProtocol { + func urlSession( + _ session: URLSession, + didReceive challenge: URLAuthenticationChallenge, + completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void + ) +} diff --git a/Sources/OpenAI/Private/StreamingSession.swift b/Sources/OpenAI/Private/StreamingSession.swift index cf56a97a..8548b96a 100644 --- a/Sources/OpenAI/Private/StreamingSession.swift +++ b/Sources/OpenAI/Private/StreamingSession.swift @@ -23,6 +23,7 @@ final class StreamingSession: NSObject, Identifiable, URLSe private let streamingCompletionMarker = "[DONE]" private let urlRequest: URLRequest + private let sslDelegate: SSLDelegateProtocol? private lazy var urlSession: URLSession = { let session = URLSession(configuration: .default, delegate: self, delegateQueue: nil) return session @@ -30,8 +31,9 @@ final class StreamingSession: NSObject, Identifiable, URLSe private var previousChunkBuffer = "" - init(urlRequest: URLRequest) { + init(urlRequest: URLRequest, sslDelegate: SSLDelegateProtocol?) { self.urlRequest = urlRequest + self.sslDelegate = sslDelegate } func perform() { @@ -52,6 +54,14 @@ final class StreamingSession: NSObject, Identifiable, URLSe processJSON(from: stringContent) } + func urlSession( + _ session: URLSession, + didReceive challenge: URLAuthenticationChallenge, + completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void + ) { + guard let sslDelegate else { return completionHandler(.performDefaultHandling, nil) } + sslDelegate.urlSession(session, didReceive: challenge, completionHandler: completionHandler) + } } extension StreamingSession { diff --git a/Tests/OpenAITests/OpenAITests.swift b/Tests/OpenAITests/OpenAITests.swift index bae3e015..de9fd1ce 100644 --- a/Tests/OpenAITests/OpenAITests.swift +++ b/Tests/OpenAITests/OpenAITests.swift @@ -20,7 +20,7 @@ class OpenAITests: XCTestCase { super.setUp() self.urlSession = URLSessionMock() let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14) - self.openAI = OpenAI(configuration: configuration, session: self.urlSession) + self.openAI = OpenAI(configuration: configuration, session: self.urlSession, sslStreamingDelegate: nil) } func testCompletions() async throws { @@ -390,14 +390,14 @@ class OpenAITests: XCTestCase { func testDefaultHostURLBuilt() { let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14) - let openAI = OpenAI(configuration: configuration, session: self.urlSession) + let openAI = OpenAI(configuration: configuration, session: self.urlSession, sslStreamingDelegate: nil) let chatsURL = openAI.buildURL(path: .chats) XCTAssertEqual(chatsURL, URL(string: "https://api.openai.com:443/v1/chat/completions")) } func testCustomURLBuilt() { let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14) - let openAI = OpenAI(configuration: configuration, session: self.urlSession) + let openAI = OpenAI(configuration: configuration, session: self.urlSession, sslStreamingDelegate: nil) let chatsURL = openAI.buildURL(path: .chats) XCTAssertEqual(chatsURL, URL(string: "https://my.host.com:443/v1/chat/completions")) } diff --git a/Tests/OpenAITests/OpenAITestsCombine.swift b/Tests/OpenAITests/OpenAITestsCombine.swift index e49ab3d7..bca51a58 100644 --- a/Tests/OpenAITests/OpenAITestsCombine.swift +++ b/Tests/OpenAITests/OpenAITestsCombine.swift @@ -22,7 +22,7 @@ final class OpenAITestsCombine: XCTestCase { super.setUp() self.urlSession = URLSessionMock() let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14) - self.openAI = OpenAI(configuration: configuration, session: self.urlSession) + self.openAI = OpenAI(configuration: configuration, session: self.urlSession, sslStreamingDelegate: nil) } func testCompletions() throws {