diff --git a/docs/07 - Extensions.md b/docs/07 - Extensions.md index ebcd3c0ed9..3f75e00548 100644 --- a/docs/07 - Extensions.md +++ b/docs/07 - Extensions.md @@ -40,12 +40,13 @@ The extensions framework is based on special functions and variables that you ca | Function | Description | |-------------|-------------| | `def setup()` | Is executed when the extension gets imported. | -| `def ui()` | Creates custom gradio elements when the UI is launched. | +| `def ui()` | Creates custom gradio elements when the UI is launched. | | `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. | | `def custom_js()` | Same as above but for javascript. | | `def input_modifier(string, state, is_chat=False)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | | `def output_modifier(string, state, is_chat=False)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | | `def chat_input_modifier(text, visible_text, state)` | Modifies both the visible and internal inputs in chat mode. Can be used to hijack the chat input with custom content. | +| `def output_stream_modifier(string, state, is_chat=False, is_final=False)` | Overrides the full text mid-stream. Called with the accumulated text for each partial token/chunk while the UI is streaming output; the final call, made once with the complete reply after generation ends, passes `is_final=True`. | | `def bot_prefix_modifier(string, state)` | Applied in chat mode to the prefix for the bot's reply. | | `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. | | `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. 
| @@ -209,6 +210,12 @@ def output_modifier(string, state, is_chat=False): """ return string +def output_stream_modifier(string, state, is_chat=False, is_final=False): + """ + Modifies the text stream of the LLM output in realtime. + """ + return string + def custom_generate_chat_prompt(user_input, state, **kwargs): """ Replaces the function that generates the prompt from the chat history. diff --git a/modules/extensions.py b/modules/extensions.py index e00103124e..4450d23261 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -81,7 +81,7 @@ def iterator(): # Extension functions that map string -> string -def _apply_string_extensions(function_name, text, state, is_chat=False): +def _apply_string_extensions(function_name, text, state, is_chat=False, **extra_kwargs): for extension, _ in iterator(): if hasattr(extension, function_name): func = getattr(extension, function_name) @@ -89,23 +89,22 @@ def _apply_string_extensions(function_name, text, state, is_chat=False): # Handle old extensions without the 'state' arg or # the 'is_chat' kwarg count = 0 - has_chat = False - for k in signature(func).parameters: + func_params = signature(func).parameters + kwargs = {} + + for k in func_params: if k == 'is_chat': - has_chat = True + kwargs['is_chat'] = is_chat + elif k in extra_kwargs: + kwargs[k] = extra_kwargs[k] else: count += 1 - if count == 2: + if count >= 2: args = [text, state] else: args = [text] - if has_chat: - kwargs = {'is_chat': is_chat} - else: - kwargs = {} - text = func(*args, **kwargs) return text @@ -231,6 +230,7 @@ def create_extensions_tabs(): "input": partial(_apply_string_extensions, "input_modifier"), "output": partial(_apply_string_extensions, "output_modifier"), "chat_input": _apply_chat_input_extensions, + "output_stream": partial(_apply_string_extensions, "output_stream_modifier"), "state": _apply_state_modifier_extensions, "history": _apply_history_modifier_extensions, "bot_prefix": partial(_apply_string_extensions, 
"bot_prefix_modifier"), diff --git a/modules/text_generation.py b/modules/text_generation.py index 27c5de7dff..7606fb37c8 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -77,6 +77,13 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat): cur_time = time.monotonic() reply, stop_found = apply_stopping_strings(reply, all_stop_strings) + + try: + reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=False) + except Exception: + logger.error('Error in streaming extension hook') + traceback.print_exc() + if escape_html: reply = html.escape(reply) @@ -102,6 +109,12 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything): break + try: + reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=True) + except Exception: + logger.error('Error in streaming extension hook') + traceback.print_exc() + if not is_chat: reply = apply_extensions('output', reply, state)