-
Notifications
You must be signed in to change notification settings - Fork 419
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add transcribe_asr extension; optimize bedrock_llm extension; fix pol…
…ly_tts bugs (#174) * add transcribe_asr_python extension * bedrock_llm_extension: add time buffer for model response. * polly_tts_extension: remove on_init, fix sample_rate data type. --------- Co-authored-by: Chen188 <hidden>
- Loading branch information
Showing
16 changed files
with
656 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
## Amazon Transcribe ASR Extension | ||
|
||
### Configurations | ||
|
||
You can config this extension by providing following environments: | ||
|
||
| Env | Required | Default | Notes | | ||
| -- | -- | -- | -- | | ||
| AWS_REGION | No | us-east-1 | The Region of Amazon Transcribe service you want to use. | | ||
| AWS_ACCESS_KEY_ID | No | - | Access Key of your IAM User, make sure you've set proper permissions to [start stream transcription](https://docs.aws.amazon.com/transcribe/latest/APIReference/API_streaming_StartStreamTranscription.html). Will use default credentials provider if not provided. Check [document](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html). | | ||
| AWS_SECRET_ACCESS_KEY | No | - | Secret Key of your IAM User. Will use default credentials provider if not provided. Check [document](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html). | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from . import transcribe_asr_addon | ||
from .extension import EXTENSION_NAME | ||
from .log import logger | ||
|
||
logger.info(f"{EXTENSION_NAME} extension loaded") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
EXTENSION_NAME = "transcribe_asr" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import logging | ||
from .extension import EXTENSION_NAME | ||
|
||
logger = logging.getLogger(EXTENSION_NAME) | ||
logger.setLevel(logging.INFO) | ||
|
||
formatter = logging.Formatter( | ||
"%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s" | ||
) | ||
|
||
console_handler = logging.StreamHandler() | ||
console_handler.setFormatter(formatter) | ||
|
||
logger.addHandler(console_handler) |
76 changes: 76 additions & 0 deletions
76
agents/addon/extension/transcribe_asr_python/manifest.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
{ | ||
"type": "extension", | ||
"name": "transcribe_asr_python", | ||
"version": "0.1.0", | ||
"language": "python", | ||
"dependencies": [ | ||
{ | ||
"type": "system", | ||
"name": "rte_runtime_python", | ||
"version": "0.4.0" | ||
} | ||
], | ||
"api": { | ||
"property": { | ||
"region": { | ||
"type": "string" | ||
}, | ||
"access_key": { | ||
"type": "string" | ||
}, | ||
"secret_key": { | ||
"type": "string" | ||
}, | ||
"sample_rate": { | ||
"type": "string" | ||
}, | ||
"lang_code": { | ||
"type": "string" | ||
} | ||
}, | ||
"pcm_frame_in": [ | ||
{ | ||
"name": "pcm_frame" | ||
} | ||
], | ||
"cmd_in": [ | ||
{ | ||
"name": "on_user_joined" | ||
}, | ||
{ | ||
"name": "on_user_left" | ||
}, | ||
{ | ||
"name": "on_connection_failure" | ||
} | ||
], | ||
"data_out": [ | ||
{ | ||
"name": "text_data", | ||
"property": { | ||
"time": { | ||
"type": "int64" | ||
}, | ||
"duration_ms": { | ||
"type": "int64" | ||
}, | ||
"language": { | ||
"type": "string" | ||
}, | ||
"text": { | ||
"type": "string" | ||
}, | ||
"is_final": { | ||
"type": "bool" | ||
}, | ||
"stream_id": { | ||
"type": "uint32" | ||
}, | ||
"end_of_segment": { | ||
"type": "bool" | ||
} | ||
} | ||
} | ||
] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
amazon-transcribe==0.6.2 |
15 changes: 15 additions & 0 deletions
15
agents/addon/extension/transcribe_asr_python/transcribe_asr_addon.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from rte import ( | ||
Addon, | ||
register_addon_as_extension, | ||
RteEnv, | ||
) | ||
from .extension import EXTENSION_NAME | ||
from .log import logger | ||
from .transcribe_asr_extension import TranscribeAsrExtension | ||
|
||
|
||
@register_addon_as_extension(EXTENSION_NAME) | ||
class TranscribeAsrExtensionAddon(Addon): | ||
def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None: | ||
logger.info("on_create_instance") | ||
rte.on_create_instance_done(TranscribeAsrExtension(addon_name), context) |
90 changes: 90 additions & 0 deletions
90
agents/addon/extension/transcribe_asr_python/transcribe_asr_extension.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from rte import ( | ||
Extension, | ||
RteEnv, | ||
Cmd, | ||
PcmFrame, | ||
StatusCode, | ||
CmdResult, | ||
) | ||
|
||
import asyncio | ||
import threading | ||
|
||
from .log import logger | ||
from .transcribe_wrapper import AsyncTranscribeWrapper, TranscribeConfig | ||
|
||
PROPERTY_REGION = "region" # Optional | ||
PROPERTY_ACCESS_KEY = "access_key" # Optional | ||
PROPERTY_SECRET_KEY = "secret_key" # Optional | ||
PROPERTY_SAMPLE_RATE = 'sample_rate'# Optional | ||
PROPERTY_LANG_CODE = 'lang_code' # Optional | ||
|
||
|
||
class TranscribeAsrExtension(Extension): | ||
def __init__(self, name: str): | ||
super().__init__(name) | ||
|
||
self.stopped = False | ||
self.queue = asyncio.Queue(maxsize=3000) # about 3000 * 10ms = 30s input | ||
self.transcribe = None | ||
self.thread = None | ||
|
||
self.loop = asyncio.new_event_loop() | ||
asyncio.set_event_loop(self.loop) | ||
|
||
def on_start(self, rte: RteEnv) -> None: | ||
logger.info("TranscribeAsrExtension on_start") | ||
|
||
transcribe_config = TranscribeConfig.default_config() | ||
|
||
for optional_param in [PROPERTY_REGION, PROPERTY_SAMPLE_RATE, PROPERTY_LANG_CODE, | ||
PROPERTY_ACCESS_KEY, PROPERTY_SECRET_KEY]: | ||
try: | ||
value = rte.get_property_string(optional_param).strip() | ||
if value: | ||
transcribe_config.__setattr__(optional_param, value) | ||
except Exception as err: | ||
logger.debug(f"GetProperty optional {optional_param} failed, err: {err}. Using default value: {transcribe_config.__getattribute__(optional_param)}") | ||
|
||
self.transcribe = AsyncTranscribeWrapper(transcribe_config, self.queue, rte, self.loop) | ||
|
||
logger.info("Starting async_transcribe_wrapper thread") | ||
self.thread = threading.Thread(target=self.transcribe.run, args=[]) | ||
self.thread.start() | ||
|
||
rte.on_start_done() | ||
|
||
def put_pcm_frame(self, pcm_frame: PcmFrame) -> None: | ||
try: | ||
asyncio.run_coroutine_threadsafe(self.queue.put(pcm_frame), self.loop).result(timeout=0.1) | ||
except asyncio.QueueFull: | ||
logger.exception("Queue is full, dropping frame") | ||
except Exception as e: | ||
logger.exception(f"Error putting frame in queue: {e}") | ||
|
||
def on_pcm_frame(self, rte: RteEnv, pcm_frame: PcmFrame) -> None: | ||
self.put_pcm_frame(pcm_frame=pcm_frame) | ||
|
||
def on_stop(self, rte: RteEnv) -> None: | ||
logger.info("TranscribeAsrExtension on_stop") | ||
|
||
# put an empty frame to stop transcribe_wrapper | ||
self.put_pcm_frame(None) | ||
self.stopped = True | ||
self.thread.join() | ||
self.loop.stop() | ||
self.loop.close() | ||
|
||
rte.on_stop_done() | ||
|
||
def on_cmd(self, rte: RteEnv, cmd: Cmd) -> None: | ||
logger.info("TranscribeAsrExtension on_cmd") | ||
cmd_json = cmd.to_json() | ||
logger.info("TranscribeAsrExtension on_cmd json: " + cmd_json) | ||
|
||
cmdName = cmd.get_name() | ||
logger.info("got cmd %s" % cmdName) | ||
|
||
cmd_result = CmdResult.create(StatusCode.OK) | ||
cmd_result.set_property_string("detail", "success") | ||
rte.return_result(cmd_result, cmd) |
29 changes: 29 additions & 0 deletions
29
agents/addon/extension/transcribe_asr_python/transcribe_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from typing import Union | ||
|
||
class TranscribeConfig: | ||
def __init__(self, | ||
region: str, | ||
access_key: str, | ||
secret_key: str, | ||
sample_rate: Union[str, int], | ||
lang_code: str): | ||
self.region = region | ||
self.access_key = access_key | ||
self.secret_key = secret_key | ||
|
||
self.lang_code = lang_code | ||
self.sample_rate = int(sample_rate) | ||
|
||
self.media_encoding = 'pcm' | ||
self.bytes_per_sample = 2, | ||
self.channel_nums = 1 | ||
|
||
@classmethod | ||
def default_config(cls): | ||
return cls( | ||
region="us-east-1", | ||
access_key="", | ||
secret_key="", | ||
sample_rate=16000, | ||
lang_code='en-US' | ||
) |
Oops, something went wrong.