Skip to content

Commit fb0cc19

Browse files
committed
Revert langchain_initializer changes
1 parent 9dacd9c commit fb0cc19

File tree

4 files changed

+16
-39
lines changed

4 files changed

+16
-39
lines changed

nemoguardrails/llm/models/langchain_initializer.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _init_chat_completion_model(model_name: str, provider_name: str, kwargs: Dic
227227

228228
def _init_text_completion_model(
229229
model_name: str, provider_name: str, kwargs: Dict[str, Any]
230-
) -> BaseLLM | None:
230+
) -> BaseLLM:
231231
"""Initialize a text completion model.
232232
233233
Args:
@@ -241,14 +241,9 @@ def _init_text_completion_model(
241241
Raises:
242242
RuntimeError: If the provider is not found
243243
"""
244-
try:
245-
provider_cls = _get_text_completion_provider(provider_name)
246-
except RuntimeError as e:
247-
return None
248-
244+
provider_cls = _get_text_completion_provider(provider_name)
249245
if provider_cls is None:
250-
return None
251-
246+
raise ValueError()
252247
kwargs = _update_model_kwargs(provider_cls, model_name, kwargs)
253248
# remove stream_usage parameter as it's not supported by text completion APIs
254249
# (e.g., OpenAI's AsyncCompletions.create() doesn't accept this parameter)
@@ -258,7 +253,7 @@ def _init_text_completion_model(
258253

259254
def _init_community_chat_models(
260255
model_name: str, provider_name: str, kwargs: Dict[str, Any]
261-
) -> BaseChatModel | None:
256+
) -> BaseChatModel:
262257
"""Initialize community chat models.
263258
264259
Args:
@@ -275,14 +270,14 @@ def _init_community_chat_models(
275270
"""
276271
provider_cls = _get_chat_completion_provider(provider_name)
277272
if provider_cls is None:
278-
return None
273+
raise ValueError()
279274
kwargs = _update_model_kwargs(provider_cls, model_name, kwargs)
280275
return provider_cls(**kwargs)
281276

282277

283278
def _init_gpt35_turbo_instruct(
284279
model_name: str, provider_name: str, kwargs: Dict[str, Any]
285-
) -> BaseLLM | None:
280+
) -> BaseLLM:
286281
"""Initialize GPT-3.5 Turbo Instruct model.
287282
288283
Currently init_chat_model from langchain infers this as a chat model.

tests/llm_providers/test_langchain_initialization_methods.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,8 @@ def test_init_community_chat_models_no_provider(self):
116116
"nemoguardrails.llm.models.langchain_initializer._get_chat_completion_provider"
117117
) as mock_get_provider:
118118
mock_get_provider.return_value = None
119-
assert (
120-
_init_community_chat_models("community-model", "provider", {}) is None
121-
)
119+
with pytest.raises(ValueError):
120+
_init_community_chat_models("community-model", "provider", {})
122121

123122

124123
class TestTextCompletionInitializer:
@@ -157,7 +156,8 @@ def test_init_text_completion_model_no_provider(self):
157156
"nemoguardrails.llm.models.langchain_initializer._get_text_completion_provider"
158157
) as mock_get_provider:
159158
mock_get_provider.return_value = None
160-
assert _init_text_completion_model("text-model", "provider", {}) is None
159+
with pytest.raises(ValueError):
160+
_init_text_completion_model("text-model", "provider", {})
161161

162162

163163
class TestUpdateModelKwargs:

tests/llm_providers/test_langchain_initializer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,14 @@ def test_text_completion_called_when_previous_fail(mock_initializers):
183183
mock_initializers["text"].assert_called_once()
184184

185185

186-
def test_text_mode_only_calls_text_initializers(mock_initializers):
187-
"""Test that text mode only tries initializers that support text mode."""
186+
def test_text_completion_supports_chat_mode(mock_initializers):
188187
mock_initializers["special"].return_value = None
188+
mock_initializers["chat"].return_value = None
189+
mock_initializers["community"].return_value = None
189190
mock_initializers["text"].return_value = "text_model"
190-
result = init_langchain_model("text-model", "provider", "text", {})
191+
result = init_langchain_model("text-model", "provider", "chat", {})
191192
assert result == "text_model"
192193
mock_initializers["special"].assert_called_once()
194+
mock_initializers["chat"].assert_called_once()
195+
mock_initializers["community"].assert_called_once()
193196
mock_initializers["text"].assert_called_once()
194-
# Since text returns a value, chat and community are not called
195-
mock_initializers["chat"].assert_not_called()
196-
mock_initializers["community"].assert_not_called()

tests/test_actions_llm_utils.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -91,24 +91,6 @@ class MockPatchedNVIDIA(MockNVIDIAOriginal):
9191
__module__ = "nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch"
9292

9393

94-
class MockTRTLLM:
95-
__module__ = "nemoguardrails.llm.providers.trtllm.llm"
96-
97-
98-
class MockAzureLLM:
99-
__module__ = "langchain_openai.chat_models"
100-
101-
102-
class MockLLMWithClient:
103-
__module__ = "langchain_openai.chat_models"
104-
105-
class _MockClient:
106-
base_url = "https://custom.endpoint.com/v1"
107-
108-
def __init__(self):
109-
self.client = self._MockClient()
110-
111-
11294
def test_infer_provider_openai():
11395
llm = MockOpenAILLM()
11496
provider = _infer_provider_from_module(llm)

0 commit comments

Comments
 (0)