Skip to content

Commit

Permalink
add transcribe_asr extension; optimize bedrock_llm extension; fix pol…
Browse files Browse the repository at this point in the history
…ly_tts bugs (#174)

* add transcribe_asr_python extension

* bedrock_llm_extension: add time buffer for model response.

* polly_tts_extension: remove on_init, fix sample_rate data type.

---------

Co-authored-by: Chen188 <hidden>
  • Loading branch information
Chen188 authored Aug 9, 2024
1 parent 239dcf0 commit 739b0b5
Show file tree
Hide file tree
Showing 16 changed files with 656 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,9 @@ def converse_stream_worker(start_time, input_text, memory):
first_sentence_sent = False

for event in stream:
if start_time < self.outdate_ts:
logger.info(
f"GetConverseStream recv interrupt and flushing for input text: [{input_text}], startTs: {start_time}, outdateTs: {self.outdate_ts}"
)
# allow 100ms buffer time, in case interruptor's flush cmd comes just after on_data event
if (start_time + 100_000) < self.outdate_ts:
logger.info(f"GetConverseStream recv interrupt and flushing for input text: [{input_text}], startTs: {start_time}, outdateTs: {self.outdate_ts}, delta > 100ms")
break

if "contentBlockDelta" in event:
Expand All @@ -278,8 +277,8 @@ def converse_stream_worker(start_time, input_text, memory):
sentence, content, sentence_is_final = parse_sentence(
sentence, content
)
if len(sentence) == 0 or not sentence_is_final:
logger.info(f"sentence {sentence} is empty or not final")
if not sentence or not sentence_is_final:
logger.info(f"sentence [{sentence}] is empty or not final")
break
logger.info(
f"GetConverseStream recv for input text: [{input_text}] got sentence: [{sentence}]"
Expand Down Expand Up @@ -313,7 +312,10 @@ def converse_stream_worker(start_time, input_text, memory):

if len(full_content.strip()):
# remember response as assistant content in memory
memory.append(
if memory and memory[-1]['role'] == 'assistant':
memory[-1]['content'].append({"text": full_content})
else:
memory.append(
{"role": "assistant", "content": [{"text": full_content}]}
)
else:
Expand Down
4 changes: 2 additions & 2 deletions agents/addon/extension/polly_tts/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"type": "string"
},
"sample_rate": {
"type": "int64"
"type": "string"
},
"lang_code": {
"type": "string"
Expand Down Expand Up @@ -60,4 +60,4 @@
}
]
}
}
}
6 changes: 0 additions & 6 deletions agents/addon/extension/polly_tts/polly_tts_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ def __init__(self, name: str):
self.bytes_per_sample = 2
self.number_of_channels = 1

def on_init(
self, rte: RteEnv, manifest: MetadataInfo, property: MetadataInfo
) -> None:
logger.info("PollyTTSExtension on_init")
rte.on_init_done(manifest, property)

def on_start(self, rte: RteEnv) -> None:
logger.info("PollyTTSExtension on_start")

Expand Down
11 changes: 11 additions & 0 deletions agents/addon/extension/transcribe_asr_python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
## Amazon Transcribe ASR Extension

### Configurations

You can config this extension by providing following environments:

| Env | Required | Default | Notes |
| -- | -- | -- | -- |
| AWS_REGION | No | us-east-1 | The Region of Amazon Transcribe service you want to use. |
| AWS_ACCESS_KEY_ID | No | - | Access Key of your IAM User, make sure you've set proper permissions to [start stream transcription](https://docs.aws.amazon.com/transcribe/latest/APIReference/API_streaming_StartStreamTranscription.html). Will use default credentials provider if not provided. Check [document](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html). |
| AWS_SECRET_ACCESS_KEY | No | - | Secret Key of your IAM User. Will use default credentials provider if not provided. Check [document](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html). |
5 changes: 5 additions & 0 deletions agents/addon/extension/transcribe_asr_python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from . import transcribe_asr_addon
from .extension import EXTENSION_NAME
from .log import logger

logger.info(f"{EXTENSION_NAME} extension loaded")
1 change: 1 addition & 0 deletions agents/addon/extension/transcribe_asr_python/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
EXTENSION_NAME = "transcribe_asr"
14 changes: 14 additions & 0 deletions agents/addon/extension/transcribe_asr_python/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import logging
from .extension import EXTENSION_NAME

logger = logging.getLogger(EXTENSION_NAME)
logger.setLevel(logging.INFO)

formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s"
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)

logger.addHandler(console_handler)
76 changes: 76 additions & 0 deletions agents/addon/extension/transcribe_asr_python/manifest.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
{
"type": "extension",
"name": "transcribe_asr_python",
"version": "0.1.0",
"language": "python",
"dependencies": [
{
"type": "system",
"name": "rte_runtime_python",
"version": "0.4.0"
}
],
"api": {
"property": {
"region": {
"type": "string"
},
"access_key": {
"type": "string"
},
"secret_key": {
"type": "string"
},
"sample_rate": {
"type": "string"
},
"lang_code": {
"type": "string"
}
},
"pcm_frame_in": [
{
"name": "pcm_frame"
}
],
"cmd_in": [
{
"name": "on_user_joined"
},
{
"name": "on_user_left"
},
{
"name": "on_connection_failure"
}
],
"data_out": [
{
"name": "text_data",
"property": {
"time": {
"type": "int64"
},
"duration_ms": {
"type": "int64"
},
"language": {
"type": "string"
},
"text": {
"type": "string"
},
"is_final": {
"type": "bool"
},
"stream_id": {
"type": "uint32"
},
"end_of_segment": {
"type": "bool"
}
}
}
]
}
}
1 change: 1 addition & 0 deletions agents/addon/extension/transcribe_asr_python/property.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
amazon-transcribe==0.6.2
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from rte import (
Addon,
register_addon_as_extension,
RteEnv,
)
from .extension import EXTENSION_NAME
from .log import logger
from .transcribe_asr_extension import TranscribeAsrExtension


@register_addon_as_extension(EXTENSION_NAME)
class TranscribeAsrExtensionAddon(Addon):
def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
logger.info("on_create_instance")
rte.on_create_instance_done(TranscribeAsrExtension(addon_name), context)
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from rte import (
Extension,
RteEnv,
Cmd,
PcmFrame,
StatusCode,
CmdResult,
)

import asyncio
import threading

from .log import logger
from .transcribe_wrapper import AsyncTranscribeWrapper, TranscribeConfig

PROPERTY_REGION = "region" # Optional
PROPERTY_ACCESS_KEY = "access_key" # Optional
PROPERTY_SECRET_KEY = "secret_key" # Optional
PROPERTY_SAMPLE_RATE = 'sample_rate'# Optional
PROPERTY_LANG_CODE = 'lang_code' # Optional


class TranscribeAsrExtension(Extension):
def __init__(self, name: str):
super().__init__(name)

self.stopped = False
self.queue = asyncio.Queue(maxsize=3000) # about 3000 * 10ms = 30s input
self.transcribe = None
self.thread = None

self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

def on_start(self, rte: RteEnv) -> None:
logger.info("TranscribeAsrExtension on_start")

transcribe_config = TranscribeConfig.default_config()

for optional_param in [PROPERTY_REGION, PROPERTY_SAMPLE_RATE, PROPERTY_LANG_CODE,
PROPERTY_ACCESS_KEY, PROPERTY_SECRET_KEY]:
try:
value = rte.get_property_string(optional_param).strip()
if value:
transcribe_config.__setattr__(optional_param, value)
except Exception as err:
logger.debug(f"GetProperty optional {optional_param} failed, err: {err}. Using default value: {transcribe_config.__getattribute__(optional_param)}")

self.transcribe = AsyncTranscribeWrapper(transcribe_config, self.queue, rte, self.loop)

logger.info("Starting async_transcribe_wrapper thread")
self.thread = threading.Thread(target=self.transcribe.run, args=[])
self.thread.start()

rte.on_start_done()

def put_pcm_frame(self, pcm_frame: PcmFrame) -> None:
try:
asyncio.run_coroutine_threadsafe(self.queue.put(pcm_frame), self.loop).result(timeout=0.1)
except asyncio.QueueFull:
logger.exception("Queue is full, dropping frame")
except Exception as e:
logger.exception(f"Error putting frame in queue: {e}")

def on_pcm_frame(self, rte: RteEnv, pcm_frame: PcmFrame) -> None:
self.put_pcm_frame(pcm_frame=pcm_frame)

def on_stop(self, rte: RteEnv) -> None:
logger.info("TranscribeAsrExtension on_stop")

# put an empty frame to stop transcribe_wrapper
self.put_pcm_frame(None)
self.stopped = True
self.thread.join()
self.loop.stop()
self.loop.close()

rte.on_stop_done()

def on_cmd(self, rte: RteEnv, cmd: Cmd) -> None:
logger.info("TranscribeAsrExtension on_cmd")
cmd_json = cmd.to_json()
logger.info("TranscribeAsrExtension on_cmd json: " + cmd_json)

cmdName = cmd.get_name()
logger.info("got cmd %s" % cmdName)

cmd_result = CmdResult.create(StatusCode.OK)
cmd_result.set_property_string("detail", "success")
rte.return_result(cmd_result, cmd)
29 changes: 29 additions & 0 deletions agents/addon/extension/transcribe_asr_python/transcribe_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Union

class TranscribeConfig:
def __init__(self,
region: str,
access_key: str,
secret_key: str,
sample_rate: Union[str, int],
lang_code: str):
self.region = region
self.access_key = access_key
self.secret_key = secret_key

self.lang_code = lang_code
self.sample_rate = int(sample_rate)

self.media_encoding = 'pcm'
self.bytes_per_sample = 2,
self.channel_nums = 1

@classmethod
def default_config(cls):
return cls(
region="us-east-1",
access_key="",
secret_key="",
sample_rate=16000,
lang_code='en-US'
)
Loading

0 comments on commit 739b0b5

Please sign in to comment.