Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SSLDelegateProtocol to StreamingSession #191

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,27 @@ final public class OpenAI: OpenAIProtocol {
}

private let session: URLSessionProtocol
private let sslStreamingDelegate: SSLDelegateProtocol?
private var streamingSessions = ArrayWithThreadSafety<NSObject>()

public let configuration: Configuration

public convenience init(apiToken: String) {
self.init(configuration: Configuration(token: apiToken), session: URLSession.shared)
self.init(configuration: Configuration(token: apiToken), session: URLSession.shared, sslStreamingDelegate: nil)
}

public convenience init(configuration: Configuration) {
self.init(configuration: configuration, session: URLSession.shared)
self.init(configuration: configuration, session: URLSession.shared, sslStreamingDelegate: nil)
}

init(configuration: Configuration, session: URLSessionProtocol) {
init(configuration: Configuration, session: URLSessionProtocol, sslStreamingDelegate: SSLDelegateProtocol?) {
self.configuration = configuration
self.session = session
self.sslStreamingDelegate = sslStreamingDelegate
}

public convenience init(configuration: Configuration, session: URLSession = URLSession.shared) {
self.init(configuration: configuration, session: session as URLSessionProtocol)
public convenience init(configuration: Configuration, session: URLSession = URLSession.shared, sslStreamingDelegate: SSLDelegateProtocol? = nil) {
self.init(configuration: configuration, session: session as URLSessionProtocol, sslStreamingDelegate: sslStreamingDelegate)
}

public func completions(query: CompletionsQuery, completion: @escaping (Result<CompletionsResult, Error>) -> Void) {
Expand Down Expand Up @@ -154,7 +156,7 @@ extension OpenAI {
let request = try request.build(token: configuration.token,
organizationIdentifier: configuration.organizationIdentifier,
timeoutInterval: configuration.timeoutInterval)
let session = StreamingSession<ResultType>(urlRequest: request)
let session = StreamingSession<ResultType>(urlRequest: request, sslDelegate: sslStreamingDelegate)
session.onReceiveContent = {_, object in
onResult(.success(object))
}
Expand Down
9 changes: 9 additions & 0 deletions Sources/OpenAI/Private/SSLDelegateProtocol.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import Foundation

public protocol SSLDelegateProtocol {
func urlSession(
_ session: URLSession,
didReceive challenge: URLAuthenticationChallenge,
completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void
)
}
12 changes: 11 additions & 1 deletion Sources/OpenAI/Private/StreamingSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ final class StreamingSession<ResultType: Codable>: NSObject, Identifiable, URLSe

private let streamingCompletionMarker = "[DONE]"
private let urlRequest: URLRequest
private let sslDelegate: SSLDelegateProtocol?
private lazy var urlSession: URLSession = {
let session = URLSession(configuration: .default, delegate: self, delegateQueue: nil)
return session
}()

private var previousChunkBuffer = ""

init(urlRequest: URLRequest) {
init(urlRequest: URLRequest, sslDelegate: SSLDelegateProtocol?) {
self.urlRequest = urlRequest
self.sslDelegate = sslDelegate
}

func perform() {
Expand All @@ -52,6 +54,14 @@ final class StreamingSession<ResultType: Codable>: NSObject, Identifiable, URLSe
processJSON(from: stringContent)
}

func urlSession(
_ session: URLSession,
didReceive challenge: URLAuthenticationChallenge,
completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void
) {
guard let sslDelegate else { return completionHandler(.performDefaultHandling, nil) }
sslDelegate.urlSession(session, didReceive: challenge, completionHandler: completionHandler)
}
}

extension StreamingSession {
Expand Down
6 changes: 3 additions & 3 deletions Tests/OpenAITests/OpenAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class OpenAITests: XCTestCase {
super.setUp()
self.urlSession = URLSessionMock()
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14)
self.openAI = OpenAI(configuration: configuration, session: self.urlSession)
self.openAI = OpenAI(configuration: configuration, session: self.urlSession, sslStreamingDelegate: nil)
}

func testCompletions() async throws {
Expand Down Expand Up @@ -390,14 +390,14 @@ class OpenAITests: XCTestCase {

func testDefaultHostURLBuilt() {
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14)
let openAI = OpenAI(configuration: configuration, session: self.urlSession)
let openAI = OpenAI(configuration: configuration, session: self.urlSession, sslStreamingDelegate: nil)
let chatsURL = openAI.buildURL(path: .chats)
XCTAssertEqual(chatsURL, URL(string: "https://api.openai.com:443/v1/chat/completions"))
}

func testCustomURLBuilt() {
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14)
let openAI = OpenAI(configuration: configuration, session: self.urlSession)
let openAI = OpenAI(configuration: configuration, session: self.urlSession, sslStreamingDelegate: nil)
let chatsURL = openAI.buildURL(path: .chats)
XCTAssertEqual(chatsURL, URL(string: "https://my.host.com:443/v1/chat/completions"))
}
Expand Down
2 changes: 1 addition & 1 deletion Tests/OpenAITests/OpenAITestsCombine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ final class OpenAITestsCombine: XCTestCase {
super.setUp()
self.urlSession = URLSessionMock()
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14)
self.openAI = OpenAI(configuration: configuration, session: self.urlSession)
self.openAI = OpenAI(configuration: configuration, session: self.urlSession, sslStreamingDelegate: nil)
}

func testCompletions() throws {
Expand Down