From 7e490acf219e2fb1f1d565b91edaa5807fd226d7 Mon Sep 17 00:00:00 2001 From: thucpham Date: Wed, 6 Dec 2023 17:24:36 +0100 Subject: [PATCH 01/10] convert opennmt py on the fly and inference with model in memory --- include/ctranslate2/models/language_model.h | 1 + include/ctranslate2/models/model.h | 16 +++ .../ctranslate2/models/sequence_to_sequence.h | 1 + include/ctranslate2/models/transformer.h | 1 + include/ctranslate2/models/wav2vec2.h | 1 + include/ctranslate2/models/whisper.h | 1 + python/cpp/generator.cc | 39 +++++++ python/cpp/replica_pool.h | 68 ++++++++++-- python/cpp/translator.cc | 8 +- python/ctranslate2/__init__.py | 2 +- python/ctranslate2/converters/converter.py | 35 ++++++ python/ctranslate2/onTheFly/__init__.py | 1 + .../onTheFly/generator_on_the_fly.py | 103 ++++++++++++++++++ python/ctranslate2/specs/attention_spec.py | 14 +-- python/ctranslate2/specs/model_spec.py | 25 ++++- python/ctranslate2/specs/transformer_spec.py | 16 +-- src/models/language_model.cc | 16 +++ src/models/model.cc | 56 ++++++++++ src/models/sequence_to_sequence.cc | 53 +++++++++ src/models/transformer.cc | 10 ++ src/models/wav2vec2.cc | 11 ++ src/models/whisper.cc | 11 ++ 22 files changed, 458 insertions(+), 31 deletions(-) create mode 100644 python/ctranslate2/onTheFly/__init__.py create mode 100644 python/ctranslate2/onTheFly/generator_on_the_fly.py diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h index 7532b9a3a..04c083da9 100644 --- a/include/ctranslate2/models/language_model.h +++ b/include/ctranslate2/models/language_model.h @@ -22,6 +22,7 @@ namespace ctranslate2 { protected: void initialize(ModelReader& model_reader) override; + void initialize(std::unordered_map>& vocabularies) override; private: std::shared_ptr _vocabulary; diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 43a4ea5b9..2e8e9baf2 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -31,6 +31,17 @@ namespace ctranslate2 { Device device = Device::CPU, int device_index = 0, ComputeType compute_type = ComputeType::DEFAULT); + static std::shared_ptr load(const std::string& spec, + const size_t& spec_version, + const size_t& binary_version, + std::unordered_map& alias, + std::unordered_map>& vocabularies, + std::unordered_map& variables, + const std::string& config, + Device device = Device::CPU, + int device_index = 0, + ComputeType compute_type = ComputeType::DEFAULT); + virtual std::unique_ptr as_sequence_to_sequence() const; virtual std::unique_ptr as_sequence_generator() const; @@ -86,6 +97,10 @@ namespace ctranslate2 { return ScopedDeviceSetter(_device, _device_index); } + void set_config(const std::string& config_str); + void set_revision(const size_t revision); + void set_binary_version(const size_t binary_version); + // If the model contains variables, they will be moved to the new device. void set_device(const Device device, const int index = 0); @@ -143,6 +158,7 @@ namespace ctranslate2 { // Runs some initialization after the model is loaded. 
virtual void initialize(ModelReader&) {} + virtual void initialize(std::unordered_map>&) {} virtual std::unique_ptr clone() const = 0; diff --git a/include/ctranslate2/models/sequence_to_sequence.h b/include/ctranslate2/models/sequence_to_sequence.h index e1d79327f..125d0d77f 100644 --- a/include/ctranslate2/models/sequence_to_sequence.h +++ b/include/ctranslate2/models/sequence_to_sequence.h @@ -32,6 +32,7 @@ namespace ctranslate2 { protected: virtual void initialize(ModelReader& model_reader) override; + virtual void initialize(std::unordered_map>& vocabularies) override; private: std::vector> _source_vocabularies; diff --git a/include/ctranslate2/models/transformer.h b/include/ctranslate2/models/transformer.h index 4e97f85e0..74ee851c5 100644 --- a/include/ctranslate2/models/transformer.h +++ b/include/ctranslate2/models/transformer.h @@ -34,6 +34,7 @@ namespace ctranslate2 { protected: bool is_linear_weight(const std::string& variable_name) const override; void initialize(ModelReader& model_reader) override; + void initialize(std::unordered_map>& vocabularies) override; std::unique_ptr clone() const override; }; diff --git a/include/ctranslate2/models/wav2vec2.h b/include/ctranslate2/models/wav2vec2.h index d1034ef88..5427d2631 100644 --- a/include/ctranslate2/models/wav2vec2.h +++ b/include/ctranslate2/models/wav2vec2.h @@ -41,6 +41,7 @@ namespace ctranslate2 { protected: void initialize(ModelReader& model_reader) override; + void initialize(std::unordered_map>& vocabularies) override; private: std::shared_ptr _vocabulary; }; diff --git a/include/ctranslate2/models/whisper.h b/include/ctranslate2/models/whisper.h index 7ade2bd20..c3ccb873b 100644 --- a/include/ctranslate2/models/whisper.h +++ b/include/ctranslate2/models/whisper.h @@ -90,6 +90,7 @@ namespace ctranslate2 { protected: void initialize(ModelReader& model_reader) override; + void initialize(std::unordered_map>& vocabularies) override; private: std::shared_ptr _vocabulary; diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc index 981c6da68..2f1cc53b3 100644 --- a/python/cpp/generator.cc +++ b/python/cpp/generator.cc @@ -1,6 +1,7 @@ #include "module.h" #include +#include #include "replica_pool.h" @@ -158,6 +159,44 @@ namespace ctranslate2 { :obj:`model_path` acts as an identifier for this model. )pbdoc") + .def(py::init&, + std::unordered_map>&, std::unordered_map&, const std::string&, const std::string&, const std::variant>, const StringOrMap&, size_t, size_t, long>(), + py::arg("spec"), + py::arg("spec_revision"), + py::arg("binary_version"), + py::arg("aliases"), + py::arg("vocabularies"), + py::arg("variables"), + py::arg("config"), + py::arg("device")="cpu", + py::arg("device_index")=0, + py::arg("compute_type")="default", + py::arg("inter_threads")=1, + py::arg("intra_threads")=0, + py::arg("max_queued_batches")=0, + R"pbdoc( + Initializes the generator. + + Arguments: + spec: The name of the model specification. + spec_revision: The model specification revision. + binary_version: The version of binary model + aliases: aliases got in the mode + vocabularies: dictionary of name and list of tokens + variables: dictionary of name of variables and storage view of variable + config: list of config (normally saved in config.json) + device: Device to use (possible values are: cpu, cuda, auto). + device_index: Device IDs where to place this generator on. 
+ compute_type: Model computation type or a dictionary mapping a device name + to the computation type (possible values are: default, auto, int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + inter_threads: Maximum number of parallel generations. + intra_threads: Number of OpenMP threads per generator (0 to use a default value). + max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, + 0 for an automatic value). When the queue is full, future requests will block + until a free slot is available. + )pbdoc") + .def_property_readonly("device", &GeneratorWrapper::device, "Device this generator is running on.") .def_property_readonly("device_index", &GeneratorWrapper::device_index, diff --git a/python/cpp/replica_pool.h b/python/cpp/replica_pool.h index a735ea363..4a26d7459 100644 --- a/python/cpp/replica_pool.h +++ b/python/cpp/replica_pool.h @@ -1,7 +1,11 @@ #pragma once #include +#include +#include +#include +#include #include "utils.h" namespace ctranslate2 { @@ -49,15 +53,53 @@ namespace ctranslate2 { { pybind11::gil_scoped_release nogil; - _model_loader.device = str_to_device(device); - _model_loader.device_indices = std::visit(DeviceIndexResolver(), device_index); - _model_loader.compute_type = std::visit(ComputeTypeResolver(device), compute_type); - _model_loader.num_replicas_per_device = inter_threads; + _model_loader->device = str_to_device(device); + _model_loader->device_indices = std::visit(DeviceIndexResolver(), device_index); + _model_loader->compute_type = std::visit(ComputeTypeResolver(device), compute_type); + _model_loader->num_replicas_per_device = inter_threads; _pool_config.num_threads_per_replica = intra_threads; _pool_config.max_queued_batches = max_queued_batches; - _pool = std::make_unique(_model_loader, _pool_config); + _pool = std::make_unique(_model_loader.value(), _pool_config); + } + + ReplicaPoolHelper(const std::string& spec, + const size_t& spec_version, + const size_t& binary_version, + std::unordered_map& aliases, + std::unordered_map>& vocabularies, + std::unordered_map& variables, + const std::string& config, + const std::string& device, + const std::variant>& device_index, + const StringOrMap& compute_type, + size_t ,//inter_threads + size_t intra_threads, + long max_queued_batches) + { + pybind11::gil_scoped_release nogil; + + // Load the variables. 
+ auto model_device = str_to_device(device); + auto model_device_indices = std::visit(DeviceIndexResolver(), device_index)[0]; + auto model_compute_type = std::visit(ComputeTypeResolver(device), compute_type); + + auto model = models::Model::load(spec, + spec_version, + binary_version, + aliases, + vocabularies, + variables, + config, + model_device, + model_device_indices, + model_compute_type); + + _pool_config.num_threads_per_replica = intra_threads; + _pool_config.max_queued_batches = max_queued_batches; + + _pool = std::make_unique(model, _pool_config); } ~ReplicaPoolHelper() { @@ -66,11 +108,19 @@ namespace ctranslate2 { } std::string device() const { - return device_to_str(_model_loader.device); + if (_model_loader.has_value()) + return device_to_str(_model_loader->device); + if (_device) + return _device.value(); + return ""; } const std::vector& device_index() const { - return _model_loader.device_indices; + if (_model_loader.has_value()) + return _model_loader->device_indices; + if (!_device_index.has_value() || _device_index->empty()) + throw pybind11::type_error("No device index found"); + return _device_index.value(); } std::string compute_type() const { @@ -91,7 +141,9 @@ namespace ctranslate2 { protected: std::unique_ptr _pool; - models::ModelLoader _model_loader; + std::optional _model_loader; + std::optional _device; + std::optional> _device_index; ReplicaPoolConfig _pool_config; const std::shared_ptr& model() const { diff --git a/python/cpp/translator.cc b/python/cpp/translator.cc index d920469fe..a544d6a25 100644 --- a/python/cpp/translator.cc +++ b/python/cpp/translator.cc @@ -42,9 +42,9 @@ namespace ctranslate2 { intra_threads, max_queued_batches, files) - , _device(_model_loader.device) - , _device_index(_model_loader.device_indices) - , _num_replicas_per_device(_model_loader.num_replicas_per_device) + , _device(_model_loader->device) + , _device_index(_model_loader->device_indices) + , _num_replicas_per_device(_model_loader->num_replicas_per_device) , _model_is_loaded(true) { } @@ -324,7 +324,7 @@ namespace ctranslate2 { return; if (_cached_models.empty()) { - _cached_models = _model_loader.load(); + _cached_models = _model_loader->load(); } else { move_cached_models(_device, _device_index, _num_replicas_per_device); } diff --git a/python/ctranslate2/__init__.py b/python/ctranslate2/__init__.py index 9c0efac2a..dfc7650dd 100644 --- a/python/ctranslate2/__init__.py +++ b/python/ctranslate2/__init__.py @@ -50,5 +50,5 @@ else: raise -from ctranslate2 import converters, models, specs +from ctranslate2 import converters, models, specs, onTheFly from ctranslate2.version import __version__ diff --git a/python/ctranslate2/converters/converter.py b/python/ctranslate2/converters/converter.py index ecede044a..05700f709 100644 --- a/python/ctranslate2/converters/converter.py +++ b/python/ctranslate2/converters/converter.py @@ -104,6 +104,41 @@ def convert( model_spec.save(output_dir) return output_dir + def convert_on_the_fly( + self, + vmap: Optional[str] = None, + quantization: Optional[str] = None, + ) -> ModelSpec: + """Converts the model to the CTranslate2 format. + + Arguments: + vmap: Optional path to a vocabulary mapping file that will be included + in the converted model directory. + quantization: Weight quantization scheme (possible values are: int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + + Returns: + Path to the output directory. 
+ + Raises: + RuntimeError: If the output directory already exists and :obj:`force` + is not set. + NotImplementedError: If the converter cannot convert this model to the + CTranslate2 format. + """ + model_spec = self._load() + if model_spec is None: + raise NotImplementedError( + "This model is not supported by CTranslate2 or this converter" + ) + if vmap is not None: + model_spec.register_vocabulary_mapping(vmap) + + model_spec.validate() + model_spec.optimize(quantization=quantization) + + return model_spec + @abc.abstractmethod def _load(self): raise NotImplementedError() diff --git a/python/ctranslate2/onTheFly/__init__.py b/python/ctranslate2/onTheFly/__init__.py new file mode 100644 index 000000000..240316127 --- /dev/null +++ b/python/ctranslate2/onTheFly/__init__.py @@ -0,0 +1 @@ +from ctranslate2.onTheFly.generator_on_the_fly import GeneratorOnTheFly diff --git a/python/ctranslate2/onTheFly/generator_on_the_fly.py b/python/ctranslate2/onTheFly/generator_on_the_fly.py new file mode 100644 index 000000000..965e135b6 --- /dev/null +++ b/python/ctranslate2/onTheFly/generator_on_the_fly.py @@ -0,0 +1,103 @@ +import os +from typing import Optional + +from ctranslate2.converters.opennmt_py import OpenNMTPyConverter +import ctranslate2 +import json + +class GeneratorOnTheFly: + def __init__(self, model_path: str, + device="cpu", + device_index=0, + compute_type="default", + inter_threads=1, + intra_threads=0, + max_queued_batches=0, + model_type="OpenNMTPy", + quantization: Optional[str] = None, + ): + """Initializes the generator on the fly. + + Arguments: + model_path: Path to the CTranslate2 model directory. + device: Device to use (possible values are: cpu, cuda, auto). + device_index: Device IDs where to place this generator on. + compute_type: Model computation type or a dictionary mapping a device name + to the computation type (possible values are: default, auto, int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + inter_threads: Maximum number of parallel generations. + intra_threads: Number of OpenMP threads per generator (0 to use a default value). + max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, + 0 for an automatic value). When the queue is full, future requests will block + until a free slot is available. 
+ model_type: type of converter to convert the model + quantization: quantize the model + """ + converter = _get_converter(model_path=model_path, model_type=model_type) + model_spec = converter.convert_on_the_fly(quantization=quantization) + + variables = model_spec.variables(ordered=True) + vocabularies = model_spec.get_vocabulary() + config = json.dumps(model_spec.config.to_dict()) + aliases = {} + + spec = model_spec.name + spec_revision = model_spec.revision + binary_version = model_spec.binary_version + variables_cpp = dict() + + for key, value in variables: + if isinstance(value, str): + aliases[key] = value + else: + variables_cpp[key] = ctranslate2.StorageView.from_array(value.numpy()) + + self.generator = ctranslate2.Generator(spec=spec, + spec_revision=spec_revision, + binary_revision=binary_version, + aliases=aliases, + vocabularies=vocabularies, + variables=variables_cpp, + config=config, + device=device, + device_index=device_index, + compute_type=compute_type, + inter_threads=inter_threads, + intra_threads=intra_threads, + max_queued_batches=max_queued_batches) + + #self.generator = ctranslate2.Generator(model_path, device=device, files=model_spec.files_memory) + + def generate_iterable(self, start_tokens, *args, **kwargs): + return self.generator.generate_tokens(start_tokens, *args, **kwargs) + + def generate_tokens(self, start_tokens, *args, **kwargs): + return self.generator.generate_tokens(start_tokens, *args, **kwargs) + + def score_iterable(self, tokens, *args, **kwargs): + return self.generator.score_iterable(tokens, *args, **kwargs) + + def async_generate_tokens(self, prompt, *args, **kwargs): + return self.generator.async_generate_tokens(prompt, *args, **kwargs) + + def _get_converter(model_path: str, model_type: str): + if model_type == "OpenNMTPy": + def get_model_file(_model_path: str): + for filename in os.listdir(model_path): + if filename.endswith(".pt"): + _model_file = os.path.join(model_path, filename) + return _model_file + return '' + + model_file = get_model_file(model_path) + if model_file == '': + raise RuntimeError( + "No model opennmt-py found in %s" % model_path + ) + + converter = OpenNMTPyConverter(model_path=model_file) + return converter + else: + raise NotImplementedError( + "Converter on the fly for %s is not implemented." 
% model_type + ) diff --git a/python/ctranslate2/specs/attention_spec.py b/python/ctranslate2/specs/attention_spec.py index 0b3a44c4b..82cf13395 100644 --- a/python/ctranslate2/specs/attention_spec.py +++ b/python/ctranslate2/specs/attention_spec.py @@ -43,18 +43,16 @@ def __init__( self.relative_attention_max_distance = None if rotary_dim is not None: - self.rotary_dim = np.dtype("int32").type(rotary_dim) + self.rotary_dim = np.array(rotary_dim, dtype="int32") self.rotary_interleave = rotary_interleave - self.rotary_base = np.dtype("float32").type(rotary_base) + self.rotary_base = np.array(rotary_base, dtype="float32") if rotary_scaling_type is not None: - self.rotary_scaling_type = np.dtype("int8").type(rotary_scaling_type) - self.rotary_scaling_factor = np.dtype("float32").type( - rotary_scaling_factor - ) + self.rotary_scaling_type = np.array(rotary_scaling_type, dtype="int8") + self.rotary_scaling_factor = np.array(rotary_scaling_factor, dtype="float32") if num_heads_kv is not None: - self.num_heads_kv = np.dtype("int32").type(num_heads_kv) + self.num_heads_kv = np.array(num_heads_kv, dtype="int32") if sliding_window is not None: - self.sliding_window = np.dtype("int32").type(sliding_window) + self.sliding_window = np.array(sliding_window, dtype="int32") diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 4cb765636..fdcdbffc8 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -114,10 +114,10 @@ def _check(spec, name, value): if value.dtype == np.float64: value = value.astype(np.float32) elif isinstance(value, float): - value = np.dtype("float32").type(value) + value = np.array(value, dtype="float32") elif isinstance(value, bool): # Convert bool to an integer type. - value = np.dtype("int8").type(value) + value = np.array(value, dtype="int8") elif isinstance(value, str): if value != OPTIONAL: value = np.frombuffer(value.encode("utf-8"), dtype=np.int8) @@ -311,6 +311,11 @@ def __init__(self): self._config = self.get_default_config() self._files = {} + @abc.abstractmethod + def get_vocabulary(self): + """Returns the map vocabulary expected by the model.""" + raise NotImplementedError() + @property def name(self): """The name of the model specification.""" @@ -325,6 +330,11 @@ def revision(self): """ return 1 + @property + def binary_version(self): + """The binary version""" + return CURRENT_BINARY_VERSION + @property def config(self): """The model configuration.""" @@ -455,6 +465,14 @@ def __init__(self): "target": [], } + def get_vocabulary(self): + vocabularies = dict(_flatten_vocabularies(self._vocabularies)) + all_vocabularies = list(vocabularies.values()) + if all(vocabulary == all_vocabularies[0] for vocabulary in all_vocabularies): + vocabularies = {"shared": all_vocabularies[0]} + + return vocabularies + def get_default_config(self): return SequenceToSequenceModelConfig() @@ -566,6 +584,9 @@ def __init__(self): super().__init__() self._vocabulary = [] + def get_vocabulary(self): + return {"vocabulary": self._vocabulary} + def get_default_config(self): return LanguageModelConfig() diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index 7208be8a9..a9232652e 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -45,10 +45,10 @@ def __init__( rms_norm: Use the root mean square layer normalization. multi_query_attention: Use multi-query attention. 
""" - self.num_heads = np.dtype("int16").type(num_heads) + self.num_heads = np.array(num_heads, dtype="int16") self.pre_norm = pre_norm - self.activation = np.dtype("int8").type(activation) - self.embeddings_merge = np.dtype("int8").type(embeddings_merge) + self.activation = np.array(activation, dtype="int8") + self.embeddings_merge = np.array(embeddings_merge, dtype="int8") self.embeddings = [ common_spec.EmbeddingsSpec() for _ in range(num_source_embeddings) ] @@ -160,11 +160,11 @@ def __init__( % num_heads_kv ) - self.num_heads = np.dtype("int16").type(num_heads) + self.num_heads = np.array(num_heads, dtype="int16") self.pre_norm = pre_norm - self.activation = np.dtype("int8").type(activation) - self.alignment_layer = np.dtype("int16").type(alignment_layer) - self.alignment_heads = np.dtype("int16").type(alignment_heads) + self.activation = np.array(activation, dtype="int8") + self.alignment_layer = np.array(alignment_layer, dtype="int16") + self.alignment_heads = np.array(alignment_heads, dtype="int16") self.embeddings = common_spec.EmbeddingsSpec() self.scale_embeddings = True self.scale_outputs = model_spec.OPTIONAL @@ -172,7 +172,7 @@ def __init__( self.alibi_use_positive_positions = alibi_use_positive_positions self.scale_alibi = scale_alibi if sliding_window is not None: - self.sliding_window = np.dtype("int32").type(sliding_window) + self.sliding_window = np.array(sliding_window, dtype="int32") if ( not relative_position and not relative_attention_bias diff --git a/src/models/language_model.cc b/src/models/language_model.cc index 466e42594..b4899fe48 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -35,6 +35,22 @@ namespace ctranslate2 { throw std::runtime_error("Cannot load the vocabulary from the model directory"); } + void LanguageModel::initialize(std::unordered_map>& vocabularies) { + if (binary_version() < 6) { + config["unk_token"] = get_attribute_with_default("unk_token", ""); + config["bos_token"] = get_attribute_with_default("bos_token", ""); + config["eos_token"] = get_attribute_with_default("eos_token", ""); + } + + VocabularyInfo vocab_info; + vocab_info.unk_token = config["unk_token"]; + vocab_info.bos_token = config["bos_token"]; + vocab_info.eos_token = config["eos_token"]; + + _vocabulary = std::make_shared(vocabularies.at("vocabulary")); + if (!_vocabulary) + throw std::runtime_error("Cannot load the vocabulary from the model directory"); + } std::vector SequenceGeneratorReplica::score(const std::vector>& tokens, diff --git a/src/models/model.cc b/src/models/model.cc index 0672494ff..48ac37bd1 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -185,6 +185,18 @@ namespace ctranslate2 { } } + void Model::set_config(const std::string& config_str) { + config = nlohmann::json::parse(config_str); + } + + void Model::set_revision(const size_t revision) { + _spec_revision = revision; + } + + void Model::set_binary_version(const size_t binary_version) { + _binary_version = binary_version; + } + const StorageView* Model::get_variable_if_exists(const std::string& name) const { auto it = _variable_index.find(name); if (it == _variable_index.end()) @@ -509,6 +521,50 @@ namespace ctranslate2 { return model; } + std::shared_ptr Model::load(const std::string& spec, + const size_t& spec_version, + const size_t& binary_version, + std::unordered_map& aliases, + std::unordered_map>& vocabularies, + std::unordered_map& variables, + const std::string& config, + Device device, + int device_index, + ComputeType compute_type) { + auto model = 
models::create_model(spec); + + // Load the variables. + for (auto& variable : variables) { + model->register_variable(variable.first, std::move(variable.second)); + } + // Maybe quantize/dequantize/convert the variables to match the requested compute type. + model->set_compute_type(compute_type, device, device_index); + + // Move variables to the target device. + model->set_device(device, device_index); + + model->set_config(config); + + model->set_revision(spec_version); + model->set_binary_version(binary_version); + // Register variable aliases. + if (binary_version >= 3) { + for (auto& alias_pair : aliases) { + const auto alias = alias_pair.first; + const auto variable_name = alias_pair.second; + model->register_variable_alias(alias, variable_name); + // Also alias the quantization scale that could be associated to variable_name. + model->register_variable_alias(alias + "_scale", variable_name + "_scale"); + } + } + + // Run additional model initialization. + const ScopedDeviceSetter scoped_device_setter(device, device_index); + model->process_linear_weights(); + model->initialize(vocabularies); + return model; + } + std::shared_ptr Model::copy_to(Device device, int device_index) const { auto model = clone(); diff --git a/src/models/sequence_to_sequence.cc b/src/models/sequence_to_sequence.cc index a7e64611f..7b239954a 100644 --- a/src/models/sequence_to_sequence.cc +++ b/src/models/sequence_to_sequence.cc @@ -76,6 +76,59 @@ namespace ctranslate2 { load_vocabularies(model_reader); } + void SequenceToSequenceModel::initialize(std::unordered_map>& vocabularies) { + if (binary_version() < 6) { + config["unk_token"] = get_attribute_with_default("unk_token", ""); + config["bos_token"] = get_attribute_with_default("bos_token", ""); + config["eos_token"] = get_attribute_with_default("eos_token", ""); + config["add_source_bos"] = get_flag_with_default("with_source_bos", false); + config["add_source_eos"] = get_flag_with_default("with_source_eos", false); + + if (get_flag_with_default("user_decoder_start_tokens", false)) + config["decoder_start_token"] = nullptr; + else if (get_flag_with_default("with_target_bos", true)) + config["decoder_start_token"] = config["bos_token"]; + else + config["decoder_start_token"] = config["eos_token"]; + } + + VocabularyInfo vocab_info; + vocab_info.unk_token = config["unk_token"]; + vocab_info.bos_token = config["bos_token"]; + vocab_info.eos_token = config["eos_token"]; + + auto shared_vocabulary = std::make_shared(std::move(vocabularies.at("shared_vocabulary"))); + + if (shared_vocabulary) { + _target_vocabulary = shared_vocabulary; + _source_vocabularies = {shared_vocabulary}; + + } else { + _target_vocabulary = std::make_shared(std::move(vocabularies.at("target_vocabulary"))); + if (!_target_vocabulary) + throw std::runtime_error("Cannot load the target vocabulary from the model directory"); + + auto source_vocabulary = std::make_shared(std::move(vocabularies.at("source_vocabulary"))); + + if (source_vocabulary) { + _source_vocabularies = {source_vocabulary}; + } else { + for (size_t i = 1;; i++) { + const std::string name = "source_" + std::to_string(i) + "_vocabulary"; + auto vocabulary = std::make_shared(std::move(vocabularies.at(name))); + + if (!vocabulary) + break; + + _source_vocabularies.emplace_back(vocabulary); + } + } + + if (_source_vocabularies.empty()) + throw std::runtime_error("Cannot load the source vocabulary from the model directory"); + } + } + size_t SequenceToSequenceModel::num_source_vocabularies() const { return 
_source_vocabularies.size(); } diff --git a/src/models/transformer.cc b/src/models/transformer.cc index f62984b2e..099ece22a 100644 --- a/src/models/transformer.cc +++ b/src/models/transformer.cc @@ -108,6 +108,16 @@ namespace ctranslate2 { } } + void TransformerDecoderModel::initialize(std::unordered_map>& vocabularies) { + LanguageModel::initialize(vocabularies); + + if (spec_revision() < 2) { + register_variable_alias("decoder/num_heads", "num_heads"); + register_variable_alias("decoder/pre_norm", "pre_norm"); + register_variable_alias("decoder/activation", "activation"); + } + } + std::unique_ptr TransformerDecoderModel::as_sequence_generator() const { const auto scoped_device_setter = get_scoped_device_setter(); diff --git a/src/models/wav2vec2.cc b/src/models/wav2vec2.cc index 79a7a40d4..f4b8ed56c 100644 --- a/src/models/wav2vec2.cc +++ b/src/models/wav2vec2.cc @@ -34,6 +34,17 @@ namespace ctranslate2 { throw std::runtime_error("Cannot load the vocabulary from the model directory"); } + void Wav2Vec2Model::initialize(std::unordered_map>& vocabularies) { + VocabularyInfo vocab_info; + vocab_info.unk_token = "[UNK]"; + vocab_info.bos_token = ""; + vocab_info.eos_token = ""; + + _vocabulary = std::make_shared(vocabularies.at("vocabulary")); + if (!_vocabulary) + throw std::runtime_error("Cannot load the vocabulary from the model directory"); + } + bool Wav2Vec2Model::is_quantizable(const std::string& variable_name) const { return (Model::is_quantizable(variable_name) && variable_name.find("conv") == std::string::npos); diff --git a/src/models/whisper.cc b/src/models/whisper.cc index da12898e9..c5d1e807c 100644 --- a/src/models/whisper.cc +++ b/src/models/whisper.cc @@ -33,6 +33,17 @@ namespace ctranslate2 { throw std::runtime_error("Cannot load the vocabulary from the model directory"); } + void WhisperModel::initialize(std::unordered_map>& vocabularies) { + VocabularyInfo vocab_info; + vocab_info.unk_token = "<|endoftext|>"; + vocab_info.bos_token = "<|startoftranscript|>"; + vocab_info.eos_token = "<|endoftext|>"; + + _vocabulary = std::make_shared(vocabularies.at("vocabulary")); + if (!_vocabulary) + throw std::runtime_error("Cannot load the vocabulary from the model directory"); + } + bool WhisperModel::is_quantizable(const std::string& variable_name) const { return (Model::is_quantizable(variable_name) && variable_name.find("conv") == std::string::npos); From b0b6c3397ace7bae28134a25331e1216cd13f515 Mon Sep 17 00:00:00 2001 From: thucpham Date: Wed, 6 Dec 2023 18:35:58 +0100 Subject: [PATCH 02/10] fix black and flake --- .../onTheFly/generator_on_the_fly.py | 126 +++++++++++------- 1 file changed, 75 insertions(+), 51 deletions(-) diff --git a/python/ctranslate2/onTheFly/generator_on_the_fly.py b/python/ctranslate2/onTheFly/generator_on_the_fly.py index 965e135b6..9303856f1 100644 --- a/python/ctranslate2/onTheFly/generator_on_the_fly.py +++ b/python/ctranslate2/onTheFly/generator_on_the_fly.py @@ -5,35 +5,65 @@ import ctranslate2 import json + +def _get_converter(model_path: str, model_type: str): + if model_type == "OpenNMTPy": + + def get_model_file(_model_path: str): + for filename in os.listdir(model_path): + if filename.endswith(".pt"): + _model_file = os.path.join(model_path, filename) + return _model_file + return "" + + model_file = get_model_file(model_path) + if model_file == "": + raise RuntimeError("No model opennmt-py found in %s" % model_path) + + converter = OpenNMTPyConverter(model_path=model_file) + return converter + else: + raise NotImplementedError( + 
"Converter on the fly for %s is not implemented." % model_type + ) + + class GeneratorOnTheFly: - def __init__(self, model_path: str, - device="cpu", - device_index=0, - compute_type="default", - inter_threads=1, - intra_threads=0, - max_queued_batches=0, - model_type="OpenNMTPy", - quantization: Optional[str] = None, - ): - """Initializes the generator on the fly. - - Arguments: - model_path: Path to the CTranslate2 model directory. - device: Device to use (possible values are: cpu, cuda, auto). - device_index: Device IDs where to place this generator on. - compute_type: Model computation type or a dictionary mapping a device name - to the computation type (possible values are: default, auto, int8, int8_float32, - int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). - inter_threads: Maximum number of parallel generations. - intra_threads: Number of OpenMP threads per generator (0 to use a default value). - max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, - 0 for an automatic value). When the queue is full, future requests will block - until a free slot is available. - model_type: type of converter to convert the model - quantization: quantize the model - """ - converter = _get_converter(model_path=model_path, model_type=model_type) + """Initializes the generator on the fly. + + Arguments: + model_path: Path to the CTranslate2 model directory. + device: Device to use (possible values are: cpu, cuda, auto). + device_index: Device IDs where to place this generator on. + compute_type: Model computation type or a dictionary mapping + a device name to the computation type (possible values are: + default, auto, int8, int8_float32, int8_float16, int8_bfloat16, + int16, float16, bfloat16, float32). + inter_threads: Maximum number of parallel generations. + intra_threads: Number of OpenMP threads per generator + (0 to use a default value). + max_queued_batches: Maximum numbers of batches in the queue + (-1 for unlimited, 0 for an automatic value). + When the queue is full, future requests will block + until a free slot is available. 
+ model_type: type of converter to convert the model + quantization: quantize the model + """ + + def __init__( + self, + model_path: str, + device="cpu", + device_index=0, + compute_type="default", + inter_threads=1, + intra_threads=0, + max_queued_batches=0, + model_type="OpenNMTPy", + quantization: Optional[str] = None, + ): + converter = _get_converter(model_path=model_path, + model_type=model_type) model_spec = converter.convert_on_the_fly(quantization=quantization) variables = model_spec.variables(ordered=True) @@ -50,7 +80,8 @@ def __init__(self, model_path: str, if isinstance(value, str): aliases[key] = value else: - variables_cpp[key] = ctranslate2.StorageView.from_array(value.numpy()) + variables_cpp[key] = ctranslate2.StorageView.from_array( + value.numpy()) self.generator = ctranslate2.Generator(spec=spec, spec_revision=spec_revision, @@ -67,6 +98,21 @@ def __init__(self, model_path: str, max_queued_batches=max_queued_batches) #self.generator = ctranslate2.Generator(model_path, device=device, files=model_spec.files_memory) + self.generator = ctranslate2.Generator( + spec=spec, + spec_revision=spec_revision, + binary_revision=binary_version, + aliases=aliases, + vocabularies=vocabularies, + variables=variables_cpp, + config=config, + device=device, + device_index=device_index, + compute_type=compute_type, + inter_threads=inter_threads, + intra_threads=intra_threads, + max_queued_batches=max_queued_batches, + ) def generate_iterable(self, start_tokens, *args, **kwargs): return self.generator.generate_tokens(start_tokens, *args, **kwargs) @@ -79,25 +125,3 @@ def score_iterable(self, tokens, *args, **kwargs): def async_generate_tokens(self, prompt, *args, **kwargs): return self.generator.async_generate_tokens(prompt, *args, **kwargs) - - def _get_converter(model_path: str, model_type: str): - if model_type == "OpenNMTPy": - def get_model_file(_model_path: str): - for filename in os.listdir(model_path): - if filename.endswith(".pt"): - _model_file = os.path.join(model_path, filename) - return _model_file - return '' - - model_file = get_model_file(model_path) - if model_file == '': - raise RuntimeError( - "No model opennmt-py found in %s" % model_path - ) - - converter = OpenNMTPyConverter(model_path=model_file) - return converter - else: - raise NotImplementedError( - "Converter on the fly for %s is not implemented." 
% model_type - ) From 00f4b2dc095d54b0c8c3b8fb4e30349cf73abf6e Mon Sep 17 00:00:00 2001 From: thucpham Date: Wed, 6 Dec 2023 18:40:47 +0100 Subject: [PATCH 03/10] fix black and flake --- .../ctranslate2/onTheFly/generator_on_the_fly.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/python/ctranslate2/onTheFly/generator_on_the_fly.py b/python/ctranslate2/onTheFly/generator_on_the_fly.py index 9303856f1..9a732b8d1 100644 --- a/python/ctranslate2/onTheFly/generator_on_the_fly.py +++ b/python/ctranslate2/onTheFly/generator_on_the_fly.py @@ -83,21 +83,6 @@ def __init__( variables_cpp[key] = ctranslate2.StorageView.from_array( value.numpy()) - self.generator = ctranslate2.Generator(spec=spec, - spec_revision=spec_revision, - binary_revision=binary_version, - aliases=aliases, - vocabularies=vocabularies, - variables=variables_cpp, - config=config, - device=device, - device_index=device_index, - compute_type=compute_type, - inter_threads=inter_threads, - intra_threads=intra_threads, - max_queued_batches=max_queued_batches) - - #self.generator = ctranslate2.Generator(model_path, device=device, files=model_spec.files_memory) self.generator = ctranslate2.Generator( spec=spec, spec_revision=spec_revision, From b283b03eff597041f9ee8348d292c4e50894f315 Mon Sep 17 00:00:00 2001 From: thucpham Date: Wed, 6 Dec 2023 18:57:52 +0100 Subject: [PATCH 04/10] fix black --- python/ctranslate2/converters/converter.py | 7 ++++--- python/ctranslate2/onTheFly/generator_on_the_fly.py | 6 ++---- python/ctranslate2/specs/attention_spec.py | 4 +++- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/python/ctranslate2/converters/converter.py b/python/ctranslate2/converters/converter.py index 05700f709..acc182efe 100644 --- a/python/ctranslate2/converters/converter.py +++ b/python/ctranslate2/converters/converter.py @@ -105,9 +105,9 @@ def convert( return output_dir def convert_on_the_fly( - self, - vmap: Optional[str] = None, - quantization: Optional[str] = None, + self, + vmap: Optional[str] = None, + quantization: Optional[str] = None, ) -> ModelSpec: """Converts the model to the CTranslate2 format. 
@@ -136,6 +136,7 @@ def convert_on_the_fly( model_spec.validate() model_spec.optimize(quantization=quantization) + # model_spec.save(output_dir, False) return model_spec diff --git a/python/ctranslate2/onTheFly/generator_on_the_fly.py b/python/ctranslate2/onTheFly/generator_on_the_fly.py index 9a732b8d1..fab804684 100644 --- a/python/ctranslate2/onTheFly/generator_on_the_fly.py +++ b/python/ctranslate2/onTheFly/generator_on_the_fly.py @@ -62,8 +62,7 @@ def __init__( model_type="OpenNMTPy", quantization: Optional[str] = None, ): - converter = _get_converter(model_path=model_path, - model_type=model_type) + converter = _get_converter(model_path=model_path, model_type=model_type) model_spec = converter.convert_on_the_fly(quantization=quantization) variables = model_spec.variables(ordered=True) @@ -80,8 +79,7 @@ def __init__( if isinstance(value, str): aliases[key] = value else: - variables_cpp[key] = ctranslate2.StorageView.from_array( - value.numpy()) + variables_cpp[key] = ctranslate2.StorageView.from_array(value.numpy()) self.generator = ctranslate2.Generator( spec=spec, diff --git a/python/ctranslate2/specs/attention_spec.py b/python/ctranslate2/specs/attention_spec.py index 82cf13395..5c13d4f6a 100644 --- a/python/ctranslate2/specs/attention_spec.py +++ b/python/ctranslate2/specs/attention_spec.py @@ -49,7 +49,9 @@ def __init__( if rotary_scaling_type is not None: self.rotary_scaling_type = np.array(rotary_scaling_type, dtype="int8") - self.rotary_scaling_factor = np.array(rotary_scaling_factor, dtype="float32") + self.rotary_scaling_factor = np.array( + rotary_scaling_factor, dtype="float32" + ) if num_heads_kv is not None: self.num_heads_kv = np.array(num_heads_kv, dtype="int32") From b35f89839670898104576b18118a5055e9b25f17 Mon Sep 17 00:00:00 2001 From: thucpham Date: Wed, 6 Dec 2023 19:07:24 +0100 Subject: [PATCH 05/10] fix isort --- python/ctranslate2/__init__.py | 2 +- python/ctranslate2/onTheFly/generator_on_the_fly.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/ctranslate2/__init__.py b/python/ctranslate2/__init__.py index dfc7650dd..3c8fb8988 100644 --- a/python/ctranslate2/__init__.py +++ b/python/ctranslate2/__init__.py @@ -50,5 +50,5 @@ else: raise -from ctranslate2 import converters, models, specs, onTheFly +from ctranslate2 import converters, models, onTheFly, specs from ctranslate2.version import __version__ diff --git a/python/ctranslate2/onTheFly/generator_on_the_fly.py b/python/ctranslate2/onTheFly/generator_on_the_fly.py index fab804684..d12fb6f1e 100644 --- a/python/ctranslate2/onTheFly/generator_on_the_fly.py +++ b/python/ctranslate2/onTheFly/generator_on_the_fly.py @@ -1,9 +1,11 @@ +import json import os + from typing import Optional -from ctranslate2.converters.opennmt_py import OpenNMTPyConverter import ctranslate2 -import json + +from ctranslate2.converters.opennmt_py import OpenNMTPyConverter def _get_converter(model_path: str, model_type: str): From a429ab34c0c0f010c611c81fd25900cbd5483e23 Mon Sep 17 00:00:00 2001 From: thucpham Date: Wed, 6 Dec 2023 20:37:48 +0100 Subject: [PATCH 06/10] reformat generator on the fly --- python/ctranslate2/__init__.py | 3 +- python/ctranslate2/extensions.py | 240 ++++++++++++++++++ .../{onTheFly => }/generator_on_the_fly.py | 16 +- python/ctranslate2/onTheFly/__init__.py | 1 - 4 files changed, 247 insertions(+), 13 deletions(-) rename python/ctranslate2/{onTheFly => }/generator_on_the_fly.py (85%) delete mode 100644 python/ctranslate2/onTheFly/__init__.py diff --git 
a/python/ctranslate2/__init__.py b/python/ctranslate2/__init__.py index 3c8fb8988..8d453ee2a 100644 --- a/python/ctranslate2/__init__.py +++ b/python/ctranslate2/__init__.py @@ -39,6 +39,7 @@ set_random_seed, ) from ctranslate2.extensions import register_extensions + from ctranslate2.generator_on_the_fly import GeneratorOnTheFly from ctranslate2.logging import get_log_level, set_log_level register_extensions() @@ -50,5 +51,5 @@ else: raise -from ctranslate2 import converters, models, onTheFly, specs +from ctranslate2 import converters, models, specs from ctranslate2.version import __version__ diff --git a/python/ctranslate2/extensions.py b/python/ctranslate2/extensions.py index b6d9fd4b5..67b45ea2d 100644 --- a/python/ctranslate2/extensions.py +++ b/python/ctranslate2/extensions.py @@ -14,6 +14,7 @@ TranslationResult, Translator, ) +from ctranslate2.generator_on_the_fly import GeneratorOnTheFly def register_extensions(): @@ -25,6 +26,10 @@ def register_extensions(): setattr(Generator, "score_iterable", generator_score_iterable) setattr(Generator, "generate_tokens", generator_generate_tokens) setattr(Generator, "async_generate_tokens", generator_async_generate_tokens) + setattr(GeneratorOnTheFly, "generate_iterable", generator_generate_iterable_on_the_fly) + setattr(GeneratorOnTheFly, "score_iterable", generator_score_iterable_on_the_fly) + setattr(GeneratorOnTheFly, "generate_tokens", generator_generate_tokens_on_the_fly) + setattr(GeneratorOnTheFly, "async_generate_tokens", generator_async_generate_tokens_on_the_fly) def translator_translate_iterable( @@ -430,6 +435,241 @@ async def generator_async_generate_tokens( yield step_result +def generator_generate_tokens_on_the_fly( + generator: GeneratorOnTheFly, + prompt: Union[List[str], List[List[str]]], + max_batch_size: int = 0, + batch_type: str = "examples", + *, + max_length: int = 512, + min_length: int = 0, + sampling_topk: int = 1, + sampling_topp: float = 1, + sampling_temperature: float = 1, + return_log_prob: bool = False, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + disable_unk: bool = False, + suppress_sequences: Optional[List[List[str]]] = None, + end_token: Optional[Union[str, List[str], List[int]]] = None, + static_prompt: Optional[List[str]] = None, + cache_static_prompt: bool = True, + callback: Callable[[GenerationStepResult], bool] = None, +) -> Iterable[GenerationStepResult]: + """Yields tokens as they are generated by the model. + + Arguments: + prompt: Batch of start tokens. If the decoder starts from a + special start token like , this token should be added to this input. + max_batch_size: The maximum batch size. + batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens". + max_length: Maximum generation length. + min_length: Minimum generation length. + sampling_topk: Randomly sample predictions from the top K candidates. + sampling_topp: Keep the most probable tokens whose cumulative probability exceeds this value. + sampling_temperature: Sampling temperature to generate more random samples. + return_log_prob: Include the token log probability in the result. + repetition_penalty: Penalty applied to the score of previously generated tokens + (set > 1 to penalize). + no_repeat_ngram_size: Prevent repetitions of ngrams with this size + (set 0 to disable). + disable_unk: Disable the generation of the unknown token. + suppress_sequences: Disable the generation of some sequences of tokens. + end_token: Stop the decoding on one these tokens (defaults to the model EOS token). 
+ static_prompt: If the model expects a static prompt (a.k.a. system prompt) + it can be set here to simplify the inputs and optionally cache the model + state for this prompt to accelerate future generations. + cache_static_prompt: Cache the model state after the static prompt and + reuse it for future generations using the same static prompt. + callback: Optional function that is called for each generated token when + obj:`beam_size` is 1. If the callback function returns ``True``, the + decoding will stop for this batch index. + + Returns: + A generator iterator over :class:`ctranslate2.GenerationStepResult` instances. + + Note: + This generation method is not compatible with beam search which requires a complete decoding. + """ + if len(prompt) > 0 and isinstance(prompt[0], str): + prompt = [prompt] + + yield from _generate_tokens( + generator.generate_batch, + prompt, + max_batch_size=max_batch_size, + batch_type=batch_type, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + disable_unk=disable_unk, + suppress_sequences=suppress_sequences, + end_token=end_token, + max_length=max_length, + min_length=min_length, + sampling_topk=sampling_topk, + sampling_topp=sampling_topp, + sampling_temperature=sampling_temperature, + return_scores=return_log_prob, + static_prompt=static_prompt, + cache_static_prompt=cache_static_prompt, + include_prompt_in_result=False, + callback=callback, + ) + +def generator_generate_iterable_on_the_fly( + generator: GeneratorOnTheFly, + start_tokens: Iterable[List[str]], + max_batch_size: int = 32, + batch_type: str = "examples", + **kwargs, +) -> Iterable[GenerationResult]: + """Generates from an iterable of tokenized prompts. + + This method is built on top of :meth:`ctranslate2.Generator.generate_batch` + to efficiently run generation on an arbitrarily large stream of data. It enables + the following optimizations: + + * stream processing (the iterable is not fully materialized in memory) + * parallel generations (if the generator has multiple workers) + * asynchronous batch prefetching + * local sorting by length + + Arguments: + start_tokens: An iterable of tokenized prompts. + max_batch_size: The maximum batch size. + batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens". + **kwargs: Any generation options accepted by + :meth:`ctranslate2.Generator.generate_batch`. + + Returns: + A generator iterator over :class:`ctranslate2.GenerationResult` instances. + """ + yield from _process_iterable( + generator.generate_batch, + [start_tokens], + max_batch_size, + batch_type, + **kwargs, + ) + +def generator_score_iterable_on_the_fly( + generator: GeneratorOnTheFly, + tokens: Iterable[List[str]], + max_batch_size: int = 64, + batch_type: str = "examples", + **kwargs, +) -> Iterable[ScoringResult]: + """Scores an iterable of tokenized examples. + + This method is built on top of :meth:`ctranslate2.Generator.score_batch` + to efficiently score an arbitrarily large stream of data. It enables + the following optimizations: + + * stream processing (the iterable is not fully materialized in memory) + * parallel scoring (if the generator has multiple workers) + * asynchronous batch prefetching + * local sorting by length + + Arguments: + tokens: An iterable of tokenized examples. + max_batch_size: The maximum batch size. + batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens". + **kwargs: Any score options accepted by + :meth:`ctranslate2.Generator.score_batch`. 
+ + Returns: + A generator iterator over :class:`ctranslate2.ScoringResult` instances. + """ + yield from _process_iterable( + generator.score_batch, + [tokens], + max_batch_size, + batch_type, + **kwargs, + ) + +async def generator_async_generate_tokens_on_the_fly( + generator: GeneratorOnTheFly, + prompt: Union[List[str], List[List[str]]], + max_batch_size: int = 0, + batch_type: str = "examples", + *, + max_length: int = 512, + min_length: int = 0, + sampling_topk: int = 1, + sampling_topp: float = 1, + sampling_temperature: float = 1, + return_log_prob: bool = False, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + disable_unk: bool = False, + suppress_sequences: Optional[List[List[str]]] = None, + end_token: Optional[Union[str, List[str], List[int]]] = None, + static_prompt: Optional[List[str]] = None, + cache_static_prompt: bool = True, + callback: Callable[[GenerationStepResult], bool] = None, +) -> AsyncIterable[GenerationStepResult]: + """Yields tokens asynchronously as they are generated by the model. + + Arguments: + prompt: Batch of start tokens. If the decoder starts from a + special start token like , this token should be added to this input. + max_batch_size: The maximum batch size. + batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens". + max_length: Maximum generation length. + min_length: Minimum generation length. + sampling_topk: Randomly sample predictions from the top K candidates. + sampling_topp: Keep the most probable tokens whose cumulative probability exceeds this value. + sampling_temperature: Sampling temperature to generate more random samples. + return_log_prob: Include the token log probability in the result. + repetition_penalty: Penalty applied to the score of previously generated tokens + (set > 1 to penalize). + no_repeat_ngram_size: Prevent repetitions of ngrams with this size + (set 0 to disable). + disable_unk: Disable the generation of the unknown token. + suppress_sequences: Disable the generation of some sequences of tokens. + end_token: Stop the decoding on one of these tokens (defaults to the model EOS token). + static_prompt: If the model expects a static prompt (a.k.a. system prompt) + it can be set here to simplify the inputs and optionally cache the model + state for this prompt to accelerate future generations. + cache_static_prompt: Cache the model state after the static prompt and + reuse it for future generations using the same static prompt. + callback: Optional function that is called for each generated token when + obj:`beam_size` is 1. If the callback function returns ``True``, the + decoding will stop for this batch index. + + Returns: + An async generator iterator over :class:`ctranslate2.GenerationStepResult` instances. + + Note: + This generation method is not compatible with beam search which requires a complete decoding. 
+ """ + if len(prompt) > 0 and isinstance(prompt[0], str): + prompt = [prompt] + async for step_result in AsyncGenerator( + generator.generate_batch, + prompt, + max_batch_size=max_batch_size, + batch_type=batch_type, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + disable_unk=disable_unk, + suppress_sequences=suppress_sequences, + end_token=end_token, + max_length=max_length, + min_length=min_length, + sampling_topk=sampling_topk, + sampling_topp=sampling_topp, + sampling_temperature=sampling_temperature, + return_scores=return_log_prob, + static_prompt=static_prompt, + cache_static_prompt=cache_static_prompt, + include_prompt_in_result=False, + callback=callback, + ): + yield step_result + + class AsyncGenerator: def __init__(self, process_func, *args, **kwargs): self.queue = asyncio.Queue() diff --git a/python/ctranslate2/onTheFly/generator_on_the_fly.py b/python/ctranslate2/generator_on_the_fly.py similarity index 85% rename from python/ctranslate2/onTheFly/generator_on_the_fly.py rename to python/ctranslate2/generator_on_the_fly.py index d12fb6f1e..3dbb5a33c 100644 --- a/python/ctranslate2/onTheFly/generator_on_the_fly.py +++ b/python/ctranslate2/generator_on_the_fly.py @@ -86,7 +86,7 @@ def __init__( self.generator = ctranslate2.Generator( spec=spec, spec_revision=spec_revision, - binary_revision=binary_version, + binary_version=binary_version, aliases=aliases, vocabularies=vocabularies, variables=variables_cpp, @@ -99,14 +99,8 @@ def __init__( max_queued_batches=max_queued_batches, ) - def generate_iterable(self, start_tokens, *args, **kwargs): - return self.generator.generate_tokens(start_tokens, *args, **kwargs) + def generate_batch(self, prompt, *args, **kwargs): + return self.generator.generate_batch(prompt, *args, **kwargs) - def generate_tokens(self, start_tokens, *args, **kwargs): - return self.generator.generate_tokens(start_tokens, *args, **kwargs) - - def score_iterable(self, tokens, *args, **kwargs): - return self.generator.score_iterable(tokens, *args, **kwargs) - - def async_generate_tokens(self, prompt, *args, **kwargs): - return self.generator.async_generate_tokens(prompt, *args, **kwargs) + def score_batch(self, tokens, *args, **kwargs): + return self.generator.score_batch(tokens, *args, **kwargs) diff --git a/python/ctranslate2/onTheFly/__init__.py b/python/ctranslate2/onTheFly/__init__.py deleted file mode 100644 index 240316127..000000000 --- a/python/ctranslate2/onTheFly/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ctranslate2.onTheFly.generator_on_the_fly import GeneratorOnTheFly From fb839f49e3ce2c9b18208e0fa3c580741ebe8816 Mon Sep 17 00:00:00 2001 From: thucpham Date: Wed, 6 Dec 2023 21:15:54 +0100 Subject: [PATCH 07/10] fix black --- python/ctranslate2/extensions.py | 147 ++++++++++++++++--------------- 1 file changed, 78 insertions(+), 69 deletions(-) diff --git a/python/ctranslate2/extensions.py b/python/ctranslate2/extensions.py index 67b45ea2d..e6d9de67a 100644 --- a/python/ctranslate2/extensions.py +++ b/python/ctranslate2/extensions.py @@ -26,10 +26,16 @@ def register_extensions(): setattr(Generator, "score_iterable", generator_score_iterable) setattr(Generator, "generate_tokens", generator_generate_tokens) setattr(Generator, "async_generate_tokens", generator_async_generate_tokens) - setattr(GeneratorOnTheFly, "generate_iterable", generator_generate_iterable_on_the_fly) + setattr( + GeneratorOnTheFly, "generate_iterable", generator_generate_iterable_on_the_fly + ) setattr(GeneratorOnTheFly, "score_iterable", 
generator_score_iterable_on_the_fly) setattr(GeneratorOnTheFly, "generate_tokens", generator_generate_tokens_on_the_fly) - setattr(GeneratorOnTheFly, "async_generate_tokens", generator_async_generate_tokens_on_the_fly) + setattr( + GeneratorOnTheFly, + "async_generate_tokens", + generator_async_generate_tokens_on_the_fly, + ) def translator_translate_iterable( @@ -436,25 +442,25 @@ async def generator_async_generate_tokens( def generator_generate_tokens_on_the_fly( - generator: GeneratorOnTheFly, - prompt: Union[List[str], List[List[str]]], - max_batch_size: int = 0, - batch_type: str = "examples", - *, - max_length: int = 512, - min_length: int = 0, - sampling_topk: int = 1, - sampling_topp: float = 1, - sampling_temperature: float = 1, - return_log_prob: bool = False, - repetition_penalty: float = 1, - no_repeat_ngram_size: int = 0, - disable_unk: bool = False, - suppress_sequences: Optional[List[List[str]]] = None, - end_token: Optional[Union[str, List[str], List[int]]] = None, - static_prompt: Optional[List[str]] = None, - cache_static_prompt: bool = True, - callback: Callable[[GenerationStepResult], bool] = None, + generator: GeneratorOnTheFly, + prompt: Union[List[str], List[List[str]]], + max_batch_size: int = 0, + batch_type: str = "examples", + *, + max_length: int = 512, + min_length: int = 0, + sampling_topk: int = 1, + sampling_topp: float = 1, + sampling_temperature: float = 1, + return_log_prob: bool = False, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + disable_unk: bool = False, + suppress_sequences: Optional[List[List[str]]] = None, + end_token: Optional[Union[str, List[str], List[int]]] = None, + static_prompt: Optional[List[str]] = None, + cache_static_prompt: bool = True, + callback: Callable[[GenerationStepResult], bool] = None, ) -> Iterable[GenerationStepResult]: """Yields tokens as they are generated by the model. @@ -516,12 +522,13 @@ def generator_generate_tokens_on_the_fly( callback=callback, ) + def generator_generate_iterable_on_the_fly( - generator: GeneratorOnTheFly, - start_tokens: Iterable[List[str]], - max_batch_size: int = 32, - batch_type: str = "examples", - **kwargs, + generator: GeneratorOnTheFly, + start_tokens: Iterable[List[str]], + max_batch_size: int = 32, + batch_type: str = "examples", + **kwargs, ) -> Iterable[GenerationResult]: """Generates from an iterable of tokenized prompts. @@ -552,12 +559,13 @@ def generator_generate_iterable_on_the_fly( **kwargs, ) + def generator_score_iterable_on_the_fly( - generator: GeneratorOnTheFly, - tokens: Iterable[List[str]], - max_batch_size: int = 64, - batch_type: str = "examples", - **kwargs, + generator: GeneratorOnTheFly, + tokens: Iterable[List[str]], + max_batch_size: int = 64, + batch_type: str = "examples", + **kwargs, ) -> Iterable[ScoringResult]: """Scores an iterable of tokenized examples. 
@@ -588,26 +596,27 @@ def generator_score_iterable_on_the_fly( **kwargs, ) + async def generator_async_generate_tokens_on_the_fly( - generator: GeneratorOnTheFly, - prompt: Union[List[str], List[List[str]]], - max_batch_size: int = 0, - batch_type: str = "examples", - *, - max_length: int = 512, - min_length: int = 0, - sampling_topk: int = 1, - sampling_topp: float = 1, - sampling_temperature: float = 1, - return_log_prob: bool = False, - repetition_penalty: float = 1, - no_repeat_ngram_size: int = 0, - disable_unk: bool = False, - suppress_sequences: Optional[List[List[str]]] = None, - end_token: Optional[Union[str, List[str], List[int]]] = None, - static_prompt: Optional[List[str]] = None, - cache_static_prompt: bool = True, - callback: Callable[[GenerationStepResult], bool] = None, + generator: GeneratorOnTheFly, + prompt: Union[List[str], List[List[str]]], + max_batch_size: int = 0, + batch_type: str = "examples", + *, + max_length: int = 512, + min_length: int = 0, + sampling_topk: int = 1, + sampling_topp: float = 1, + sampling_temperature: float = 1, + return_log_prob: bool = False, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + disable_unk: bool = False, + suppress_sequences: Optional[List[List[str]]] = None, + end_token: Optional[Union[str, List[str], List[int]]] = None, + static_prompt: Optional[List[str]] = None, + cache_static_prompt: bool = True, + callback: Callable[[GenerationStepResult], bool] = None, ) -> AsyncIterable[GenerationStepResult]: """Yields tokens asynchronously as they are generated by the model. @@ -647,25 +656,25 @@ async def generator_async_generate_tokens_on_the_fly( if len(prompt) > 0 and isinstance(prompt[0], str): prompt = [prompt] async for step_result in AsyncGenerator( - generator.generate_batch, - prompt, - max_batch_size=max_batch_size, - batch_type=batch_type, - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - disable_unk=disable_unk, - suppress_sequences=suppress_sequences, - end_token=end_token, - max_length=max_length, - min_length=min_length, - sampling_topk=sampling_topk, - sampling_topp=sampling_topp, - sampling_temperature=sampling_temperature, - return_scores=return_log_prob, - static_prompt=static_prompt, - cache_static_prompt=cache_static_prompt, - include_prompt_in_result=False, - callback=callback, + generator.generate_batch, + prompt, + max_batch_size=max_batch_size, + batch_type=batch_type, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + disable_unk=disable_unk, + suppress_sequences=suppress_sequences, + end_token=end_token, + max_length=max_length, + min_length=min_length, + sampling_topk=sampling_topk, + sampling_topp=sampling_topp, + sampling_temperature=sampling_temperature, + return_scores=return_log_prob, + static_prompt=static_prompt, + cache_static_prompt=cache_static_prompt, + include_prompt_in_result=False, + callback=callback, ): yield step_result From 141034aa75df358b1ae72fd39f5638ffa62963a6 Mon Sep 17 00:00:00 2001 From: thucpham Date: Thu, 7 Dec 2023 17:28:22 +0100 Subject: [PATCH 08/10] accept directly onmt-py model file --- python/ctranslate2/generator_on_the_fly.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/python/ctranslate2/generator_on_the_fly.py b/python/ctranslate2/generator_on_the_fly.py index 3dbb5a33c..f2796e1fc 100644 --- a/python/ctranslate2/generator_on_the_fly.py +++ b/python/ctranslate2/generator_on_the_fly.py @@ -10,19 +10,10 @@ def _get_converter(model_path: str, 
model_type: str):
     if model_type == "OpenNMTPy":
-
-        def get_model_file(_model_path: str):
-            for filename in os.listdir(model_path):
-                if filename.endswith(".pt"):
-                    _model_file = os.path.join(model_path, filename)
-                    return _model_file
-            return ""
-
-        model_file = get_model_file(model_path)
-        if model_file == "":
+        if not os.path.exists(model_path):
             raise RuntimeError("No model opennmt-py found in %s" % model_path)
-        converter = OpenNMTPyConverter(model_path=model_file)
+        converter = OpenNMTPyConverter(model_path=model_path)
         return converter
     else:
         raise NotImplementedError(

From 86f35657b97a1f23436819dbc31cedceacfe25a2 Mon Sep 17 00:00:00 2001
From: thucpham
Date: Thu, 7 Dec 2023 18:08:14 +0100
Subject: [PATCH 09/10] keep vocabularies and config in the generator for post-processing

---
 python/ctranslate2/generator_on_the_fly.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/ctranslate2/generator_on_the_fly.py b/python/ctranslate2/generator_on_the_fly.py
index f2796e1fc..5cd1eeebf 100644
--- a/python/ctranslate2/generator_on_the_fly.py
+++ b/python/ctranslate2/generator_on_the_fly.py
@@ -59,8 +59,8 @@ def __init__(
         model_spec = converter.convert_on_the_fly(quantization=quantization)
 
         variables = model_spec.variables(ordered=True)
-        vocabularies = model_spec.get_vocabulary()
-        config = json.dumps(model_spec.config.to_dict())
+        self.vocabularies = model_spec.get_vocabulary()
+        self.config = json.dumps(model_spec.config.to_dict())
         aliases = {}
 
         spec = model_spec.name
@@ -79,9 +79,9 @@ def __init__(
             spec_revision=spec_revision,
             binary_version=binary_version,
             aliases=aliases,
-            vocabularies=vocabularies,
+            vocabularies=self.vocabularies,
             variables=variables_cpp,
-            config=config,
+            config=self.config,
             device=device,
             device_index=device_index,
             compute_type=compute_type,

From a21d721f68e1b9074988cafe137c683d8660e42e Mon Sep 17 00:00:00 2001
From: thucpham
Date: Thu, 7 Dec 2023 18:32:24 +0100
Subject: [PATCH 10/10] clean up parameters

---
 python/ctranslate2/generator_on_the_fly.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/ctranslate2/generator_on_the_fly.py b/python/ctranslate2/generator_on_the_fly.py
index 5cd1eeebf..f408b2ff4 100644
--- a/python/ctranslate2/generator_on_the_fly.py
+++ b/python/ctranslate2/generator_on_the_fly.py
@@ -90,8 +90,8 @@ def __init__(
             max_queued_batches=max_queued_batches,
         )
 
-    def generate_batch(self, prompt, *args, **kwargs):
-        return self.generator.generate_batch(prompt, *args, **kwargs)
+    def generate_batch(self, *args, **kwargs):
+        return self.generator.generate_batch(*args, **kwargs)
 
-    def score_batch(self, tokens, *args, **kwargs):
-        return self.generator.score_batch(tokens, *args, **kwargs)
+    def score_batch(self, *args, **kwargs):
+        return self.generator.score_batch(*args, **kwargs)
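
Not part of the patches above: a minimal usage sketch of the on-the-fly path this series adds, written under a few assumptions. It assumes the GeneratorOnTheFly constructor accepts model_path, model_type and device and forwards them as the hunks above suggest, that register_extensions() has already attached the streaming generate_tokens helper to the class, and that an OpenNMT-py checkpoint exists at the illustrative path ./model.pt; the prompt tokens and decoding options are placeholders, not values taken from the patches.

    from ctranslate2.generator_on_the_fly import GeneratorOnTheFly

    # Convert the OpenNMT-py checkpoint in memory and load the resulting model
    # directly; no CTranslate2 model directory is written to disk.
    generator = GeneratorOnTheFly(
        model_path="./model.pt",  # illustrative checkpoint path (assumed parameter)
        model_type="OpenNMTPy",   # the only type handled by _get_converter
        device="cpu",
    )

    # generate_batch/score_batch forward to the in-memory ctranslate2.Generator.
    results = generator.generate_batch([["Hello", "world"]], max_length=32)
    print(results[0].sequences[0])

    # The extension registered in extensions.py streams tokens as they are produced.
    for step in generator.generate_tokens(["Hello", "world"], max_length=32):
        print(step.token, end=" ")

Because patch 09 keeps the vocabularies and config on the instance, post-processing (for example, mapping generated tokens back through the target vocabulary) can read generator.vocabularies and generator.config without reloading the original checkpoint.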