diff --git a/docs/decoding.md b/docs/decoding.md
index c2107c1f9..a4f82d3db 100644
--- a/docs/decoding.md
+++ b/docs/decoding.md
@@ -79,7 +79,7 @@ The prefix effectively changes the target context and the rest of the translatio
 
 > Dieses Projekt ist auf das effiziente **Servieren** von Standard-Übersetzungsmodellen ausgerichtet, ist aber auch ein Ort für Experimente rund um Modellkompression und Inferenzbeschleunigung.
 
-## Biased decoding
+## Biased prefix decoding for translation
 
 Instead of using {ref}`decoding:autocompletion` to force a translation to start with a `target_prefix` argument, we can "bias" a translation towards a prefix by setting `prefix_bias_beta` to a value in (0, 1). The higher `prefix_bias_beta` is, the stronger the bias. A translation can diverge from a prefix when `prefix_bias_beta` is low and the translator is confident in decoding tokens that are different from the prefix's tokens. See [section 4.2](https://arxiv.org/abs/1912.03393) for more details on the biasing algorithm.
 
@@ -113,6 +113,31 @@ Lowering the bias by setting `prefix_bias_beta=0.1` results in a divergence in t
 
 > Dieses Projekt ist auf **die** effiziente Bedienung von Standard-Übersetzungsmodellen ausgerichtet, ist aber auch ein Ort für Experimente rund um Modellkompression und Inferenzbeschleunigung.
 
+## Shallow biasing for contextual ASR
+
+Set `sequence_bias` to a list of `(sequence, biasing_multiplier)` tuples to boost or attenuate hypotheses that hit words in the biasing list during beam search with Whisper models. See [Section 3.3](https://aclanthology.org/2024.lrec-main.328.pdf) for the general concept, and the [HuggingFace implementation](https://huggingface.co/docs/transformers/en/internal/generation_utils#transformers.SequenceBiasLogitsProcessor) for an additive variant of the same idea.
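+
+For example, the following sketch biases a transcription towards the spelling "Mr. Quiltre". The model path and token IDs are illustrative (the IDs should come from the Whisper tokenizer for the phrase being boosted), and `features` and `prompts` are assumed to be prepared as for a regular transcription:
+
+```python
+import ctranslate2
+
+model = ctranslate2.models.Whisper("whisper-tiny-ct2")
+
+# Multiply the logits of each sequence's last token by 1.3 whenever a
+# hypothesis already ends with that sequence's preceding tokens.
+results = model.generate(
+    features,
+    prompts,
+    beam_size=2,
+    sequence_bias=[
+        ([2221, 13, 2326, 2352, 265], 1.3),
+        ([2221, 13, 2326, 2352], 1.3),
+        ([2221, 13, 2326], 1.3),
+    ],
+)
+```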
+
 ## Alternatives at a position
 
 Combining `target_prefix` with the `return_alternatives` flag returns alternative sequences just after the prefix:
diff --git a/include/ctranslate2/decoding.h b/include/ctranslate2/decoding.h
index 5c1d316dc..961ec43c0 100644
--- a/include/ctranslate2/decoding.h
+++ b/include/ctranslate2/decoding.h
@@ -161,6 +161,7 @@ namespace ctranslate2 {
     std::vector<size_t> disable_ids;
     std::vector<size_t> disable_ids_begin;
     std::vector<std::vector<size_t>> disable_sequences;
+    std::vector<std::pair<std::vector<size_t>, float>> sequence_bias;
     std::vector<std::shared_ptr<LogitsProcessor>> logits_processors;
     std::function<void(DecodingStepResult)> callback = nullptr;
   };
diff --git a/include/ctranslate2/decoding_utils.h b/include/ctranslate2/decoding_utils.h
index 3d146c4a1..0cb51e528 100644
--- a/include/ctranslate2/decoding_utils.h
+++ b/include/ctranslate2/decoding_utils.h
@@ -70,6 +70,46 @@ namespace ctranslate2 {
     std::vector<int32_t> _flat_indices;
   };
 
+  // Helper class to bias tokens in the model output.
+  class BiasTokens {
+  public:
+    BiasTokens(StorageView& logits);
+
+    void add(dim_t batch_id, dim_t token_id, float bias_value) {
+      const auto flat_index = batch_id * _vocabulary_size + token_id;
+
+      if (_logits_data) {
+        // On CPU, apply the bias to the logit immediately.
+        _logits_data[flat_index] *= bias_value;
+      } else {
+        // On GPU, accumulate a sorted list of unique indices and bias values to apply later.
+        const auto it = std::lower_bound(_flat_indices.begin(), _flat_indices.end(), flat_index,
+                                         [](const auto& a, const auto& b) { return a.first < b; });
+
+        if (it == _flat_indices.end() || it->first != flat_index) {
+          _flat_indices.emplace(it, flat_index, bias_value);
+        } else {
+          it->second *= bias_value;
+        }
+      }
+    }
+
+    // Bias a token for all batches.
+    void add(dim_t token_id, float bias_value) {
+      for (dim_t batch_id = 0; batch_id < _batch_size; ++batch_id)
+        add(batch_id, token_id, bias_value);
+    }
+
+    void apply();
+
+  private:
+    StorageView& _logits;
+    float* _logits_data;
+    const dim_t _batch_size;
+    const dim_t _vocabulary_size;
+    std::vector<std::pair<int32_t, float>> _flat_indices;
+  };
+
   // Base class for processing the output logits.
   class LogitsProcessor {
   public:
@@ -82,6 +122,7 @@ namespace ctranslate2 {
     virtual void apply(dim_t step,
                        StorageView& logits,
                        DisableTokens& disable_tokens,
+                       BiasTokens& bias_tokens,
                        const StorageView& sequences,
                        const std::vector<dim_t>& batch_offset,
                        const std::vector<std::vector<size_t>>* prefix) = 0;
@@ -109,6 +150,7 @@ namespace ctranslate2 {
     void apply(dim_t step,
                StorageView& logits,
                DisableTokens& disable_tokens,
+               BiasTokens& bias_tokens,
                const StorageView& sequences,
                const std::vector<dim_t>& batch_offset,
                const std::vector<std::vector<size_t>>* prefix) override;
@@ -124,6 +166,7 @@ namespace ctranslate2 {
     void apply(dim_t step,
                StorageView& logits,
                DisableTokens& disable_tokens,
+               BiasTokens& bias_tokens,
                const StorageView& sequences,
                const std::vector<dim_t>& batch_offset,
                const std::vector<std::vector<size_t>>* prefix) override;
@@ -139,6 +182,7 @@ namespace ctranslate2 {
     void apply(dim_t step,
                StorageView& logits,
                DisableTokens& disable_tokens,
+               BiasTokens& bias_tokens,
                const StorageView& sequences,
                const std::vector<dim_t>& batch_offset,
                const std::vector<std::vector<size_t>>* prefix) override;
@@ -148,6 +192,23 @@ namespace ctranslate2 {
     std::vector<std::vector<size_t>> _sequences;
   };
 
+  // Bias the generation of some sequences of tokens.
+  class BiasSequences : public LogitsProcessor {
+  public:
+    BiasSequences(std::vector<std::pair<std::vector<size_t>, float>> sequences);
+    void apply(dim_t step,
+               StorageView& logits,
+               DisableTokens& disable_tokens,
+               BiasTokens& bias_tokens,
+               const StorageView& sequences,
+               const std::vector<dim_t>& batch_offset,
+               const std::vector<std::vector<size_t>>* prefix) override;
+
+  private:
+    std::vector<std::pair<size_t, float>> _ids;
+    std::vector<std::pair<std::vector<size_t>, float>> _sequences;
+  };
+
   // Disable the generation of some tokens.
   class SuppressTokens : public LogitsProcessor {
   public:
@@ -155,6 +216,7 @@ namespace ctranslate2 {
     void apply(dim_t step,
                StorageView& logits,
                DisableTokens& disable_tokens,
+               BiasTokens& bias_tokens,
                const StorageView& sequences,
                const std::vector<dim_t>& batch_offset,
                const std::vector<std::vector<size_t>>* prefix) override;
@@ -170,6 +232,7 @@ namespace ctranslate2 {
     void apply(dim_t step,
                StorageView& logits,
                DisableTokens& disable_tokens,
+               BiasTokens& bias_tokens,
                const StorageView& sequences,
                const std::vector<dim_t>& batch_offset,
                const std::vector<std::vector<size_t>>* prefix) override;
diff --git a/include/ctranslate2/models/whisper.h b/include/ctranslate2/models/whisper.h
index e9818cc4e..51a4d2832 100644
--- a/include/ctranslate2/models/whisper.h
+++ b/include/ctranslate2/models/whisper.h
@@ -56,6 +56,11 @@ namespace ctranslate2 {
       // List of token IDs to suppress.
       // -1 will suppress a default set of symbols as defined in the model config.json file.
       std::vector<int> suppress_tokens = {-1};
+
+      // List of token sequences, each with a biasing factor, to contextualize decoding.
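+      // Each entry multiplies the logit of its last token by the given factor
+      // whenever a hypothesis already ends with the sequence's preceding tokens.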
+      std::vector<std::pair<std::vector<size_t>, float>> sequence_bias = {};
     };
 
     struct WhisperGenerationResult {
diff --git a/include/ctranslate2/primitives.h b/include/ctranslate2/primitives.h
index 571121554..816bff4d2 100644
--- a/include/ctranslate2/primitives.h
+++ b/include/ctranslate2/primitives.h
@@ -19,6 +19,8 @@ namespace ctranslate2 {
     template <typename T>
     static void strided_fill(T* x, T a, dim_t inc_x, dim_t size);
     template <typename T>
     static void indexed_fill(T* x, T a, const int32_t* indices, dim_t num_indices);
+    template <typename T>
+    static void indexed_pointwise_multiply(T* x, const T* values, const int32_t* indices, dim_t num_indices);
     template <typename T>
     static void copy(const T* x, T* y, dim_t size);
diff --git a/python/cpp/whisper.cc b/python/cpp/whisper.cc
index d0156c8c1..706975f6b 100644
--- a/python/cpp/whisper.cc
+++ b/python/cpp/whisper.cc
@@ -45,6 +45,7 @@ namespace ctranslate2 {
                  size_t max_initial_timestamp_index,
                  bool suppress_blank,
                  const std::optional<std::vector<int>>& suppress_tokens,
+                 const std::optional<std::vector<std::pair<std::vector<size_t>, float>>>& sequence_bias,
                  size_t sampling_topk,
                  float sampling_temperature) {
        std::vector<std::future<models::WhisperGenerationResult>> futures;
@@ -69,6 +70,10 @@ namespace ctranslate2 {
          options.suppress_tokens = suppress_tokens.value();
        else
          options.suppress_tokens.clear();
+        if (sequence_bias)
+          options.sequence_bias = sequence_bias.value();
+        else
+          options.sequence_bias.clear();
 
        std::shared_lock lock(_mutex);
        assert_model_is_ready();
@@ -254,6 +259,7 @@ namespace ctranslate2 {
             py::arg("max_initial_timestamp_index")=50,
             py::arg("suppress_blank")=true,
             py::arg("suppress_tokens")=std::vector<int>{-1},
+            py::arg("sequence_bias")=std::vector<std::pair<std::vector<size_t>, float>>{},
             py::arg("sampling_topk")=1,
             py::arg("sampling_temperature")=1,
             py::call_guard<py::gil_scoped_release>(),
@@ -286,6 +292,8 @@ namespace ctranslate2 {
                 suppress_blank: Suppress blank outputs at the beginning of the sampling.
                 suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
                   of symbols as defined in the model ``config.json`` file.
+                sequence_bias: List of pairs of a token sequence and a biasing factor used to
+                  boost or suppress matching sequences during decoding.
                 sampling_topk: Randomly sample predictions from the top K candidates.
                 sampling_temperature: Sampling temperature to generate more random samples.
 
diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py
index 1fed8196d..8cb0f1160 100644
--- a/python/tests/test_transformers.py
+++ b/python/tests/test_transformers.py
@@ -832,6 +832,145 @@ def _get_features(audio):
         transcription = processor.decode(token_ids)
         assert transcription == expected_transcription
 
+    @test_utils.only_on_linux
+    @test_utils.on_available_devices
+    @pytest.mark.parametrize(
+        "model_name,prompts,expected_transcriptions,expected_no_speech_probs",
+        [
+            (
+                "openai/whisper-tiny",
+                [
+                    [
+                        "<|startoftranscript|>",
+                        "<|en|>",
+                        "<|transcribe|>",
+                        "<|notimestamps|>",
+                    ],
+                    [
+                        "<|startoftranscript|>",
+                        "<|en|>",
+                        "<|transcribe|>",
+                        "<|notimestamps|>",
+                        "ĠAnd",
+                        "Ġthus",
+                        "Ġmy",
+                    ],
+                ],
+                [
+                    " Mr. Quiltre is the apostle of the middle classes and we are glad"
+                    " to welcome his gospel.",
+                    " And thus my fellow Americans ask not what your country can do for you,"
+                    " ask what you can do for your country.",
+                ],
+                [
+                    pytest.approx(0.0022832120303064585, abs=1e-4),
+                    pytest.approx(0.06885894387960434, abs=1e-3),
+                ],
+            ),
+            (
+                "openai/whisper-tiny",
+                [
+                    ["<|startoftranscript|>", "<|en|>", "<|transcribe|>"],
+                    ["<|startoftranscript|>", "<|en|>", "<|transcribe|>"],
+                ],
+                [
+                    " Mr. Quiltre is the apostle of the middle classes and we are glad"
+                    " to welcome his gospel.",
+                    " And so, my fellow Americans, ask not what your country can do for you,"
+                    " ask what you can do for your country.",
+                ],
+                [
+                    pytest.approx(0.0022832120303064585, abs=1e-4),
+                    pytest.approx(0.06885894387960434, abs=1e-3),
+                ],
+            ),
+        ],
+    )
+    def test_transformers_contextually_biased_whisper(
+        self,
+        tmp_dir,
+        device,
+        model_name,
+        prompts,
+        expected_transcriptions,
+        expected_no_speech_probs,
+    ):
+        import transformers
+
+        converter = ctranslate2.converters.TransformersConverter(model_name)
+        output_dir = str(tmp_dir.join("ctranslate2_model"))
+        output_dir = converter.convert(output_dir)
+
+        audio_paths = [
+            os.path.join(test_utils.get_data_dir(), "audio", "mr_quilter.npy"),
+            os.path.join(test_utils.get_data_dir(), "audio", "jfk.npy"),
+        ]
+        audio = list(map(np.load, audio_paths))
+
+        processor = transformers.WhisperProcessor.from_pretrained(model_name)
+
+        def _get_features(audio):
+            # Pad after computing the log-Mel spectrogram to match the openai/whisper behavior.
+            inputs = processor(audio, padding=False, sampling_rate=16000)
+            features = inputs.input_features[0]
+            features = np.pad(features, [(0, 0), (0, 3000 - features.shape[-1])])
+            return features
+
+        features = np.stack(list(map(_get_features, audio)))
+        features = ctranslate2.StorageView.from_array(features)
+
+        model = ctranslate2.models.Whisper(output_dir, device=device)
+
+        assert model.is_multilingual == (not model_name.endswith(".en"))
+
+        if model.is_multilingual:
+            for result in model.detect_language(features):
+                best_lang, best_prob = result[0]
+                assert best_lang == "<|en|>"
+                assert best_prob > 0.9
+        else:
+            with pytest.raises(RuntimeError, match="multilingual"):
+                model.detect_language(features)
+
+        # Bias the first generated words towards "Mr. Quiltre".
+        results = model.generate(
+            features,
+            prompts,
+            beam_size=2,
+            num_hypotheses=2,
+            return_no_speech_prob=True,
+            sequence_bias=[
+                ([2221, 13, 2326, 2352, 265], 1.3),
+                ([2221, 13, 2326, 2352], 1.3),
+                ([2221, 13, 2326], 1.3),
+            ],
+        )
+
+        timestamp_begin = (
+            processor.tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
+        )
+
+        for prompt, result, expected_transcription, expected_no_speech_prob in zip(
+            prompts, results, expected_transcriptions, expected_no_speech_probs
+        ):
+            assert len(result.sequences_ids) == 2
+            assert result.no_speech_prob == expected_no_speech_prob
+
+            for tokens in result.sequences_ids:
+                if "<|notimestamps|>" in prompt:
+                    assert all(token < timestamp_begin for token in tokens)
+                else:
+                    assert tokens[0] >= timestamp_begin
+                    assert tokens[-1] >= timestamp_begin
+                    assert tokens[-1] > tokens[0]
+
+            token_ids = list(
+                filter(lambda token: token < timestamp_begin, result.sequences_ids[0])
+            )
+
+            transcription = processor.decode(token_ids)
+            assert transcription == expected_transcription
+
     @test_utils.only_on_linux
     @test_utils.on_available_devices
     @pytest.mark.parametrize(
@@ -1025,6 +1164,7 @@ def test_transformers_wav2vec2(
 
     assert transcription == expected_transcription[0]
 
+
 class TestWav2Vec2Bert:
     @classmethod
     def teardown_class(cls):
diff --git a/python/tests/test_utils.py b/python/tests/test_utils.py
index 4b2f28dd0..4c6c1ca91 100644
--- a/python/tests/test_utils.py
+++ b/python/tests/test_utils.py
@@ -13,7 +13,7 @@ def get_data_dir():
     )
 
     # Verify that downloaded files are present.
-    translit_model = os.path.join(data_dir, "models", "transliteration-aren-all")
+    translit_model = os.path.join(data_dir, "models", "v1", "aren-transliteration")
     if not os.path.isdir(translit_model):
         pytest.skip("Data files are not available")
diff --git a/src/cpu/primitives.cc b/src/cpu/primitives.cc
index 5e0fd2999..c34fa9652 100644
--- a/src/cpu/primitives.cc
+++ b/src/cpu/primitives.cc
@@ -62,6 +62,14 @@ namespace ctranslate2 {
       x[indices[i]] = a;
   }
 
+  template<>
+  template <typename T>
+  void primitives<Device::CPU>::indexed_pointwise_multiply(T* x, const T* values, const int32_t* indices, dim_t num_indices) {
+    for (dim_t i = 0; i < num_indices; ++i) {
+      x[indices[i]] = x[indices[i]] * values[i];
+    }
+  }
+
   template<>
   template <typename T>
   void primitives<Device::CPU>::copy(const T* x, T* y, dim_t size) {
@@ -1153,6 +1161,8 @@ namespace ctranslate2 {
   template void                                                         \
   primitives<Device::CPU>::indexed_fill(T*, T, const int32_t*, dim_t);  \
   template void                                                         \
+  primitives<Device::CPU>::indexed_pointwise_multiply(T*, const T*, const int32_t*, dim_t); \
+  template void                                                         \
   primitives<Device::CPU>::copy(const T* x, T* y, dim_t size);          \
   template T                                                            \
   primitives<Device::CPU>::sum(const T* array, dim_t size);             \
diff --git a/src/cuda/primitives.cu b/src/cuda/primitives.cu
index 9915bb12c..5526c377a 100644
--- a/src/cuda/primitives.cu
+++ b/src/cuda/primitives.cu
@@ -41,6 +41,21 @@ namespace ctranslate2 {
     THRUST_CALL(thrust::fill, it, it + num_indices, cuda::device_type<T>(a));
   }
 
+  template<>
+  template <typename T>
+  void primitives<Device::CUDA>::indexed_pointwise_multiply(T* x, const T* values, const int32_t* indices, dim_t num_indices) {
+    auto element_it = thrust::device_pointer_cast(cuda::device_cast(x));
+    auto index_it = thrust::device_pointer_cast(indices);
+    auto value_it = thrust::device_pointer_cast(cuda::device_cast(values));
+
+    // Gather the logits at the given flat indices, multiply them by the bias
+    // values, and write the products back in place.
+    auto permutation_it = thrust::make_permutation_iterator(element_it, index_it);
+
+    THRUST_CALL(thrust::transform, permutation_it, permutation_it + num_indices, value_it, permutation_it,
+                thrust::multiplies<cuda::device_type<T>>());
+  }
+
   template<>
   template <typename T>
   void primitives<Device::CUDA>::copy(const T* x, T* y, dim_t size) {
@@ -726,6 +741,8 @@ namespace ctranslate2 {
   template void                                                         \
   primitives<Device::CUDA>::indexed_fill(T*, T, const int32_t*, dim_t); \
   template void                                                         \
+  primitives<Device::CUDA>::indexed_pointwise_multiply(T*, const T*, const int32_t*, dim_t); \
+  template void                                                         \
   primitives<Device::CUDA>::copy(const T* x, T* y, dim_t size);         \
   template T                                                            \
   primitives<Device::CUDA>::sum(const T* array, dim_t size);            \
@@ -795,7 +812,7 @@ namespace ctranslate2 {
   template void primitives<Device::CUDA>::gelu(const T*, T*, dim_t);          \
   template void primitives<Device::CUDA>::gelu_tanh(const T*, T*, dim_t);     \
   template void primitives<Device::CUDA>::gelu_sigmoid(const T*, T*, dim_t);  \
-  template void primitives<Device::CUDA>::sigmoid(const T*, T*, dim_t);       \
+  template void primitives<Device::CUDA>::sigmoid(const T*, T*, dim_t);       \
   template void primitives<Device::CUDA>::swish(const T*, T*, dim_t);         \
   template float primitives<Device::CUDA>::logsumexp(const T*, dim_t);        \
   template void primitives<Device::CUDA>::sin(const T*, T*, dim_t);           \
diff --git a/src/decoding.cc b/src/decoding.cc
index 55a9d7844..b6ed96057 100644
--- a/src/decoding.cc
+++ b/src/decoding.cc
@@ -498,6 +498,7 @@ namespace ctranslate2 {
       const dim_t cur_batch_size = is_expanded ? logits.dim(0) / _beam_size : logits.dim(0);
 
       DisableTokens disable_tokens(logits);
+      BiasTokens bias_tokens(logits);
 
       // Prevent the generation of end_ids until the minimum length is reached.
       apply_min_length(step,
@@ -512,12 +513,14 @@ namespace ctranslate2 {
         if (alive_seq)
           merge_batch_beam(alive_seq);
         for (const auto& logits_processor : logits_processors)
-          logits_processor->apply(step, logits, disable_tokens, alive_seq, batch_offset, prefix_ids);
+          logits_processor->apply(step, logits, disable_tokens, bias_tokens, alive_seq, batch_offset, prefix_ids);
         if (alive_seq)
           split_batch_beam(alive_seq, _beam_size);
       }
 
       disable_tokens.apply();
+      bias_tokens.apply();
+
       std::vector<StorageView> logits_vec;
       if (return_logits_vocab)
         logits_vec = build_logits(logits, cur_batch_size);
@@ -840,6 +843,7 @@ namespace ctranslate2 {
                      gather_attention ? &attention_step_device : nullptr);
 
       DisableTokens disable_tokens(logits);
+      BiasTokens bias_tokens(logits);
 
       // Prevent the generation of end_id until the minimum length is reached.
       apply_min_length(step,
@@ -851,9 +855,10 @@ namespace ctranslate2 {
                       prefix_ids);
 
       for (const auto& logits_processor : logits_processors)
-        logits_processor->apply(step, logits, disable_tokens, alive_seq, batch_offset, prefix_ids);
+        logits_processor->apply(step, logits, disable_tokens, bias_tokens, alive_seq, batch_offset, prefix_ids);
 
       disable_tokens.apply();
+      bias_tokens.apply();
 
       std::vector<StorageView> logits_vec;
       StorageView logits_orig(dtype, device);
@@ -1100,6 +1105,9 @@ namespace ctranslate2 {
     if (!options.disable_sequences.empty())
       processors.emplace_back(std::make_shared<SuppressSequences>(options.disable_sequences));
 
+    if (!options.sequence_bias.empty())
+      processors.emplace_back(std::make_shared<BiasSequences>(options.sequence_bias));
+
     for (const auto& processor : options.logits_processors) {
       if (!processor->apply_first())
         processors.emplace_back(processor);
diff --git a/src/decoding_utils.cc b/src/decoding_utils.cc
index fed4670d3..360d3debd 100644
--- a/src/decoding_utils.cc
+++ b/src/decoding_utils.cc
@@ -34,6 +34,44 @@ namespace ctranslate2 {
     _flat_indices.clear();
   }
 
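+  // BiasTokens mirrors DisableTokens: on CPU the multiplication is applied
+  // immediately in add(), while on GPU the (index, factor) pairs are collected
+  // and applied in a single indexed_pointwise_multiply call in apply().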
+  BiasTokens::BiasTokens(StorageView& logits)
+    : _logits(logits)
+    , _logits_data(logits.device() == Device::CPU ? logits.data<float>() : nullptr)
+    , _batch_size(logits.dim(0))
+    , _vocabulary_size(logits.dim(1))
+  {
+  }
+
+  void BiasTokens::apply() {
+    const dim_t num_indices = _flat_indices.size();
+    if (num_indices == 0)
+      return;
+
+    const Device device = _logits.device();
+    const DataType dtype = _logits.dtype();
+
+    std::vector<int32_t> indices(num_indices);
+    std::vector<float> values(num_indices);
+
+    std::transform(_flat_indices.begin(), _flat_indices.end(), indices.begin(),
+                   [](const auto& pair) { return pair.first; });
+    std::transform(_flat_indices.begin(), _flat_indices.end(), values.begin(),
+                   [](const auto& pair) { return pair.second; });
+
+    const StorageView flat_indices({num_indices}, indices, device);
+    const StorageView flat_values({num_indices}, values, device);
+
+    // Multiply the logits at the collected indices by the accumulated bias values.
+    DEVICE_AND_TYPE_DISPATCH(device, dtype,
+                             primitives<D>::indexed_pointwise_multiply(_logits.data<T>(),
+                                                                       flat_values.data<T>(),
+                                                                       flat_indices.data<int32_t>(),
+                                                                       num_indices));
+    _flat_indices.clear();
+  }
+
   RepetitionPenalty::RepetitionPenalty(const float penalty)
     : _penalty(penalty)
   {
@@ -43,6 +81,7 @@ namespace ctranslate2 {
   void RepetitionPenalty::apply(dim_t,
                                 StorageView& logits,
                                 DisableTokens&,
+                                BiasTokens&,
                                 const StorageView& sequences,
                                 const std::vector<dim_t>&,
                                 const std::vector<std::vector<size_t>>*) {
@@ -75,6 +114,7 @@ namespace ctranslate2 {
   void NoRepeatNgram::apply(dim_t,
                             StorageView&,
                             DisableTokens& disable_tokens,
+                            BiasTokens&,
                             const StorageView& sequences,
                             const std::vector<dim_t>&,
                             const std::vector<std::vector<size_t>>*) {
@@ -119,6 +159,7 @@ namespace ctranslate2 {
   void SuppressSequences::apply(dim_t,
                                 StorageView&,
                                 DisableTokens& disable_tokens,
+                                BiasTokens&,
                                 const StorageView& sequences,
                                 const std::vector<dim_t>&,
                                 const std::vector<std::vector<size_t>>*) {
@@ -152,6 +193,59 @@ namespace ctranslate2 {
     }
   }
 
+  BiasSequences::BiasSequences(std::vector<std::pair<std::vector<size_t>, float>> sequences) {
+    for (auto& sequence : sequences) {
+      if (sequence.first.empty())
+        continue;
+      // Single tokens are always biased.
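+      // Multi-token sequences only bias their last token, and only when the
+      // decoded hypothesis already ends with the preceding tokens (see apply()).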
+      if (sequence.first.size() == 1)
+        _ids.emplace_back(sequence.first[0], sequence.second);
+      else
+        _sequences.emplace_back(std::move(sequence));
+    }
+  }
+
+  void BiasSequences::apply(dim_t,
+                            StorageView& logits,
+                            DisableTokens&,
+                            BiasTokens& bias_tokens,
+                            const StorageView& sequences,
+                            const std::vector<dim_t>&,
+                            const std::vector<std::vector<size_t>>*) {
+    for (const auto& token_id : _ids)
+      bias_tokens.add(token_id.first, token_id.second);
+
+    if (!sequences)
+      return;
+
+    const dim_t batch_size = sequences.dim(0);
+    const dim_t length = sequences.dim(1);
+
+    for (dim_t batch_id = 0; batch_id < batch_size; ++batch_id) {
+      const auto* begin = sequences.index<int32_t>({batch_id, 0});
+      const auto* end = begin + length;
+
+      for (const auto& biased_sequence : _sequences) {
+        const dim_t compare_length = biased_sequence.first.size() - 1;
+
+        if (length < compare_length)
+          continue;
+
+        // Bias the last token of the sequence when the decoded hypothesis
+        // already ends with all of its preceding tokens.
+        const bool bias_last = std::equal(end - compare_length,
+                                          end,
+                                          biased_sequence.first.begin(),
+                                          biased_sequence.first.begin() + compare_length);
+
+        if (bias_last)
+          bias_tokens.add(batch_id, biased_sequence.first.back(), biased_sequence.second);
+      }
+    }
+  }
+
   SuppressTokens::SuppressTokens(std::vector<size_t> ids)
     : _ids(std::move(ids))
   {
@@ -161,6 +255,7 @@ namespace ctranslate2 {
   void SuppressTokens::apply(dim_t,
                              StorageView&,
                              DisableTokens& disable_tokens,
+                             BiasTokens&,
                              const StorageView&,
                              const std::vector<dim_t>&,
                              const std::vector<std::vector<size_t>>*) {
@@ -177,6 +272,7 @@ namespace ctranslate2 {
   void SuppressTokensBegin::apply(dim_t step,
                                   StorageView& logits,
                                   DisableTokens& disable_tokens,
+                                  BiasTokens&,
                                   const StorageView&,
                                   const std::vector<dim_t>& batch_offset,
                                   const std::vector<std::vector<size_t>>* prefix) {
diff --git a/src/models/whisper.cc b/src/models/whisper.cc
index 7cdf2dc5b..9d71f436f 100644
--- a/src/models/whisper.cc
+++ b/src/models/whisper.cc
@@ -210,6 +210,7 @@ namespace ctranslate2 {
       void apply(dim_t step,
                  StorageView& logits,
                  DisableTokens&,
+                 BiasTokens&,
                  const StorageView&,
                  const std::vector<dim_t>& batch_offset,
                  const std::vector<std::vector<size_t>>*) override {
@@ -304,6 +305,7 @@ namespace ctranslate2 {
       decoding_options.return_scores = options.return_scores;
      decoding_options.return_logits_vocab = options.return_logits_vocab;
       decoding_options.include_eos_in_hypotheses = false;
+      decoding_options.sequence_bias = options.sequence_bias;
 
       for (const auto& id : options.suppress_tokens) {
         if (id >= 0)
@@ -749,6 +751,7 @@ namespace ctranslate2 {
       void apply(dim_t step,
                  StorageView& logits,
                  DisableTokens& disable_tokens,
+                 BiasTokens&,
                  const StorageView& sequences,
                  const std::vector<dim_t>& batch_offset,
                  const std::vector<std::vector<size_t>>* prefix) override {
diff --git a/tests/decoding_test.cc b/tests/decoding_test.cc
index a5e9937f1..3a10e77d5 100644
--- a/tests/decoding_test.cc
+++ b/tests/decoding_test.cc
@@ -14,3 +14,19 @@ TEST(DecodingTest, DisableTokens) {
 
   expect_storage_eq(input, expected);
 }
+
+TEST(BiasDecodingTest, BiasTokens) {
+  StorageView input({2, 5}, std::vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
+  StorageView expected({2, 5}, std::vector<float>{1, 3, 6, 4, 5, 6, 7, 48, 27, 10});
+  BiasTokens bias_tokens(input);
+
+  // Token 2 is biased in all batches; repeated (batch, token) biases multiply
+  // together, so token 2 of batch 1 gets 8 * 2.0 * 3.0 = 48.
+  bias_tokens.add(2, 2.0);
+  bias_tokens.add(0, 1, 1.5);
+  bias_tokens.add(1, 3, 3.0);
+  bias_tokens.add(1, 2, 3.0);
+  bias_tokens.apply();
+
+  expect_storage_eq(input, expected);
+}
diff --git a/tests/primitives_test.cc b/tests/primitives_test.cc
index 9f603de33..c4f7edc02 100644
--- a/tests/primitives_test.cc
+++ b/tests/primitives_test.cc
@@ -22,6 +22,17 @@ TEST_P(PrimitiveTest, IndexedFill) {
   expect_storage_eq(x, expected);
 }
 
+TEST_P(PrimitiveTest, IndexedPointwiseMultiply) {
+  const Device device = GetParam();
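+  // Multiply x[0], x[2] and x[5] by 3 in place; the other elements are unchanged.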
+  StorageView x({6}, float(2), device);
+  StorageView values({3}, std::vector<float>{3, 3, 3}, device);
+  StorageView ids({3}, std::vector<int32_t>{0, 2, 5}, device);
+  StorageView expected({6}, std::vector<float>{6, 2, 6, 2, 2, 6}, device);
+  DEVICE_DISPATCH(device, primitives<D>::indexed_pointwise_multiply(x.data<float>(), values.data<float>(), ids.data<int32_t>(), 3));
+  expect_storage_eq(x, expected);
+}
+
 TEST_P(PrimitiveTest, LogSumExp) {
   const Device device = GetParam();
   StorageView x({8}, std::vector<float>{0.6, 0.2, -1.2, 0.1, 0.3, 0.5, -1.3, 0.2}, device);