Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion docs/07 - Extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ The extensions framework is based on special functions and variables that you ca
| Function | Description |
|-------------|-------------|
| `def setup()` | Is executed when the extension gets imported. |
| `def ui()` | Creates custom gradio elements when the UI is launched. |
| `def ui()` | Creates custom gradio elements when the UI is launched. |
| `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. |
| `def custom_js()` | Same as above, but for JavaScript. |
| `def input_modifier(string, state, is_chat=False)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
| `def output_modifier(string, state, is_chat=False)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
| `def chat_input_modifier(text, visible_text, state)` | Modifies both the visible and internal inputs in chat mode. Can be used to hijack the chat input with custom content. |
| `def output_stream_modifier(string, state, is_chat=False, is_final=False)` | Overrides the full text mid-stream. Called with `is_final=False` for each partial token/chunk while the UI is streaming output, and called once more with `is_final=True` after the last token has been generated. |
| `def bot_prefix_modifier(string, state)` | Applied in chat mode to the prefix for the bot's reply. |
| `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. |
| `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. |
Expand Down Expand Up @@ -209,6 +210,12 @@ def output_modifier(string, state, is_chat=False):
"""
return string

def output_stream_modifier(string, state, is_chat=False, is_final=False):
    """
    Modifies the text stream of the LLM output in realtime.

    Called repeatedly while the reply is being streamed, and once more
    after streaming has finished. The returned string replaces the reply
    text.

    Args:
        string: The reply text at the time of the call.
        state: The dictionary containing the UI input parameters.
        is_chat: True when the text is being generated in chat mode.
        is_final: False for calls made while the output is still
            streaming; True for the single call made after the stream
            has ended.

    Returns:
        The (possibly modified) reply text. This template returns it
        unchanged.
    """
    return string

def custom_generate_chat_prompt(user_input, state, **kwargs):
"""
Replaces the function that generates the prompt from the chat history.
Expand Down
20 changes: 10 additions & 10 deletions modules/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,31 +81,30 @@ def iterator():


# Extension functions that map string -> string
def _apply_string_extensions(function_name, text, state, is_chat=False, **extra_kwargs):
    """
    Run a string -> string hook of every extension over `text`, in order.

    Each extension yielded by `iterator()` that defines `function_name`
    is called with the text returned by the previous one, so the hooks
    form a chain.

    Args:
        function_name: Name of the hook to look up on each extension
            (e.g. "input_modifier", "output_stream_modifier").
        text: The string to transform.
        state: Passed as the second positional argument to hooks that
            accept it.
        is_chat: Forwarded as a keyword argument to hooks that declare
            an `is_chat` parameter.
        **extra_kwargs: Additional keyword arguments (e.g. `is_final`).
            Each one is forwarded only to hooks whose signature declares
            a parameter with the same name.

    Returns:
        The text after all matching hooks have been applied.
    """
    for extension, _ in iterator():
        if not hasattr(extension, function_name):
            continue

        func = getattr(extension, function_name)

        # Handle old extensions without the 'state' arg or
        # the 'is_chat' kwarg: inspect the signature and only pass
        # what the hook actually declares.
        count = 0  # parameters that are neither 'is_chat' nor an extra kwarg
        func_params = signature(func).parameters
        kwargs = {}

        for k in func_params:
            if k == 'is_chat':
                kwargs['is_chat'] = is_chat
            elif k in extra_kwargs:
                kwargs[k] = extra_kwargs[k]
            else:
                count += 1

        # Two or more remaining parameters: the hook takes (text, state).
        # Otherwise it only takes the text.
        if count >= 2:
            args = [text, state]
        else:
            args = [text]

        text = func(*args, **kwargs)

    return text
Expand Down Expand Up @@ -231,6 +230,7 @@ def create_extensions_tabs():
"input": partial(_apply_string_extensions, "input_modifier"),
"output": partial(_apply_string_extensions, "output_modifier"),
"chat_input": _apply_chat_input_extensions,
"output_stream": partial(_apply_string_extensions, "output_stream_modifier"),
"state": _apply_state_modifier_extensions,
"history": _apply_history_modifier_extensions,
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
Expand Down
19 changes: 19 additions & 0 deletions modules/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat):
cur_time = time.monotonic()
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)

try:
reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=False)
except Exception:
try:
logger.error('Error in streaming extension hook')
except Exception:
pass
traceback.print_exc()

if escape_html:
reply = html.escape(reply)

Expand All @@ -102,6 +112,15 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
break

try:
reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=True)
except Exception:
try:
logger.error('Error in streaming extension hook')
except Exception:
pass
traceback.print_exc()

if not is_chat:
reply = apply_extensions('output', reply, state)

Expand Down