examples/inference_using_cross_provider.py (1 addition, 1 deletion)
@@ -2,7 +2,7 @@
from unitxt.text_utils import print_dict

if __name__ == "__main__":
for provider in ["watsonx", "rits", "watsonx-sdk", "hf-local"]:
for provider in ["vllm", "watsonx", "rits", "watsonx-sdk", "hf-local"]:
print()
print("------------------------------------------------ ")
print("PROVIDER:", provider)
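For context on what the extended loop exercises, here is a minimal usage sketch of the new provider value. It assumes CrossProviderInferenceEngine accepts provider="vllm" after this change; the model key shown is illustrative and not taken from this PR.

```python
# Minimal sketch, assuming the "vllm" provider resolves through the same
# model map as "hf-local". The model key below is illustrative.
from unitxt.inference import CrossProviderInferenceEngine

engine = CrossProviderInferenceEngine(
    model="granite-3-2b-instruct",  # assumed key in the shared model map
    provider="vllm",
    max_tokens=32,
)
```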
src/unitxt/inference.py (14 additions, 7 deletions)
@@ -3014,26 +3014,30 @@ class VLLMParamsMixin(Artifact):
model: str
n: int = 1
best_of: Optional[int] = None
_real_n: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
repetition_penalty: float = 1.0
temperature: float = 0.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
top_k: int = 0
min_p: float = 0.0
seed: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
repetition_penalty: float = 1.0
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
bad_words: Optional[List[str]] = None
include_stop_str_in_output: bool = False
ignore_eos: bool = False
max_tokens: Optional[int] = 16
min_tokens: int = 0
logprobs: Optional[int] = None
prompt_logprobs: Optional[int] = None
detokenize: bool = True
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True


class VLLMInferenceEngine(InferenceEngine, PackageRequirementsMixin, VLLMParamsMixin):
_requirements_list: list = ["vllm"]
label = "vllm"

def get_engine_id(self):
@@ -3047,7 +3051,6 @@ def prepare_engine(self):
self.sampling_params = SamplingParams(**args)
self.llm = LLM(
model=self.model,
device="auto",
trust_remote_code=True,
max_num_batched_tokens=4096,
gpu_memory_utilization=0.7,
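For readers unfamiliar with vLLM, the prepare_engine hunk above splits the mixin fields between SamplingParams and the LLM constructor. A rough standalone sketch of that wiring, using only vLLM's public API and an illustrative model id (this is not the unitxt code itself), looks like this:

```python
# Rough standalone sketch of the wiring in prepare_engine: sampling fields
# go to SamplingParams, the model plus engine options go to LLM.
# The model id below is illustrative.
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=16)
llm = LLM(
    model="ibm-granite/granite-3.0-2b-instruct",
    trust_remote_code=True,
    max_num_batched_tokens=4096,
    gpu_memory_utilization=0.7,
)
outputs = llm.generate(["Answer with one word: 3 + 4 ="], sampling_params)
print(outputs[0].outputs[0].text)
```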
@@ -3231,6 +3234,7 @@ def get_return_object(self, responses, return_meta_data):
"vertex-ai",
"replicate",
"hf-local",
"vllm",
]


@@ -3477,6 +3481,7 @@ class CrossProviderInferenceEngine(
provider_model_map["watsonx"] = {
k: f"watsonx/{v}" for k, v in provider_model_map["watsonx-sdk"].items()
}
provider_model_map["vllm"] = provider_model_map["hf-local"]

_provider_to_base_class = {
"watsonx": LiteLLMInferenceEngine,
@@ -3490,12 +3495,14 @@
"vertex-ai": LiteLLMInferenceEngine,
"replicate": LiteLLMInferenceEngine,
"hf-local": HFAutoModelInferenceEngine,
"vllm": VLLMInferenceEngine,
}

_provider_param_renaming = {
"watsonx-sdk": {"model": "model_name"},
"rits": {"model": "model_name"},
"hf-local": {"model": "model_name", "max_tokens": "max_new_tokens"},
"vllm": {"top_logprobs": "logprobs", "logprobs": "prompt_logprobs"},
}

def get_return_object(self, **kwargs):
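Taken together, the hunks above register "vllm" as a cross-provider target: it reuses the hf-local model map, routes to VLLMInferenceEngine, and renames logprob parameters on the way in. A hedged sketch of what that implies for a caller, with the parameter names following the renaming table added above and an illustrative model key:

```python
# Illustrative only: with provider="vllm", top_logprobs is expected to be
# forwarded to the vLLM engine as logprobs, and logprobs as prompt_logprobs,
# per the renaming table added above.
from unitxt.inference import CrossProviderInferenceEngine

engine = CrossProviderInferenceEngine(
    model="granite-3-2b-instruct",  # assumed key in the shared model map
    provider="vllm",
    max_tokens=16,
    top_logprobs=5,  # should reach VLLMInferenceEngine as logprobs=5
)
```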
tests/inference/test_inference_engine.py (15 additions, 0 deletions)
@@ -19,6 +19,7 @@
OptionSelectingByLogProbsInferenceEngine,
RITSInferenceEngine,
TextGenerationInferenceOutput,
VLLMInferenceEngine,
WMLInferenceEngineChat,
WMLInferenceEngineGeneration,
)
@@ -189,6 +190,20 @@ def test_watsonx_chat_inference(self):

self.assertListEqual(predictions, ["7", "2"])

def test_vllm_chat_inference(self):
model = VLLMInferenceEngine(
model=local_decoder_model,
data_classification_policy=["public"],
temperature=0,
max_tokens=1,
)

dataset = get_text_dataset()

predictions = model(dataset)

self.assertListEqual(list(predictions), ["7", "1"])

def test_watsonx_inference_with_external_client(self):
from ibm_watsonx_ai.client import APIClient, Credentials

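The new test mirrors test_watsonx_chat_inference but drives VLLMInferenceEngine directly; local_decoder_model and get_text_dataset are defined elsewhere in the test module and are not shown in this diff. Outside the test harness, a comparable direct construction might look like the sketch below, with an illustrative model id:

```python
# Hedged sketch of direct VLLMInferenceEngine use, mirroring the new test.
# The model id is illustrative; the test itself uses local_decoder_model,
# which is defined elsewhere in the test module.
from unitxt.inference import VLLMInferenceEngine

engine = VLLMInferenceEngine(
    model="ibm-granite/granite-3.0-2b-instruct",
    temperature=0,
    max_tokens=1,
)
```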