From 097fc9a21e71dbc6b93b86f024aac5518fcc5f55 Mon Sep 17 00:00:00 2001 From: Dan Montgomery Date: Fri, 8 Nov 2024 16:40:54 -0500 Subject: [PATCH] Front end now uses new inference endpoint. Add multi-model mode. --- gai-frontend/lib/chat/chat.dart | 399 +++++++++++++----- gai-frontend/lib/chat/chat_bubble.dart | 171 +++++--- gai-frontend/lib/chat/chat_button.dart | 31 ++ gai-frontend/lib/chat/chat_message.dart | 51 ++- gai-frontend/lib/chat/chat_model_button.dart | 239 +++++------ gai-frontend/lib/chat/chat_prompt.dart | 30 +- .../lib/chat/chat_settings_button.dart | 196 +++++---- gai-frontend/lib/chat/inference_client.dart | 223 ++++++++++ gai-frontend/lib/chat/models.dart | 94 +++++ .../lib/chat/provider_connection.dart | 301 +++++++++++++ gai-frontend/pubspec.lock | 98 +++-- gai-frontend/pubspec.yaml | 6 +- 12 files changed, 1416 insertions(+), 423 deletions(-) create mode 100644 gai-frontend/lib/chat/inference_client.dart create mode 100644 gai-frontend/lib/chat/models.dart create mode 100644 gai-frontend/lib/chat/provider_connection.dart diff --git a/gai-frontend/lib/chat/chat.dart b/gai-frontend/lib/chat/chat.dart index ce37beba3..30217aa34 100644 --- a/gai-frontend/lib/chat/chat.dart +++ b/gai-frontend/lib/chat/chat.dart @@ -21,9 +21,10 @@ import 'package:url_launcher/url_launcher.dart'; import 'chat_bubble.dart'; import 'chat_button.dart'; import 'chat_message.dart'; -import '../provider.dart'; import 'chat_prompt.dart'; import 'chat_model_button.dart'; +import 'models.dart'; +import 'provider_connection.dart'; import '../config/providers_config.dart'; class ChatView extends StatefulWidget { @@ -35,14 +36,13 @@ class ChatView extends StatefulWidget { class _ChatViewState extends State { List _messages = []; -// List _providers = []; -// Map> _providers = {'gpt4': {'url': 'https://nanogenera.danopato.com/ws/', 'name': 'ChatGPT-4'}}; late final Map> _providers; int _providerIndex = 0; bool _debugMode = false; + bool _multiSelectMode = false; bool _connected = false; - double _bid = 0.00007; ProviderConnection? _providerConnection; + int? _maxTokens; // The active account components EthereumAddress? _funder; @@ -52,17 +52,20 @@ class _ChatViewState extends State { final _funderFieldController = AddressValueFieldController(); final ScrollController messageListController = ScrollController(); final _promptTextController = TextEditingController(); - final _bidController = NumericValueFieldController(); + final _maxTokensController = NumericValueFieldController(); bool _showPromptDetails = false; Chain _selectedChain = Chains.Gnosis; + final ModelsState _modelsState = ModelsState(); + List _selectedModelIds = []; @override void initState() { super.initState(); _providers = ProvidersConfig.getProviders(); - _bidController.value = _bid; try { _initFromParams(); + // If we have providers and an account, connect to first provider + _connectToInitialProvider(); } catch (e, stack) { log('Error initializing from params: $e, $stack'); } @@ -75,6 +78,30 @@ class _ChatViewState extends State { return true; } + void _setMaxTokens(int? value) { + setState(() { + _maxTokens = value; + }); + } + + void _connectToInitialProvider() { + if (_providers.isEmpty) { + log('No providers configured'); + return; + } + + // Get first provider from the list + final firstProviderId = _providers.keys.first; + final firstProvider = _providers[firstProviderId]; + if (firstProvider == null) { + log('Invalid provider configuration'); + return; + } + + log('Connecting to initial provider: ${firstProvider['name']}'); + _connectProvider(firstProviderId); + } + Account? get _account { if (_funder == null || _signerKey == null) { return null; @@ -98,7 +125,16 @@ class _ChatViewState extends State { _accountDetail = null; if (_account != null) { _accountDetail = AccountDetailPoller(account: _account!); - _accountDetail?.pollOnce(); + await _accountDetail?.pollOnce(); + + // Disconnect any existing provider connection + if (_connected) { + _providerConnection?.dispose(); + _connected = false; + } + + // Connect to provider with new account + _connectToInitialProvider(); } _accountDetailNotifier.notifyListeners(); setState(() {}); @@ -127,91 +163,165 @@ class _ChatViewState extends State { } void providerConnected([name = '']) { - String nameTag = ''; _connected = true; - if (!name.isEmpty) { - nameTag = ' ${name}'; - } - addMessage(ChatMessageSource.system, 'Connected to provider${nameTag}.'); + // Only show connection message in debug mode + addMessage( + ChatMessageSource.internal, + 'Connected to provider${name.isEmpty ? '' : ' $name'}.', + ); } void providerDisconnected() { _connected = false; - addMessage(ChatMessageSource.system, 'Provider disconnected'); + // Only show disconnection in debug mode + addMessage( + ChatMessageSource.internal, + 'Provider disconnected', + ); } - - void _connectProvider([provider = '']) { + +void _connectProvider([String provider = '']) async { var account = _accountDetail; - String url; - String name; - String providerId = ''; if (account == null) { + log('_connectProvider() -- No account'); return; } - if (_providers.length == 0) { - log('_connectProvider() -- _providers.length == 0'); + if (_providers.isEmpty) { + log('_connectProvider() -- _providers.isEmpty'); return; } + + // Clean up existing connection if any if (_connected) { _providerConnection?.dispose(); - _providerIndex = (_providerIndex + 1) % _providers.length; _connected = false; } + + // Determine which provider to connect to + String providerId; if (provider.isEmpty) { - _providerIndex += 1; - providerId = _providers.keys.elementAt(_providerIndex); + providerId = _providers.keys.first; } else { providerId = provider; } - url = _providers[providerId]?['url'] ?? ''; - name = _providers[providerId]?['name'] ?? ''; - - log('Connecting to provider: ${name}'); - _providerConnection = ProviderConnection( - onMessage: (msg) { - addMessage(ChatMessageSource.internal, msg); - }, - onConnect: () { providerConnected(name); }, - onChat: (msg, metadata) { - addMessage(ChatMessageSource.provider, msg, metadata: metadata, sourceName: name); - }, - onDisconnect: providerDisconnected, - onError: (msg) { - addMessage(ChatMessageSource.system, 'Provider error: $msg'); - }, - onSystemMessage: (msg) { - addMessage(ChatMessageSource.system, msg); - }, - onInternalMessage: (msg) { - addMessage(ChatMessageSource.internal, msg); - }, - accountDetail: account, - contract: - EthereumAddress.from('0x6dB8381b2B41b74E17F5D4eB82E8d5b04ddA0a82'), - url: url, - ); - log('connected...'); - } - void addMessage(ChatMessageSource source, String msg, - {Map? metadata, String sourceName = ''}) { + final providerInfo = _providers[providerId]; + if (providerInfo == null) { + log('Provider not found: $providerId'); + return; + } + + final wsUrl = providerInfo['url'] ?? ''; + final name = providerInfo['name'] ?? ''; + final httpUrl = wsUrl.replaceFirst('ws:', 'http:').replaceFirst('wss:', 'https:'); + + log('Connecting to provider: $name (ws: $wsUrl, http: $httpUrl)'); + + try { + _providerConnection = await ProviderConnection.connect( + billingUrl: wsUrl, + inferenceUrl: httpUrl, + contract: EthereumAddress.from('0x6dB8381b2B41b74E17F5D4eB82E8d5b04ddA0a82'), + accountDetail: account, + onMessage: (msg) { + addMessage(ChatMessageSource.internal, msg); + }, + onConnect: () { + providerConnected(name); + }, + onChat: (msg, metadata) { + print('onChat received metadata: $metadata'); // See what metadata we get + final modelId = metadata['model_id']; + print('Found model_id: $modelId'); // Verify we extract model_id + + String? modelName; + if (modelId != null) { + print('Available models: ${_modelsState.allModels.map((m) => '${m.id}: ${m.name}')}'); // See what models we have + final model = _modelsState.allModels.firstWhere( + (m) => m.id == modelId, + orElse: () => ModelInfo( + id: modelId, + name: modelId, + provider: '', + apiType: '', + ), + ); + modelName = model.name; + print('Looked up model name: $modelName'); // See what name we found + } + + print('Adding message with modelId: $modelId, modelName: $modelName'); // Verify what we're passing + addMessage( + ChatMessageSource.provider, + msg, + metadata: metadata, + modelId: modelId, + modelName: modelName, + ); + }, + onDisconnect: providerDisconnected, + onError: (msg) { + addMessage(ChatMessageSource.system, 'Provider error: $msg'); + }, + onSystemMessage: (msg) { + addMessage(ChatMessageSource.system, msg); + }, + onInternalMessage: (msg) { + addMessage(ChatMessageSource.internal, msg); + }, + onAuthToken: (token, url) async { + // Fetch models after receiving token + log('Fetching models after auth token receipt'); + if (_providerConnection?.inferenceClient != null) { + await _modelsState.fetchModelsForProvider( + providerId, + _providerConnection!.inferenceClient!, + ); + } + }, + ); + + // Request auth token - model fetch will happen in callback + await _providerConnection?.requestAuthToken(); + + } catch (e, stack) { + log('Error connecting to provider: $e\n$stack'); + addMessage(ChatMessageSource.system, 'Failed to connect to provider: $e'); + } + } + + void addMessage( + ChatMessageSource source, + String msg, { + Map? metadata, + String sourceName = '', + String? modelId, + String? modelName, + }) { log('Adding message: ${msg.truncate(64)}'); setState(() { - if (sourceName.isEmpty) { - _messages.add(ChatMessage(source, msg, metadata: metadata)); - } else { - _messages.add(ChatMessage(source, msg, metadata: metadata, sourceName: sourceName)); - } + _messages.add(ChatMessage( + source, + msg, + metadata: metadata, + sourceName: sourceName, + modelId: modelId, + modelName: modelName, + )); }); - // if (source != ChatMessageSource.internal || _debugMode == true) { scrollMessagesDown(); - // } } - void _setBid(double? value) { + void _updateSelectedModels(List modelIds) { setState(() { - _bid = value ?? _bid; + if (_multiSelectMode) { + _selectedModelIds = modelIds; + } else { + // In single-select mode, only keep the most recently selected model + _selectedModelIds = modelIds.isNotEmpty ? [modelIds.last] : []; + } }); + log('Selected models updated to: $_selectedModelIds'); } // TODO: Break out widget @@ -317,18 +427,67 @@ class _ChatViewState extends State { _account != null ? _sendPrompt() : _popAccountDialog(); } - void _sendPrompt() { +void _sendPrompt() async { var msg = _promptTextController.text; if (msg.trim().isEmpty) { return; } - var message = '{"type": "job", "bid": $_bid, "prompt": "$msg"}'; - _providerConnection?.sendProviderMessage(message); + + if (_providerConnection == null) { + addMessage(ChatMessageSource.system, 'Not connected to provider'); + return; + } + + if (_selectedModelIds.isEmpty) { + addMessage(ChatMessageSource.system, + _multiSelectMode ? 'Please select at least one model' : 'Please select a model' + ); + return; + } + + // Add user message immediately to update UI and include in history + addMessage(ChatMessageSource.client, msg); _promptTextController.clear(); FocusManager.instance.primaryFocus?.unfocus(); - addMessage(ChatMessageSource.client, msg); - addMessage(ChatMessageSource.internal, 'Client: $message'); - log('Sending message to provider $message'); + + for (final modelId in _selectedModelIds) { + try { + final modelInfo = _modelsState.allModels + .firstWhere((m) => m.id == modelId, + orElse: () => ModelInfo( + id: modelId, + name: modelId, + provider: '', + apiType: '', + )); + + // Get messages relevant to this model + final relevantMessages = _messages.where((m) => + (m.source == ChatMessageSource.provider && m.modelId == modelId) || + m.source == ChatMessageSource.client + ).toList(); + + Map? params; + if (_maxTokens != null) { + params = {'max_tokens': _maxTokens!}; + } + + addMessage( + ChatMessageSource.internal, + 'Querying ${modelInfo.name}...', + modelId: modelId, + modelName: modelInfo.name, + ); + + await _providerConnection?.requestInference( + modelId, + relevantMessages, + params: params, + ); + } catch (e) { + addMessage(ChatMessageSource.system, 'Error querying model $modelId: $e'); + } + } } void scrollMessagesDown() { @@ -389,9 +548,9 @@ class _ChatViewState extends State { fit: BoxFit.scaleDown, child: SizedBox( width: minWidth, - child: _buildHeaderRow(showIcons: showIcons, providers: _providers))) + child: _buildHeaderRow(showIcons: showIcons))) else - _buildHeaderRow(showIcons: showIcons, providers: _providers), + _buildHeaderRow(showIcons: showIcons), // Messages area _buildChatPane(), // Prompt row @@ -399,16 +558,12 @@ class _ChatViewState extends State { alignment: Alignment.topCenter, duration: millis(150), child: ChatPromptPanel( - promptTextController: _promptTextController, - onSubmit: _send, - setBid: _setBid, - bidController: _bidController) - .top(8), + promptTextController: _promptTextController, + onSubmit: _send, + setMaxTokens: _setMaxTokens, + maxTokensController: _maxTokensController, + ).top(8), ), - if (!_showPromptDetails) - Text('Your bid is $_bid XDAI per token.', - style: OrchidText.normal_14) - .top(12), ], ), ).top(8).bottom(8), @@ -486,34 +641,72 @@ class _ChatViewState extends State { ); } - Widget _buildHeaderRow({required bool showIcons, required Map> providers}) { + Widget _buildHeaderRow({required bool showIcons}) { + final buttonHeight = 40.0; + final settingsIconSize = buttonHeight * 1.5; + return Row( children: [ - SizedBox(height: 40, child: OrchidAsset.image.logo), + // Logo + SizedBox(height: buttonHeight, child: OrchidAsset.image.logo), + + // Model selector with loading state + ListenableBuilder( + listenable: _modelsState, + builder: (context, _) { + if (_modelsState.isAnyLoading) { + return SizedBox( + width: buttonHeight, + height: buttonHeight, + child: Center( + child: CircularProgressIndicator( + strokeWidth: 2, + ), + ), + ); + } + + return ModelSelectionButton( + models: _modelsState.allModels, + selectedModelIds: _selectedModelIds, + updateModels: _updateSelectedModels, + multiSelectMode: _multiSelectMode, + ); + }, + ).left(24), + const Spacer(), - // Connect button - ChatModelButton( - updateModel: (id) { log(id); _connectProvider(id); }, - providers: providers, - ).left(8), -/* - ChatButton( - text: 'Reroll', - onPressed: _connectProvider, - ).left(8), -*/ - // Clear button - ChatButton(text: 'Clear Chat', onPressed: _clearChat).left(8), + // Account button - ChatButton(text: 'Account', onPressed: _popAccountDialog).left(8), + OutlinedChatButton( + text: 'Account', + onPressed: _popAccountDialog, + height: buttonHeight, + ).left(8), + // Settings button - ChatSettingsButton( - debugMode: _debugMode, - onDebugModeChanged: () { - setState(() { - _debugMode = !_debugMode; - }); - }, + SizedBox( + width: settingsIconSize, + height: buttonHeight, + child: Center( + child: ChatSettingsButton( + debugMode: _debugMode, + multiSelectMode: _multiSelectMode, + onDebugModeChanged: () { + setState(() { + _debugMode = !_debugMode; + }); + }, + onMultiSelectModeChanged: () { + setState(() { + _multiSelectMode = !_multiSelectMode; + // Reset selections when toggling modes + _selectedModelIds = []; + }); + }, + onClearChat: _clearChat, + ), + ), ).left(8), ], ); diff --git a/gai-frontend/lib/chat/chat_bubble.dart b/gai-frontend/lib/chat/chat_bubble.dart index 26c6653ef..735588b2a 100644 --- a/gai-frontend/lib/chat/chat_bubble.dart +++ b/gai-frontend/lib/chat/chat_bubble.dart @@ -15,19 +15,10 @@ class ChatBubble extends StatelessWidget { Widget build(BuildContext context) { ChatMessageSource src = message.source; - List msgBubbleColor(ChatMessageSource src) { - if (src == ChatMessageSource.client) { - return [ - const Color(0xff52319c), - const Color(0xff3b146a), - ]; - } else { - return [ - const Color(0xff005965), - OrchidColors.dark_ff3a3149, - ]; - } - } + // Constants for consistent spacing + const double iconSize = 16.0; + const double iconSpacing = 8.0; + const double iconTotalWidth = iconSize + iconSpacing; if (src == ChatMessageSource.system || src == ChatMessageSource.internal) { if (!debugMode && src == ChatMessageSource.internal) { @@ -39,9 +30,13 @@ class ChatBubble extends StatelessWidget { children: [ Text( message.message, - style: src == ChatMessageSource.system - ? OrchidText.normal_14 - : OrchidText.normal_14.grey, + style: const TextStyle( + fontFamily: 'Baloo2', + fontSize: 14, // 16px equivalent + height: 1.0, + fontWeight: FontWeight.normal, + color: Colors.white, + ), ), const SizedBox(height: 2), ], @@ -53,68 +48,112 @@ class ChatBubble extends StatelessWidget { alignment: src == ChatMessageSource.provider ? Alignment.centerLeft : Alignment.centerRight, - child: SizedBox( - width: 0.6 * 800, //MediaQuery.of(context).size.width * 0.6, + child: Container( + constraints: BoxConstraints(maxWidth: 0.6 * 800), child: Column( - crossAxisAlignment: CrossAxisAlignment.end, + crossAxisAlignment: src == ChatMessageSource.provider + ? CrossAxisAlignment.start + : CrossAxisAlignment.end, children: [ - Align( - alignment: Alignment.centerLeft, - child: _chatSourceText(message), -// child: Text(src == ChatMessageSource.provider ? 'Chat' : 'You', -// style: OrchidText.normal_14), - ), - const SizedBox(height: 2), - ClipRRect( - borderRadius: BorderRadius.circular(10), - child: Stack( - children: [ - Positioned.fill( - child: Container( - width: 0.6 * 800, - // MediaQuery.of(context).size.width * 0.6, - decoration: BoxDecoration( - gradient: LinearGradient( - colors: msgBubbleColor(message.source), - ), - ), + // Header row with icon and name for both provider and user + Row( + mainAxisAlignment: src == ChatMessageSource.provider + ? MainAxisAlignment.start + : MainAxisAlignment.end, + crossAxisAlignment: CrossAxisAlignment.center, + children: [ + if (src == ChatMessageSource.provider) ...[ + Icon( + Icons.stars_rounded, + color: OrchidColors.blue_highlight, + size: iconSize, + ), + SizedBox(width: iconSpacing), + Text( + message.displayName ?? 'Chat', + style: TextStyle( + fontFamily: 'Baloo2', + fontSize: 14, // 16px equivalent + height: 1.0, + fontWeight: FontWeight.w500, + color: OrchidColors.blue_highlight, ), ), - Container( - width: 0.6 * 800, - // MediaQuery.of(context).size.width * 0.6, - padding: const EdgeInsets.all(8.0), - child: Text(message.message, - style: OrchidText.medium_20_050), + ] else ...[ + Text( + 'You', + style: TextStyle( + fontFamily: 'Baloo2', + fontSize: 14, // 16px equivalent + height: 1.0, + fontWeight: FontWeight.w500, + color: OrchidColors.blue_highlight, + ), + ), + SizedBox(width: iconSpacing), + Icon( + Icons.account_circle_rounded, + color: OrchidColors.blue_highlight, + size: iconSize, ), ], - ), + ], ), - const SizedBox(height: 2), - if (src == ChatMessageSource.provider) ...[ - Text( - style: OrchidText.normal_14, - 'model: ${message.metadata?["model"]} usage: ${message.metadata?["usage"]}', + const SizedBox(height: 4), + // Message content with padding for provider messages + if (src == ChatMessageSource.provider) + Padding( + padding: EdgeInsets.only(left: iconTotalWidth), + child: Text( + message.message, + style: const TextStyle( + fontFamily: 'Baloo2', + fontSize: 20, // 20px design spec + height: 1.0, + fontWeight: FontWeight.normal, + color: Colors.white, + ), + ), ) + else + Container( + padding: const EdgeInsets.symmetric(horizontal: 25, vertical: 8), + decoration: BoxDecoration( + color: Colors.black, + borderRadius: BorderRadius.circular(10), + ), + child: Text( + message.message, + style: const TextStyle( + fontFamily: 'Baloo2', + fontSize: 20, // 20px design spec + height: 1.0, + fontWeight: FontWeight.normal, + color: Colors.white, + ), + ), + ), + // Usage metadata for provider messages + if (src == ChatMessageSource.provider) ...[ + const SizedBox(height: 4), + Padding( + padding: EdgeInsets.only(left: iconTotalWidth), + child: Text( + message.formatUsage(), + style: TextStyle( + fontFamily: 'Baloo2', + fontSize: 14, // 16px equivalent + height: 1.0, + fontWeight: FontWeight.normal, + color: OrchidColors.purpleCaption, + ), + ), + ), + const SizedBox(height: 6), ], - const SizedBox(height: 6), ], ), ), ); } - - Widget _chatSourceText(ChatMessage msg) { - final String srcText; - if (msg.sourceName.isEmpty) { - srcText = msg.source == ChatMessageSource.provider ? 'Chat' : 'You'; - } else { - srcText = msg.sourceName; - } - return Text( - srcText, - style: OrchidText.normal_14, - ); - } } - diff --git a/gai-frontend/lib/chat/chat_button.dart b/gai-frontend/lib/chat/chat_button.dart index 18f245bd6..f281397fb 100644 --- a/gai-frontend/lib/chat/chat_button.dart +++ b/gai-frontend/lib/chat/chat_button.dart @@ -31,3 +31,34 @@ class ChatButton extends StatelessWidget { } } +class OutlinedChatButton extends StatelessWidget { + const OutlinedChatButton({ + super.key, + required this.text, + required this.onPressed, + this.width, + this.height = 40, + }); + + final String text; + final VoidCallback onPressed; + final double? width, height; + + @override + Widget build(BuildContext context) { + return SizedBox( + height: height, + width: width, + child: OutlinedButton( + style: OutlinedButton.styleFrom( + side: BorderSide(color: Colors.white), + shape: RoundedRectangleBorder( + borderRadius: BorderRadius.circular(16), + ), + ), + onPressed: onPressed, + child: Text(text).button.white, + ), + ); + } +} diff --git a/gai-frontend/lib/chat/chat_message.dart b/gai-frontend/lib/chat/chat_message.dart index 8f722153b..a6430249f 100644 --- a/gai-frontend/lib/chat/chat_message.dart +++ b/gai-frontend/lib/chat/chat_message.dart @@ -1,4 +1,3 @@ - enum ChatMessageSource { client, provider, system, internal } class ChatMessage { @@ -6,11 +5,55 @@ class ChatMessage { final String sourceName; final String msg; final Map? metadata; + final String? modelId; + final String? modelName; + + ChatMessage( + this.source, + this.msg, { + this.metadata, + this.sourceName = '', + this.modelId, + this.modelName, + }); + + String get message => msg; + + String? get displayName { + print('Getting displayName. source: $source, modelName: $modelName, sourceName: $sourceName'); // Debug what we have + if (source == ChatMessageSource.provider && modelName != null) { + return modelName; + } + if (sourceName.isNotEmpty) { + return sourceName; + } + print('Returning null displayName'); // See when we hit this case + return null; + } - ChatMessage(this.source, this.msg, {this.metadata, this.sourceName = ''}); + String formatUsage() { + if (metadata == null || !metadata!.containsKey('usage')) { + return ''; + } + + final usage = metadata!['usage']; + if (usage == null) { + return ''; + } + + final prompt = usage['prompt_tokens'] ?? 0; + final completion = usage['completion_tokens'] ?? 0; + + if (prompt == 0 && completion == 0) { + return ''; + } + + return 'tokens: $prompt in, $completion out'; + } - String get message { - return msg; + @override + String toString() { + return 'ChatMessage(source: $source, model: $modelName, msg: ${msg.substring(0, msg.length.clamp(0, 50))}...)'; } } diff --git a/gai-frontend/lib/chat/chat_model_button.dart b/gai-frontend/lib/chat/chat_model_button.dart index d232ff70f..00d2356be 100644 --- a/gai-frontend/lib/chat/chat_model_button.dart +++ b/gai-frontend/lib/chat/chat_model_button.dart @@ -1,43 +1,73 @@ import 'package:orchid/orchid/orchid.dart'; -import 'package:orchid/api/orchid_language.dart'; -import 'package:orchid/api/preferences/user_preferences_ui.dart'; -import 'package:orchid/orchid/menu/expanding_popup_menu_item.dart'; -import 'package:orchid/orchid/menu/orchid_popup_menu_item_utils.dart'; -import 'package:orchid/orchid/menu/submenu_popup_menu_item.dart'; -import 'package:url_launcher/url_launcher_string.dart'; -import '../../../orchid/menu/orchid_popup_menu_button.dart'; -import 'chat_button.dart'; +import 'package:orchid/orchid/menu/orchid_popup_menu_button.dart'; +import 'models.dart'; -class ChatModelButton extends StatefulWidget { -// final bool debugMode; -// final VoidCallback onDebugModeChanged; - final updateModel; - final Map> providers; +class ModelSelectionButton extends StatefulWidget { + final List models; + final List selectedModelIds; + final Function(List) updateModels; + final bool multiSelectMode; - const ChatModelButton({ + const ModelSelectionButton({ Key? key, -// required this.debugMode, -// required this.onDebugModeChanged - required this.providers, - required this.updateModel, + required this.models, + required this.selectedModelIds, + required this.updateModels, + required this.multiSelectMode, }) : super(key: key); @override - State createState() => _ChatModelButtonState(); + State createState() => _ModelSelectionButtonState(); } -class _ChatModelButtonState extends State { - final _width = 273.0; - final _height = 50.0; +class _ModelSelectionButtonState extends State { + final _menuWidth = 273.0; + final _menuHeight = 50.0; final _textStyle = OrchidText.medium_16_025.copyWith(height: 2.0); bool _buttonSelected = false; + String get _buttonText { + if (widget.selectedModelIds.isEmpty) { + return widget.multiSelectMode ? 'Select Models' : 'Select Model'; + } + if (!widget.multiSelectMode || widget.selectedModelIds.length == 1) { + final modelId = widget.selectedModelIds.first; + return widget.models + .firstWhere( + (m) => m.id == modelId, + orElse: () => ModelInfo( + id: modelId, + name: modelId, + provider: '', + apiType: '', + ), + ) + .name; + } + return '${widget.selectedModelIds.length} Models'; + } + + void _handleModelSelection(String modelId) { + if (widget.multiSelectMode) { + final newSelection = List.from(widget.selectedModelIds); + if (newSelection.contains(modelId)) { + newSelection.remove(modelId); + } else { + newSelection.add(modelId); + } + widget.updateModels(newSelection); + } else { + widget.updateModels([modelId]); + } + } + @override Widget build(BuildContext context) { - return OrchidPopupMenuButton( - width: 80, + return OrchidPopupMenuButton( + width: null, height: 40, selected: _buttonSelected, + backgroundColor: Colors.transparent, onSelected: (item) { setState(() { _buttonSelected = false; @@ -52,116 +82,73 @@ class _ChatModelButtonState extends State { setState(() { _buttonSelected = true; }); - - return widget.providers.entries.map((entry) { - final providerId = entry.key; - final providerName = entry.value['name'] ?? providerId; + + if (widget.models.isEmpty) { + return [ + PopupMenuItem( + enabled: false, + height: _menuHeight, + child: SizedBox( + width: _menuWidth, + child: Text('No models available', style: _textStyle), + ), + ), + ]; + } + + return widget.models.map((model) { + final isSelected = widget.selectedModelIds.contains(model.id); return PopupMenuItem( - onTap: () { widget.updateModel(providerId); }, - height: _height, + onTap: () => _handleModelSelection(model.id), + height: _menuHeight, child: SizedBox( - width: _width, - child: Text(providerName, style: _textStyle), + width: _menuWidth, + child: Row( + children: [ + Expanded( + child: Text( + model.name, + style: _textStyle.copyWith( + color: isSelected ? Theme.of(context).primaryColor : null, + ), + ), + ), + if (isSelected) + Icon( + Icons.check, + size: 16, + color: Theme.of(context).primaryColor, + ), + ], + ), ), ); }).toList(); }, -/* - itemBuilder: (itemBuilderContext) { - setState(() { - _buttonSelected = true; - }); - - const div = PopupMenuDivider(height: 1.0); - return [ - PopupMenuItem( - onTap: () { widget.updateModel('gpt4'); }, - height: _height, - child: SizedBox( - width: _width, - child: Text('GPT-4', style: _textStyle), - ), - ), - PopupMenuItem( - onTap: () { widget.updateModel('gpt4o'); }, - height: _height, - child: SizedBox( - width: _width, - child: Text('GPT-4o', style: _textStyle), - ), - ), -// div, - PopupMenuItem( - onTap: () { widget.updateModel('mistral'); }, - height: _height, - child: SizedBox( - width: _width, - child: Text('Mistral 7B', style: _textStyle), - ), - ), - PopupMenuItem( - onTap: () { widget.updateModel('mixtral-8x22b'); }, - height: _height, - child: SizedBox( - width: _width, - child: Text('Mixtral 8x22b', style: _textStyle), - ), + child: SizedBox( + height: 40, + child: Padding( + padding: EdgeInsets.symmetric(horizontal: 16), + child: Row( + mainAxisSize: MainAxisSize.min, + children: [ + Flexible( + child: Text( + _buttonText, + textAlign: TextAlign.left, + overflow: TextOverflow.ellipsis, + ).button.white, + ), + Icon( + Icons.arrow_drop_down, + color: Colors.white, + size: 24, + ), + ], ), - PopupMenuItem( - onTap: () { widget.updateModel('gemini'); }, - height: _height, - child: SizedBox( - width: _width, - child: Text('Gemini 1.5', style: _textStyle), - ), - ), - PopupMenuItem( - onTap: () { widget.updateModel('claude-3'); }, - height: _height, - child: SizedBox( - width: _width, - child: Text('Claude 3 Opus', style: _textStyle), - ), - ), - PopupMenuItem( - onTap: () { widget.updateModel('claude35sonnet'); }, - height: _height, - child: SizedBox( - width: _width, - child: Text('Claude 3.5 Sonnet', style: _textStyle), - ), - ), - ]; - }, -*/ - - /* - child: FittedBox( - fit: BoxFit.scaleDown, - child: SizedBox( - width: 80, height: 20, child: Text('Model'))), -*/ -// child: SizedBox( -// width: 120, height: 20, child: Text('Model', style: _textStyle).white), - child: Align( - alignment: Alignment.center, - child: Text('Model', textAlign: TextAlign.center).button.white, - ), - ); - } - - PopupMenuItem _listMenuItem({ - required bool selected, - required String title, - required VoidCallback onTap, - }) { - return OrchidPopupMenuItemUtils.listMenuItem( - context: context, - selected: selected, - title: title, - onTap: onTap, - textStyle: _textStyle, + ), + ), ); } } diff --git a/gai-frontend/lib/chat/chat_prompt.dart b/gai-frontend/lib/chat/chat_prompt.dart index 06470beba..793d6f00e 100644 --- a/gai-frontend/lib/chat/chat_prompt.dart +++ b/gai-frontend/lib/chat/chat_prompt.dart @@ -2,19 +2,18 @@ import 'package:orchid/orchid/field/orchid_labeled_numeric_field.dart'; import 'package:orchid/orchid/field/orchid_text_field.dart'; import 'package:orchid/orchid/orchid.dart'; -// The prompt row and collapsible bid form footer class ChatPromptPanel extends StatefulWidget { final TextEditingController promptTextController; final VoidCallback onSubmit; - final ValueChanged setBid; - final NumericValueFieldController bidController; + final ValueChanged setMaxTokens; + final NumericValueFieldController maxTokensController; const ChatPromptPanel({ super.key, required this.promptTextController, required this.onSubmit, - required this.setBid, - required this.bidController, + required this.setMaxTokens, + required this.maxTokensController, }); @override @@ -63,26 +62,27 @@ class _ChatPromptPanelState extends State { ], ).padx(8), if (_showPromptDetails) - _buildBidForm(widget.setBid, widget.bidController), + _buildPromptParamsForm(widget.setMaxTokens, widget.maxTokensController), ], ); } - Widget _buildBidForm( - ValueChanged setBid, - NumericValueFieldController bidController, + Widget _buildPromptParamsForm( + ValueChanged setMaxTokens, + NumericValueFieldController maxTokensController, ) { return Container( padding: const EdgeInsets.all(10.0), child: Column( children: [ - Text('Your bid is the price per token in/out you will pay.', - style: OrchidText.medium_20_050) - .top(8), + Text( + 'Set the maximum number of tokens for the response.', + style: OrchidText.medium_20_050, + ).top(8), OrchidLabeledNumericField( - label: 'Bid', - onChange: setBid, - controller: bidController, + label: 'Max Tokens', + onChange: (value) => setMaxTokens(value?.toInt()), + controller: maxTokensController, ).top(12) ], ), diff --git a/gai-frontend/lib/chat/chat_settings_button.dart b/gai-frontend/lib/chat/chat_settings_button.dart index b7d46e46a..7d5b4e958 100644 --- a/gai-frontend/lib/chat/chat_settings_button.dart +++ b/gai-frontend/lib/chat/chat_settings_button.dart @@ -9,12 +9,18 @@ import '../../../orchid/menu/orchid_popup_menu_button.dart'; class ChatSettingsButton extends StatefulWidget { final bool debugMode; + final bool multiSelectMode; final VoidCallback onDebugModeChanged; + final VoidCallback onMultiSelectModeChanged; + final VoidCallback onClearChat; const ChatSettingsButton({ Key? key, required this.debugMode, + required this.multiSelectMode, required this.onDebugModeChanged, + required this.onMultiSelectModeChanged, + required this.onClearChat, }) : super(key: key); @override @@ -33,93 +39,128 @@ class _ChatSettingsButtonState extends State { const String.fromEnvironment('build_commit', defaultValue: '...'); final githubUrl = 'https://github.com/OrchidTechnologies/orchid/tree/$buildCommit/web-ethereum/dapp2'; - return OrchidPopupMenuButton( - width: 40, - height: 40, - selected: _buttonSelected, - onSelected: (item) { - setState(() { - _buttonSelected = false; - }); - }, - onCanceled: () { - setState(() { - _buttonSelected = false; - }); - }, - itemBuilder: (itemBuilderContext) { - setState(() { - _buttonSelected = true; - }); + + return Center( + child: OrchidPopupMenuButton( + width: 30, + height: 30, + selected: _buttonSelected, + backgroundColor: Colors.transparent, + onSelected: (item) { + setState(() { + _buttonSelected = false; + }); + }, + onCanceled: () { + setState(() { + _buttonSelected = false; + }); + }, + itemBuilder: (itemBuilderContext) { + setState(() { + _buttonSelected = true; + }); - const div = PopupMenuDivider(height: 1.0); - return [ - // debug mode - PopupMenuItem( - onTap: widget.onDebugModeChanged, - height: _height, - child: SizedBox( - width: _width, - child: Row( - mainAxisAlignment: MainAxisAlignment.spaceBetween, - children: [ - Text("Debug Mode", style: _textStyle), - Icon( - widget.debugMode - ? Icons.check_box_outlined - : Icons.check_box_outline_blank, - color: Colors.white, - ), - ], + const div = PopupMenuDivider(height: 1.0); + return [ + // Clear chat + PopupMenuItem( + onTap: widget.onClearChat, + height: _height, + child: SizedBox( + width: _width, + child: Text("Clear Chat", style: _textStyle), ), ), - ), - div, - SubmenuPopopMenuItemBuilder( - builder: _buildIdenticonsPref, - ), - div, - SubmenuPopopMenuItemBuilder( - builder: _buildLanguagePref, - ), - div, - PopupMenuItem( - onTap: () { - Future.delayed(millis(0), () async { - _openLicensePage(context); - }); - }, - height: _height, - child: SizedBox( - width: _width, - child: Text(s.openSourceLicenses, style: _textStyle), + div, + // debug mode + PopupMenuItem( + onTap: widget.onDebugModeChanged, + height: _height, + child: SizedBox( + width: _width, + child: Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + Text("Debug Mode", style: _textStyle), + Icon( + widget.debugMode + ? Icons.check_box_outlined + : Icons.check_box_outline_blank, + color: Colors.white, + ), + ], + ), + ), ), + div, + // multi-select mode + PopupMenuItem( + onTap: widget.onMultiSelectModeChanged, + height: _height, + child: SizedBox( + width: _width, + child: Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + Text("Multi-Model Mode", style: _textStyle), + Icon( + widget.multiSelectMode + ? Icons.check_box_outlined + : Icons.check_box_outline_blank, + color: Colors.white, + ), + ], + ), + ), + ), + div, + SubmenuPopopMenuItemBuilder( + builder: _buildIdenticonsPref, + ), + div, + SubmenuPopopMenuItemBuilder( + builder: _buildLanguagePref, + ), + div, + PopupMenuItem( + onTap: () { + Future.delayed(millis(0), () async { + _openLicensePage(context); + }); + }, + height: _height, + child: SizedBox( + width: _width, + child: Text(s.openSourceLicenses, style: _textStyle), + ), + ), + div, + // dapp version item + _listMenuItem( + selected: false, + title: 'Version: ' + buildCommit, + onTap: () async { + launchUrlString(githubUrl); + }, + ), + ]; + }, + child: SizedBox( + width: 30, + height: 30, + child: FittedBox( + fit: BoxFit.contain, + child: OrchidAsset.svg.settings_gear, ), - div, - // dapp version item - _listMenuItem( - selected: false, - title: 'Version: ' + buildCommit, - onTap: () async { - launchUrlString(githubUrl); - }, - ), - ]; - }, - child: FittedBox( - fit: BoxFit.scaleDown, - child: SizedBox( - width: 20, height: 20, child: OrchidAsset.svg.settings_gear)), + ), + ), ); } Future _openLicensePage(BuildContext context) { // TODO: return Future.delayed(millis(100), () async {}); - // return Navigator.push(context, - // MaterialPageRoute(builder: (BuildContext context) { - // return OpenSourcePage(); - // })); } Widget _buildLanguagePref(bool expanded) { @@ -147,12 +188,11 @@ class _ChatSettingsButtonState extends State { ), ) .toList() - .cast() // so that we can add the items below + .cast() .separatedWith( PopupMenuDivider(height: 1.0), ); - // Default system language option items.insert( 0, _listMenuItem( diff --git a/gai-frontend/lib/chat/inference_client.dart b/gai-frontend/lib/chat/inference_client.dart new file mode 100644 index 000000000..4ded9a1d7 --- /dev/null +++ b/gai-frontend/lib/chat/inference_client.dart @@ -0,0 +1,223 @@ +import 'dart:convert'; +import 'package:http/http.dart' as http; +import 'chat_message.dart'; + +class InferenceError implements Exception { + final int statusCode; + final String message; + + InferenceError(this.statusCode, this.message); + + @override + String toString() => 'InferenceError($statusCode): $message'; +} + +class TokenUsage { + final int promptTokens; + final int completionTokens; + final int totalTokens; + + TokenUsage({ + required this.promptTokens, + required this.completionTokens, + required this.totalTokens, + }); + + factory TokenUsage.fromJson(Map json) { + return TokenUsage( + promptTokens: json['prompt_tokens'], + completionTokens: json['completion_tokens'], + totalTokens: json['total_tokens'], + ); + } + + Map toJson() => { + 'prompt_tokens': promptTokens, + 'completion_tokens': completionTokens, + 'total_tokens': totalTokens, + }; +} + +class ModelInfo { + final String id; + final String name; + final String apiType; + + ModelInfo({ + required this.id, + required this.name, + required this.apiType, + }); + + factory ModelInfo.fromJson(Map json) { + return ModelInfo( + id: json['id'], + name: json['name'], + apiType: json['api_type'], + ); + } +} + +class InferenceResponse { + final String response; + final TokenUsage usage; + + InferenceResponse({ + required this.response, + required this.usage, + }); + + factory InferenceResponse.fromJson(Map json) { + return InferenceResponse( + response: json['response'], + usage: TokenUsage.fromJson(json['usage']), + ); + } + + Map toMetadata() => { + 'usage': usage.toJson(), + }; +} + +class InferenceClient { + final String baseUrl; + String? _authToken; + + InferenceClient({required String baseUrl}) + : baseUrl = _normalizeBaseUrl(baseUrl); + + static String _normalizeBaseUrl(String url) { + if (url.endsWith('/')) { + url = url.substring(0, url.length - 1); + } + + if (url.endsWith('/v1/inference')) { + url = url.substring(0, url.length - '/v1/inference'.length); + } + + return url; + } + + void setAuthToken(String token) { + _authToken = token; + } + + Future> listModels() async { + if (_authToken == null) { + throw InferenceError(401, 'No auth token'); + } + + final response = await http.get( + Uri.parse('$baseUrl/v1/models'), + headers: {'Authorization': 'Bearer $_authToken'}, + ); + + if (response.statusCode != 200) { + throw InferenceError(response.statusCode, response.body); + } + + final data = json.decode(response.body) as Map; + return data.map((key, value) => MapEntry( + key, + ModelInfo.fromJson(value as Map), + )); + } + + Map _chatMessageToJson(ChatMessage msg) { + String role; + switch (msg.source) { + case ChatMessageSource.client: + role = 'user'; + break; + case ChatMessageSource.provider: + role = 'assistant'; + break; + default: + role = 'system'; + } + + final map = { + 'role': role, + 'content': msg.message, + }; + + if (msg.modelName != null) { + map['name'] = msg.modelName; + } + + return map; + } + + // Simple token estimation + int _estimateTokenCount(String text) { + // Average English word is ~4 characters + space + // Average token is ~4 characters + return (text.length / 4).ceil(); + } + + Future> inference({ + required List messages, + String? model, + Map? params, + }) async { + if (_authToken == null) { + throw InferenceError(401, 'No auth token'); + } + + if (messages.isEmpty) { + throw InferenceError(400, 'No messages provided'); + } + + final estimatedTokens = messages.fold( + 0, (sum, msg) => sum + _estimateTokenCount(msg.message) + ); + + final formattedMessages = messages.map(_chatMessageToJson).toList(); + + final Map payload = { + 'messages': formattedMessages, + 'estimated_prompt_tokens': estimatedTokens, + }; + + if (model != null) { + payload['model'] = model; + } + + if (params != null) { + payload.addAll(params); + } + + print('InferenceClient: Preparing request to $baseUrl/v1/inference'); + print('Payload: ${const JsonEncoder.withIndent(' ').convert(payload)}'); + + final response = await http.post( + Uri.parse('$baseUrl/v1/inference'), + headers: { + 'Authorization': 'Bearer $_authToken', + 'Content-Type': 'application/json', + }, + body: json.encode(payload), + ); + + print('InferenceClient: Received response status ${response.statusCode}'); + print('Response body: ${response.body}'); + + if (response.statusCode == 402) { + throw InferenceError(402, 'Insufficient balance'); + } + + if (response.statusCode != 200) { + throw InferenceError(response.statusCode, response.body); + } + + final inferenceResponse = InferenceResponse.fromJson( + json.decode(response.body) + ); + + return { + 'response': inferenceResponse.response, + 'usage': inferenceResponse.usage.toJson(), + 'estimated_prompt_tokens': estimatedTokens, + }; + } +} diff --git a/gai-frontend/lib/chat/models.dart b/gai-frontend/lib/chat/models.dart new file mode 100644 index 000000000..79cfd8cb2 --- /dev/null +++ b/gai-frontend/lib/chat/models.dart @@ -0,0 +1,94 @@ +import 'package:flutter/foundation.dart'; + +class ModelInfo { + final String id; + final String name; + final String provider; + final String apiType; + + ModelInfo({ + required this.id, + required this.name, + required this.provider, + required this.apiType, + }); + + factory ModelInfo.fromJson(Map json, String providerId) { + return ModelInfo( + id: json['id'], + name: json['name'], + provider: providerId, + apiType: json['api_type'], + ); + } + + @override + String toString() => 'ModelInfo(id: $id, name: $name, provider: $provider)'; +} + +class ModelsState extends ChangeNotifier { + final _modelsByProvider = >{}; + final _loadingProviders = {}; + final _errors = {}; + + bool isLoading(String providerId) => _loadingProviders.contains(providerId); + String? getError(String providerId) => _errors[providerId]; + + List getModelsForProvider(String providerId) { + return _modelsByProvider[providerId] ?? []; + } + + List get allModels { + final models = _modelsByProvider.values.expand((models) => models).toList(); + print('ModelsState.allModels returning ${models.length} models: $models'); + return models; + } + + Future fetchModelsForProvider( + String providerId, + dynamic client, + ) async { + print('ModelsState: Fetching models for provider $providerId'); + _loadingProviders.add(providerId); + _errors.remove(providerId); + notifyListeners(); + + try { + final response = await client.listModels(); + print('ModelsState: Received model data from client: $response'); + + // Convert the response map entries directly to ModelInfo objects + final modelsList = response.entries.map((entry) => ModelInfo( + id: entry.value.id, + name: entry.value.name, + provider: providerId, + apiType: entry.value.apiType, + )).toList(); + + print('ModelsState: Created models list: $modelsList'); + + _modelsByProvider[providerId] = modelsList; + print('ModelsState: Updated models for provider $providerId: $modelsList'); + } catch (e, stack) { + print('ModelsState: Error fetching models: $e\n$stack'); + _errors[providerId] = e.toString(); + } finally { + _loadingProviders.remove(providerId); + notifyListeners(); + print('ModelsState: Notified listeners, current state: \n' + 'Models: ${_modelsByProvider}\n' + 'Loading: $_loadingProviders\n' + 'Errors: $_errors'); + } + } + + void clearProviderModels(String providerId) { + _modelsByProvider.remove(providerId); + _errors.remove(providerId); + _loadingProviders.remove(providerId); + notifyListeners(); + } + + bool get isAnyLoading => _loadingProviders.isNotEmpty; + Set get activeProviders => _modelsByProvider.keys.toSet(); +} diff --git a/gai-frontend/lib/chat/provider_connection.dart b/gai-frontend/lib/chat/provider_connection.dart new file mode 100644 index 000000000..d9f1aa39b --- /dev/null +++ b/gai-frontend/lib/chat/provider_connection.dart @@ -0,0 +1,301 @@ +import 'dart:async'; +import 'dart:convert'; +import 'dart:math'; +import 'package:web_socket_channel/web_socket_channel.dart'; +import 'package:orchid/api/orchid_crypto.dart'; +import 'package:orchid/api/orchid_eth/orchid_ticket.dart'; +import 'package:orchid/api/orchid_eth/orchid_account.dart'; +import 'package:orchid/api/orchid_eth/orchid_account_detail.dart'; +import 'inference_client.dart'; +import 'chat_message.dart'; + +typedef MessageCallback = void Function(String message); +typedef ChatCallback = void Function(String message, Map metadata); +typedef VoidCallback = void Function(); +typedef ErrorCallback = void Function(String error); +typedef AuthTokenCallback = void Function(String token, String inferenceUrl); + +class _PendingRequest { + final String requestId; + final String modelId; + final List messages; + final Map? params; + final DateTime timestamp; + + _PendingRequest({ + required this.requestId, + required this.modelId, + required this.messages, + required this.params, + }) : timestamp = DateTime.now(); +} + +class ProviderConnection { + final maxuint256 = BigInt.two.pow(256) - BigInt.one; + final maxuint64 = BigInt.two.pow(64) - BigInt.one; + final wei = BigInt.from(10).pow(18); + WebSocketChannel? _providerChannel; + InferenceClient? get inferenceClient => _inferenceClient; + InferenceClient? _inferenceClient; + final MessageCallback onMessage; + final ChatCallback onChat; + final VoidCallback onConnect; + final ErrorCallback onError; + final VoidCallback onDisconnect; + final MessageCallback onSystemMessage; + final MessageCallback onInternalMessage; + final EthereumAddress contract; + final String url; + final AccountDetail accountDetail; + final AuthTokenCallback? onAuthToken; + final Map _requestModels = {}; + final Map _pendingRequests = {}; + + String _generateRequestId() { + return '${DateTime.now().millisecondsSinceEpoch}-${Random().nextInt(10000)}'; + } + + ProviderConnection({ + required this.onMessage, + required this.onConnect, + required this.onChat, + required this.onDisconnect, + required this.onError, + required this.onSystemMessage, + required this.onInternalMessage, + required this.contract, + required this.url, + required this.accountDetail, + this.onAuthToken, + }) { + try { + _providerChannel = WebSocketChannel.connect(Uri.parse(url)); + _providerChannel?.ready; + } catch (e) { + onError('Failed on provider connection: $e'); + return; + } + _providerChannel?.stream.listen( + receiveProviderMessage, + onDone: () => onDisconnect(), + onError: (error) => onError('ws error: $error'), + ); + onConnect(); + } + + static Future connect({ + required String billingUrl, + required String inferenceUrl, // This won't be used initially + required EthereumAddress contract, + required AccountDetail accountDetail, + required MessageCallback onMessage, + required ChatCallback onChat, + required VoidCallback onConnect, + required ErrorCallback onError, + required VoidCallback onDisconnect, + required MessageCallback onSystemMessage, + required MessageCallback onInternalMessage, + AuthTokenCallback? onAuthToken, + }) async { + final connection = ProviderConnection( + onMessage: onMessage, + onConnect: onConnect, + onChat: onChat, + onDisconnect: onDisconnect, + onError: onError, + onSystemMessage: onSystemMessage, + onInternalMessage: onInternalMessage, + contract: contract, + url: billingUrl, + accountDetail: accountDetail, + onAuthToken: onAuthToken, + ); + + return connection; + } + + void _handleAuthToken(Map data) { + final token = data['session_id']; + final inferenceUrl = data['inference_url']; + if (token == null || inferenceUrl == null) { + onError('Invalid auth token response'); + return; + } + + // Create new inference client with the URL from the auth token + _inferenceClient = InferenceClient(baseUrl: inferenceUrl); + _inferenceClient!.setAuthToken(token); + onInternalMessage('Auth token received and inference client initialized'); + + onAuthToken?.call(token, inferenceUrl); + } + + bool validInvoice(invoice) { + return invoice.containsKey('amount') && invoice.containsKey('commit') && + invoice.containsKey('recipient'); + } + + void payInvoice(Map invoice) { + var payment; + if (!validInvoice(invoice)) { + onError('Invalid invoice ${invoice}'); + return; + } + + assert(accountDetail.funder != null); + final balance = accountDetail.lotteryPot?.balance.intValue ?? BigInt.zero; + final deposit = accountDetail.lotteryPot?.deposit.intValue ?? BigInt.zero; + + if (balance <= BigInt.zero || deposit <= BigInt.zero) { + onError('Insufficient funds: balance=$balance, deposit=$deposit'); + return; + } + + final faceval = _bigIntMin(balance, (wei * deposit) ~/ (wei * BigInt.two)); + if (faceval <= BigInt.zero) { + onError('Invalid face value: $faceval'); + return; + } + + final data = BigInt.zero; + final due = BigInt.parse(invoice['amount']); + final lotaddr = contract; + final token = EthereumAddress.zero; + + BigInt ratio; + try { + ratio = maxuint64 & (maxuint64 * due ~/ faceval); + } catch (e) { + onError('Failed to calculate ratio: $e (due=$due, faceval=$faceval)'); + return; + } + + final commit = BigInt.parse(invoice['commit'] ?? '0x0'); + final recipient = invoice['recipient']; + + final ticket = OrchidTicket( + data: data, + lotaddr: lotaddr, + token: token, + amount: faceval, + ratio: ratio, + funder: accountDetail.account.funder, + recipient: EthereumAddress.from(recipient), + commitment: commit, + privateKey: accountDetail.account.signerKey.private, + millisecondsSinceEpoch: DateTime.now().millisecondsSinceEpoch, + ); + + payment = '{"type": "payment", "tickets": ["${ticket.serializeTicket()}"]}'; + onInternalMessage('Client: $payment'); + _sendProviderMessage(payment); + } + + void receiveProviderMessage(dynamic message) { + final data = jsonDecode(message) as Map; + print(message); + onMessage('Provider: $message'); + + switch (data['type']) { + case 'job_complete': + final requestId = data['request_id']; + final pendingRequest = requestId != null ? _pendingRequests.remove(requestId) : null; + + onChat(data['output'], { + ...data, + 'model_id': pendingRequest?.modelId, + }); + break; + case 'invoice': + payInvoice(data); + break; + case 'bid_low': + onSystemMessage("Bid below provider's reserve price."); + break; + case 'auth_token': + _handleAuthToken(data); + break; + } + } + + Future requestAuthToken() async { + final message = '{"type": "request_token"}'; + onInternalMessage('Requesting auth token'); + _sendProviderMessage(message); + } + + Future requestInference( + String modelId, + List messages, { + Map? params, + }) async { + if (_inferenceClient == null) { + await requestAuthToken(); + await Future.delayed(Duration(milliseconds: 100)); + + if (_inferenceClient == null) { + onError('No inference connection available'); + return; + } + } + + try { + final requestId = _generateRequestId(); + + _pendingRequests[requestId] = _PendingRequest( + requestId: requestId, + modelId: modelId, + messages: messages, + params: params, + ); + + final allParams = { + ...?params, + 'request_id': requestId, + }; + + onInternalMessage('Sending inference request:\n' + 'Model: $modelId\n' + 'Messages: ${messages.map((m) => "${m.source}: ${m.message}").join("\n")}\n' + 'Params: $allParams' + ); + + final result = await _inferenceClient!.inference( + messages: messages, + model: modelId, + params: allParams, + ); + + final pendingRequest = _pendingRequests.remove(requestId); + + onChat(result['response'], { + 'type': 'job_complete', + 'output': result['response'], + 'usage': result['usage'], + 'model_id': pendingRequest?.modelId, + 'request_id': requestId, + 'estimated_prompt_tokens': result['estimated_prompt_tokens'], + }); + } catch (e, stack) { + onError('Failed to send inference request: $e\n$stack'); + } + } + + void _sendProviderMessage(String message) { + print('Sending message to provider $message'); + _providerChannel?.sink.add(message); + } + + void dispose() { + _providerChannel?.sink.close(); + _pendingRequests.clear(); + onDisconnect(); + } + + BigInt _bigIntMin(BigInt a, BigInt b) { + if (a > b) { + return b; + } + return a; + } +} diff --git a/gai-frontend/pubspec.lock b/gai-frontend/pubspec.lock index 3d744db6b..dc26a272f 100644 --- a/gai-frontend/pubspec.lock +++ b/gai-frontend/pubspec.lock @@ -125,10 +125,10 @@ packages: dependency: transitive description: name: collection - sha256: f092b211a4319e98e5ff58223576de6c2803db36221657b46c82574721240687 + sha256: ee67cb0715911d28db6bf4af1026078bd6f0128b07a5f66fb2ed94ec6783c09a url: "https://pub.dev" source: hosted - version: "1.17.2" + version: "1.18.0" convert: dependency: transitive description: @@ -161,6 +161,14 @@ packages: url: "https://pub.dev" source: hosted version: "2.3.4" + decimal: + dependency: "direct main" + description: + name: decimal + sha256: "4140a688f9e443e2f4de3a1162387bf25e1ac6d51e24c9da263f245210f41440" + url: "https://pub.dev" + source: hosted + version: "3.0.2" fake_async: dependency: transitive description: @@ -254,7 +262,7 @@ packages: source: hosted version: "2.1.2" http: - dependency: transitive + dependency: "direct main" description: name: http sha256: "5895291c13fa8a3bd82e76d5627f69e0d85ca6a30dcac95c4ea19a5d555879c2" @@ -273,10 +281,10 @@ packages: dependency: "direct main" description: name: intl - sha256: "3bc132a9dbce73a7e4a21a17d06e1878839ffbf975568bc875c60537824b0c4d" + sha256: d6f56758b7d3014a48af9701c085700aac781a92a87a62b1333b46d8879661cf url: "https://pub.dev" source: hosted - version: "0.18.1" + version: "0.19.0" jdenticon_dart: dependency: "direct main" description: @@ -309,6 +317,30 @@ packages: url: "https://pub.dev" source: hosted version: "3.0.2" + leak_tracker: + dependency: transitive + description: + name: leak_tracker + sha256: "3f87a60e8c63aecc975dda1ceedbc8f24de75f09e4856ea27daf8958f2f0ce05" + url: "https://pub.dev" + source: hosted + version: "10.0.5" + leak_tracker_flutter_testing: + dependency: transitive + description: + name: leak_tracker_flutter_testing + sha256: "932549fb305594d82d7183ecd9fa93463e9914e1b67cacc34bc40906594a1806" + url: "https://pub.dev" + source: hosted + version: "3.0.5" + leak_tracker_testing: + dependency: transitive + description: + name: leak_tracker_testing + sha256: "6ba465d5d76e67ddf503e1161d1f4a6bc42306f9d66ca1e8f079a47290fb06d3" + url: "https://pub.dev" + source: hosted + version: "3.0.1" lints: dependency: transitive description: @@ -329,26 +361,26 @@ packages: dependency: transitive description: name: matcher - sha256: "1803e76e6653768d64ed8ff2e1e67bea3ad4b923eb5c56a295c3e634bad5960e" + sha256: d2323aa2060500f906aa31a895b4030b6da3ebdcc5619d14ce1aada65cd161cb url: "https://pub.dev" source: hosted - version: "0.12.16" + version: "0.12.16+1" material_color_utilities: dependency: transitive description: name: material_color_utilities - sha256: "9528f2f296073ff54cb9fee677df673ace1218163c3bc7628093e7eed5203d41" + sha256: f7142bb1154231d7ea5f96bc7bde4bda2a0945d2806bb11670e30b850d56bdec url: "https://pub.dev" source: hosted - version: "0.5.0" + version: "0.11.1" meta: dependency: transitive description: name: meta - sha256: "3c74dbf8763d36539f114c799d8a2d87343b5067e9d796ca22b5eb8437090ee3" + sha256: bdb68674043280c3428e9ec998512fb681678676b3c54e773629ffe74419f8c7 url: "https://pub.dev" source: hosted - version: "1.9.1" + version: "1.15.0" package_config: dependency: transitive description: @@ -361,10 +393,10 @@ packages: dependency: transitive description: name: path - sha256: "8829d8a55c13fc0e37127c29fedf290c102f4e40ae94ada574091fe0ff96c917" + sha256: "087ce49c3f0dc39180befefc60fdb4acd8f8620e5682fe2476afd0b3688bb4af" url: "https://pub.dev" source: hosted - version: "1.8.3" + version: "1.9.0" path_drawing: dependency: transitive description: @@ -453,6 +485,14 @@ packages: url: "https://pub.dev" source: hosted version: "1.2.3" + rational: + dependency: transitive + description: + name: rational + sha256: cb808fb6f1a839e6fc5f7d8cb3b0a10e1db48b3be102de73938c627f0b636336 + url: "https://pub.dev" + source: hosted + version: "2.2.3" rxdart: dependency: "direct main" description: @@ -534,18 +574,18 @@ packages: dependency: transitive description: name: stack_trace - sha256: c3c7d8edb15bee7f0f74debd4b9c5f3c2ea86766fe4178eb2a18eb30a0bdaed5 + sha256: "73713990125a6d93122541237550ee3352a2d84baad52d375a4cad2eb9b7ce0b" url: "https://pub.dev" source: hosted - version: "1.11.0" + version: "1.11.1" stream_channel: dependency: "direct main" description: name: stream_channel - sha256: "83615bee9045c1d322bbbd1ba209b7a749c2cbcdcb3fdd1df8eb488b3279c1c8" + sha256: ba2aa5d8cc609d96bbb2899c28934f9e1af5cddbd60a827822ea467161eb54e7 url: "https://pub.dev" source: hosted - version: "2.1.1" + version: "2.1.2" stream_transform: dependency: transitive description: @@ -590,10 +630,10 @@ packages: dependency: transitive description: name: test_api - sha256: "75760ffd7786fffdfb9597c35c5b27eaeec82be8edfb6d71d32651128ed7aab8" + sha256: "5b8a98dafc4d5c4c9c72d8b31ab2b23fc13422348d2997120294d3bac86b4ddb" url: "https://pub.dev" source: hosted - version: "0.6.0" + version: "0.7.2" typed_data: dependency: transitive description: @@ -682,22 +722,22 @@ packages: url: "https://pub.dev" source: hosted version: "2.1.4" - watcher: + vm_service: dependency: transitive description: - name: watcher - sha256: "3d2ad6751b3c16cf07c7fca317a1413b3f26530319181b37e3b9039b84fc01d8" + name: vm_service + sha256: f652077d0bdf60abe4c1f6377448e8655008eef28f128bc023f7b5e8dfeb48fc url: "https://pub.dev" source: hosted - version: "1.1.0" - web: + version: "14.2.4" + watcher: dependency: transitive description: - name: web - sha256: dc8ccd225a2005c1be616fe02951e2e342092edf968cf0844220383757ef8f10 + name: watcher + sha256: "3d2ad6751b3c16cf07c7fca317a1413b3f26530319181b37e3b9039b84fc01d8" url: "https://pub.dev" source: hosted - version: "0.1.4-beta" + version: "1.1.0" web3dart: dependency: "direct main" description: @@ -755,5 +795,5 @@ packages: source: hosted version: "3.1.2" sdks: - dart: ">=3.1.0 <4.0.0" - flutter: ">=3.13.0" + dart: ">=3.3.0 <4.0.0" + flutter: ">=3.18.0-18.0.pre.54" diff --git a/gai-frontend/pubspec.yaml b/gai-frontend/pubspec.yaml index 2d554748a..1a299ff65 100644 --- a/gai-frontend/pubspec.yaml +++ b/gai-frontend/pubspec.yaml @@ -21,10 +21,10 @@ dependencies: # flutter_lints: ^1.0.0 flutter_svg: 1.0.3 flutter_web3: 2.1.6 - intl: 0.18.1 + intl: 0.19.0 pointycastle: 3.5.0 rxdart: 0.27.7 - shared_preferences: ^2.2.2 + shared_preferences: ^2.0.5 styled_text: 4.0.0 uuid: 3.0.5 url_launcher: 6.1.3 @@ -34,6 +34,8 @@ dependencies: browser_detector: ^2.0.0 badges: 3.1.1 jdenticon_dart: 2.0.0 + decimal: ^3.0.2 + http: ^0.13.4 dev_dependencies: flutter_test: