Skip to content
9 changes: 8 additions & 1 deletion docs/07 - Extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ The extensions framework is based on special functions and variables that you ca
| Function | Description |
|-------------|-------------|
| `def setup()` | Is executed when the extension gets imported. |
| `def ui()` | Creates custom gradio elements when the UI is launched. |
| `def ui()` | Creates custom gradio elements when the UI is launched. |
| `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. |
| `def custom_js()` | Returns custom JavaScript as a string. It is applied whenever the web UI is loaded. |
| `def input_modifier(string, state, is_chat=False)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
| `def output_modifier(string, state, is_chat=False)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
| `def chat_input_modifier(text, visible_text, state)` | Modifies both the visible and internal inputs in chat mode. Can be used to hijack the chat input with custom content. |
| `def output_stream_modifier(string, state, is_chat=False, is_final=False)` | Modifies the full output text mid-stream. It is called on each partial reply while the output is being streamed (`is_final=False`), and once more when generation has finished (`is_final=True`). |
| `def bot_prefix_modifier(string, state)` | Applied in chat mode to the prefix for the bot's reply. |
| `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. |
| `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. |
Expand Down Expand Up @@ -210,6 +211,12 @@ def output_modifier(string, state, is_chat=False):
"""
return string

def output_stream_modifier(string, state, is_chat=False, is_final=False):
    """
    Modifies the text stream of the LLM output in realtime.

    Called on each partial reply while the output is being streamed
    (is_final=False) and once more when generation has finished
    (is_final=True). The returned string replaces the reply text.
    """
    return string

def custom_generate_chat_prompt(user_input, state, **kwargs):
"""
Replaces the function that generates the prompt from the chat history.
Expand Down
20 changes: 10 additions & 10 deletions modules/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,31 +84,30 @@ def iterator():


# Extension functions that map string -> string
def _apply_string_extensions(function_name, text, state, is_chat=False, **extra_kwargs):
    """
    Pass `text` through `function_name` of every extension that defines it.

    Each extension's function receives the text returned by the previous
    one, and the final result is returned. The call is adapted to the
    signature of each function:

    - `is_chat` is passed as a keyword only if the function declares it.
    - Entries of `extra_kwargs` (e.g. `is_final`) are passed as keywords
      only if the function declares a parameter with the same name.
    - `state` is passed positionally after `text` only if the function
      has at least two other parameters.
    """
    for extension, _ in iterator():
        if not hasattr(extension, function_name):
            continue

        func = getattr(extension, function_name)

        # Handle old extensions without the 'state' arg or
        # the 'is_chat' kwarg
        kwargs = {}
        count = 0  # parameters that are neither 'is_chat' nor an extra kwarg
        for k in signature(func).parameters:
            if k == 'is_chat':
                kwargs['is_chat'] = is_chat
            elif k in extra_kwargs:
                kwargs[k] = extra_kwargs[k]
            else:
                count += 1

        args = [text, state] if count >= 2 else [text]
        text = func(*args, **kwargs)

    return text
Expand Down Expand Up @@ -234,6 +233,7 @@ def create_extensions_tabs():
"input": partial(_apply_string_extensions, "input_modifier"),
"output": partial(_apply_string_extensions, "output_modifier"),
"chat_input": _apply_chat_input_extensions,
"output_stream": partial(_apply_string_extensions, "output_stream_modifier"),
"state": _apply_state_modifier_extensions,
"history": _apply_history_modifier_extensions,
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
Expand Down
19 changes: 19 additions & 0 deletions modules/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat):
cur_time = time.monotonic()
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)

try:
reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=False)
except Exception:
try:
logger.error('Error in streaming extension hook')
except Exception:
pass
traceback.print_exc()

if escape_html:
reply = html.escape(reply)

Expand All @@ -122,6 +132,15 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
if stop_found or shared.stop_everything or (stop_event and stop_event.is_set()):
break

try:
reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=True)
except Exception:
try:
logger.error('Error in streaming extension hook')
except Exception:
pass
traceback.print_exc()

if not is_chat:
reply = apply_extensions('output', reply, state)

Expand Down