diff --git a/app/lib/backend/http/api/conversation_chat.dart b/app/lib/backend/http/api/conversation_chat.dart new file mode 100644 index 0000000000..10a90da25c --- /dev/null +++ b/app/lib/backend/http/api/conversation_chat.dart @@ -0,0 +1,175 @@ +import 'dart:convert'; + +import 'package:flutter/material.dart'; +import 'package:omi/backend/http/shared.dart'; +import 'package:omi/backend/schema/message.dart'; +import 'package:omi/env/env.dart'; +import 'package:omi/utils/other/string_utils.dart'; + +// Models for conversation chat +class ConversationChatMessage { + final String id; + final String text; + final DateTime createdAt; + final String sender; // 'human' or 'ai' + final String conversationId; + final List memoriesId; + final List actionItemsId; + final bool reported; + + ConversationChatMessage({ + required this.id, + required this.text, + required this.createdAt, + required this.sender, + required this.conversationId, + this.memoriesId = const [], + this.actionItemsId = const [], + this.reported = false, + }); + + factory ConversationChatMessage.fromJson(Map json) { + return ConversationChatMessage( + id: json['id'], + text: json['text'], + createdAt: DateTime.parse(json['created_at']), + sender: json['sender'], + conversationId: json['conversation_id'], + memoriesId: List.from(json['memories_id'] ?? []), + actionItemsId: List.from(json['action_items_id'] ?? []), + reported: json['reported'] ?? false, + ); + } + + bool get isFromUser => sender == 'human'; + bool get isFromAI => sender == 'ai'; +} + +class ConversationChatResponse { + final ConversationChatMessage message; + final bool askForNps; + + ConversationChatResponse({ + required this.message, + required this.askForNps, + }); + + factory ConversationChatResponse.fromJson(Map json) { + return ConversationChatResponse( + message: ConversationChatMessage.fromJson(json), + askForNps: json['ask_for_nps'] ?? 
false, + ); + } +} + +// API Functions +Future> getConversationMessages(String conversationId) async { + var response = await makeApiCall( + url: '${Env.apiBaseUrl}v2/conversations/$conversationId/chat/messages', + headers: {}, + method: 'GET', + body: '', + ); + + if (response == null) return []; + if (response.statusCode == 200) { + var body = utf8.decode(response.bodyBytes); + var decodedBody = jsonDecode(body) as List; + if (decodedBody.isEmpty) { + return []; + } + var messages = decodedBody.map((messageJson) => ConversationChatMessage.fromJson(messageJson)).toList(); + debugPrint('getConversationMessages length: ${messages.length}'); + return messages; + } + debugPrint('getConversationMessages error ${response.statusCode}'); + return []; +} + +Future clearConversationChat(String conversationId) async { + var response = await makeApiCall( + url: '${Env.apiBaseUrl}v2/conversations/$conversationId/chat/messages', + headers: {}, + method: 'DELETE', + body: '', + ); + + if (response == null) { + return false; + } + + return response.statusCode == 200; +} + +// Parse conversation chat streaming chunks (similar to main chat) +ServerMessageChunk? 
parseConversationChatChunk(String line, String messageId) { + if (line.startsWith('think: ')) { + return ServerMessageChunk(messageId, line.substring(7).replaceAll("__CRLF__", "\n"), MessageChunkType.think); + } + + if (line.startsWith('data: ')) { + return ServerMessageChunk(messageId, line.substring(6).replaceAll("__CRLF__", "\n"), MessageChunkType.data); + } + + if (line.startsWith('done: ')) { + var text = decodeBase64(line.substring(6)); + var responseJson = json.decode(text); + return ServerMessageChunk( + messageId, + text, + MessageChunkType.done, + message: ServerMessage( + responseJson['id'], + DateTime.parse(responseJson['created_at']).toLocal(), + responseJson['text'], + MessageSender.values.firstWhere((e) => e.toString().split('.').last == responseJson['sender']), + MessageType.text, + null, // appId + false, // fromIntegration + [], // files + [], // filesId + [], // memories + askForNps: responseJson['ask_for_nps'] ?? false, + ), + ); + } + + return null; +} + +Stream sendConversationMessageStream(String conversationId, String text) async* { + var url = '${Env.apiBaseUrl}v2/conversations/$conversationId/chat/messages'; + var messageId = "conv_chat_${DateTime.now().millisecondsSinceEpoch}"; + + await for (var line in makeStreamingApiCall( + url: url, + body: jsonEncode({ + 'text': text, + 'conversation_id': conversationId, + }), + )) { + var messageChunk = parseConversationChatChunk(line, messageId); + if (messageChunk != null) { + yield messageChunk; + } else { + yield ServerMessageChunk.failedMessage(); + return; + } + } +} + +Future?> getConversationContext(String conversationId) async { + var response = await makeApiCall( + url: '${Env.apiBaseUrl}v2/conversations/$conversationId/chat/context', + headers: {}, + method: 'GET', + body: '', + ); + + if (response == null) return null; + if (response.statusCode == 200) { + return jsonDecode(utf8.decode(response.bodyBytes)); + } + debugPrint('getConversationContext error ${response.statusCode}'); + 
return null; +} diff --git a/app/lib/pages/conversation_detail/conversation_detail_provider.dart b/app/lib/pages/conversation_detail/conversation_detail_provider.dart index 12cc24b47e..3dace93b4c 100644 --- a/app/lib/pages/conversation_detail/conversation_detail_provider.dart +++ b/app/lib/pages/conversation_detail/conversation_detail_provider.dart @@ -32,6 +32,19 @@ class ConversationDetailProvider extends ChangeNotifier with MessageNotifierMixi final scaffoldKey = GlobalKey(); List get appsList => appProvider?.apps ?? []; + // Callback for clearing chat messages in UI + VoidCallback? _clearChatMessagesCallback; + + void registerClearChatCallback(VoidCallback callback) { + _clearChatMessagesCallback = callback; + } + + void clearChatMessages() { + if (_clearChatMessagesCallback != null) { + _clearChatMessagesCallback!(); + } + } + Structured get structured { return conversation.structured; } diff --git a/app/lib/pages/conversation_detail/page.dart b/app/lib/pages/conversation_detail/page.dart index 5c0e21dcc5..df75750f3e 100644 --- a/app/lib/pages/conversation_detail/page.dart +++ b/app/lib/pages/conversation_detail/page.dart @@ -28,6 +28,7 @@ import 'package:pull_down_button/pull_down_button.dart'; import 'conversation_detail_provider.dart'; import 'widgets/name_speaker_sheet.dart'; +import 'widgets/chat_tab.dart'; import 'share.dart'; import 'test_prompts.dart'; import 'package:omi/pages/settings/developer.dart'; @@ -128,8 +129,11 @@ class _ConversationDetailPageState extends State with Ti void initState() { super.initState(); - _controller = TabController(length: 3, vsync: this, initialIndex: 1); // Start with summary tab + _controller = TabController(length: 4, vsync: this, initialIndex: 1); // Start with summary tab _controller!.addListener(() { + // Dismiss keyboard when switching tabs for clean UX + FocusScope.of(context).unfocus(); + setState(() { switch (_controller!.index) { case 0: @@ -141,6 +145,9 @@ class _ConversationDetailPageState extends State 
with Ti case 2: selectedTab = ConversationTab.actionItems; break; + case 3: + selectedTab = ConversationTab.chat; + break; default: debugPrint('Invalid tab index: ${_controller!.index}'); selectedTab = ConversationTab.summary; @@ -205,6 +212,8 @@ class _ConversationDetailPageState extends State with Ti return 'Conversation'; case ConversationTab.actionItems: return 'Action Items'; + case ConversationTab.chat: + return 'Chat'; } } @@ -300,6 +309,7 @@ class _ConversationDetailPageState extends State with Ti child: Scaffold( key: scaffoldKey, extendBody: true, + resizeToAvoidBottomInset: selectedTab != ConversationTab.chat, // Don't resize on chat tab backgroundColor: Theme.of(context).colorScheme.primary, appBar: AppBar( automaticallyImplyLeading: false, @@ -578,14 +588,15 @@ class _ConversationDetailPageState extends State with Ti child: Column( children: [ Expanded( - child: Padding( - padding: const EdgeInsets.symmetric(horizontal: 16), - child: Builder(builder: (context) { - return TabBarView( - controller: _controller, - physics: const NeverScrollableScrollPhysics(), - children: [ - TranscriptWidgets( + child: Builder(builder: (context) { + return TabBarView( + controller: _controller, + physics: const NeverScrollableScrollPhysics(), + children: [ + // Other tabs with padding + Padding( + padding: const EdgeInsets.symmetric(horizontal: 16), + child: TranscriptWidgets( searchQuery: _searchQuery, currentResultIndex: getCurrentResultIndexForHighlighting(), onTapWhenSearchEmpty: () { @@ -598,7 +609,10 @@ class _ConversationDetailPageState extends State with Ti } }, ), - SummaryTab( + ), + Padding( + padding: const EdgeInsets.symmetric(horizontal: 16), + child: SummaryTab( searchQuery: _searchQuery, currentResultIndex: getCurrentResultIndexForHighlighting(), onTapWhenSearchEmpty: () { @@ -611,11 +625,16 @@ class _ConversationDetailPageState extends State with Ti } }, ), - ActionItemsTab(), - ], - ); - }), - ), + ), + Padding( + padding: const 
EdgeInsets.symmetric(horizontal: 16), + child: ActionItemsTab(), + ), + // Chat tab with NO padding - 100% width + ChatTab(), + ], + ); + }), ), ], ), @@ -647,6 +666,9 @@ class _ConversationDetailPageState extends State with Ti case ConversationTab.actionItems: index = 2; break; + case ConversationTab.chat: + index = 3; + break; } _controller!.animateTo(index); }, diff --git a/app/lib/pages/conversation_detail/widgets/chat_actions_sheet.dart b/app/lib/pages/conversation_detail/widgets/chat_actions_sheet.dart new file mode 100644 index 0000000000..713f3f7d4d --- /dev/null +++ b/app/lib/pages/conversation_detail/widgets/chat_actions_sheet.dart @@ -0,0 +1,236 @@ +import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; +import 'package:font_awesome_flutter/font_awesome_flutter.dart'; +import 'package:omi/backend/http/api/conversation_chat.dart'; +import 'package:omi/pages/conversation_detail/conversation_detail_provider.dart'; +import 'package:omi/widgets/dialog.dart'; +import 'package:provider/provider.dart'; + +class ChatActionsBottomSheet extends StatelessWidget { + const ChatActionsBottomSheet({super.key}); + + @override + Widget build(BuildContext context) { + return DraggableScrollableSheet( + initialChildSize: 0.3, + minChildSize: 0.2, + maxChildSize: 0.5, + expand: false, + builder: (context, scrollController) { + return Consumer( + builder: (context, provider, _) { + return _SheetContainer( + scrollController: scrollController, + children: [ + const _SheetHeader(), + _ActionsList(provider: provider), + ], + ); + }, + ); + }, + ); + } +} + +class _SheetContainer extends StatelessWidget { + final ScrollController scrollController; + final List children; + + const _SheetContainer({ + required this.scrollController, + required this.children, + }); + + @override + Widget build(BuildContext context) { + return Container( + decoration: const BoxDecoration( + color: Color(0xFF1F1F25), + borderRadius: BorderRadius.only( + topLeft: 
Radius.circular(24), + topRight: Radius.circular(24), + ), + ), + child: Column( + children: children, + ), + ); + } +} + +class _SheetHeader extends StatelessWidget { + const _SheetHeader(); + + @override + Widget build(BuildContext context) { + return Container( + padding: const EdgeInsets.only(top: 16, bottom: 8), + child: Column( + children: [ + Container( + width: 40, + height: 4, + decoration: BoxDecoration( + color: Colors.grey[600], + borderRadius: BorderRadius.circular(2), + ), + ), + const SizedBox(height: 16), + Text( + 'Chat Actions', + style: Theme.of(context).textTheme.titleLarge!.copyWith( + color: Colors.white, + fontWeight: FontWeight.w600, + ), + ), + ], + ), + ); + } +} + +class _ActionsList extends StatelessWidget { + final ConversationDetailProvider provider; + + const _ActionsList({required this.provider}); + + void _showClearChatDialog(BuildContext context) { + Navigator.pop(context); // Close bottom sheet first + + HapticFeedback.lightImpact(); + + showDialog( + context: context, + builder: (dialogContext) { + return getDialog( + dialogContext, + () { + Navigator.of(dialogContext).pop(); // Cancel - use dialog context + }, + () { + // Close dialog first + Navigator.of(dialogContext).pop(); + + // Clear chat with haptic feedback (no context needed) + HapticFeedback.mediumImpact(); + + // Call API to clear chat + clearConversationChat(provider.conversation.id).then((success) { + if (success) { + // Clear UI messages immediately after API success + provider.clearChatMessages(); + } + }); + }, + "Clear Chat?", + "Are you sure you want to clear this conversation chat? 
This action cannot be undone.", + ); + }, + ); + } + + @override + Widget build(BuildContext context) { + return Padding( + padding: const EdgeInsets.symmetric(horizontal: 8, vertical: 16), + child: Column( + children: [ + // Clear Chat Action + _ActionItem( + icon: FontAwesomeIcons.trashCan, + title: 'Clear Chat', + subtitle: 'Remove all messages from this conversation chat', + iconColor: Colors.red[400]!, + onTap: () => _showClearChatDialog(context), + ), + + // Future actions can be added here + // _ActionItem( + // icon: FontAwesomeIcons.download, + // title: 'Export Chat', + // subtitle: 'Download chat history as text file', + // iconColor: Colors.blue[400]!, + // onTap: () => _exportChat(context), + // ), + ], + ), + ); + } +} + +class _ActionItem extends StatelessWidget { + final IconData icon; + final String title; + final String subtitle; + final Color iconColor; + final VoidCallback onTap; + + const _ActionItem({ + required this.icon, + required this.title, + required this.subtitle, + required this.iconColor, + required this.onTap, + }); + + @override + Widget build(BuildContext context) { + return GestureDetector( + onTap: onTap, + child: Container( + margin: const EdgeInsets.symmetric(vertical: 4), + padding: const EdgeInsets.symmetric(horizontal: 16, vertical: 12), + decoration: BoxDecoration( + color: Colors.grey[900], + borderRadius: BorderRadius.circular(12), + ), + child: Row( + children: [ + Container( + padding: const EdgeInsets.all(8), + decoration: BoxDecoration( + color: iconColor.withValues(alpha: 0.2), + borderRadius: BorderRadius.circular(8), + ), + child: Icon( + icon, + color: iconColor, + size: 16, + ), + ), + const SizedBox(width: 12), + Expanded( + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + title, + style: const TextStyle( + color: Colors.white, + fontSize: 16, + fontWeight: FontWeight.w500, + ), + ), + const SizedBox(height: 2), + Text( + subtitle, + style: TextStyle( + color: 
Colors.grey[400], + fontSize: 13, + ), + ), + ], + ), + ), + Icon( + FontAwesomeIcons.chevronRight, + color: Colors.grey[500], + size: 12, + ), + ], + ), + ), + ); + } +} diff --git a/app/lib/pages/conversation_detail/widgets/chat_input_area.dart b/app/lib/pages/conversation_detail/widgets/chat_input_area.dart new file mode 100644 index 0000000000..434cc83ac0 --- /dev/null +++ b/app/lib/pages/conversation_detail/widgets/chat_input_area.dart @@ -0,0 +1,147 @@ +import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; +import 'package:font_awesome_flutter/font_awesome_flutter.dart'; + +class ConversationChatInputArea extends StatelessWidget { + final TextEditingController textController; + final FocusNode textFieldFocusNode; + final bool isSending; + final Function(String) onSendMessage; + final VoidCallback? onVoicePressed; + final bool hideButtons; + + const ConversationChatInputArea({ + super.key, + required this.textController, + required this.textFieldFocusNode, + required this.isSending, + required this.onSendMessage, + this.onVoicePressed, + this.hideButtons = false, + }); + + @override + Widget build(BuildContext context) { + // ONE CLEAN STATE - 100% horizontal width + return Container( + width: double.infinity, // Force 100% width + margin: const EdgeInsets.only(top: 10), // Keep small top margin for spacing + decoration: const BoxDecoration( + color: Color(0xFF1f1f25), + borderRadius: BorderRadius.only( + topLeft: Radius.circular(22), + topRight: Radius.circular(22), + ), + ), + child: Column( + children: [ + Container( + padding: const EdgeInsets.only(left: 8, right: 8, top: 20, bottom: 20), // Reduced horizontal padding + child: Row( + crossAxisAlignment: CrossAxisAlignment.end, + children: [ + // Main input container - ALWAYS the same + Expanded( + child: Container( + padding: const EdgeInsets.only(left: 8, right: 4), // Reduced inner padding for more space + child: Row( + crossAxisAlignment: CrossAxisAlignment.end, + children: [ 
+ Expanded( + child: Container( + alignment: Alignment.centerLeft, + child: TextField( + enabled: true, + controller: textController, + focusNode: textFieldFocusNode, + obscureText: false, + textAlign: TextAlign.start, + textAlignVertical: TextAlignVertical.center, + decoration: const InputDecoration( + hintText: 'Ask about this conversation...', + hintStyle: TextStyle(fontSize: 16.0, color: Colors.white54), + focusedBorder: InputBorder.none, + enabledBorder: InputBorder.none, + contentPadding: EdgeInsets.symmetric(horizontal: 8, vertical: 12), + isDense: true, + ), + minLines: 1, + maxLines: 10, + keyboardType: TextInputType.multiline, + textCapitalization: TextCapitalization.sentences, + style: const TextStyle(fontSize: 16.0, color: Colors.white, height: 1.4), + ), + ), + ), + + // Microphone button - hidden when voice recorder is active + if (!hideButtons) + GestureDetector( + child: Container( + height: 44, + width: 44, + alignment: Alignment.center, + child: const Icon( + FontAwesomeIcons.microphone, + color: Colors.white, + size: 20, + ), + ), + onTap: () { + FocusScope.of(context).unfocus(); + if (onVoicePressed != null) { + onVoicePressed!(); + } + }, + ), + ], + ), + ), + ), + + // Send button - hidden when voice recorder is active + if (!hideButtons) ...[ + const SizedBox(width: 4), // Minimal gap to maximize space usage + + GestureDetector( + onTap: () { + HapticFeedback.mediumImpact(); + String message = textController.text.trim(); + if (message.isEmpty) return; + onSendMessage(message); + }, + child: Container( + height: 44, + width: 44, + decoration: BoxDecoration( + color: Colors.white, + borderRadius: BorderRadius.circular(22), + boxShadow: [ + BoxShadow( + color: Colors.black.withValues(alpha: 0.1), + blurRadius: 8, + offset: const Offset(0, 2), + ), + ], + ), + child: Icon( + FontAwesomeIcons.arrowUp, + color: isSending ? 
Colors.grey[400] : Colors.black, + size: 20, + ), + ), + ), + ], + ], + ), + ), + + // Smart padding - moves up with keyboard but navbar stays fixed + SizedBox( + height: MediaQuery.of(context).padding.bottom + 64 + MediaQuery.of(context).viewInsets.bottom, + ), + ], + ), + ); + } +} diff --git a/app/lib/pages/conversation_detail/widgets/chat_tab.dart b/app/lib/pages/conversation_detail/widgets/chat_tab.dart new file mode 100644 index 0000000000..50f37275f7 --- /dev/null +++ b/app/lib/pages/conversation_detail/widgets/chat_tab.dart @@ -0,0 +1,475 @@ +import 'dart:async'; + +import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; +import 'package:font_awesome_flutter/font_awesome_flutter.dart'; +import 'package:omi/backend/http/api/conversation_chat.dart'; +import 'package:omi/backend/schema/message.dart'; +import 'package:omi/pages/chat/widgets/typing_indicator.dart'; +import 'package:omi/pages/chat/widgets/voice_recorder_widget.dart'; +import 'package:omi/pages/conversation_detail/conversation_detail_provider.dart'; +import 'package:omi/pages/conversation_detail/widgets/chat_input_area.dart'; +import 'package:omi/utils/platform/platform_service.dart'; +import 'package:provider/provider.dart'; + +// Import desktop voice recorder for desktop platforms +import 'package:omi/desktop/pages/chat/widgets/desktop_voice_recorder_widget.dart' + if (dart.library.html) 'package:omi/pages/chat/widgets/voice_recorder_widget.dart'; + +class ChatTab extends StatefulWidget { + const ChatTab({super.key}); + + @override + State createState() => _ChatTabState(); +} + +class _ChatTabState extends State { + final TextEditingController _messageController = TextEditingController(); + final FocusNode _messageFocusNode = FocusNode(); + final ScrollController _scrollController = ScrollController(); + + List _messages = []; + bool _isLoading = true; + bool _isSending = false; + bool _showVoiceRecorder = false; + bool _showTypingIndicator = false; + + // For streaming 
messages + ConversationChatMessage? _streamingMessage; + String _streamingTextBuffer = ''; + Timer? _streamingTimer; + + @override + void initState() { + super.initState(); + WidgetsBinding.instance.addPostFrameCallback((_) { + // Register clear messages callback with provider + final provider = Provider.of(context, listen: false); + provider.registerClearChatCallback(clearMessages); + + _loadMessages(); + }); + } + + @override + void dispose() { + _messageController.dispose(); + _messageFocusNode.dispose(); + _scrollController.dispose(); + _streamingTimer?.cancel(); + super.dispose(); + } + + void _loadMessages() async { + final provider = Provider.of(context, listen: false); + try { + final messages = await getConversationMessages(provider.conversation.id); + if (mounted) { + setState(() { + _messages = messages.reversed.toList(); // Show latest at bottom + _isLoading = false; + }); + } + } catch (e) { + debugPrint('Error loading conversation messages: $e'); + if (mounted) { + setState(() { + _isLoading = false; + }); + } + } + } + + // Public method to clear messages (can be called from outside) + void clearMessages() { + if (mounted) { + setState(() { + _messages.clear(); + }); + } + } + + // Voice recorder callbacks + void _onTranscriptReady(String transcript) { + if (mounted) { + setState(() { + _messageController.text = transcript; + _showVoiceRecorder = false; + }); + // Focus text field after transcript is ready + _messageFocusNode.requestFocus(); + } + } + + void _onVoiceRecorderClose() { + if (mounted) { + setState(() { + _showVoiceRecorder = false; + }); + } + } + + void _startVoiceRecording() { + // Hide keyboard when voice recording starts + FocusScope.of(context).unfocus(); + setState(() { + _showVoiceRecorder = true; + }); + } + + void _sendMessage(ConversationDetailProvider provider, String text) async { + if (text.trim().isEmpty || _isSending) return; + + setState(() { + _isSending = true; + }); + + // Add user message immediately to UI + final 
userMessage = ConversationChatMessage( + id: 'temp_${DateTime.now().millisecondsSinceEpoch}', + text: text.trim(), + createdAt: DateTime.now(), + sender: 'human', + conversationId: provider.conversation.id, + ); + + setState(() { + _messages.add(userMessage); + }); + + _messageController.clear(); + _messageFocusNode.unfocus(); + _scrollToBottom(); + + // Create an empty AI message for streaming + final streamingMessage = ConversationChatMessage( + id: 'streaming_${DateTime.now().millisecondsSinceEpoch}', + text: '', + createdAt: DateTime.now(), + sender: 'ai', + conversationId: provider.conversation.id, + ); + + setState(() { + _streamingMessage = streamingMessage; + _messages.add(streamingMessage); + _showTypingIndicator = true; + _streamingTextBuffer = ''; + }); + + _scrollToBottom(); + + // Flush buffer periodically for smooth streaming + void flushBuffer() { + if (_streamingTextBuffer.isNotEmpty && mounted) { + setState(() { + final index = _messages.indexWhere((m) => m.id == _streamingMessage?.id); + if (index != -1) { + _messages[index] = ConversationChatMessage( + id: _streamingMessage!.id, + text: _streamingMessage!.text + _streamingTextBuffer, + createdAt: _streamingMessage!.createdAt, + sender: 'ai', + conversationId: provider.conversation.id, + ); + _streamingMessage = _messages[index]; + } + _streamingTextBuffer = ''; + }); + HapticFeedback.lightImpact(); + _scrollToBottom(); + } + } + + try { + // Stream the AI response + await for (var chunk in sendConversationMessageStream(provider.conversation.id, text.trim())) { + if (chunk.type == MessageChunkType.data) { + // Add to buffer for batched updates + _streamingTextBuffer += chunk.text; + + // Start timer for periodic flush if not already running + _streamingTimer ??= Timer.periodic(const Duration(milliseconds: 100), (_) { + flushBuffer(); + }); + } else if (chunk.type == MessageChunkType.done && chunk.message != null) { + // Cancel timer and flush any remaining buffer + _streamingTimer?.cancel(); + 
_streamingTimer = null; + flushBuffer(); + + // Replace streaming message with final message + final aiMessage = ConversationChatMessage( + id: chunk.message!.id, + text: chunk.message!.text, + createdAt: chunk.message!.createdAt, + sender: 'ai', + conversationId: provider.conversation.id, + ); + + setState(() { + final index = _messages.indexWhere((m) => m.id == _streamingMessage?.id); + if (index != -1) { + _messages[index] = aiMessage; + } + _streamingMessage = null; + _showTypingIndicator = false; + _isSending = false; + }); + _scrollToBottom(); + break; + } + } + } catch (e) { + debugPrint('Error sending message: $e'); + _streamingTimer?.cancel(); + _streamingTimer = null; + + setState(() { + // Remove the streaming message on error + if (_streamingMessage != null) { + _messages.removeWhere((m) => m.id == _streamingMessage?.id); + } + _streamingMessage = null; + _showTypingIndicator = false; + _isSending = false; + }); + } + } + + void _scrollToBottom() { + WidgetsBinding.instance.addPostFrameCallback((_) { + if (_scrollController.hasClients) { + _scrollController.animateTo( + _scrollController.position.maxScrollExtent, + duration: const Duration(milliseconds: 300), + curve: Curves.easeOut, + ); + } + }); + } + + @override + Widget build(BuildContext context) { + return Consumer( + builder: (context, provider, child) { + return Column( + children: [ + // Chat messages area + Expanded( + child: GestureDetector( + // Dismiss keyboard when tapping anywhere in messages area + onTap: () { + FocusScope.of(context).unfocus(); + }, + child: _buildMessagesArea(provider), + ), + ), + + // Input area at bottom with voice recorder overlay + Stack( + children: [ + // Input area (always visible) + ConversationChatInputArea( + textController: _messageController, + textFieldFocusNode: _messageFocusNode, + isSending: _isSending, + onSendMessage: (text) => _sendMessage(provider, text), + onVoicePressed: _startVoiceRecording, + hideButtons: _showVoiceRecorder, // Hide buttons 
when voice recorder is showing + ), + + // Voice recorder overlay (only when recording) - positioned exactly over the Row + if (_showVoiceRecorder) + Positioned( + top: 30, // Top padding (10 margin + 20 padding) + left: 16, // Left padding (8 + 8) + right: 16, // Right padding (8 + 8) + height: 44, // Height of the buttons row + child: _buildVoiceRecorderOverlay(), + ), + ], + ), + ], + ); + }, + ); + } + + Widget _buildMessagesArea(ConversationDetailProvider provider) { + if (_isLoading) { + return const SizedBox.expand( + child: Center( + child: CircularProgressIndicator(color: Colors.white54), + ), + ); + } + + if (_messages.isEmpty) { + return _buildEmptyState(); + } + + return CustomScrollView( + controller: _scrollController, + slivers: [ + // Top padding for breathing room + const SliverToBoxAdapter( + child: SizedBox(height: 16), + ), + + // Chat messages + SliverList( + delegate: SliverChildBuilderDelegate( + (context, index) { + final message = _messages[index]; + return Padding( + padding: const EdgeInsets.symmetric(horizontal: 4, vertical: 4), + child: message.isFromUser ? _buildUserMessage(message) : _buildAIMessage(message), + ); + }, + childCount: _messages.length, + ), + ), + ], + ); + } + + Widget _buildVoiceRecorderOverlay() { + // Voice recorder overlay - positioned exactly over text field + buttons Row + return Container( + decoration: BoxDecoration( + color: Colors.black, + borderRadius: BorderRadius.circular(12), + ), + child: PlatformService.isDesktop + ? 
DesktopVoiceRecorderWidget( + onTranscriptReady: _onTranscriptReady, + onClose: _onVoiceRecorderClose, + ) + : VoiceRecorderWidget( + onTranscriptReady: _onTranscriptReady, + onClose: _onVoiceRecorderClose, + ), + ); + } + + Widget _buildEmptyState() { + return SizedBox.expand( + child: Center( + child: Padding( + padding: const EdgeInsets.all(32), + child: Column( + mainAxisAlignment: MainAxisAlignment.center, + children: [ + Container( + padding: const EdgeInsets.all(20), + decoration: BoxDecoration( + color: Colors.grey[900], + borderRadius: BorderRadius.circular(20), + ), + child: const Icon( + FontAwesomeIcons.solidComment, + color: Colors.white54, + size: 32, + ), + ), + const SizedBox(height: 16), + Text( + 'Start a conversation', + style: Theme.of(context).textTheme.titleMedium!.copyWith( + color: Colors.white, + fontWeight: FontWeight.w600, + ), + ), + const SizedBox(height: 8), + Text( + 'Ask questions about this conversation', + textAlign: TextAlign.center, + style: Theme.of(context).textTheme.bodyMedium!.copyWith( + color: Colors.grey.shade400, + height: 1.4, + ), + ), + ], + ), + ), + ), + ); + } + + Widget _buildUserMessage(ConversationChatMessage message) { + return Padding( + padding: const EdgeInsets.only(left: 24), + child: Column( + crossAxisAlignment: CrossAxisAlignment.end, + children: [ + Container( + decoration: const BoxDecoration( + color: Color(0xFF1f1f25), + borderRadius: BorderRadius.only( + topLeft: Radius.circular(16.0), + topRight: Radius.circular(16.0), + bottomRight: Radius.circular(4.0), + bottomLeft: Radius.circular(16.0), + ), + ), + padding: const EdgeInsets.symmetric(horizontal: 16, vertical: 14), + child: Text( + message.text, + style: const TextStyle(color: Colors.white, fontSize: 16), + ), + ), + ], + ), + ); + } + + Widget _buildAIMessage(ConversationChatMessage message) { + final isStreaming = _streamingMessage?.id == message.id; + final showTypingIndicator = isStreaming && _showTypingIndicator && message.text.isEmpty; + + 
return Padding( + padding: const EdgeInsets.only(right: 24), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + showTypingIndicator + ? Container( + decoration: BoxDecoration( + color: Colors.grey[900], + borderRadius: const BorderRadius.only( + topLeft: Radius.circular(4.0), + topRight: Radius.circular(16.0), + bottomRight: Radius.circular(16.0), + bottomLeft: Radius.circular(16.0), + ), + ), + padding: const EdgeInsets.symmetric(horizontal: 16, vertical: 14), + child: const Row( + mainAxisSize: MainAxisSize.min, + children: [ + TypingIndicator(), + ], + ), + ) + : Container( + decoration: BoxDecoration( + color: Colors.grey[900], + borderRadius: const BorderRadius.only( + topLeft: Radius.circular(4.0), + topRight: Radius.circular(16.0), + bottomRight: Radius.circular(16.0), + bottomLeft: Radius.circular(16.0), + ), + ), + padding: const EdgeInsets.symmetric(horizontal: 16, vertical: 14), + child: Text( + message.text, + style: const TextStyle(color: Colors.white, fontSize: 16, height: 1.4), + ), + ), + ], + ), + ); + } +} diff --git a/app/lib/widgets/conversation_bottom_bar.dart b/app/lib/widgets/conversation_bottom_bar.dart index 32f4c173a0..4ac8829081 100644 --- a/app/lib/widgets/conversation_bottom_bar.dart +++ b/app/lib/widgets/conversation_bottom_bar.dart @@ -6,6 +6,7 @@ import 'package:omi/backend/schema/app.dart'; import 'package:omi/gen/assets.gen.dart'; import 'package:omi/pages/conversation_detail/conversation_detail_provider.dart'; import 'package:omi/pages/conversation_detail/widgets/summarized_apps_sheet.dart'; +import 'package:omi/pages/conversation_detail/widgets/chat_actions_sheet.dart'; import 'package:omi/widgets/conversation_bottom_bar/tab_button.dart'; import 'package:provider/provider.dart'; @@ -14,7 +15,7 @@ enum ConversationBottomBarMode { detail // For viewing completed conversations } -enum ConversationTab { transcript, summary, actionItems } +enum ConversationTab { transcript, summary, actionItems, chat } 
class ConversationBottomBar extends StatelessWidget { final ConversationBottomBarMode mode; @@ -81,8 +82,9 @@ class ConversationBottomBar extends StatelessWidget { _buildSummaryTab(context), const SizedBox(width: 4), _buildActionItemsTab(), + const SizedBox(width: 4), + _buildChatTab(context), ], - _ => [_buildSummaryTab(context)], }, ], ), @@ -191,4 +193,28 @@ class ConversationBottomBar extends StatelessWidget { onTap: () => onTabSelected(ConversationTab.actionItems), ); } + + Widget _buildChatTab(BuildContext context) { + void handleTap() { + if (selectedTab == ConversationTab.chat) { + // Show chat actions bottom sheet when already on chat tab + showModalBottomSheet( + context: context, + isScrollControlled: true, + backgroundColor: Colors.transparent, + builder: (context) => const ChatActionsBottomSheet(), + ); + } else { + onTabSelected(ConversationTab.chat); + } + } + + return TabButton( + icon: FontAwesomeIcons.solidComment, + isSelected: selectedTab == ConversationTab.chat, + onTap: handleTap, + showDropdownArrow: true, // Add dropdown arrow like summary + onDropdownPressed: handleTap, + ); + } } diff --git a/backend/database/chat_convo.py b/backend/database/chat_convo.py new file mode 100644 index 0000000000..67de657b1e --- /dev/null +++ b/backend/database/chat_convo.py @@ -0,0 +1,301 @@ +import copy +import uuid +from datetime import datetime, timezone +from typing import Optional, List, Dict, Any + +from google.cloud import firestore +from google.cloud.firestore_v1 import FieldFilter + +from models.chat_convo import ConversationChatMessage +from utils import encryption +from ._client import db +from .helpers import set_data_protection_level, prepare_for_write, prepare_for_read + + +# ********************************* +# ******* ENCRYPTION HELPERS ****** +# ********************************* + + +def _encrypt_conversation_chat_data(chat_data: Dict[str, Any], uid: str) -> Dict[str, Any]: + """Encrypt conversation chat data for storage""" + data = 
copy.deepcopy(chat_data) + + if 'text' in data and isinstance(data['text'], str): + data['text'] = encryption.encrypt(data['text'], uid) + return data + + +def _decrypt_conversation_chat_data(chat_data: Dict[str, Any], uid: str) -> Dict[str, Any]: + """Decrypt conversation chat data for reading""" + data = copy.deepcopy(chat_data) + + if 'text' in data and isinstance(data['text'], str): + try: + data['text'] = encryption.decrypt(data['text'], uid) + except Exception: + pass + + return data + + +def _prepare_data_for_write(data: Dict[str, Any], uid: str, level: str) -> Dict[str, Any]: + """Prepare conversation chat data for writing with encryption if needed""" + if level == 'enhanced': + return _encrypt_conversation_chat_data(data, uid) + return data + + +def _prepare_message_for_read(message_data: Optional[Dict[str, Any]], uid: str) -> Optional[Dict[str, Any]]: + """Prepare conversation chat message for reading with decryption if needed""" + if not message_data: + return None + + level = message_data.get('data_protection_level') + if level == 'enhanced': + return _decrypt_conversation_chat_data(message_data, uid) + + return message_data + + +# ***************************** +# ********** CRUD ************* +# ***************************** + + +@set_data_protection_level(data_arg_name='message_data') +@prepare_for_write(data_arg_name='message_data', prepare_func=_prepare_data_for_write) +def add_conversation_message(uid: str, message_data: dict): + """Add a message to a conversation chat""" + # Remove any computed fields that shouldn't be stored + if 'conversation' in message_data: + del message_data['conversation'] + + user_ref = db.collection('users').document(uid) + user_ref.collection('conversation_chats').add(message_data) + return message_data + + +@prepare_for_read(decrypt_func=_prepare_message_for_read) +def get_conversation_messages( + uid: str, + conversation_id: str, + limit: int = 100, + offset: int = 0, + include_references: bool = False, +) -> 
List[dict]: + """Get all messages for a specific conversation chat""" + print('get_conversation_messages', uid, conversation_id, limit, offset, include_references) + + user_ref = db.collection('users').document(uid) + messages_ref = ( + user_ref.collection('conversation_chats') + .where(filter=FieldFilter('conversation_id', '==', conversation_id)) + .order_by('created_at', direction=firestore.Query.DESCENDING) + .limit(limit) + .offset(offset) + ) + + messages = [] + memories_id = set() + action_items_id = set() + + # Fetch messages and collect reference IDs + for doc in messages_ref.stream(): + message = doc.to_dict() + if message.get('reported') is True: + continue + messages.append(message) + memories_id.update(message.get('memories_id', [])) + action_items_id.update(message.get('action_items_id', [])) + + if not include_references: + return messages + + # Fetch referenced memories and action items + if memories_id: + memories = get_conversation_memories(uid, conversation_id, list(memories_id)) + memories_dict = {memory['id']: memory for memory in memories} + + for message in messages: + message['memories'] = [ + memories_dict[memory_id] for memory_id in message.get('memories_id', []) if memory_id in memories_dict + ] + + if action_items_id: + action_items = get_conversation_action_items(uid, conversation_id, list(action_items_id)) + action_items_dict = {item['id']: item for item in action_items} + + for message in messages: + message['action_items'] = [ + action_items_dict[item_id] + for item_id in message.get('action_items_id', []) + if item_id in action_items_dict + ] + + return messages + + +def get_conversation_message(uid: str, message_id: str) -> tuple[ConversationChatMessage, str] | None: + """Get a specific conversation chat message by ID""" + user_ref = db.collection('users').document(uid) + message_ref = user_ref.collection('conversation_chats').where('id', '==', message_id).limit(1).stream() + message_doc = next(message_ref, None) + if not 
message_doc: + return None + + message_data = message_doc.to_dict() + if not message_data: + return None + + decrypted_data = _prepare_message_for_read(message_data, uid) + message = ConversationChatMessage(**decrypted_data) + + return message, message_doc.id + + +def report_conversation_message(uid: str, msg_doc_id: str): + """Report a conversation chat message""" + user_ref = db.collection('users').document(uid) + message_ref = user_ref.collection('conversation_chats').document(msg_doc_id) + try: + message_ref.update({'reported': True}) + return {"message": "Message reported"} + except Exception as e: + print("Update failed:", e) + return {"message": f"Update failed: {e}"} + + +def clear_conversation_chat(uid: str, conversation_id: str): + """Clear all messages in a conversation chat""" + try: + user_ref = db.collection('users').document(uid) + print(f"Deleting conversation chat messages for user: {uid}, conversation: {conversation_id}") + if not user_ref.get().exists: + return {"message": "User not found"} + batch_delete_conversation_messages(user_ref, conversation_id) + return None + except Exception as e: + return {"message": str(e)} + + +def batch_delete_conversation_messages(parent_doc_ref, conversation_id: str, batch_size=450): + """Batch delete conversation chat messages""" + messages_ref = parent_doc_ref.collection('conversation_chats').where( + filter=FieldFilter('conversation_id', '==', conversation_id) + ) + print('batch_delete_conversation_messages', conversation_id) + + while True: + docs_stream = messages_ref.limit(batch_size).stream() + docs_list = list(docs_stream) + + if not docs_list: + print("No more conversation chat messages to delete") + break + + batch = db.batch() + for doc in docs_list: + batch.delete(doc.reference) + batch.commit() + + print(f'Deleted {len(docs_list)} conversation chat messages') + + if len(docs_list) < batch_size: + print("Processed all conversation chat messages") + break + + +def get_conversation_memories(uid: str, 
conversation_id: str, memory_ids: Optional[List[str]] = None) -> List[dict]: + """Get memories associated with a conversation""" + user_ref = db.collection('users').document(uid) + memories_ref = user_ref.collection('memories') + + # Filter by conversation_id + memories_ref = memories_ref.where(filter=FieldFilter('conversation_id', '==', conversation_id)) + + # If specific memory IDs are requested, filter by those too + if memory_ids: + memories_ref = memories_ref.where(filter=FieldFilter('id', 'in', memory_ids)) + + return [doc.to_dict() for doc in memories_ref.stream()] + + +def get_conversation_action_items( + uid: str, conversation_id: str, action_item_ids: Optional[List[str]] = None +) -> List[dict]: + """Get action items associated with a conversation""" + user_ref = db.collection('users').document(uid) + action_items_ref = user_ref.collection('action_items') + + # Filter by conversation_id + action_items_ref = action_items_ref.where(filter=FieldFilter('conversation_id', '==', conversation_id)) + + # If specific action item IDs are requested, filter by those too + if action_item_ids: + action_items_ref = action_items_ref.where(filter=FieldFilter('id', 'in', action_item_ids)) + + return [doc.to_dict() for doc in action_items_ref.stream()] + + +def get_conversation_data(uid: str, conversation_id: str) -> dict: + """Get the base conversation data""" + user_ref = db.collection('users').document(uid) + conversation_ref = user_ref.collection('conversations').document(conversation_id) + conversation_doc = conversation_ref.get() + + if not conversation_doc.exists: + return None + + return conversation_doc.to_dict() + + +# ************************************** +# ********* MIGRATION HELPERS ********** +# ************************************** + + +def get_conversation_chats_to_migrate(uid: str, target_level: str) -> List[dict]: + """Find all conversation chat messages that need protection level migration""" + messages_ref = 
db.collection('users').document(uid).collection('conversation_chats') + all_messages = messages_ref.select(['data_protection_level']).stream() + + to_migrate = [] + for doc in all_messages: + doc_data = doc.to_dict() + current_level = doc_data.get('data_protection_level', 'standard') + if target_level != current_level: + to_migrate.append({'id': doc.id, 'type': 'conversation_chat'}) + + return to_migrate + + +def migrate_conversation_chats_level_batch(uid: str, message_doc_ids: List[str], target_level: str): + """Migrate a batch of conversation chat messages to the target protection level""" + batch = db.batch() + messages_ref = db.collection('users').document(uid).collection('conversation_chats') + doc_refs = [messages_ref.document(msg_id) for msg_id in message_doc_ids] + doc_snapshots = db.get_all(doc_refs) + + for doc_snapshot in doc_snapshots: + if not doc_snapshot.exists: + print(f"Conversation chat message {doc_snapshot.id} not found, skipping.") + continue + + message_data = doc_snapshot.to_dict() + current_level = message_data.get('data_protection_level', 'standard') + + if current_level == target_level: + continue + + plain_data = _prepare_message_for_read(message_data, uid) + plain_text = plain_data.get('text') + migrated_text = plain_text + if target_level == 'enhanced': + if isinstance(plain_text, str): + migrated_text = encryption.encrypt(plain_text, uid) + + update_data = {'data_protection_level': target_level, 'text': migrated_text} + batch.update(doc_snapshot.reference, update_data) + + batch.commit() diff --git a/backend/database/vector_db_convos.py b/backend/database/vector_db_convos.py new file mode 100644 index 0000000000..7e8917bbc2 --- /dev/null +++ b/backend/database/vector_db_convos.py @@ -0,0 +1,217 @@ +import json +import gzip +import zlib +import base64 +from typing import List, Dict, Any, Optional +from datetime import datetime + +import database.chat_convo as chat_convo_db +from models.conversation import Conversation +from 
models.transcript_segment import TranscriptSegment + + +def get_conversation_context(uid: str, conversation_id: str) -> Dict[str, Any]: + """ + Get all relevant context for a conversation chat. + Returns transcript, summary, memories, and action items for the specific conversation. + """ + print(f"Getting conversation context for {conversation_id}") + + # Get base conversation data + conversation_data = chat_convo_db.get_conversation_data(uid, conversation_id) + if not conversation_data: + return {'transcript': '', 'summary': '', 'memories': [], 'action_items': [], 'context_text': ''} + + # Extract transcript + transcript_text = _extract_transcript_text(conversation_data) + + # Get summary - check both direct and structured locations + summary = conversation_data.get('overview', '') + if not summary and 'structured' in conversation_data: + summary = conversation_data['structured'].get('overview', '') + + # Get associated memories + memories = chat_convo_db.get_conversation_memories(uid, conversation_id) + + # Get associated action items + action_items = chat_convo_db.get_conversation_action_items(uid, conversation_id) + + # Compile all context into searchable text + context_text = _compile_context_text(transcript_text, summary, memories, action_items) + + return { + 'transcript': transcript_text, + 'summary': summary, + 'memories': memories, + 'action_items': action_items, + 'context_text': context_text, + 'conversation_id': conversation_id, + 'conversation_title': ( + conversation_data.get('title') + or (conversation_data.get('structured', {}).get('title')) + or 'Untitled Conversation' + ), + } + + +def _extract_transcript_text(conversation_data: Dict[str, Any]) -> str: + """Extract and decompress transcript text from conversation data""" + + # Check if transcript_segments exist and are compressed + if conversation_data.get('transcript_segments_compressed', False): + # Handle compressed transcript segments + transcript_segments_data = 
conversation_data.get('transcript_segments') + if transcript_segments_data: + try: + # Handle both string and bytes cases + if isinstance(transcript_segments_data, bytes): + # If it's bytes, it's already the compressed data (Firebase client decoded base64 for us) + compressed_data = transcript_segments_data + else: + # If it's string, it's base64 that we need to decode first + compressed_data = base64.b64decode(transcript_segments_data) + + # Try zlib first (most common), then gzip as fallback + try: + decompressed_data = zlib.decompress(compressed_data) + except Exception: + try: + decompressed_data = gzip.decompress(compressed_data) + except Exception as e: + print(f"Error decompressing transcript: {e}") + return "" + + segments_data = json.loads(decompressed_data.decode('utf-8')) + + # Convert to TranscriptSegment objects and extract text + if isinstance(segments_data, list): + segments = [TranscriptSegment(**segment) for segment in segments_data] + return TranscriptSegment.segments_as_string(segments) + + except Exception as e: + print(f"Error processing transcript: {e}") + return "" + + # Fall back to regular transcript field if available + transcript = conversation_data.get('transcript', '') + if transcript: + return transcript + + # If no transcript found, try to extract from structured data + transcript_segments = conversation_data.get('transcript_segments', []) + if isinstance(transcript_segments, list): + try: + segments = [TranscriptSegment(**segment) for segment in transcript_segments] + return TranscriptSegment.segments_as_string(segments) + except Exception as e: + print(f"Error processing transcript segments: {e}") + return "" + + return "" + + +def _compile_context_text( + transcript: str, summary: str, memories: List[Dict[str, Any]], action_items: List[Dict[str, Any]] +) -> str: + """Compile all context into a single searchable text""" + + context_parts = [] + + # Add summary + if summary: + context_parts.append(f"CONVERSATION SUMMARY:\n{summary}") 
+ + # Add transcript + if transcript: + context_parts.append(f"CONVERSATION TRANSCRIPT:\n{transcript}") + + # Add memories + if memories: + memories_text = "RELATED MEMORIES:\n" + for memory in memories: + title = memory.get('title', 'Untitled Memory') + overview = memory.get('overview', '') + memories_text += f"- {title}: {overview}\n" + context_parts.append(memories_text) + + # Add action items + if action_items: + action_items_text = "ACTION ITEMS:\n" + for item in action_items: + description = item.get('description', '') + status = "Completed" if item.get('completed', False) else "Pending" + due_at = item.get('due_at') + due_text = f" (Due: {due_at.strftime('%Y-%m-%d')})" if due_at else "" + action_items_text += f"- [{status}] {description}{due_text}\n" + context_parts.append(action_items_text) + + return "\n\n".join(context_parts) + + +def search_conversation_context( + uid: str, conversation_id: str, query: str = "", include_memories: bool = True, include_action_items: bool = True +) -> Dict[str, Any]: + """ + Search within conversation context. + Since we're dealing with a single conversation, we return all relevant context. + The query parameter can be used for future filtering if needed. 
+ """ + + context = get_conversation_context(uid, conversation_id) + + # For now, return all context since it's scoped to one conversation + # Future enhancement: could implement keyword filtering based on query + + result = { + 'context_text': context['context_text'], + 'transcript': context['transcript'], + 'summary': context['summary'], + 'conversation_title': context['conversation_title'], + 'conversation_id': conversation_id, + 'memories_found': context['memories'] if include_memories else [], + 'action_items_found': context['action_items'] if include_action_items else [], + 'total_context_length': len(context['context_text']), + } + + print(f"Conversation context search returned {len(result['context_text'])} characters of context") + return result + + +def get_conversation_summary_for_chat(uid: str, conversation_id: str) -> str: + """Get a formatted summary of the conversation for chat context""" + + context = get_conversation_context(uid, conversation_id) + + summary_parts = [] + + if context['summary']: + summary_parts.append(f"Summary: {context['summary']}") + + if context['memories']: + summary_parts.append(f"Related memories: {len(context['memories'])} items") + + if context['action_items']: + pending_items = sum(1 for item in context['action_items'] if not item.get('completed', False)) + completed_items = len(context['action_items']) - pending_items + summary_parts.append(f"Action items: {pending_items} pending, {completed_items} completed") + + transcript_length = len(context['transcript']) + if transcript_length > 0: + summary_parts.append(f"Transcript: {transcript_length} characters") + + return " | ".join(summary_parts) if summary_parts else "No additional context available" + + +def validate_conversation_context(uid: str, conversation_id: str) -> bool: + """Validate that conversation has sufficient context for chat""" + + context = get_conversation_context(uid, conversation_id) + + # Check if we have at least some content + has_transcript = 
bool(context['transcript']) + has_summary = bool(context['summary']) + has_memories = bool(context['memories']) + has_action_items = bool(context['action_items']) + + # Conversation should have at least transcript or summary to be chatworthy + return has_transcript or has_summary or has_memories or has_action_items diff --git a/backend/main.py b/backend/main.py index 0bba95aca2..e3da37c7db 100644 --- a/backend/main.py +++ b/backend/main.py @@ -8,6 +8,7 @@ from routers import ( workflow, chat, + chat_convo, firmware, plugins, transcribe, @@ -46,6 +47,7 @@ app.include_router(action_items.router) app.include_router(memories.router) app.include_router(chat.router) +app.include_router(chat_convo.router) app.include_router(plugins.router) app.include_router(speech_profile.router) # app.include_router(screenpipe.router) diff --git a/backend/models/chat_convo.py b/backend/models/chat_convo.py new file mode 100644 index 0000000000..7050e3cf49 --- /dev/null +++ b/backend/models/chat_convo.py @@ -0,0 +1,127 @@ +from datetime import datetime +from enum import Enum +from typing import List, Optional, Any + +from pydantic import BaseModel, model_validator + + +class MessageSender(str, Enum): + ai = 'ai' + human = 'human' + + +class MessageType(str, Enum): + text = 'text' + + +class ConversationReference(BaseModel): + """Reference to the parent conversation this chat belongs to""" + + id: str + title: str + created_at: datetime + + +class ConversationChatMessage(BaseModel): + """Message within a conversation-specific chat""" + + id: str + text: str + created_at: datetime + sender: MessageSender + type: MessageType + conversation_id: str # Always tied to a specific conversation + + # References to memories/action items cited in the response + memories_id: List[str] = [] + action_items_id: List[str] = [] + + # Response metadata + reported: bool = False + report_reason: Optional[str] = None + data_protection_level: Optional[str] = None + + @staticmethod + def get_messages_as_string( 
+ messages: List['ConversationChatMessage'], use_user_name_if_available: bool = False + ) -> str: + """Convert messages to string format for LLM processing""" + sorted_messages = sorted(messages, key=lambda m: m.created_at) + + def get_sender_name(message: ConversationChatMessage) -> str: + if message.sender == 'human': + return 'User' + return 'AI' + + formatted_messages = [ + f"({message.created_at.strftime('%d %b %Y at %H:%M UTC')}) {get_sender_name(message)}: {message.text}" + for message in sorted_messages + ] + + return '\n'.join(formatted_messages) + + @staticmethod + def get_messages_as_xml(messages: List['ConversationChatMessage'], use_user_name_if_available: bool = False) -> str: + """Convert messages to XML format for LLM processing""" + sorted_messages = sorted(messages, key=lambda m: m.created_at) + + def get_sender_name(message: ConversationChatMessage) -> str: + if message.sender == 'human': + return 'User' + return 'AI' + + formatted_messages = [ + f""" + + + {message.created_at.strftime('%d %b %Y at %H:%M UTC')} + + + {get_sender_name(message)} + + + {message.text} + + + """.replace( + ' ', '' + ) + .replace('\n\n\n', '\n\n') + .strip() + for message in sorted_messages + ] + + return '\n'.join(formatted_messages) + + +class SendConversationMessageRequest(BaseModel): + """Request model for sending a message in conversation chat""" + + text: str + conversation_id: str + + +class ConversationChatResponse(ConversationChatMessage): + """Response model with additional metadata""" + + ask_for_nps: Optional[bool] = False + conversation: Optional[ConversationReference] = None + + +class ConversationMemoryReference(BaseModel): + """Referenced memory from the conversation context""" + + id: str + title: str + overview: str + created_at: datetime + + +class ConversationActionItemReference(BaseModel): + """Referenced action item from the conversation context""" + + id: str + description: str + completed: bool + due_at: Optional[datetime] = None + created_at: 
datetime diff --git a/backend/routers/chat_convo.py b/backend/routers/chat_convo.py new file mode 100644 index 0000000000..6da307cf15 --- /dev/null +++ b/backend/routers/chat_convo.py @@ -0,0 +1,223 @@ +import uuid +import re +import base64 +from datetime import datetime, timezone +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse + +import database.chat_convo as chat_convo_db +import database.conversations as conversations_db +from models.chat_convo import ( + ConversationChatMessage, + SendConversationMessageRequest, + MessageSender, + ConversationChatResponse, + ConversationReference, +) +from utils.other import endpoints as auth +from utils.retrieval.graph_convos import execute_conversation_chat_stream # We'll create this + +router = APIRouter() + + +def _validate_conversation_access(uid: str, conversation_id: str) -> dict: + """Validate that user has access to the conversation""" + conversation = chat_convo_db.get_conversation_data(uid, conversation_id) + if not conversation: + raise HTTPException(status_code=404, detail='Conversation not found') + return conversation + + +@router.post( + '/v2/conversations/{conversation_id}/chat/messages', + tags=['conversation-chat'], + response_model=ConversationChatResponse, +) +def send_conversation_message( + conversation_id: str, + data: SendConversationMessageRequest, + uid: str = Depends(auth.get_current_user_uid), +): + """Send a message in a conversation-specific chat with streaming response""" + print('send_conversation_message', conversation_id, data.text, uid) + + # Validate conversation access + conversation = _validate_conversation_access(uid, conversation_id) + + # Ensure the conversation_id matches between URL and request body + if data.conversation_id != conversation_id: + raise HTTPException(status_code=400, detail='Conversation ID mismatch') + + # Create human message + message = ConversationChatMessage( + 
id=str(uuid.uuid4()), + text=data.text, + created_at=datetime.now(timezone.utc), + sender='human', + type='text', + conversation_id=conversation_id, + ) + + # Store human message + chat_convo_db.add_conversation_message(uid, message.dict()) + + # Get recent messages for context (last 10 messages) + messages = list( + reversed( + [ + ConversationChatMessage(**msg) + for msg in chat_convo_db.get_conversation_messages(uid, conversation_id, limit=10) + ] + ) + ) + + def process_message(response: str, callback_data: dict): + """Process the AI response and create the AI message""" + memories = callback_data.get('memories_found', []) + action_items = callback_data.get('action_items_found', []) + ask_for_nps = callback_data.get('ask_for_nps', False) + + # Extract cited indices from response + cited_memory_idxs = {int(i) for i in re.findall(r'\[(\d+)\]', response)} + if len(cited_memory_idxs) > 0: + response = re.sub(r'\[\d+\]', '', response) + + # Get referenced memories and action items + memories_id = [] + action_items_id = [] + + if memories and cited_memory_idxs: + referenced_memories = [memories[i - 1] for i in cited_memory_idxs if 0 < i <= len(memories)] + memories_id = [m.get('id') for m in referenced_memories if m.get('id')] + + # Create AI message + ai_message = ConversationChatMessage( + id=str(uuid.uuid4()), + text=response, + created_at=datetime.now(timezone.utc), + sender='ai', + type='text', + conversation_id=conversation_id, + memories_id=memories_id, + action_items_id=action_items_id, + ) + + # Store AI message + chat_convo_db.add_conversation_message(uid, ai_message.dict()) + + return ai_message, ask_for_nps + + async def generate_stream(): + """Generate streaming response""" + callback_data = {} + async for chunk in execute_conversation_chat_stream( + uid, conversation_id, messages, callback_data=callback_data + ): + if chunk: + msg = chunk.replace("\n", "__CRLF__") + yield f'{msg}\n\n' + else: + response = callback_data.get('answer') + if response: + 
ai_message, ask_for_nps = process_message(response, callback_data) + ai_message_dict = ai_message.dict() + + # Add conversation reference + conversation_ref = ConversationReference( + id=conversation['id'], + title=conversation.get('title', 'Untitled Conversation'), + created_at=conversation['created_at'], + ) + + response_message = ConversationChatResponse(**ai_message_dict) + response_message.ask_for_nps = ask_for_nps + response_message.conversation = conversation_ref + + data = base64.b64encode(bytes(response_message.model_dump_json(), 'utf-8')).decode('utf-8') + yield f"done: {data}\n\n" + + return StreamingResponse(generate_stream(), media_type="text/event-stream") + + +@router.get( + '/v2/conversations/{conversation_id}/chat/messages', + response_model=List[ConversationChatMessage], + tags=['conversation-chat'], +) +def get_conversation_messages( + conversation_id: str, limit: int = 100, offset: int = 0, uid: str = Depends(auth.get_current_user_uid) +): + """Get all messages in a conversation chat""" + # Validate conversation access + _validate_conversation_access(uid, conversation_id) + + messages = chat_convo_db.get_conversation_messages( + uid, conversation_id, limit=limit, offset=offset, include_references=True + ) + print('get_conversation_messages', len(messages), conversation_id) + + return messages + + +@router.post( + '/v2/conversations/{conversation_id}/chat/messages/{message_id}/report', + tags=['conversation-chat'], + response_model=dict, +) +def report_conversation_message(conversation_id: str, message_id: str, uid: str = Depends(auth.get_current_user_uid)): + """Report a message in conversation chat""" + # Validate conversation access + _validate_conversation_access(uid, conversation_id) + + message, msg_doc_id = chat_convo_db.get_conversation_message(uid, message_id) + if message is None: + raise HTTPException(status_code=404, detail='Message not found') + if message.sender != 'ai': + raise HTTPException(status_code=400, detail='Only AI 
messages can be reported') + if message.reported: + raise HTTPException(status_code=400, detail='Message already reported') + if message.conversation_id != conversation_id: + raise HTTPException(status_code=400, detail='Message does not belong to this conversation') + + chat_convo_db.report_conversation_message(uid, msg_doc_id) + return {'message': 'Message reported'} + + +@router.delete('/v2/conversations/{conversation_id}/chat/messages', tags=['conversation-chat']) +def clear_conversation_chat(conversation_id: str, uid: str = Depends(auth.get_current_user_uid)): + """Clear all messages in a conversation chat""" + # Validate conversation access + _validate_conversation_access(uid, conversation_id) + + err = chat_convo_db.clear_conversation_chat(uid, conversation_id) + if err: + raise HTTPException(status_code=500, detail='Failed to clear conversation chat') + + return {'message': 'Conversation chat cleared successfully'} + + +@router.get('/v2/conversations/{conversation_id}/chat/context', tags=['conversation-chat']) +def get_conversation_context(conversation_id: str, uid: str = Depends(auth.get_current_user_uid)): + """Get context information for a conversation (transcript, memories, action items)""" + # Validate conversation access + conversation = _validate_conversation_access(uid, conversation_id) + + # Use our proper context extraction logic + from database.vector_db_convos import get_conversation_context + + context = get_conversation_context(uid, conversation_id) + + return { + 'conversation': { + 'id': conversation['id'], + 'title': context['conversation_title'], + 'summary': context['summary'], + 'transcript': context['transcript'], + 'created_at': conversation['created_at'], + }, + 'memories': context['memories'], + 'action_items': context['action_items'], + 'context_items_count': len(context['memories']) + len(context['action_items']), + } diff --git a/backend/utils/llm/chat_convos.py b/backend/utils/llm/chat_convos.py new file mode 100644 index 
# ============================================================================
# backend/utils/llm/chat_convos.py  (new file introduced by this diff)
#
# NOTE(review): this chunk is a unified diff collapsed onto a few physical
# lines; the content below is the reconstructed file with '+' markers removed
# and conventional formatting restored. Exact original whitespace inside the
# prompt strings is not recoverable from the mangled diff — confirm against
# the committed file.
# ============================================================================
from datetime import datetime, timezone
from typing import List, Optional

from pydantic import BaseModel, Field

from .clients import llm_mini, llm_mini_stream, llm_medium_stream, llm_medium
from database.vector_db_convos import search_conversation_context, get_conversation_summary_for_chat
from models.chat_convo import ConversationChatMessage
from utils.llms.memory import get_prompt_memories


# ****************************************
# ************* CONVERSATION CHAT ********
# ****************************************


class ConversationQuestion(BaseModel):
    """Structured-output schema for question extraction."""

    question: str = Field(description='The extracted user question about the conversation.')


def extract_question_from_conversation_messages(messages: List[ConversationChatMessage]) -> str:
    """Extract the user's question from recent conversation-chat messages.

    Walks backwards from the newest message until the most recent AI reply and
    treats everything the user said since then as the pending question. Returns
    an empty string when the user has said nothing since the last AI message.
    """
    print("extract_question_from_conversation_messages")

    # Index of the oldest human message sent after the last AI message.
    # Stays == len(messages) when the newest message is from the AI.
    user_message_idx = len(messages)
    for i in range(len(messages) - 1, -1, -1):
        if messages[i].sender == 'ai':
            break
        if messages[i].sender == 'human':
            user_message_idx = i

    user_last_messages = messages[user_message_idx:]
    if len(user_last_messages) == 0:
        return ""

    prompt = f'''
    You will be given recent messages from a conversation-specific chat where a user is asking questions about a particular conversation.

    Your task is to identify the question the user is asking about this conversation.

    If the user is not asking a question (e.g., just saying "Hi", "Hello", "Thanks"), respond with an empty string.

    If the user is asking a question, extract and rephrase it as a clear, complete question about the conversation.

    Examples:
    - "What did we talk about?" → "What topics were discussed in this conversation?"
    - "Any action items from this?" → "What action items were generated from this conversation?"
    - "Who was speaking?" → "Who were the participants in this conversation?"
    - "What was decided?" → "What decisions were made in this conversation?"

    Recent messages:
    {ConversationChatMessage.get_messages_as_xml(user_last_messages)}

    Previous context (for reference):
    {ConversationChatMessage.get_messages_as_xml(messages)}
    '''.replace(
        '    ', ''
    ).strip()

    question = llm_mini.with_structured_output(ConversationQuestion).invoke(prompt).question
    print(f"Extracted question: {question}")
    return question


class RequiresContext(BaseModel):
    """Structured-output schema for the context-requirement classifier."""

    value: bool = Field(description="Whether the question requires conversation context to answer")


def question_requires_conversation_context(question: str) -> bool:
    """Classify whether `question` needs the conversation's content to answer.

    Returns False immediately for blank questions (no LLM call).
    """
    if not question.strip():
        return False

    prompt = f'''
    Based on the user's question about a conversation, determine if this requires specific context from that conversation to answer properly.

    Examples requiring context:
    - "What did we discuss?" → True
    - "Who was in this conversation?" → True
    - "What action items were created?" → True
    - "What was the main topic?" → True

    Examples NOT requiring context:
    - "Hi" → False
    - "How are you?" → False
    - "Thank you" → False
    - "What is artificial intelligence?" (general question) → False

    User's Question: {question}
    '''

    with_parser = llm_mini.with_structured_output(RequiresContext)
    response: RequiresContext = with_parser.invoke(prompt)
    return response.value


def get_simple_conversation_response_prompt(
    uid: str, messages: List[ConversationChatMessage], conversation_id: str
) -> str:
    """Build the prompt for small-talk replies that need no transcript context."""
    user_name, memories_str = get_prompt_memories(uid)  # Same as main chat
    conversation_history = ConversationChatMessage.get_messages_as_string(messages)
    conversation_summary = get_conversation_summary_for_chat(uid, conversation_id)

    return f"""
    You are a helpful assistant for {user_name} discussing a specific conversation.

    About {user_name}: {memories_str}

    You are currently in a chat about this conversation: {conversation_summary}

    Respond naturally and helpfully. If the user asks about specific details from the conversation that you don't have context for, let them know you'd be happy to help but need to search through the conversation content.

    Chat History:
    {conversation_history}

    Response:
    """.replace(
        '    ', ''
    ).strip()


def answer_simple_conversation_message(uid: str, messages: List[ConversationChatMessage], conversation_id: str) -> str:
    """Generate a simple (non-streaming) response without conversation context."""
    prompt = get_simple_conversation_response_prompt(uid, messages, conversation_id)
    return llm_mini.invoke(prompt).content


def answer_simple_conversation_message_stream(
    uid: str, messages: List[ConversationChatMessage], conversation_id: str, callbacks: Optional[list] = None
) -> str:
    """Generate a simple streaming response without conversation context.

    `callbacks` defaults to None instead of a shared mutable [] (bug fix);
    tokens are delivered through the provided LangChain callbacks.
    """
    prompt = get_simple_conversation_response_prompt(uid, messages, conversation_id)
    return llm_mini_stream.invoke(prompt, {'callbacks': callbacks or []}).content


def get_conversation_qa_prompt(
    uid: str,
    question: str,
    conversation_context: str,
    messages: List[ConversationChatMessage],
    conversation_title: str,
    conversation_id: str,
) -> str:
    """Build the grounded Q&A prompt for a specific conversation.

    NOTE(review): the original prompt appears to use XML-style section tags
    that were stripped by the diff mangling; the section bodies below match
    the visible text — confirm tag markup against the committed file.
    """
    user_name, memories_str = get_prompt_memories(uid)  # Same as main chat
    memories_str = '\n'.join(memories_str.split('\n')[1:]).strip()  # Same processing as main chat
    messages_history = ConversationChatMessage.get_messages_as_xml(messages)

    return (
        f"""
    You are an assistant helping {user_name} understand and analyze a specific conversation.

    Answer the user's question about the conversation titled "{conversation_title}" using the provided conversation context.

    - Use the conversation context (transcript, summary, memories, action items) to answer the question accurately
    - Be specific and cite relevant parts of the conversation when possible
    - If asked about participants, refer to speakers by their identifiers (Speaker 0, Speaker 1, etc.) or names if provided
    - For action items, include completion status and due dates if available
    - If the question cannot be answered from the available context, be honest about limitations
    - Keep responses conversational and helpful
    - You can reference line numbers, timestamps, or specific quotes from the transcript when relevant

    [Use the following User Facts if relevant to the conversation analysis]
    {memories_str.strip()}

    {conversation_context}

    {question}

    {messages_history}

    Current date time in UTC: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
    """.replace(
            '    ', ''
        )
        .replace('\n\n\n', '\n\n')
        .strip()
    )


def answer_conversation_question(
    uid: str,
    question: str,
    conversation_context: str,
    messages: List[ConversationChatMessage],
    conversation_title: str,
    conversation_id: str,
) -> str:
    """Answer a question about a conversation using context (non-streaming)."""
    prompt = get_conversation_qa_prompt(
        uid, question, conversation_context, messages, conversation_title, conversation_id
    )
    return llm_medium.invoke(prompt).content


def answer_conversation_question_stream(
    uid: str,
    question: str,
    conversation_context: str,
    messages: List[ConversationChatMessage],
    conversation_title: str,
    conversation_id: str,
    callbacks: Optional[list] = None,
) -> str:
    """Answer a question about a conversation using context, streaming tokens.

    `callbacks` defaults to None instead of a shared mutable [] (bug fix).
    """
    prompt = get_conversation_qa_prompt(
        uid, question, conversation_context, messages, conversation_title, conversation_id
    )
    return llm_medium_stream.invoke(prompt, {'callbacks': callbacks or []}).content


def get_conversation_context_for_question(uid: str, conversation_id: str, question: str) -> dict:
    """Get relevant context from a conversation for answering a question.

    For conversation chats we always return the full context since it is
    scoped to one conversation. Future enhancement: filter by question topic.
    """
    context = search_conversation_context(
        uid=uid, conversation_id=conversation_id, query=question, include_memories=True, include_action_items=True
    )
    return context


# ************************************************
# ************* CONVERSATION ANALYSIS ************
# ************************************************


def analyze_conversation_for_insights(uid: str, conversation_id: str) -> str:
    """Generate insights and analysis about a conversation."""
    # NOTE(review): called without a `query` — assumes search_conversation_context
    # returns the full context when no query is given; confirm its signature.
    context = search_conversation_context(uid, conversation_id)
    user_name, memories_str = get_prompt_memories(uid)  # Same as main chat

    prompt = f"""
    As an AI assistant for {user_name}, analyze this conversation and provide helpful insights.

    About {user_name}: {memories_str}

    Conversation Content:
    {context['context_text']}

    Please provide:
    1. Key topics discussed
    2. Important decisions made
    3. Action items and next steps
    4. Notable quotes or insights
    5. Overall conversation summary

    Keep the analysis concise and actionable.
    """

    return llm_medium.invoke(prompt).content


# TODO: Future Enhancement - Dynamic Suggestions
# Make suggestions dynamic based on chat history progression:
# - Include recent chat messages to avoid repeated questions
# - Generate contextual suggestions based on chat evolution
# - Avoid suggesting topics already discussed


def suggest_follow_up_questions(uid: str, conversation_id: str) -> List[str]:
    """Suggest up to five follow-up questions about the conversation.

    Asks the LLM for a dash-prefixed list and parses only lines matching that
    format, so any preamble the model emits is ignored.
    """
    context = search_conversation_context(uid, conversation_id)

    prompt = f"""
    Based on this conversation content, suggest 3-5 relevant follow-up questions that someone might want to ask to better understand the conversation.

    Conversation Content:
    {context['context_text'][:1000]}...

    Format as a simple list of questions, one per line.
    Focus on actionable questions about decisions, action items, key points, or participants.

    Example format:
    - What were the main decisions made in this conversation?
    - Who was responsible for the action items?
    - What are the next steps discussed?
    """

    response = llm_mini.invoke(prompt).content
    # Keep only dash-prefixed lines and strip the bullet markers.
    questions = [q.strip('- ').strip() for q in response.split('\n') if q.strip() and q.strip().startswith('- ')]
    return questions[:5]  # Return max 5 questions


# ============================================================================
# backend/utils/retrieval/graph_convos.py  (new file introduced by this diff)
# ============================================================================
import asyncio
from typing import AsyncGenerator

from langchain.callbacks.base import BaseCallbackHandler
from langgraph.constants import END
from langgraph.graph import START, StateGraph
from typing_extensions import TypedDict

# Consolidated the original's duplicate split imports from utils.llm.chat_convos.
from utils.llm.chat_convos import (
    extract_question_from_conversation_messages,
    question_requires_conversation_context,
    answer_simple_conversation_message_stream,
    answer_conversation_question_stream,
    get_conversation_context_for_question,
)
from models.chat_convo import ConversationChatMessage
from utils.other.endpoints import timeit


class AsyncStreamingCallback(BaseCallbackHandler):
    """LangChain callback that funnels streamed tokens into an asyncio.Queue.

    Frames are prefixed with "data: " (answer tokens) or "think: " (thoughts);
    a bare None on the queue is the end-of-stream sentinel.
    """

    def __init__(self):
        self.queue = asyncio.Queue()

    async def put_data(self, text):
        await self.queue.put(f"data: {text}")

    async def put_thought(self, text):
        await self.queue.put(f"think: {text}")

    def put_thought_nowait(self, text):
        self.queue.put_nowait(f"think: {text}")

    async def end(self):
        # None is the sentinel consumed by execute_conversation_chat_stream.
        await self.queue.put(None)

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        await self.put_data(token)

    async def on_llm_end(self, response, **kwargs) -> None:
        await self.end()

    async def on_llm_error(self, error: Exception, **kwargs) -> None:
        print(f"Error on LLM {error}")
        await self.end()

    def put_data_nowait(self, text):
        self.queue.put_nowait(f"data: {text}")

    def end_nowait(self):
        self.queue.put_nowait(None)


class ConversationGraphState(TypedDict, total=False):
    """State for the conversation-chat graph.

    total=False: TypedDict fields cannot carry default values (the original's
    `= False` / `= None` assignments are invalid); all nodes read the state
    via state.get(...) with fallbacks, so optional keys preserve behavior.
    """

    uid: str
    conversation_id: str
    messages: List[ConversationChatMessage]

    # Processing state
    streaming: bool
    callback: Optional[AsyncStreamingCallback]

    # Extracted information
    parsed_question: Optional[str]
    requires_context: Optional[bool]

    # Context and response
    conversation_context: Optional[dict]
    answer: Optional[str]
    ask_for_nps: Optional[bool]
    # Declared so context_based_response's update is a known state channel
    # (LangGraph rejects updates for undeclared keys).
    memories_found: List[dict]
    action_items_found: List[dict]


def determine_conversation_question(state: ConversationGraphState):
    """Graph node: extract the user's question from the chat messages."""
    print("determine_conversation_question")

    question = extract_question_from_conversation_messages(state.get("messages", []))
    print(f"Extracted question: {question}")

    return {"parsed_question": question}


def determine_context_requirement(state: ConversationGraphState):
    """Graph node: decide whether the question needs conversation context."""
    print("determine_context_requirement")

    question = state.get("parsed_question", "")
    requires_context = question_requires_conversation_context(question)

    print(f"Requires context: {requires_context}")
    return {"requires_context": requires_context}


def simple_conversation_response(state: ConversationGraphState):
    """Graph node: answer small talk without retrieving conversation context."""
    print("simple_conversation_response")

    uid = state.get("uid")
    conversation_id = state.get("conversation_id")
    messages = state.get("messages", [])
    streaming = state.get("streaming", False)

    if streaming:
        answer = answer_simple_conversation_message_stream(
            uid, messages, conversation_id, callbacks=[state.get('callback')]
        )
        return {"answer": answer, "ask_for_nps": False}

    # TODO: Implement non-streaming version if needed
    return {"answer": "Sorry, non-streaming responses not yet implemented.", "ask_for_nps": False}


def retrieve_conversation_context(state: ConversationGraphState):
    """Graph node: fetch transcript/memories/action-item context for the question."""
    print("retrieve_conversation_context")

    uid = state.get("uid")
    conversation_id = state.get("conversation_id")
    question = state.get("parsed_question", "")

    context = get_conversation_context_for_question(uid, conversation_id, question)
    print(f"Retrieved context: {len(context.get('context_text', ''))} characters")

    return {"conversation_context": context}


def context_based_response(state: ConversationGraphState):
    """Graph node: answer the question using the retrieved conversation context."""
    print("context_based_response")

    uid = state.get("uid")
    conversation_id = state.get("conversation_id")
    question = state.get("parsed_question", "")
    messages = state.get("messages", [])
    context = state.get("conversation_context", {})
    streaming = state.get("streaming", False)

    context_text = context.get('context_text', '')
    conversation_title = context.get('conversation_title', 'Conversation')

    if streaming:
        answer = answer_conversation_question_stream(
            uid=uid,
            question=question,
            conversation_context=context_text,
            messages=messages,
            conversation_title=conversation_title,
            conversation_id=conversation_id,
            callbacks=[state.get('callback')],
        )
        return {
            "answer": answer,
            "ask_for_nps": True,
            "memories_found": context.get('memories_found', []),
            "action_items_found": context.get('action_items_found', []),
        }

    # TODO: Implement non-streaming version if needed
    return {"answer": "Sorry, non-streaming responses not yet implemented.", "ask_for_nps": False}


def route_conversation_type(state: ConversationGraphState) -> str:
    """Conditional edge: pick the simple path when no contextual question exists."""
    requires_context = state.get("requires_context", False)
    question = state.get("parsed_question", "")

    if not question.strip() or not requires_context:
        return "simple_response"

    return "context_response"


# Create the conversation chat workflow
workflow = StateGraph(ConversationGraphState)

# Add nodes
workflow.add_node("determine_question", determine_conversation_question)
workflow.add_node("determine_context", determine_context_requirement)
workflow.add_node("simple_response", simple_conversation_response)
workflow.add_node("retrieve_context", retrieve_conversation_context)
workflow.add_node("context_response", context_based_response)

# Add edges
workflow.add_edge(START, "determine_question")
workflow.add_edge("determine_question", "determine_context")
workflow.add_conditional_edges(
    "determine_context",
    route_conversation_type,
    {"simple_response": "simple_response", "context_response": "retrieve_context"},
)
workflow.add_edge("simple_response", END)
workflow.add_edge("retrieve_context", "context_response")
workflow.add_edge("context_response", END)

# Compile the graph
conversation_graph = workflow.compile()


@timeit
def execute_conversation_chat(
    uid: str, conversation_id: str, messages: List[ConversationChatMessage]
) -> tuple[str, bool, dict]:
    """Execute conversation chat (non-streaming).

    Returns (answer, ask_for_nps, conversation_context).
    """
    print(f'execute_conversation_chat for conversation: {conversation_id}')

    result = conversation_graph.invoke(
        {"uid": uid, "conversation_id": conversation_id, "messages": messages, "streaming": False}
    )

    return (result.get("answer", ""), result.get('ask_for_nps', False), result.get("conversation_context", {}))


async def execute_conversation_chat_stream(
    uid: str, conversation_id: str, messages: List[ConversationChatMessage], callback_data: Optional[dict] = None
) -> AsyncGenerator[str, None]:
    """Execute conversation chat with streaming.

    Yields "data: ..."/"think: ..." frames followed by a final None sentinel.
    `callback_data` is an out-parameter filled with the final results; its
    default is None instead of a shared mutable {} (bug fix).
    """
    print(f'execute_conversation_chat_stream for conversation: {conversation_id}')

    if callback_data is None:
        callback_data = {}

    callback = AsyncStreamingCallback()

    task = asyncio.create_task(
        conversation_graph.ainvoke(
            {
                "uid": uid,
                "conversation_id": conversation_id,
                "messages": messages,
                "streaming": True,
                "callback": callback,
            }
        )
    )

    while True:
        try:
            chunk = await callback.queue.get()
            # Only a bare None terminates the stream; an empty frame ("data: ")
            # is still forwarded instead of being dropped by a truthiness test.
            if chunk is None:
                break
            yield chunk
        except asyncio.CancelledError:
            break

    await task
    result = task.result()

    # Pass results back to caller
    callback_data['answer'] = result.get("answer")
    callback_data['memories_found'] = result.get("memories_found", [])
    callback_data['action_items_found'] = result.get("action_items_found", [])
    callback_data['ask_for_nps'] = result.get('ask_for_nps', False)

    yield None
    return