Skip to content

Stream client for chat completion, entities update and some junit #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@
<artifactId>reactor-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<distributionManagement>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.github.reactiveclown.openaiwebfluxclient;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.reactiveclown.openaiwebfluxclient.client.audio.AudioService;
import io.github.reactiveclown.openaiwebfluxclient.client.audio.AudioServiceImpl;
import io.github.reactiveclown.openaiwebfluxclient.client.chat.ChatService;
Expand Down Expand Up @@ -78,8 +79,8 @@ public AudioService audioService(@Qualifier("OpenAIClient") WebClient client) {

@Bean
@ConditionalOnMissingBean
public ChatService chatService(@Qualifier("OpenAIClient") WebClient client) {
return new ChatServiceImpl(client);
public ChatService chatService(@Qualifier("OpenAIClient") WebClient client, ObjectMapper objectMapper) {
return new ChatServiceImpl(client, objectMapper);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package io.github.reactiveclown.openaiwebfluxclient.client.chat;

import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.List;

public record ChatCompletionChunk(@JsonProperty("id") String id,
@JsonProperty("object") String object,
@JsonProperty("created") Long created,
@JsonProperty("model") String model,
@JsonProperty("system_fingerprint") String systemFingerprint,
@JsonProperty("choices") List<ChoiceData> choices) {
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.github.reactiveclown.openaiwebfluxclient.client.chat;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public interface ChatService {
Expand All @@ -8,7 +9,16 @@ public interface ChatService {
* Creates a completion for the chat message.
*
* @param request {@link CreateChatCompletionRequest }
* @return A Mono of {@link CreateChatCompletionResponse}
* @return A {@link Mono} of {@link CreateChatCompletionResponse}
*/
Mono<CreateChatCompletionResponse> createChatCompletion(CreateChatCompletionRequest request);

/**
* Creates a completion for the chat message, but with stream.
* The method returns a Flux with chucks of the chat completion response.
*
* @param request {@link CreateChatCompletionRequest }
* @return A {@link Flux} of {@link ChatCompletionChunk}
*/
Flux<ChatCompletionChunk> createStreamChatCompletion(CreateChatCompletionRequest request);
}
Original file line number Diff line number Diff line change
@@ -1,24 +1,60 @@
package io.github.reactiveclown.openaiwebfluxclient.client.chat;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.stereotype.Service;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

@Service
public class ChatServiceImpl implements ChatService{

private final WebClient client;
public ChatServiceImpl(WebClient client){
private final ObjectMapper objectMapper;
private static final String CREATE_CHAT_COMPLETION_URL = "/chat/completions";
public ChatServiceImpl(WebClient client, ObjectMapper objectMapper){
this.client = client;
this.objectMapper = objectMapper;
}

@Override
public Mono<CreateChatCompletionResponse> createChatCompletion(CreateChatCompletionRequest request) {
String createChatCompletionUrl = "/chat/completions";
return client.post()
.uri(createChatCompletionUrl)
.uri(CREATE_CHAT_COMPLETION_URL)
.bodyValue(request)
.retrieve()
.bodyToMono(CreateChatCompletionResponse.class);
}

@Override
public Flux<ChatCompletionChunk> createStreamChatCompletion(CreateChatCompletionRequest request) {
if (request.stream() == null || !request.stream()) {
request = request.withStream();
}
return client.post()
.uri(CREATE_CHAT_COMPLETION_URL)
.accept(MediaType.TEXT_EVENT_STREAM)
.bodyValue(request)
.retrieve()
// transfer to String first to handle the "[DONE]"
.bodyToFlux(new ParameterizedTypeReference<ServerSentEvent<String>>() {
})
.flatMap(serverSentEvent -> {
String data = serverSentEvent.data();
// ignore the done text
if (data == null || data.equals("[DONE]")) {
return Mono.empty();
}
try {
ChatCompletionChunk parsedResponse = objectMapper.readValue(data, ChatCompletionChunk.class);
return Mono.justOrEmpty(parsedResponse);
} catch (JsonProcessingException e) {
return Mono.error(e);
}
});
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package io.github.reactiveclown.openaiwebfluxclient.client.chat;

import com.fasterxml.jackson.annotation.JsonAlias;
import com.fasterxml.jackson.annotation.JsonProperty;

public record ChoiceData(@JsonProperty("index") Integer index,
@JsonProperty("message") MessageData message,
@JsonProperty("logprobs") Logprobs logprobs,
@JsonAlias("delta") @JsonProperty("message") MessageData message,
@JsonProperty("finish_reason") String finishReason) {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.github.reactiveclown.openaiwebfluxclient.client.chat;

import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.List;

public record Content(@JsonProperty("token") String token,
@JsonProperty("logprob") Integer logprob,
@JsonProperty("bytes") List<Integer> bytes,
@JsonProperty("top_logprobs") List<Content> topLogprobs) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ public record CreateChatCompletionRequest(@JsonProperty("model") String model,
@JsonProperty("presence_penalty") Double presencePenalty,
@JsonProperty("frequency_penalty") Double frequencyPenalty,
@JsonProperty("logit_bias") Map<String, Integer> logitBias,
@JsonProperty("user") String user) {
@JsonProperty("user") String user,
@JsonProperty("stream") Boolean stream) {
public CreateChatCompletionRequest {
if (model == null || model.isBlank())
throw new IllegalArgumentException("model value can't be null or blank");
Expand All @@ -67,6 +68,14 @@ public record CreateChatCompletionRequest(@JsonProperty("model") String model,
throw new IllegalArgumentException("messages can't be null or empty");
}

public CreateChatCompletionRequest withStream() {
return new CreateChatCompletionRequest(
model, messages, temperature,
topP, n, stop, maxTokens,
presencePenalty, frequencyPenalty, logitBias,
user, true);
}

public static Builder builder(String model, List<MessageData> messages) {
return new Builder(model, messages);
}
Expand All @@ -83,13 +92,14 @@ public static final class Builder {
private Double frequencyPenalty;
private Map<String, Integer> logitBias;
private String user;
private Boolean stream;

public CreateChatCompletionRequest build() {
return new CreateChatCompletionRequest(
model, messages, temperature,
topP, n, stop, maxTokens,
presencePenalty, frequencyPenalty, logitBias,
user);
user, stream);
}

public Builder(String model, List<MessageData> messages) {
Expand Down Expand Up @@ -147,5 +157,10 @@ public Builder user(String user) {
return this;
}

public Builder stream(Boolean stream) {
this.stream = stream;
return this;
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package io.github.reactiveclown.openaiwebfluxclient.client.chat;

import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.List;

public record Logprobs(@JsonProperty("content") List<Content> content) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package io.github.reactiveclown.openaiwebfluxclient.client.chat;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.reactiveclown.openaiwebfluxclient.client.UsageData;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import java.util.List;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class ChatServiceImplTest {


@Mock
WebClient.RequestBodyUriSpec requestBodyUriSpec;
@Mock
WebClient.RequestBodySpec requestBodySpec;
@Mock
WebClient.RequestHeadersSpec requestHeadersSpec;
@Mock
WebClient.ResponseSpec responseSpec;
@Mock
private WebClient webClient;

private ObjectMapper objectMapper;

@InjectMocks
private ChatServiceImpl chatService;

@BeforeEach
void setUp() {
objectMapper = new ObjectMapper().setSerializationInclusion(JsonInclude.Include.NON_NULL);
ReflectionTestUtils.setField(chatService, "objectMapper", objectMapper);
}

@Test
public void createChatCompletion() {
// Arrange
CreateChatCompletionRequest request = CreateChatCompletionRequest
.builder(
"model",
List.of(new MessageData(
"role",
"content")))
.stream(true)
.build();
CreateChatCompletionResponse expectedResponse = new CreateChatCompletionResponse(
"id",
"object",
1L,
"model",
List.of(new ChoiceData(
1,
null,
new MessageData(
"role",
"content"),
"finishReason")),
new UsageData(
1,
1,
2));
when(webClient.post()).thenReturn(requestBodyUriSpec);
when(requestBodyUriSpec.uri(anyString())).thenReturn(requestBodySpec);
when(requestBodySpec.bodyValue(any())).thenReturn(requestHeadersSpec);
when(requestHeadersSpec.retrieve()).thenReturn(responseSpec);
when(responseSpec.bodyToMono(CreateChatCompletionResponse.class)).thenReturn(Mono.just(expectedResponse));

// Assert
StepVerifier.create(chatService.createChatCompletion(request))
.expectNext(expectedResponse)
.verifyComplete();
}

@Test
public void createStreamChatCompletion() {
// Arrange
CreateChatCompletionRequest request = CreateChatCompletionRequest
.builder(
"model",
List.of(new MessageData(
"role",
"content")))
.stream(true)
.build();
when(webClient.post()).thenReturn(requestBodyUriSpec);
when(requestBodyUriSpec.uri(anyString())).thenReturn(requestBodySpec);
when(requestBodySpec.accept(any(MediaType.class))).thenReturn(requestBodySpec);
when(requestBodySpec.bodyValue(any())).thenReturn(requestHeadersSpec);
when(requestHeadersSpec.retrieve()).thenReturn(responseSpec);

// case [DONE]
ServerSentEvent<String> mockEvent = ServerSentEvent.builder("[DONE]").build();
Flux<ServerSentEvent<String>> responseFlux = Flux.just(mockEvent);
when(responseSpec.bodyToFlux(new ParameterizedTypeReference<ServerSentEvent<String>>() {
}))
.thenReturn(responseFlux);

// Assert
StepVerifier.create(chatService.createStreamChatCompletion(request))
.expectNextCount(0) // Since the data is "[DONE]", it should return an empty flux
.verifyComplete();

// case delta
mockEvent = ServerSentEvent.builder("""
{
"id": "chatcmpl-mock",
"object": "chat.completion.chunk",
"created": 1703257917,
"model": "gpt-3.5-turbo-0613",
"system_fingerprint": null,
"choices": [
{
"index": 0,
"delta": {
"content": "mock"
},
"logprobs": null,
"finish_reason": null
}
]
}""").build();
responseFlux = Flux.just(mockEvent);
when(responseSpec.bodyToFlux(new ParameterizedTypeReference<ServerSentEvent<String>>() {
}))
.thenReturn(responseFlux);

// Assert
StepVerifier.create(chatService.createStreamChatCompletion(request))
.expectNextMatches(chatCompletionChunk -> chatCompletionChunk.choices().get(0).message().content().equals("mock"))
.verifyComplete();
}
}