Skip to content

Commit

Permalink
gai: Work in progress on scripting extensions ("party mode" implement…
Browse files Browse the repository at this point in the history
…ed).
  • Loading branch information
patniemeyer committed Dec 9, 2024
1 parent 826a8c3 commit 59b8a45
Show file tree
Hide file tree
Showing 13 changed files with 130 additions and 72 deletions.
3 changes: 2 additions & 1 deletion gai-frontend/build-scripting.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ for file in *.ts; do
echo "Building extension $file"
# rename the file to js
jsfile=$(echo $file | sed 's/\.ts/\.js/')
tsc --outFile "$jsfile" "$file"
#tsc --target es2020 --module es2022 "$file"
tsc --target es2020 "$file"
mv "$jsfile" "$base/web/lib/extensions/"
done

Expand Down
17 changes: 10 additions & 7 deletions gai-frontend/lib/chat/chat.dart
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,17 @@ class _ChatViewState extends State<ChatView> {
log('Error initializing from params: $e, $stack');
}

/*
// Initialize scripting extension
// ChatScripting.init(
// url: 'lib/extensions/test.js',
// debugMode: true,
// providerManager: _providerManager,
// chatHistory: _chatHistory,
// addChatMessageToUI: _addChatMessage,
// );
ChatScripting.init(
// url: 'lib/extensions/test.js',
url: 'lib/extensions/party_mode.js',
debugMode: true,
providerManager: _providerManager,
chatHistory: _chatHistory,
addChatMessageToUI: _addChatMessage,
);
*/
}

bool get _connected {
Expand Down
5 changes: 3 additions & 2 deletions gai-frontend/lib/chat/chat_message.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ class ChatMessage {
final String sourceName;
final String msg;
final Map<String, dynamic>? metadata;

// The modelId of the model that generated this message
final String? modelId;

// The name of the model that generated this message
final String? modelName;

Expand Down Expand Up @@ -38,7 +40,7 @@ class ChatMessage {
if (metadata == null || !metadata!.containsKey('usage')) {
return '';
}

final usage = metadata!['usage'];
if (usage == null) {
return '';
Expand All @@ -59,4 +61,3 @@ class ChatMessage {
return 'ChatMessage(source: $source, modelId: $modelId, model: $modelName, msg: ${msg.substring(0, msg.length.clamp(0, 50))}...)';
}
}

2 changes: 1 addition & 1 deletion gai-frontend/lib/chat/model_manager.dart
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ModelManager extends ChangeNotifier {
List<ModelInfo> get allModels {
final List<ModelInfo> models =
_modelsByProvider.values.expand((models) => models).toList();
log('ModelsState.allModels returning ${models.length} models: ${models.toString().truncate(64)}');
// log('ModelsState.allModels returning ${models.length} models: ${models.toString().truncate(64)}');
return models;
}

Expand Down
2 changes: 1 addition & 1 deletion gai-frontend/lib/chat/provider_manager.dart
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class ProviderManager {
onConnect: () {
_providerConnected(name);
},
onChat: (msg, metadata) {
onChat: (String msg, Map<String, dynamic> metadata) {
log('onChat received metadata: $metadata');
final modelId = metadata['model_id'];
log('Found model_id: $modelId');
Expand Down
3 changes: 2 additions & 1 deletion gai-frontend/lib/chat/scripting/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

// generated by the build script for dev time
// generated by the build script
chat_scripting_api.js
chat_scripting_api.d.ts
12 changes: 7 additions & 5 deletions gai-frontend/lib/chat/scripting/chat_message_js.dart
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ChatMessageJS {

static ChatMessageJS fromChatMessage(ChatMessage chatMessage) {
return ChatMessageJS(
source: chatMessage.source.toString(),
source: chatMessage.source.name, // enum name not toString()
sourceName: chatMessage.sourceName,
msg: chatMessage.msg,
metadata: jsonEncode(chatMessage.metadata).toJS,
Expand All @@ -37,16 +37,18 @@ class ChatMessageJS {
return ChatMessage(
ChatMessageSource.values.byName(chatMessageJS.source),
chatMessageJS.msg,

// TODO:
// metadata: jsonDecode((chatMessageJS.metadata ?? "").toString()),
metadata: {},
// sourceName: '',
// modelId: chatMessageJS.modelId,
// modelName: chatMessageJS.modelName,

sourceName: '',
modelId: chatMessageJS.modelId,
modelName: chatMessageJS.modelName,
);
}

// Map a list of ChatMessageJS to a list of ChatMessage
// Map a list of ChatMessageJS to a list of ChatMessage
static List<ChatMessage> toChatMessages(List<ChatMessageJS> chatMessagesJS) {
return chatMessagesJS.map((msg) => toChatMessage(msg)).toList();
}
Expand Down
76 changes: 52 additions & 24 deletions gai-frontend/lib/chat/scripting/chat_scripting.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ import 'chat_message_js.dart';
import 'model_info_js.dart';
import 'package:http/http.dart' as http;


class ChatScripting {

// Singleton
static ChatScripting? _instance;

static ChatScripting get instance {
Expand All @@ -23,6 +22,7 @@ class ChatScripting {

static bool get enabled => _instance != null;

// Scripting config
late String script;
late ProviderManager providerManager;
late ChatHistory chatHistory;
Expand All @@ -35,7 +35,6 @@ class ChatScripting {

// If debugMode is true, the script will be re-loaded before each invocation
bool debugMode = false,

required ProviderManager providerManager,
required ChatHistory chatHistory,
required Function(ChatMessage) addChatMessageToUI,
Expand All @@ -51,11 +50,11 @@ class ChatScripting {
instance.addChatMessageToUI = addChatMessageToUI;

// Install persistent callback functions
doGlobalSetup();
addGlobalBindings();

await instance.loadExtensionScript(url);
// Do one setup and evaluation of the script now
instance.doPerCallSetup();
instance.updatePerCallBindings();
}

Future<void> loadExtensionScript(String url) async {
Expand All @@ -64,40 +63,43 @@ class ChatScripting {
final response = await http.get(Uri.parse(url));

if (response.statusCode != 200) {
throw Exception("Failed to load script from $url: ${response.statusCode}");
throw Exception(
"Failed to load script from $url: ${response.statusCode}");
}

// If the result is HTML we have failed
if (response.headers['content-type']!.contains('text/html')) {
throw Exception("Failed to load script from $url: HTML response: ${response.body.truncate(64)}");
throw Exception(
"Failed to load script from $url: HTML response: ${response.body.truncate(64)}");
}

script = response.body;
// log("Loaded script: $script");
}

void evalExtensionScript() {
// Wrap the script in an async function to allow top level await without messing with modules.
// final wrappedScript = "(async () => {$script})();";
try {
final result = evaluateJS(script);
log("Evaluated script: $result, ${result.runtimeType}");
evaluateJS(script); // We could get a result back async here if needed
} catch (e, stack) {
log("Failed to evaluate script: $e");
log(stack.toString());
}
}

// Install the persistent callback functions
static void doGlobalSetup() {
static void addGlobalBindings() {
addChatMessageJS = instance.addChatMessageFromJS.toJS;
sendMessagesToModelJS = instance.sendMessagesToModelFromJS.toJS;
}

// Items that need to be copied before each invocation of the JS scripting extension
void doPerCallSetup({List<ModelInfo>? userSelectedModels}) {
chatHistoryJS =
ChatMessageJS.fromChatMessages(chatHistory.messages).jsify() as JSArray;
userSelectedModelsJS =
ModelInfoJS.fromModelInfos(userSelectedModels ?? []).jsify() as JSArray;
void updatePerCallBindings({List<ModelInfo>? userSelectedModels}) {
chatHistoryJS = ChatMessageJS.fromChatMessages(chatHistory.messages).jsify() as JSArray;
if (userSelectedModels != null) {
userSelectedModelsJS = ModelInfoJS.fromModelInfos(userSelectedModels).jsify() as JSArray;
}
if (debugMode) {
evalExtensionScript();
}
Expand All @@ -106,7 +108,7 @@ class ChatScripting {
// Send the user prompt to the JS scripting extension
void sendUserPrompt(String userPrompt, List<ModelInfo> userSelectedModels) {
log("Invoke onUserPrompt on the scripting extension: $userPrompt");
doPerCallSetup(userSelectedModels: userSelectedModels);
updatePerCallBindings(userSelectedModels: userSelectedModels);
onUserPromptJS(userPrompt);
}

Expand All @@ -119,18 +121,44 @@ class ChatScripting {
void addChatMessageFromJS(ChatMessageJS message) {
log("Add chat message: ${message.source}, ${message.msg}");
addChatMessageToUI(ChatMessageJS.toChatMessage(message));
updatePerCallBindings(); // History has changed
}

// Implementation of sendMessagesToModel callback function invoked from JS
// Send a list of ChatMessage to a model for inference
String sendMessagesToModelFromJS(JSArray messagesJS, String modelId, int? maxTokens) {
final List<ChatMessageJS> listJS = (messagesJS.dartify() as List).cast<ChatMessageJS>();
final List<ChatMessage> messages = ChatMessageJS.toChatMessages(listJS);
log("Send messages to model: $modelId, ${messages.length} messages");
if (messages.isNotEmpty) {
log("messages[0] = ${messages[0]}");
}
return "result from dart";
JSPromise sendMessagesToModelFromJS(
JSArray messagesJS, String modelId, int? maxTokens) {
log("dart: Send messages to model called.");
// We must capture the Future and return convert it to a JSPromise
return (() async {
try {
final listJS = (messagesJS.toDart).cast<ChatMessageJS>();
final messages = ChatMessageJS.toChatMessages(listJS);
log("messages = ${messages}");
if (messages.isEmpty) {
return [];
}

// Simulate delay
// log("dart: simulate delay");
// await Future.delayed(const Duration(seconds: 3));
// log("dart: after delay response from sendMessagesToModel sent.");

// Send the messages to the model
await providerManager.sendMessagesToModel(messages, modelId, maxTokens);

// TODO: Fake return
return ["message 1", "message 2"].jsify(); // Don't forget value to JS
} catch (e, stack) {
log("Failed to send messages to model: $e");
log(stack.toString());
return ["error: $e"].jsify();
}
})()
.toJS;
}

///
/// END: callbacks from JS
///
}
9 changes: 6 additions & 3 deletions gai-frontend/lib/chat/scripting/chat_scripting_api.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
// A TS API for the chat scripting environment
// This is compiled to JS by the build script and included in the html for basic structural type info.
// See chat_scripting_bindings_js.dart for the dart bindings.

enum ChatMessageSource {
CLIENT = 'client',
PROVIDER = 'provider',
Expand Down Expand Up @@ -41,9 +45,8 @@ declare let userSelectedModels: ReadonlyArray<ModelInfo>;
declare function sendMessagesToModel(
messages: Array<ChatMessage>,
modelId: string,
maxTokens?: number,
// ) : Promise<string[]>
) : string
maxTokens?: number | null,
): Promise<Array<string>>

// Send a list of formatted messages to a model for inference
declare function sendFormattedMessagesToModel(
Expand Down
32 changes: 22 additions & 10 deletions gai-frontend/lib/chat/scripting/extensions/party_mode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,28 @@
/// <reference path="../chat_scripting_api.ts" />

// Main entry point when the user has hit enter on a new prompt
function onUserPrompt(userPrompt: string): void
{
console.log('Party mode: on user prompt');
addChatMessage(new ChatMessage(ChatMessageSource.SYSTEM, 'Extension: Party mode', {}));
function onUserPrompt(userPrompt: string): void {
(async () => {
console.log(`Party mode: on user prompt: ${userPrompt}`);
addChatMessage(new ChatMessage(ChatMessageSource.SYSTEM, 'Extension: Party mode invoked', {}));
addChatMessage(new ChatMessage(ChatMessageSource.CLIENT, userPrompt, {}));

// Gather messages of source type 'client' or 'provider', irrespective of the model
const filteredMessages = chatHistory.filter(
(message) =>
message.source === ChatMessageSource.CLIENT ||
message.source === ChatMessageSource.PROVIDER
);
console.log("party_mode: chatHistory: ", chatHistory);
// Gather messages of source type 'client' or 'provider', irrespective of the model
const filteredMessages = chatHistory.filter(
(message) =>
message.source === ChatMessageSource.CLIENT ||
message.source === ChatMessageSource.PROVIDER
);
console.log(`party_mode: Filtered messages: ${filteredMessages}`, filteredMessages);

// Send them to all user-selected models
for (const model of userSelectedModels) {
console.log(`party_mode: Sending messages to model: ${model.name}`);
const promise = sendMessagesToModel(filteredMessages, model.id, null);
console.log(`party_mode: promise for model ${model.name}: ${promise}`, promise);
const result = await promise;
}
})();
}

27 changes: 14 additions & 13 deletions gai-frontend/lib/chat/scripting/extensions/test.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
/// Let the IDE see the types from the chat_scripting_api during development.
/// <reference path="../chat_scripting_api.ts" />

console.log('Test Script: Evaluating JavaScript code from Dart...');
(async () => {
console.log('test_script: Evaluating JavaScript code from Dart...');
console.log('test_script: Chat History:', chatHistory);

console.log('Chat History:', chatHistory);
const chatMessage = new ChatMessage(
ChatMessageSource.SYSTEM,
'Extension: Test Script',
{'foo': 'bar'}
);
addChatMessage(chatMessage);

const chatMessage = new ChatMessage(
ChatMessageSource.SYSTEM,
'Extension: Test Script',
{'foo': 'bar'}
);
addChatMessage(chatMessage);
const promise = sendMessagesToModel([chatMessage], 'test-model', null);
const result = await promise;
console.log(`test_script: awaited 2 sendMessagesToModel. Result from dart: ${result}`);

var result = sendMessagesToModel([chatMessage], 'test-model', 999);
console.log('sendMessagesToModel Result from dart:', result);
return 42;
})();

// Return a result
const returnResult = 42;
returnResult;
4 changes: 0 additions & 4 deletions gai-frontend/lib/chat/scripting/extensions/test2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,3 @@ script.onload = () => {
console.error("Lodash is not available!");
}
};

// Define a result variable
const result = 42;
result; // Return this value from eval
10 changes: 10 additions & 0 deletions gai-frontend/lib/chat/scripting/extensions/tsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// This is just for the IDE to tell it the typescript level
// See build-scripting.sh for the tsc options.
{
"compilerOptions": {
"target": "es2020",
// "module": "es2022",
// "moduleResolution": "node",
"strict": true
}
}

0 comments on commit 59b8a45

Please sign in to comment.