Skip to content

Commit

Permalink
gai: Work in progress on scripting extensions.
Browse files Browse the repository at this point in the history
  • Loading branch information
patniemeyer committed Dec 9, 2024
1 parent f5661ad commit 826a8c3
Show file tree
Hide file tree
Showing 25 changed files with 750 additions and 182 deletions.
3 changes: 3 additions & 0 deletions gai-backend/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ __pycache__

*~

cfg_server_test.json
venv
env.sh
2 changes: 2 additions & 0 deletions gai-frontend/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,5 @@ deploy-pat.sh

set_providers.sh

web/lib/extensions/*.js
web/index.html.new
32 changes: 32 additions & 0 deletions gai-frontend/build-scripting.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/bin/sh
set -euo pipefail

base=$(dirname "$(realpath "$0")")
cd "$base" || exit

# if the tsc command is not installed tell the user to install typescript
if ! [ -x "$(command -v tsc)" ]; then
echo 'Error: Typescript is not installed. Install with "npm -g install typescript"' >&2
exit 1
fi

src="lib/chat/scripting"
ext_src="$src/extensions"

# Compile the scripting api typescript to js
mkdir -p ./web/lib
tsc --outFile ./web/lib/extensions/chat.js $src/chat_scripting_api.ts

# Build extensions
cd $ext_src || exit
for file in *.ts; do
echo "Building extension $file"
# rename the file to js
jsfile=$(echo $file | sed 's/\.ts/\.js/')
tsc --outFile "$jsfile" "$file"
mv "$jsfile" "$base/web/lib/extensions/"
done

# Create a declaration file for development time use
#tsc $SRC/chat_scripting_api.ts --declaration --emitDeclarationOnly --outDir $SRC

4 changes: 4 additions & 0 deletions gai-frontend/build.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#!/bin/bash
set -euxo pipefail

# Build the scripting extensions
sh build-scripting.sh

# Set default mode to prod if not specified
MODE=${1:-prod}
Expand Down
93 changes: 50 additions & 43 deletions gai-frontend/lib/chat/chat.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import 'package:orchid/api/orchid_eth/chains.dart';
import 'package:orchid/api/orchid_eth/orchid_account.dart';
import 'package:orchid/api/orchid_eth/orchid_account_detail.dart';
import 'package:orchid/api/orchid_keys.dart';
import 'package:orchid/chat/model.dart';
import 'package:orchid/chat/scripting/chat_scripting.dart';
import 'package:orchid/common/app_sizes.dart';
import 'package:orchid/chat/chat_settings_button.dart';
import 'package:orchid/orchid/field/orchid_labeled_numeric_field.dart';
Expand All @@ -15,9 +17,9 @@ import 'chat_button.dart';
import 'chat_message.dart';
import 'chat_prompt.dart';
import 'chat_model_button.dart';
import 'models.dart';
import 'auth_dialog.dart';
import 'chat_history.dart';
import 'model_manager.dart';
import 'provider_manager.dart';

class ChatView extends StatefulWidget {
Expand Down Expand Up @@ -45,9 +47,12 @@ class _ChatViewState extends State<ChatView> {
late final ProviderManager _providerManager;

// Models
final ModelsState _modelsState = ModelsState();
final ModelManager _modelsState = ModelManager();
List<String> _selectedModelIds = [];

List<ModelInfo> get _selectedModels =>
_modelsState.getModelsOrDefault(_selectedModelIds);

// Account
// This should be wrapped up in a provider. See WIP in vpn app.
EthereumAddress? _funder;
Expand Down Expand Up @@ -80,6 +85,15 @@ class _ChatViewState extends State<ChatView> {
} catch (e, stack) {
log('Error initializing from params: $e, $stack');
}

// Initialize scripting extension
// ChatScripting.init(
// url: 'lib/extensions/test.js',
// debugMode: true,
// providerManager: _providerManager,
// chatHistory: _chatHistory,
// addChatMessageToUI: _addChatMessage,
// );
}

bool get _connected {
Expand Down Expand Up @@ -259,7 +273,7 @@ class _ChatViewState extends State<ChatView> {

void _send() {
if (_canSendMessages()) {
_sendPrompt();
_sendUserPrompt();
} else {
_popAccountDialog();
}
Expand All @@ -270,18 +284,22 @@ class _ChatViewState extends State<ChatView> {
(_authToken != null && _inferenceUrl != null);
}

// Apply the prompt to history and send to selected models
void _sendPrompt() async {
// Validate the prompt, selections, and provider connection and then send the prompt to models.
void _sendUserPrompt() async {
var msg = _promptTextController.text;

// Validate the prompt
if (msg.trim().isEmpty) {
return;
}

// Validate the provider connection
if (!_providerManager.hasProviderConnection) {
_addMessage(ChatMessageSource.system, 'Not connected to provider');
return;
}

// Validate the selected models
if (_selectedModelIds.isEmpty) {
_addMessage(
ChatMessageSource.system,
Expand All @@ -291,52 +309,40 @@ class _ChatViewState extends State<ChatView> {
return;
}

// Manage the prompt UI
_promptTextController.clear();
// FocusManager.instance.primaryFocus?.unfocus(); // ?

// If we have a script selected allow it to handle the prompt
if (ChatScripting.enabled) {
ChatScripting.instance.sendUserPrompt(msg, _selectedModels);
} else {
_sendUserPromptDefaultBehavior(msg);
}
}

// The default behavior for handling the user prompt and selected models.
Future<void> _sendUserPromptDefaultBehavior(String msg) async {
// Add user message immediately to update UI and include in history
_addMessage(ChatMessageSource.client, msg);
_promptTextController.clear();
FocusManager.instance.primaryFocus?.unfocus();

await _sendChatToModels();
// Send the prompt to the selected models
await _sendChatHistoryToSelectedModels();
}

// Send the appropriate chat history to the selected models
Future<void> _sendChatToModels() async {
// The default strategy for sending the next round of the full, potentially multi-model, chat history:
// This strategy selects messages based on the isolated / party mode and sends them sequentially to each
// of the user-selected models allowing each model to see the previous responses.
Future<void> _sendChatHistoryToSelectedModels() async {
for (final modelId in _selectedModelIds) {
try {
final modelInfo = _modelsState.allModels.firstWhere(
(m) => m.id == modelId,
orElse: () => ModelInfo(
id: modelId,
name: modelId,
provider: '',
apiType: '',
),
);

// Prepare messages for this specific model
final preparedMessages = _chatHistory.prepareForModel(
modelId: modelId,
preparationFunction:
_partyMode ? ChatHistory.partyMode : ChatHistory.isolatedMode,
);
Map<String, Object>? params;
if (_maxTokens != null) {
params = {'max_tokens': _maxTokens!};
}
// Filter messages based on conversation mode.
final selectedMessages = _partyMode
? _chatHistory.getConversation()
: _chatHistory.getConversation(withModelId: modelId);

_addMessage(
ChatMessageSource.internal,
'Querying ${modelInfo.name}...',
modelId: modelId,
modelName: modelInfo.name,
);

// TODO: Move request inference to the provider manager?
await _providerManager.providerConnection?.requestInference(
modelId,
preparedMessages,
params: params,
);
await _providerManager.sendMessagesToModel(
selectedMessages, modelId, _maxTokens);
} catch (e) {
_addMessage(
ChatMessageSource.system, 'Error querying model $modelId: $e');
Expand Down Expand Up @@ -591,3 +597,4 @@ Future<void> _launchURL(String urlString) async {
enum AuthTokenMethod { manual, walletConnect }

enum OrchataMenuItem { debug }

86 changes: 11 additions & 75 deletions gai-frontend/lib/chat/chat_history.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import 'package:orchid/chat/chat_message.dart';
/// user-supplied preparation scripts.
class ChatHistory {
final List<ChatMessage> _messages = [];

// Built-in preparation functions
static const String isolatedMode = 'isolated';
static const String partyMode = 'party-mode';
Expand All @@ -21,79 +21,15 @@ class ChatHistory {
_messages.clear();
}

/// Prepare messages for a specific model's inference request.
/// Returns messages formatted for the chat completions API.
List<Map<String, dynamic>> prepareForModel({
required String modelId,
required String preparationFunction,
}) {
switch (preparationFunction) {
case isolatedMode:
return _prepareIsolated(modelId);
case partyMode:
return _preparePartyMode(modelId);
default:
// TODO: Hook up JS engine dispatch
throw UnimplementedError('Custom preparation functions not yet supported');
}
}

/// Default preparation mode where models only see their own history
List<Map<String, dynamic>> _prepareIsolated(String modelId) {
final relevantMessages = _messages.where((msg) =>
// Only include client messages and this model's responses
(msg.source == ChatMessageSource.client) ||
(msg.source == ChatMessageSource.provider && msg.modelId == modelId)
);

return _formatMessages(relevantMessages.toList(), modelId);
}

/// Party mode where models can see and respond to each other
List<Map<String, dynamic>> _preparePartyMode(String modelId) {
// Filter to only include actual conversation messages
final relevantMessages = _messages.where((msg) =>
msg.source == ChatMessageSource.client ||
msg.source == ChatMessageSource.provider
).toList();

return _formatMessages(relevantMessages, modelId);
}

/// Format a single message from the perspective of the target model
List<Map<String, dynamic>> _formatMessages(List<ChatMessage> messages, String modelId) {
return messages.map((msg) {
// Skip internal messages entirely
if (msg.source == ChatMessageSource.system ||
msg.source == ChatMessageSource.internal) {
return null;
}

String role;
String content = msg.message;

// Map conversation messages to appropriate roles
if (msg.source == ChatMessageSource.client) {
role = 'user';
} else if (msg.source == ChatMessageSource.provider) {
if (msg.modelId == modelId) {
role = 'assistant';
} else {
// Another model's message - show as user with identification
role = 'user';
final modelName = msg.modelName ?? msg.modelId;
content = '[$modelName]: $content';
}
} else {
// Should never hit this due to the filter above
log('Error: Unexpected message source: ${msg.source}');
return null;
}

return {
'role': role,
'content': content,
};
}).whereType<Map<String, dynamic>>().toList(); // Remove any nulls from skipped messages
// Return the client and provider messages, optionally limited to the specifid model id.
// System and internal messages are always excluded.
List<ChatMessage> getConversation({String? withModelId}) {
return _messages
.where((msg) =>
// Only include client messages and this model's responses
(msg.source == ChatMessageSource.client) ||
(msg.source == ChatMessageSource.provider &&
(withModelId == null || msg.modelId == withModelId)))
.toList();
}
}
5 changes: 3 additions & 2 deletions gai-frontend/lib/chat/chat_message.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ class ChatMessage {
final String sourceName;
final String msg;
final Map<String, dynamic>? metadata;
// The modelId of the model that generated this message
final String? modelId;
// The name of the model that generated this message
final String? modelName;

ChatMessage(
Expand All @@ -22,7 +24,6 @@ class ChatMessage {
String get message => msg;

String? get displayName {
log('Getting displayName. source: $source, modelName: $modelName, sourceName: $sourceName'); // Debug what we have
if (source == ChatMessageSource.provider && modelName != null) {
return modelName;
}
Expand Down Expand Up @@ -55,7 +56,7 @@ class ChatMessage {

@override
String toString() {
return 'ChatMessage(source: $source, model: $modelName, msg: ${msg.substring(0, msg.length.clamp(0, 50))}...)';
return 'ChatMessage(source: $source, modelId: $modelId, model: $modelName, msg: ${msg.substring(0, msg.length.clamp(0, 50))}...)';
}
}

2 changes: 1 addition & 1 deletion gai-frontend/lib/chat/chat_model_button.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import 'package:orchid/orchid/orchid.dart';
import 'package:orchid/orchid/menu/orchid_popup_menu_button.dart';
import 'models.dart';
import 'model.dart';

class ModelSelectionButton extends StatefulWidget {
final List<ModelInfo> models;
Expand Down
2 changes: 1 addition & 1 deletion gai-frontend/lib/chat/chat_settings_button.dart
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class _ChatSettingsButtonState extends State<ChatSettingsButton> {
const String.fromEnvironment('build_commit', defaultValue: '...');
final githubUrl =
'https://github.com/OrchidTechnologies/orchid/tree/$buildCommit/web-ethereum/dapp2';

return Center(
child: OrchidPopupMenuButton<dynamic>(
width: 30,
Expand Down
Loading

0 comments on commit 826a8c3

Please sign in to comment.