-
Notifications
You must be signed in to change notification settings - Fork 811
Timestamps in parakeet_runner
#16545
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
d402b49
0396189
fa3e404
62774fa
afc3427
6313a49
df2e8e8
3a4a2f1
1a23c14
0c9768d
365896d
08b82fd
5c27d9d
9504e37
349d0b6
4b5b15a
7ada1eb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,18 +6,28 @@ | |
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|
|
||
| #include <algorithm> | ||
| #include <cmath> | ||
| #include <cstdint> | ||
| #include <cstring> | ||
| #include <exception> | ||
| #include <fstream> | ||
| #include <iostream> | ||
| #include <memory> | ||
| #include <optional> | ||
| #include <string> | ||
| #include <unordered_set> | ||
| #include <vector> | ||
|
|
||
| #include <gflags/gflags.h> | ||
|
|
||
| #include "timestamp_utils.h" | ||
| #include "tokenizer_utils.h" | ||
| #include "types.h" | ||
|
|
||
| #include <executorch/extension/llm/runner/llm_runner_helper.h> | ||
| #include <executorch/extension/llm/runner/wav_loader.h> | ||
| #include <executorch/extension/llm/tokenizers/third-party/llama.cpp-unicode/include/unicode.h> | ||
| #include <executorch/extension/module/module.h> | ||
| #include <executorch/extension/tensor/tensor_ptr_maker.h> | ||
| #include <executorch/runtime/core/evalue.h> | ||
|
|
@@ -27,24 +37,75 @@ DEFINE_string(model_path, "parakeet.pte", "Path to Parakeet model (.pte)."); | |
| DEFINE_string(audio_path, "", "Path to input audio file (.wav)."); | ||
| DEFINE_string( | ||
| tokenizer_path, | ||
| "tokenizer.json", | ||
| "tokenizer.model", | ||
| "Path to SentencePiece tokenizer model file."); | ||
| DEFINE_string( | ||
| data_path, | ||
| "", | ||
| "Path to data file (.ptd) for delegate data (optional, required for CUDA)."); | ||
| DEFINE_string( | ||
| timestamps, | ||
| "none", | ||
| "Timestamp output mode: none|token|word|segment|all"); | ||
|
|
||
| using ::executorch::extension::from_blob; | ||
| using ::executorch::extension::Module; | ||
| using ::executorch::runtime::Error; | ||
| using ::executorch::runtime::EValue; | ||
|
|
||
| namespace { | ||
| using ::parakeet::TextWithOffsets; | ||
| using ::parakeet::Token; | ||
| using ::parakeet::TokenId; | ||
| using ::parakeet::TokenWithTextInfo; | ||
|
|
||
| namespace { | ||
| // TDT duration values | ||
| const std::vector<int> DURATIONS = {0, 1, 2, 3, 4}; | ||
|
|
||
| std::vector<int64_t> greedy_decode_executorch( | ||
// Selects which timestamp granularities are printed after the transcription.
// Every flag defaults to false, which means no timestamp output at all.
struct TimestampOutputMode {
  bool token = false; // print per-token timestamps
  bool word = false; // print per-word timestamps
  bool segment = false; // print per-segment timestamps

  // Reports whether at least one granularity has been requested.
  bool enabled() const {
    const bool nothing_requested = !token && !word && !segment;
    return !nothing_requested;
  }
};
|
|
||
// Returns a copy of `s` in which the ASCII letters 'A'-'Z' are lowercased.
//
// Every other byte (digits, punctuation, non-ASCII/UTF-8 bytes) is returned
// unchanged. The mapping is done by hand rather than with std::tolower so the
// result never depends on the process's C locale and no <cctype> include is
// needed.
std::string to_lower_ascii(std::string s) {
  for (char& ch : s) {
    if (ch >= 'A' && ch <= 'Z') {
      ch = static_cast<char>(ch - 'A' + 'a');
    }
  }
  return s;
}
|
|
||
| TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) { | ||
| if (raw_arg.empty()) { | ||
| throw std::invalid_argument( | ||
| "Invalid --timestamps value (empty). Expected: token, word, segment, all."); | ||
| } | ||
| const std::string mode = to_lower_ascii(raw_arg); | ||
| if (mode == "none") { | ||
| return {false, false, false}; | ||
| } | ||
| if (mode == "token") { | ||
| return {true, false, false}; | ||
| } | ||
| if (mode == "word") { | ||
| return {false, true, false}; | ||
| } | ||
| if (mode == "segment") { | ||
| return {false, false, true}; | ||
| } | ||
| if (mode == "all") { | ||
| return {true, true, true}; | ||
| } | ||
| throw std::invalid_argument( | ||
| "Invalid --timestamps value '" + raw_arg + | ||
| "'. Expected: token, word, segment, all."); | ||
| } | ||
|
|
||
| std::vector<Token> greedy_decode_executorch( | ||
| Module& model, | ||
| const ::executorch::aten::Tensor& encoder_output, | ||
| int64_t encoder_len, | ||
|
|
@@ -53,7 +114,7 @@ std::vector<int64_t> greedy_decode_executorch( | |
| int64_t num_rnn_layers = 2, | ||
| int64_t pred_hidden = 640, | ||
| int64_t max_symbols_per_step = 10) { | ||
| std::vector<int64_t> hypothesis; | ||
| std::vector<Token> hypothesis; | ||
| int64_t num_token_classes = vocab_size + 1; | ||
|
|
||
| // Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim] | ||
|
|
@@ -205,10 +266,10 @@ std::vector<int64_t> greedy_decode_executorch( | |
| int64_t dur = DURATIONS[dur_idx]; | ||
|
|
||
| if (k == blank_id) { | ||
| t += std::max(dur, (int64_t)1); | ||
| t += std::max(dur, static_cast<int64_t>(1)); | ||
| symbols_on_frame = 0; | ||
| } else { | ||
| hypothesis.push_back(k); | ||
| hypothesis.push_back({static_cast<TokenId>(k), t, dur}); | ||
|
|
||
| // Update decoder state | ||
| std::vector<int64_t> token_data = {k}; | ||
|
|
@@ -268,29 +329,19 @@ std::vector<int64_t> greedy_decode_executorch( | |
| return hypothesis; | ||
| } | ||
|
|
||
| std::string tokens_to_text( | ||
| const std::vector<int64_t>& tokens, | ||
| tokenizers::Tokenizer* tokenizer) { | ||
| // Decode tokens to text one by one | ||
| std::string result; | ||
| uint64_t prev_token = 0; | ||
| for (size_t i = 0; i < tokens.size(); i++) { | ||
| uint64_t token = static_cast<uint64_t>(tokens[i]); | ||
| auto decode_result = tokenizer->decode(prev_token, token); | ||
| if (decode_result.ok()) { | ||
| result += decode_result.get(); | ||
| } | ||
| prev_token = token; | ||
| } | ||
|
|
||
| return result; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| int main(int argc, char** argv) { | ||
| gflags::ParseCommandLineFlags(&argc, &argv, true); | ||
|
|
||
| TimestampOutputMode timestamp_mode; | ||
| try { | ||
| timestamp_mode = parse_timestamp_output_mode(FLAGS_timestamps); | ||
| } catch (const std::invalid_argument& e) { | ||
| ET_LOG(Error, "%s", e.what()); | ||
| return 1; | ||
| } | ||
|
|
||
| if (FLAGS_audio_path.empty()) { | ||
| ET_LOG(Error, "audio_path flag must be provided."); | ||
| return 1; | ||
|
|
@@ -381,10 +432,14 @@ int main(int argc, char** argv) { | |
| auto vocab_size_result = model->execute("vocab_size", empty_inputs); | ||
| auto blank_id_result = model->execute("blank_id", empty_inputs); | ||
| auto sample_rate_result = model->execute("sample_rate", empty_inputs); | ||
| auto window_stride_result = model->execute("window_stride", empty_inputs); | ||
| auto encoder_subsampling_factor_result = | ||
| model->execute("encoder_subsampling_factor", empty_inputs); | ||
|
|
||
| if (!num_rnn_layers_result.ok() || !pred_hidden_result.ok() || | ||
| !vocab_size_result.ok() || !blank_id_result.ok() || | ||
| !sample_rate_result.ok()) { | ||
| !sample_rate_result.ok() || !window_stride_result.ok() || | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note that this will break compatibility with previously exported Parakeet models. I chose to do this because it is early in development and to avoid having to make a separate path that allows everything but timestamps if the new metadata isn't present. Open to doing such a thing if reviewers feel strongly.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I don't mind breaking backward compatibility at this early stage. And it's in examples anyway, not in a core directory (like /extensions). |
||
| !encoder_subsampling_factor_result.ok()) { | ||
| ET_LOG( | ||
| Error, | ||
| "Failed to query model metadata. Make sure the model was exported with constant_methods."); | ||
|
|
@@ -396,18 +451,23 @@ int main(int argc, char** argv) { | |
| int64_t num_rnn_layers = num_rnn_layers_result.get()[0].toInt(); | ||
| int64_t pred_hidden = pred_hidden_result.get()[0].toInt(); | ||
| int64_t sample_rate = sample_rate_result.get()[0].toInt(); | ||
| double window_stride = window_stride_result.get()[0].toDouble(); | ||
| int64_t encoder_subsampling_factor = | ||
| encoder_subsampling_factor_result.get()[0].toInt(); | ||
|
|
||
| ET_LOG( | ||
| Info, | ||
| "Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld", | ||
| "Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld, window_stride=%.6f, encoder_subsampling_factor=%lld", | ||
| static_cast<long long>(vocab_size), | ||
| static_cast<long long>(blank_id), | ||
| static_cast<long long>(num_rnn_layers), | ||
| static_cast<long long>(pred_hidden), | ||
| static_cast<long long>(sample_rate)); | ||
| static_cast<long long>(sample_rate), | ||
| window_stride, | ||
| encoder_subsampling_factor); | ||
|
|
||
| ET_LOG(Info, "Running TDT greedy decode..."); | ||
| auto tokens = greedy_decode_executorch( | ||
| auto decoded_tokens = greedy_decode_executorch( | ||
| *model, | ||
| encoded, | ||
| encoded_len, | ||
|
|
@@ -416,7 +476,7 @@ int main(int argc, char** argv) { | |
| num_rnn_layers, | ||
| pred_hidden); | ||
|
|
||
| ET_LOG(Info, "Decoded %zu tokens", tokens.size()); | ||
| ET_LOG(Info, "Decoded %zu tokens", decoded_tokens.size()); | ||
|
|
||
| // Load tokenizer | ||
| ET_LOG(Info, "Loading tokenizer from: %s", FLAGS_tokenizer_path.c_str()); | ||
|
|
@@ -431,9 +491,68 @@ int main(int argc, char** argv) { | |
| } | ||
|
|
||
| // Convert tokens to text | ||
| std::string text = tokens_to_text(tokens, tokenizer.get()); | ||
| std::cout << "Transcription tokens: " << text << std::endl; | ||
| std::string text = parakeet::tokenizer_utils::decode_token_sequence( | ||
| decoded_tokens, *tokenizer); | ||
| std::cout << "Transcribed text: " << text << std::endl; | ||
|
|
||
| if (!timestamp_mode.enabled()) { | ||
| return 0; | ||
| } | ||
|
|
||
| ET_LOG(Info, "Computing timestamps..."); | ||
| std::unordered_set<std::string> supported_punctuation = | ||
| parakeet::tokenizer_utils::derive_supported_punctuation(*tokenizer); | ||
| ET_LOG( | ||
| Info, | ||
| "Derived supported_punctuation size=%zu", | ||
| supported_punctuation.size()); | ||
|
|
||
| // for simplicity, compute all levels of timestamps regardless of mode | ||
| std::vector<TokenWithTextInfo> tokens_with_text_info; | ||
| try { | ||
| tokens_with_text_info = | ||
| parakeet::timestamp_utils::get_tokens_with_text_info( | ||
| decoded_tokens, *tokenizer, supported_punctuation); | ||
| } catch (const std::exception& e) { | ||
| ET_LOG(Error, "Failed to get tokens with text info: %s", e.what()); | ||
| return 1; | ||
| } | ||
| const auto word_offsets = parakeet::timestamp_utils::get_words_offsets( | ||
| tokens_with_text_info, *tokenizer, supported_punctuation); | ||
| const auto segment_offsets = | ||
| parakeet::timestamp_utils::get_segment_offsets(word_offsets); | ||
|
|
||
| const double frame_to_seconds = | ||
| window_stride * static_cast<double>(encoder_subsampling_factor); | ||
|
|
||
| if (timestamp_mode.segment) { | ||
| std::cout << "\nSegment timestamps:" << std::endl; | ||
| for (const auto& segment : segment_offsets) { | ||
| const double start = segment.start_offset * frame_to_seconds; | ||
| const double end = segment.end_offset * frame_to_seconds; | ||
| std::cout << start << "s - " << end << "s : " << segment.text | ||
| << std::endl; | ||
| } | ||
| } | ||
|
|
||
| if (timestamp_mode.word) { | ||
| std::cout << "\nWord timestamps:" << std::endl; | ||
| for (const auto& word : word_offsets) { | ||
| const double start = word.start_offset * frame_to_seconds; | ||
| const double end = word.end_offset * frame_to_seconds; | ||
| std::cout << start << "s - " << end << "s : " << word.text << std::endl; | ||
| } | ||
| } | ||
|
|
||
| if (timestamp_mode.token) { | ||
| std::cout << "\nToken timestamps:" << std::endl; | ||
| for (const auto& token : tokens_with_text_info) { | ||
| const double start = token.start_offset * frame_to_seconds; | ||
| const double end = token.end_offset * frame_to_seconds; | ||
| std::cout << start << "s - " << end << "s : " << token.decoded_text | ||
| << std::endl; | ||
| } | ||
| } | ||
|
|
||
| ET_LOG(Info, "Done!"); | ||
| return 0; | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we default to one of the options?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, defaulted to
"segment" because the others can get pretty verbose in output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you say in README.md that segment is the default
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes :)