
Commit

gai: Complete callbacks for scripted extensions and add filtering example. Misc refactoring and cleanup.
patniemeyer committed Dec 13, 2024
1 parent 01465d0 commit 66f0523
Showing 11 changed files with 191 additions and 127 deletions.
44 changes: 37 additions & 7 deletions gai-frontend/lib/chat/chat.dart
@@ -3,6 +3,7 @@ import 'package:orchid/api/orchid_eth/orchid_account.dart';
import 'package:orchid/api/orchid_eth/orchid_account_detail.dart';
import 'package:orchid/api/orchid_keys.dart';
import 'package:orchid/chat/model.dart';
+import 'package:orchid/chat/provider_connection.dart';
import 'package:orchid/chat/scripting/chat_scripting.dart';
import 'package:orchid/common/app_sizes.dart';
import 'package:orchid/chat/chat_settings_button.dart';
@@ -86,11 +87,12 @@ class _ChatViewState extends State<ChatView> {
log('Error initializing from params: $e, $stack');
}

-/*
// Initialize scripting extension
+/*
ChatScripting.init(
// url: 'lib/extensions/test.js',
-url: 'lib/extensions/party_mode.js',
+// url: 'lib/extensions/party_mode.js',
+url: 'lib/extensions/filter_example.js',
debugMode: true,
providerManager: _providerManager,
chatHistory: _chatHistory,
@@ -201,8 +203,8 @@ class _ChatViewState extends State<ChatView> {
String? modelName,
}) {
final message = ChatMessage(
-source,
-msg,
+source: source,
+message: msg,
metadata: metadata,
sourceName: sourceName,
modelId: modelId,
@@ -212,7 +214,7 @@ class _ChatViewState extends State<ChatView> {
}

void _addChatMessage(ChatMessage message) {
-log('Adding message: ${message.msg.truncate(64)}');
+log('Adding message: ${message.message.truncate(64)}');
setState(() {
_chatHistory.addMessage(message);
});
@@ -302,6 +304,15 @@ class _ChatViewState extends State<ChatView> {
return;
}

+// Debug hack
+if (_selectedModelIds.isEmpty &&
+ChatScripting.enabled &&
+ChatScripting.instance.debugMode) {
+setState(() {
+_selectedModelIds = ['gpt-4o'];
+});
+}
+
// Validate the selected models
if (_selectedModelIds.isEmpty) {
_addMessage(
@@ -344,15 +355,35 @@ class _ChatViewState extends State<ChatView> {
? _chatHistory.getConversation()
: _chatHistory.getConversation(withModelId: modelId);

-await _providerManager.sendMessagesToModel(
+final chatResponse = await _providerManager.sendMessagesToModel(
selectedMessages, modelId, _maxTokens);
+
+if (chatResponse != null) {
+_handleChatResponseDefaultBehavior(chatResponse);
+} else {
+// The provider connection should have logged the issue. Do nothing.
+}
} catch (e) {
_addMessage(
ChatMessageSource.system, 'Error querying model $modelId: $e');
}
}
}

+// The default handler for chat responses from the models (simply adds response to the chat history).
+void _handleChatResponseDefaultBehavior(ChatInferenceResponse chatResponse) {
+final metadata = chatResponse.metadata;
+final modelId = metadata['model_id']; // or request.modelId?
+log('Handle response: ${chatResponse.message}, $metadata');
+_addMessage(
+ChatMessageSource.provider,
+chatResponse.message,
+metadata: metadata,
+modelId: modelId,
+modelName: _modelsState.getModelOrDefaultNullable(modelId)?.name,
+);
+}
+
void scrollMessagesDown() {
// Dispatching it to the next frame seems to mitigate overlapping scrolls.
Future.delayed(millis(50), () {
@@ -600,4 +631,3 @@ Future<void> _launchURL(String urlString) async {
enum AuthTokenMethod { manual, walletConnect }

enum OrchataMenuItem { debug }

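Note: the following is a standalone sketch of the new call contract in chat.dart — sendMessagesToModel now returns the response, with null meaning the connection layer has already reported the failure, and the caller applies the default handling. The stand-in types and function below are simplified assumptions for illustration, not code from this commit.

```dart
// Standalone sketch of the returned-response contract (simplified stand-ins;
// the real types live in provider_connection.dart and the provider manager).
class ChatInferenceResponse {
  final String message;
  final Map<String, dynamic> metadata;
  ChatInferenceResponse({required this.message, required this.metadata});
}

// Stand-in for ProviderManager.sendMessagesToModel(); returning null signals
// that the provider connection already logged the problem.
Future<ChatInferenceResponse?> sendMessagesToModel(String modelId) async {
  return ChatInferenceResponse(
      message: 'Hello from $modelId', metadata: {'model_id': modelId});
}

Future<void> main() async {
  final history = <String>[];
  final chatResponse = await sendMessagesToModel('gpt-4o');
  if (chatResponse != null) {
    // Default behavior: append the provider response to the chat history.
    history.add(chatResponse.message);
  } else {
    // The provider connection should have logged the issue. Do nothing.
  }
  print(history);
}
```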
12 changes: 5 additions & 7 deletions gai-frontend/lib/chat/chat_message.dart
@@ -5,7 +5,7 @@ enum ChatMessageSource { client, provider, system, internal }
class ChatMessage {
final ChatMessageSource source;
final String sourceName;
-final String msg;
+final String message;
final Map<String, dynamic>? metadata;

// The modelId of the model that generated this message
@@ -14,17 +14,15 @@ class ChatMessage {
// The name of the model that generated this message
final String? modelName;

-ChatMessage(
-this.source,
-this.msg, {
+ChatMessage({
+required this.source,
+required this.message,
this.metadata,
this.sourceName = '',
this.modelId,
this.modelName,
});

-String get message => msg;
-
String? get displayName {
if (source == ChatMessageSource.provider && modelName != null) {
return modelName;
@@ -58,6 +56,6 @@

@override
String toString() {
-return 'ChatMessage(source: $source, modelId: $modelId, model: $modelName, msg: ${msg.substring(0, msg.length.clamp(0, 50))}...)';
+return 'ChatMessage(source: $source, modelId: $modelId, model: $modelName, msg: ${message.substring(0, message.length.clamp(0, 50))}...)';
}
}
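For illustration, a minimal runnable sketch of the constructor change in this file: positional `ChatMessage(source, msg)` becomes named parameters, and the `msg` field plus its `message` getter collapse into a single `message` field. The class is pared down here to just the relevant members.

```dart
enum ChatMessageSource { client, provider, system, internal }

// Pared-down ChatMessage showing only the members touched by this commit.
class ChatMessage {
  final ChatMessageSource source;
  final String message; // was: final String msg; plus a `message` getter

  ChatMessage({required this.source, required this.message});
}

void main() {
  // Before: ChatMessage(ChatMessageSource.client, 'Hello!') -- positional
  // arguments of similar shape are easy to transpose silently.
  final msg = ChatMessage(
    source: ChatMessageSource.client,
    message: 'Hello!',
  );
  print(msg.message);
}
```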
116 changes: 63 additions & 53 deletions gai-frontend/lib/chat/provider_connection.dart
@@ -5,32 +5,52 @@ import 'package:web_socket_channel/web_socket_channel.dart';
import 'package:orchid/api/orchid_crypto.dart';
import 'package:orchid/api/orchid_eth/orchid_ticket.dart';
import 'package:orchid/api/orchid_eth/orchid_account_detail.dart';
-import 'inference_client.dart';
+import 'chat_message.dart';
import 'package:orchid/api/orchid_log.dart';
+import 'inference_client.dart';

typedef MessageCallback = void Function(String message);
typedef ChatCallback = void Function(
String message, Map<String, dynamic> metadata);
typedef VoidCallback = void Function();
typedef ErrorCallback = void Function(String error);
typedef AuthTokenCallback = void Function(String token, String inferenceUrl);

-class _PendingRequest {
-final String requestId;
+class ChatInferenceRequest {
final String modelId;
-final List<ChatMessage> messages;
-final Map<String, Object>? params;
+final List<Map<String, dynamic>> preparedMessages;
+final Map<String, Object>? requestParams;
final DateTime timestamp;

-_PendingRequest({
-required this.requestId,
+ChatInferenceRequest({
required this.modelId,
-required this.messages,
-required this.params,
+required this.preparedMessages,
+required this.requestParams,
}) : timestamp = DateTime.now();
}

+class ChatInferenceResponse {
+// Request
+final ChatInferenceRequest request;
+
+// Result
+final String message;
+final Map<String, dynamic> metadata;
+
+ChatInferenceResponse({
+required this.request,
+required this.message,
+required this.metadata,
+});
+
+ChatMessage toChatMessage() {
+return ChatMessage(
+source: ChatMessageSource.provider,
+message: message,
+// sourceName: request.modelId,
+metadata: metadata,
+modelId: request.modelId,
+);
+}
+}

class ProviderConnection {
final maxuint256 = BigInt.two.pow(256) - BigInt.one;
final maxuint64 = BigInt.two.pow(64) - BigInt.one;
@@ -40,7 +60,7 @@ class ProviderConnection {
InferenceClient? get inferenceClient => _inferenceClient;
InferenceClient? _inferenceClient;
final MessageCallback onMessage;
-final ChatCallback onChat;
+
final VoidCallback onConnect;
final ErrorCallback onError;
final VoidCallback onDisconnect;
@@ -51,8 +71,7 @@ class ProviderConnection {
final String? authToken;
final AccountDetail? accountDetail;
final AuthTokenCallback? onAuthToken;
-final Map<String, String> _requestModels = {};
-final Map<String, _PendingRequest> _pendingRequests = {};
+
bool _usingDirectAuth = false;

String _generateRequestId() {
@@ -62,7 +81,7 @@
ProviderConnection({
required this.onMessage,
required this.onConnect,
-required this.onChat,
+// required this.onChat,
required this.onDisconnect,
required this.onError,
required this.onSystemMessage,
@@ -104,7 +123,7 @@
AccountDetail? accountDetail,
String? authToken,
required MessageCallback onMessage,
-required ChatCallback onChat,
+// required ChatCallback onChat,
required VoidCallback onConnect,
required ErrorCallback onError,
required VoidCallback onDisconnect,
@@ -119,7 +138,7 @@
final connection = ProviderConnection(
onMessage: onMessage,
onConnect: onConnect,
-onChat: onChat,
+// onChat: onChat,
onDisconnect: onDisconnect,
onError: onError,
onSystemMessage: onSystemMessage,
@@ -222,16 +241,6 @@
onMessage('Provider: $message');

switch (data['type']) {
-case 'job_complete':
-final requestId = data['request_id'];
-final pendingRequest =
-requestId != null ? _pendingRequests.remove(requestId) : null;
-
-onChat(data['output'], {
-...data,
-'model_id': pendingRequest?.modelId,
-});
-break;
case 'invoice':
payInvoice(data);
break;
@@ -255,11 +264,16 @@
_sendProviderMessage(message);
}

-Future<void> requestInference(
+Future<ChatInferenceResponse?> requestInference(
String modelId,
List<Map<String, dynamic>> preparedMessages, {
Map<String, Object>? params,
}) async {
+var request = ChatInferenceRequest(
+modelId: modelId,
+preparedMessages: preparedMessages,
+requestParams: params,
+);
/*
Requesting inference for model gpt-4o-mini
Prepared messages: [{role: user, content: Hello!}, {role: assistant, content: Hello! How can I assist you today?}, {role: user, content: How are you?}]
@@ -271,49 +285,46 @@

if (_inferenceClient == null) {
onError('No inference connection available');
-return;
+return null;
}
}

try {
final requestId = _generateRequestId();

-_pendingRequests[requestId] = _PendingRequest(
-requestId: requestId,
-modelId: modelId,
-messages: [], // Empty since we're using preparedMessages now
-params: params,
-);

final allParams = {
...?params,
'request_id': requestId,
};

onInternalMessage('Sending inference request:\n'
-'Model: $modelId\n'
-'Messages: ${preparedMessages}\n'
-'Params: $allParams'
-);
+'Model: $modelId\n'
+'Messages: ${preparedMessages}\n'
+'Params: $allParams');

final Map<String, dynamic> result = await _inferenceClient!.inference(
messages: preparedMessages,
model: modelId,
params: allParams,
);

-_pendingRequests.remove(requestId);
-
-onChat(result['response'], {
-'type': 'job_complete',
-'output': result['response'],
-'usage': result['usage'],
-'model_id': modelId,
-'request_id': requestId,
-'estimated_prompt_tokens': result['estimated_prompt_tokens'],
-});
-
+final chatResult = ChatInferenceResponse(
+request: request,
+message: result['response'],
+metadata: {
+'type': 'job_complete',
+'output': result['response'],
+'usage': result['usage'],
+'model_id': modelId,
+'request_id': requestId,
+'estimated_prompt_tokens': result['estimated_prompt_tokens'],
+});
+
+return chatResult;
+
} catch (e, stack) {
onError('Failed to send inference request: $e\n$stack');
+return null;
}
}

@@ -328,7 +339,6 @@

void dispose() {
_providerChannel?.sink.close();
-_pendingRequests.clear();
onDisconnect();
}

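A minimal sketch of the design change in this file, with simplified stand-in types: the `_pendingRequests` map keyed by `request_id` goes away because each `ChatInferenceResponse` now carries its originating `ChatInferenceRequest`, so metadata such as the model id travels with the result. The fake inference call below is an assumption for illustration, not the real `InferenceClient` API.

```dart
class ChatInferenceRequest {
  final String modelId;
  final DateTime timestamp;
  ChatInferenceRequest({required this.modelId}) : timestamp = DateTime.now();
}

class ChatInferenceResponse {
  final ChatInferenceRequest request; // the response remembers its request
  final String message;
  final Map<String, dynamic> metadata;
  ChatInferenceResponse(
      {required this.request, required this.message, required this.metadata});
}

// Stand-in for InferenceClient.inference(); the real call performs paid
// inference over the provider connection.
Future<ChatInferenceResponse?> requestInference(String modelId) async {
  final request = ChatInferenceRequest(modelId: modelId);
  final result = {'response': 'Hi there!'};
  return ChatInferenceResponse(
    request: request,
    message: result['response']!,
    metadata: {'model_id': request.modelId, 'type': 'job_complete'},
  );
}

Future<void> main() async {
  final response = await requestInference('gpt-4o');
  // No pending-request lookup: the model id rides along with the response.
  print('${response?.request.modelId}: ${response?.message}');
}
```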
