Fix Choices Deserialization
Some OpenAI-compatible APIs, such as Ollama, don't always return `logprobs`
as part of the `choices` JSON. To handle this, we give `logprobs` a default
value so upickle can deserialize responses where the field is missing.

This PR moves the argument to the end of the parameter list so the defaulted parameter comes last, which is a breaking change for code that constructs `Choices` with positional arguments.

This PR also adds an Ollama deserialization test for completions.
kapunga committed Dec 8, 2024
1 parent ca4b48f commit 29e240e
Showing 3 changed files with 69 additions and 8 deletions.
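
For context, a minimal, self-contained sketch (not part of this commit) of the upickle behavior the fix relies on: a case-class field with a default value may be absent from the incoming JSON, and the macro-generated reader falls back to that default. The `Item` case class and the plain `upickle.default` API below are illustrative assumptions; the project itself reads through its SnakePickle configuration, as shown in the diff.

import upickle.default._

// `logprobs` has a default, so the key may be missing from the JSON:
// the macro-generated reader substitutes None when it is absent.
case class Item(text: String, logprobs: Option[String] = None)
object Item {
  implicit val rw: ReadWriter[Item] = macroRW
}

// Succeeds even though "logprobs" is not present, yielding Item("hi", None).
val item = read[Item]("""{"text": "hi"}""")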
@@ -7,8 +7,8 @@ object CompletionsResponseData {
case class Choices(
text: String,
index: Int,
- logprobs: Option[String],
- finishReason: String
+ finishReason: String,
+ logprobs: Option[String] = None
)
object Choices {
implicit val choicesR: SnakePickle.Reader[Choices] = SnakePickle.macroR[Choices]
27 changes: 27 additions & 0 deletions core/src/test/scala/sttp/openai/fixtures/CompletionsFixture.scala
@@ -21,6 +21,33 @@ object CompletionsFixture {
| }
|}""".stripMargin

/**
* Generated from:
* curl http://localhost:11434/v1/completions -d '{
* "model": "llama3.2",
* "prompt": "Say Hello World as a haiku."
* }'
*/
val ollamaPromptResponse: String = """{
| "id": "cmpl-712",
| "object": "text_completion",
| "created": 1733664264,
| "model": "llama3.2",
| "system_fingerprint": "fp_ollama",
| "choices": [
| {
| "text": "Greeting coding dawn\n\"Hello, world!\" echoes bright\nProgramming's start",
| "index": 0,
| "finish_reason": "stop"
| }
| ],
| "usage": {
| "prompt_tokens": 33,
| "completion_tokens": 17,
| "total_tokens": 50
| }
|}""".stripMargin

val jsonMultiplePromptResponse: String = """{
| "id":"cmpl-76D8UlnqOEkhVXu29nY7UPZFDTTlP",
| "object":"text_completion",
@@ -24,8 +24,8 @@ class CompletionsDataSpec extends AnyFlatSpec with Matchers with EitherValues {
Choices(
text = "\n\nThis is indeed a test.",
index = 0,
- logprobs = None,
- finishReason = "stop"
+ finishReason = "stop",
+ logprobs = None
)
),
usage = Usage(
@@ -42,6 +42,40 @@ class CompletionsDataSpec extends AnyFlatSpec with Matchers with EitherValues {
givenResponse.value shouldBe expectedResponse
}

"Given ollama completions response as Json" should "be properly deserialized to case class" in {
import sttp.openai.requests.completions.CompletionsResponseData._
import sttp.openai.requests.completions.CompletionsResponseData.CompletionsResponse._
import sttp.openai.requests.completions.CompletionsRequestBody.CompletionModel.CustomCompletionModel

// given
val jsonResponse = fixtures.CompletionsFixture.ollamaPromptResponse
val expectedResponse: CompletionsResponse = CompletionsResponse(
id = "cmpl-712",
`object` = "text_completion",
created = 1733664264,
model = CustomCompletionModel("llama3.2"),
choices = Seq(
Choices(
text = "Greeting coding dawn\n\"Hello, world!\" echoes bright\nProgramming's start",
index = 0,
finishReason = "stop",
logprobs = None
)
),
usage = Usage(
promptTokens = 33,
completionTokens = 17,
totalTokens = 50
)
)

// when
val givenResponse: Either[Exception, CompletionsResponse] = SttpUpickleApiExtension.deserializeJsonSnake.apply(jsonResponse)

// then
givenResponse.value shouldBe expectedResponse
}

"Given completions request as case class" should "be properly serialized to Json" in {
import sttp.openai.requests.completions.CompletionsRequestBody._
import sttp.openai.requests.completions.CompletionsRequestBody.CompletionsBody._
@@ -84,14 +118,14 @@
Choices(
text = "\n\nThis is indeed a test",
index = 0,
- logprobs = None,
- finishReason = "length"
+ finishReason = "length",
+ logprobs = None
),
Choices(
text = "\n\nYes, this is also",
index = 1,
- logprobs = None,
- finishReason = "length"
+ finishReason = "length",
+ logprobs = None
)
),
usage = Usage(
