diff --git a/python/ctranslate2/onTheFly/generator_on_the_fly.py b/python/ctranslate2/onTheFly/generator_on_the_fly.py
index 965e135b6..9303856f1 100644
--- a/python/ctranslate2/onTheFly/generator_on_the_fly.py
+++ b/python/ctranslate2/onTheFly/generator_on_the_fly.py
@@ -5,35 +5,66 @@
 import ctranslate2
 import json
 
+
+def _get_converter(model_path: str, model_type: str):
+    if model_type == "OpenNMTPy":
+
+        def get_model_file(_model_path: str):
+            # Return the first OpenNMT-py checkpoint (.pt) found in the directory.
+            for filename in os.listdir(_model_path):
+                if filename.endswith(".pt"):
+                    _model_file = os.path.join(_model_path, filename)
+                    return _model_file
+            return ""
+
+        model_file = get_model_file(model_path)
+        if model_file == "":
+            raise RuntimeError("No OpenNMT-py model found in %s" % model_path)
+
+        converter = OpenNMTPyConverter(model_path=model_file)
+        return converter
+    else:
+        raise NotImplementedError(
+            "Converter on the fly for %s is not implemented." % model_type
+        )
+
+
 class GeneratorOnTheFly:
-    def __init__(self, model_path: str,
-                 device="cpu",
-                 device_index=0,
-                 compute_type="default",
-                 inter_threads=1,
-                 intra_threads=0,
-                 max_queued_batches=0,
-                 model_type="OpenNMTPy",
-                 quantization: Optional[str] = None,
-                 ):
-        """Initializes the generator on the fly.
-
-        Arguments:
-          model_path: Path to the CTranslate2 model directory.
-          device: Device to use (possible values are: cpu, cuda, auto).
-          device_index: Device IDs where to place this generator on.
-          compute_type: Model computation type or a dictionary mapping a device name
-            to the computation type (possible values are: default, auto, int8, int8_float32,
-            int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
-          inter_threads: Maximum number of parallel generations.
-          intra_threads: Number of OpenMP threads per generator (0 to use a default value).
-          max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited,
-            0 for an automatic value). When the queue is full, future requests will block
-            until a free slot is available.
-          model_type: type of converter to convert the model
-          quantization: quantize the model
-        """
-        converter = _get_converter(model_path=model_path, model_type=model_type)
+    """Generator that converts a model on the fly before loading it.
+
+    Arguments:
+        model_path: Path to the CTranslate2 model directory.
+        device: Device to use (possible values are: cpu, cuda, auto).
+        device_index: Device ID(s) on which to place this generator.
+        compute_type: Model computation type or a dictionary mapping
+            a device name to the computation type (possible values are:
+            default, auto, int8, int8_float32, int8_float16, int8_bfloat16,
+            int16, float16, bfloat16, float32).
+        inter_threads: Maximum number of parallel generations.
+        intra_threads: Number of OpenMP threads per generator
+            (0 to use a default value).
+        max_queued_batches: Maximum number of batches in the queue
+            (-1 for unlimited, 0 for an automatic value).
+            When the queue is full, future requests will block
+            until a free slot is available.
+        model_type: Type of the original model to convert (only "OpenNMTPy").
+        quantization: Weight quantization to apply during the conversion.
+    """
+
+    def __init__(
+        self,
+        model_path: str,
+        device="cpu",
+        device_index=0,
+        compute_type="default",
+        inter_threads=1,
+        intra_threads=0,
+        max_queued_batches=0,
+        model_type="OpenNMTPy",
+        quantization: Optional[str] = None,
+    ):
+        converter = _get_converter(model_path=model_path,
+                                   model_type=model_type)
 
         model_spec = converter.convert_on_the_fly(quantization=quantization)
 
         variables = model_spec.variables(ordered=True)
@@ -50,23 +81,27 @@ def __init__(self, model_path: str,
             if isinstance(value, str):
                 aliases[key] = value
             else:
-                variables_cpp[key] = ctranslate2.StorageView.from_array(value.numpy())
+                variables_cpp[key] = ctranslate2.StorageView.from_array(
+                    value.numpy())
 
-        self.generator = ctranslate2.Generator(spec=spec,
-                                               spec_revision=spec_revision,
-                                               binary_revision=binary_version,
-                                               aliases=aliases,
-                                               vocabularies=vocabularies,
-                                               variables=variables_cpp,
-                                               config=config,
-                                               device=device,
-                                               device_index=device_index,
-                                               compute_type=compute_type,
-                                               inter_threads=inter_threads,
-                                               intra_threads=intra_threads,
-                                               max_queued_batches=max_queued_batches)
-
         #self.generator = ctranslate2.Generator(model_path, device=device, files=model_spec.files_memory)
+        # Load the converted weights directly into the generator without
+        # writing them to disk.
+        self.generator = ctranslate2.Generator(
+            spec=spec,
+            spec_revision=spec_revision,
+            binary_revision=binary_version,
+            aliases=aliases,
+            vocabularies=vocabularies,
+            variables=variables_cpp,
+            config=config,
+            device=device,
+            device_index=device_index,
+            compute_type=compute_type,
+            inter_threads=inter_threads,
+            intra_threads=intra_threads,
+            max_queued_batches=max_queued_batches,
+        )
 
     def generate_iterable(self, start_tokens, *args, **kwargs):
         return self.generator.generate_tokens(start_tokens, *args, **kwargs)
@@ -79,25 +114,3 @@ def score_iterable(self, tokens, *args, **kwargs):
 
     def async_generate_tokens(self, prompt, *args, **kwargs):
         return self.generator.async_generate_tokens(prompt, *args, **kwargs)
-
-    def _get_converter(model_path: str, model_type: str):
-        if model_type == "OpenNMTPy":
-            def get_model_file(_model_path: str):
-                for filename in os.listdir(model_path):
-                    if filename.endswith(".pt"):
-                        _model_file = os.path.join(model_path, filename)
-                        return _model_file
-                return ''
-
-            model_file = get_model_file(model_path)
-            if model_file == '':
-                raise RuntimeError(
-                    "No model opennmt-py found in %s" % model_path
-                )
-
-            converter = OpenNMTPyConverter(model_path=model_file)
-            return converter
-        else:
-            raise NotImplementedError(
-                "Converter on the fly for %s is not implemented." % model_type
-            )
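
Usage sketch (not part of the patch): a minimal example of driving the patched class, assuming a directory containing an OpenNMT-py .pt checkpoint. The model path and start token below are hypothetical placeholders.

    from ctranslate2.onTheFly.generator_on_the_fly import GeneratorOnTheFly

    generator = GeneratorOnTheFly(
        "/path/to/opennmt_py_model",  # hypothetical directory with a .pt checkpoint
        device="cpu",
        model_type="OpenNMTPy",
        quantization="int8",
    )

    # generate_iterable() forwards to ctranslate2.Generator.generate_tokens(),
    # which yields one GenerationStepResult per generated token.
    for step in generator.generate_iterable(["<s>"], max_length=32):
        print(step.token, end=" ")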