Skip to content

Commit

Permalink
gai: Fix type errors in provider messages.
Browse files Browse the repository at this point in the history
  • Loading branch information
patniemeyer committed Dec 3, 2024
1 parent 2eb7c8f commit f83bca8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 30 deletions.
2 changes: 1 addition & 1 deletion gai-frontend/lib/chat/models.dart
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class ModelsState extends ChangeNotifier {

print('ModelsState: Created models list: $modelsList');

_modelsByProvider[providerId] = modelsList;
_modelsByProvider[providerId] = modelsList.cast<ModelInfo>();
print('ModelsState: Updated models for provider $providerId: $modelsList');
} catch (e, stack) {
print('ModelsState: Error fetching models: $e\n$stack');
Expand Down
61 changes: 32 additions & 29 deletions gai-frontend/lib/chat/provider_connection.dart
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import 'inference_client.dart';
import 'chat_message.dart';

typedef MessageCallback = void Function(String message);
typedef ChatCallback = void Function(String message, Map<String, dynamic> metadata);
typedef ChatCallback = void Function(
String message, Map<String, dynamic> metadata);
typedef VoidCallback = void Function();
typedef ErrorCallback = void Function(String error);
typedef AuthTokenCallback = void Function(String token, String inferenceUrl);
Expand All @@ -35,6 +36,7 @@ class ProviderConnection {
final maxuint64 = BigInt.two.pow(64) - BigInt.one;
final wei = BigInt.from(10).pow(18);
WebSocketChannel? _providerChannel;

InferenceClient? get inferenceClient => _inferenceClient;
InferenceClient? _inferenceClient;
final MessageCallback onMessage;
Expand All @@ -49,7 +51,7 @@ class ProviderConnection {
final String? authToken;
final AccountDetail? accountDetail;
final AuthTokenCallback? onAuthToken;
final Map<String, String> _requestModels = {};
final Map<String, String> _requestModels = {};
final Map<String, _PendingRequest> _pendingRequests = {};
bool _usingDirectAuth = false;

Expand All @@ -72,7 +74,7 @@ class ProviderConnection {
this.onAuthToken,
}) {
_usingDirectAuth = authToken != null;

if (!_usingDirectAuth) {
try {
_providerChannel = WebSocketChannel.connect(Uri.parse(url));
Expand Down Expand Up @@ -128,7 +130,7 @@ class ProviderConnection {
authToken: authToken,
onAuthToken: onAuthToken,
);

return connection;
}

Expand All @@ -143,13 +145,14 @@ class ProviderConnection {
_inferenceClient = InferenceClient(baseUrl: inferenceUrl);
_inferenceClient!.setAuthToken(token);
onInternalMessage('Auth token received and inference client initialized');

onAuthToken?.call(token, inferenceUrl);
}
}

bool validInvoice(invoice) {
return invoice.containsKey('amount') && invoice.containsKey('commit') &&
invoice.containsKey('recipient');
return invoice.containsKey('amount') &&
invoice.containsKey('commit') &&
invoice.containsKey('recipient');
}

void payInvoice(Map<String, dynamic> invoice) {
Expand All @@ -163,38 +166,38 @@ class ProviderConnection {
onError('Invalid invoice ${invoice}');
return;
}

assert(accountDetail?.funder != null);
final balance = accountDetail?.lotteryPot?.balance.intValue ?? BigInt.zero;
final deposit = accountDetail?.lotteryPot?.deposit.intValue ?? BigInt.zero;

if (balance <= BigInt.zero || deposit <= BigInt.zero) {
onError('Insufficient funds: balance=$balance, deposit=$deposit');
return;
}

final faceval = _bigIntMin(balance, (wei * deposit) ~/ (wei * BigInt.two));
if (faceval <= BigInt.zero) {
onError('Invalid face value: $faceval');
return;
}

final data = BigInt.zero;
final due = BigInt.parse(invoice['amount']);
final due = BigInt.from(invoice['amount']);
final lotaddr = contract;
final token = EthereumAddress.zero;

BigInt ratio;
try {
ratio = maxuint64 & (maxuint64 * due ~/ faceval);
} catch (e) {
onError('Failed to calculate ratio: $e (due=$due, faceval=$faceval)');
return;
}

final commit = BigInt.parse(invoice['commit'] ?? '0x0');
final recipient = invoice['recipient'];

final ticket = OrchidTicket(
data: data,
lotaddr: lotaddr!,
Expand All @@ -207,7 +210,7 @@ class ProviderConnection {
privateKey: accountDetail!.account.signerKey.private,
millisecondsSinceEpoch: DateTime.now().millisecondsSinceEpoch,
);

payment = '{"type": "payment", "tickets": ["${ticket.serializeTicket()}"]}';
onInternalMessage('Client: $payment');
_sendProviderMessage(payment);
Expand All @@ -221,8 +224,9 @@ class ProviderConnection {
switch (data['type']) {
case 'job_complete':
final requestId = data['request_id'];
final pendingRequest = requestId != null ? _pendingRequests.remove(requestId) : null;

final pendingRequest =
requestId != null ? _pendingRequests.remove(requestId) : null;

onChat(data['output'], {
...data,
'model_id': pendingRequest?.modelId,
Expand All @@ -245,7 +249,7 @@ class ProviderConnection {
onError('Cannot request auth token when using direct auth');
return;
}

final message = '{"type": "request_token"}';
onInternalMessage('Requesting auth token');
_sendProviderMessage(message);
Expand All @@ -259,7 +263,7 @@ class ProviderConnection {
if (!_usingDirectAuth && _inferenceClient == null) {
await requestAuthToken();
await Future.delayed(Duration(milliseconds: 100));

if (_inferenceClient == null) {
onError('No inference connection available');
return;
Expand All @@ -268,7 +272,7 @@ class ProviderConnection {

try {
final requestId = _generateRequestId();

_pendingRequests[requestId] = _PendingRequest(
requestId: requestId,
modelId: modelId,
Expand All @@ -281,20 +285,19 @@ class ProviderConnection {
'request_id': requestId,
};

onInternalMessage('Sending inference request:\n'
'Model: $modelId\n'
'Messages: ${messages.map((m) => "${m.source}: ${m.message}").join("\n")}\n'
'Params: $allParams'
);
onInternalMessage('Sending inference request:\n'
'Model: $modelId\n'
'Messages: ${messages.map((m) => "${m.source}: ${m.message}").join("\n")}\n'
'Params: $allParams');

final result = await _inferenceClient!.inference(
messages: messages,
model: modelId,
params: allParams,
);

final pendingRequest = _pendingRequests.remove(requestId);

onChat(result['response'], {
'type': 'job_complete',
'output': result['response'],
Expand Down

0 comments on commit f83bca8

Please sign in to comment.