
Commit ff07682

rmitsch and svlandeg authored
Add registry functions to instantiate models by provider (#428)
* Add provider-specific registry functions.
* Update model registry handles used in tests.
* Update readme and usage examples.
* Update spacy_llm/models/rest/openai/registry.py (co-authored-by: Sofie Van Landeghem <[email protected]>)
* Fix HF registry return type.
* Fix GPU test error message regexes.
* Fix tests. Bump default OAI model to GPT-4.
* Fix external tests.
* Format.
* Ignore LangChain deprecation warning. Ease sentiment tests.
* Use GPT-4 for sharding spancat test case.
* Relax EL test. Remove unnecessary warning contexts.
* Fix comparison in EL test.
* Fix GPU tests.

---------

Co-authored-by: Sofie Van Landeghem <[email protected]>
1 parent c87d5a6 commit ff07682


46 files changed: 363 additions and 97 deletions (only a subset of the changed files is shown below).

README.md (+2 / -1)

@@ -119,7 +119,8 @@ factory = "llm"
 labels = ["COMPLIMENT", "INSULT"]

 [components.llm.model]
-@llm_models = "spacy.GPT-4.v2"
+@llm_models = "spacy.OpenAI.v1"
+name = "gpt-4"
 ```

 Now run:
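For reference, a minimal sketch (not part of this commit) of how a config using the new `spacy.OpenAI.v1` handle could be loaded; the file name `config.cfg` and the `OPENAI_API_KEY` environment variable are assumptions for illustration:

```python
# Hypothetical usage of the README config above; assumes the snippet is saved as
# config.cfg and that OPENAI_API_KEY is set in the environment.
from spacy_llm.util import assemble

nlp = assemble("config.cfg")
doc = nlp("You look gorgeous today!")
print(doc.cats)  # e.g. {"COMPLIMENT": 1.0, "INSULT": 0.0}
```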

pyproject.toml (+3 / -1)

@@ -27,7 +27,9 @@ filterwarnings = [
     "ignore:^.*The `construct` method is deprecated.*",
     "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
     "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*",
-    "ignore:^.*was deprecated in langchain-community.*"
+    "ignore:^.*was deprecated in langchain-community.*",
+    "ignore:^.*was deprecated in LangChain 0.0.1.*",
+    "ignore:^.*the load_module() method is deprecated and slated for removal in Python 3.12.*"
 ]
 markers = [
     "external: interacts with a (potentially cost-incurring) third-party API",

requirements-dev.txt (+2 / -1)

@@ -13,7 +13,8 @@ langchain>=0.1,<0.2; python_version>="3.9"
 openai>=0.27,<=0.28.1; python_version>="3.9"

 # Necessary for running all local models on GPU.
-transformers[sentencepiece]>=4.0.0
+# TODO: transformers > 4.38 causes bug in model handling due to unknown factors. To be investigated.
+transformers[sentencepiece]>=4.0.0,<=4.38
 torch
 einops>=0.4

spacy_llm/models/hf/__init__.py (+2)

@@ -4,12 +4,14 @@
 from .llama2 import llama2_hf
 from .mistral import mistral_hf
 from .openllama import openllama_hf
+from .registry import huggingface_v1
 from .stablelm import stablelm_hf

 __all__ = [
     "HuggingFace",
     "dolly_hf",
     "falcon_hf",
+    "huggingface_v1",
     "llama2_hf",
     "mistral_hf",
     "openllama_hf",

spacy_llm/models/hf/mistral.py (+1 / -2)

@@ -99,8 +99,7 @@ def mistral_hf(
     name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
     config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
     config_run (Optional[Dict[str, Any]]): HF config for running the model.
-    RETURNS (Callable[[Iterable[str]], Iterable[str]]): Mistral instance that can execute a set of prompts and return
-        the raw responses.
+    RETURNS (Mistral): Mistral instance that can execute a set of prompts and return the raw responses.
     """
     return Mistral(
         name=name, config_init=config_init, config_run=config_run, context_length=8000

spacy_llm/models/hf/registry.py (+51, new file)

@@ -0,0 +1,51 @@
+from typing import Any, Dict, Optional
+
+from confection import SimpleFrozenDict
+
+from ...registry import registry
+from .base import HuggingFace
+from .dolly import Dolly
+from .falcon import Falcon
+from .llama2 import Llama2
+from .mistral import Mistral
+from .openllama import OpenLLaMA
+from .stablelm import StableLM
+
+
+@registry.llm_models("spacy.HF.v1")
+@registry.llm_models("spacy.HuggingFace.v1")
+def huggingface_v1(
+    name: str,
+    config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
+    config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
+) -> HuggingFace:
+    """Returns HuggingFace model instance.
+    name (str): Name of model to use.
+    config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
+    config_run (Optional[Dict[str, Any]]): HF config for running the model.
+    RETURNS (Callable[[Iterable[str]], Iterable[str]]): Model instance that can execute a set of prompts and return
+        the raw responses.
+    """
+    model_context_lengths = {
+        Dolly: 2048,
+        Falcon: 2048,
+        Llama2: 4096,
+        Mistral: 8000,
+        OpenLLaMA: 2048,
+        StableLM: 4096,
+    }
+
+    for model_cls, context_length in model_context_lengths.items():
+        model_names = getattr(model_cls, "MODEL_NAMES")
+        if model_names and name in model_names.__args__:
+            return model_cls(
+                name=name,
+                config_init=config_init,
+                config_run=config_run,
+                context_length=context_length,
+            )
+
+    raise ValueError(
+        f"Name {name} could not be associated with any of the supported models. Please check "
+        f"https://spacy.io/api/large-language-models#models-hf to ensure the specified model name is correct."
+    )
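As a usage sketch (not part of the diff), the new catch-all entry can be resolved from the registry and called directly; the model name `"dolly-v2-3b"` is an assumption and only works if it is listed in `Dolly.MODEL_NAMES`:

```python
# Sketch only: resolving the new HuggingFace registry entry by hand instead of via a config.
# "dolly-v2-3b" is assumed to be one of Dolly's supported model names.
from spacy_llm.registry import registry

huggingface_v1 = registry.llm_models.get("spacy.HuggingFace.v1")
model = huggingface_v1(name="dolly-v2-3b")  # dispatches to the Dolly wrapper with context_length=2048
```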

spacy_llm/models/langchain/model.py (+1)

@@ -99,6 +99,7 @@ def query_langchain(
     prompts (Iterable[Iterable[Any]]): Prompts to execute.
     RETURNS (Iterable[Iterable[Any]]): LLM responses.
     """
+    assert callable(model)
     return [
         [model.invoke(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts
     ]

spacy_llm/models/rest/anthropic/registry.py (+37)

@@ -7,6 +7,43 @@
 from .model import Anthropic, Endpoints


+@registry.llm_models("spacy.Anthropic.v1")
+def anthropic_v1(
+    name: str,
+    config: Dict[Any, Any] = SimpleFrozenDict(),
+    strict: bool = Anthropic.DEFAULT_STRICT,
+    max_tries: int = Anthropic.DEFAULT_MAX_TRIES,
+    interval: float = Anthropic.DEFAULT_INTERVAL,
+    max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME,
+    context_length: Optional[int] = None,
+) -> Anthropic:
+    """Returns Anthropic model instance using REST to prompt API.
+    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
+    name (str): Name of model to use.
+    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
+        or other response object that does not conform to the expectation of how a well-formed response object from
+        this API should look like). If False, the API error responses are returned by __call__(), but no error will
+        be raised.
+    max_tries (int): Max. number of tries for API request.
+    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
+        at each retry.
+    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
+    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
+        natively provided by spacy-llm.
+    RETURNS (Anthropic): Instance of Anthropic model.
+    """
+    return Anthropic(
+        name=name,
+        endpoint=Endpoints.COMPLETIONS.value,
+        config=config,
+        strict=strict,
+        max_tries=max_tries,
+        interval=interval,
+        max_request_time=max_request_time,
+        context_length=context_length,
+    )
+
+
 @registry.llm_models("spacy.Claude-2.v2")
 def anthropic_claude_2_v2(
     config: Dict[Any, Any] = SimpleFrozenDict(),
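A hedged usage sketch of the new provider-level handle (not part of the diff); the model name `"claude-2"` and the config key are placeholders, and the Anthropic API key environment variable expected by spacy-llm must be set:

```python
# Illustrative only: instantiate the Anthropic wrapper via the new registry entry.
from spacy_llm.registry import registry

anthropic_v1 = registry.llm_models.get("spacy.Anthropic.v1")
model = anthropic_v1(name="claude-2", config={"max_tokens_to_sample": 300})
# Prompts are passed as a nested iterable (one inner list per doc); responses come back in the same shape.
responses = model([["What is the capital of France?"]])
```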

spacy_llm/models/rest/cohere/registry.py (+38 / -1)

@@ -7,6 +7,43 @@
 from .model import Cohere, Endpoints


+@registry.llm_models("spacy.Cohere.v1")
+def cohere_v1(
+    name: str,
+    config: Dict[Any, Any] = SimpleFrozenDict(),
+    strict: bool = Cohere.DEFAULT_STRICT,
+    max_tries: int = Cohere.DEFAULT_MAX_TRIES,
+    interval: float = Cohere.DEFAULT_INTERVAL,
+    max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
+    context_length: Optional[int] = None,
+) -> Cohere:
+    """Returns Cohere model instance using REST to prompt API.
+    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
+    name (str): Name of model to use.
+    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
+        or other response object that does not conform to the expectation of how a well-formed response object from
+        this API should look like). If False, the API error responses are returned by __call__(), but no error will
+        be raised.
+    max_tries (int): Max. number of tries for API request.
+    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
+        at each retry.
+    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
+    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
+        natively provided by spacy-llm.
+    RETURNS (Cohere): Instance of Cohere model.
+    """
+    return Cohere(
+        name=name,
+        endpoint=Endpoints.COMPLETION.value,
+        config=config,
+        strict=strict,
+        max_tries=max_tries,
+        interval=interval,
+        max_request_time=max_request_time,
+        context_length=context_length,
+    )
+
+
 @registry.llm_models("spacy.Command.v2")
 def cohere_command_v2(
     config: Dict[Any, Any] = SimpleFrozenDict(),

@@ -56,7 +93,7 @@ def cohere_command(
     max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
 ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
     """Returns Cohere instance for 'command' model using REST to prompt API.
-    name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use.
+    name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Name of model to use.
     config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
     strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
         or other response object that does not conform to the expectation of how a well-formed response object from
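A sketch of wiring the new `spacy.Cohere.v1` handle into a pipeline via `add_pipe` (not part of the diff); the task, labels, and model name are illustrative, and the Cohere API key environment variable expected by spacy-llm must be set:

```python
# Sketch: add an llm component that pairs a TextCat task with the new Cohere handle.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        "task": {"@llm_tasks": "spacy.TextCat.v3", "labels": ["COMPLIMENT", "INSULT"]},
        "model": {"@llm_models": "spacy.Cohere.v1", "name": "command"},
    },
)
```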

spacy_llm/models/rest/openai/registry.py (+41)

@@ -8,6 +8,47 @@

 _DEFAULT_TEMPERATURE = 0.0

+
+@registry.llm_models("spacy.OpenAI.v1")
+def openai_v1(
+    name: str,
+    config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
+    strict: bool = OpenAI.DEFAULT_STRICT,
+    max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
+    interval: float = OpenAI.DEFAULT_INTERVAL,
+    max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
+    endpoint: Optional[str] = None,
+    context_length: Optional[int] = None,
+) -> OpenAI:
+    """Returns OpenAI model instance using REST to prompt API.
+
+    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
+    name (str): Model name to use. Can be any model name supported by the OpenAI API.
+    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
+        or other response object that does not conform to the expectation of how a well-formed response object from
+        this API should look like). If False, the API error responses are returned by __call__(), but no error will
+        be raised.
+    max_tries (int): Max. number of tries for API request.
+    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
+        at each retry.
+    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
+    endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
+    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
+        natively provided by spacy-llm.
+    RETURNS (OpenAI): OpenAI model instance.
+    """
+    return OpenAI(
+        name=name,
+        endpoint=endpoint or Endpoints.CHAT.value,
+        config=config,
+        strict=strict,
+        max_tries=max_tries,
+        interval=interval,
+        max_request_time=max_request_time,
+        context_length=context_length,
+    )
+
+
 """
 Parameter explanations:
 strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
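For illustration (not part of the diff), the new `endpoint` parameter lets `spacy.OpenAI.v1` target an OpenAI-compatible server; the URL below is a placeholder and `OPENAI_API_KEY` is assumed to be set:

```python
# Sketch: override the default chat endpoint with a placeholder URL.
from spacy_llm.registry import registry

openai_v1 = registry.llm_models.get("spacy.OpenAI.v1")
model = openai_v1(
    name="gpt-4",
    config={"temperature": 0.0},
    endpoint="https://example.com/v1/chat/completions",  # placeholder, not a real endpoint
)
```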

spacy_llm/models/rest/palm/registry.py (+44 / -2)

@@ -7,6 +7,48 @@
 from .model import Endpoints, PaLM


+@registry.llm_models("spacy.Google.v1")
+def google_v1(
+    name: str,
+    config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
+    strict: bool = PaLM.DEFAULT_STRICT,
+    max_tries: int = PaLM.DEFAULT_MAX_TRIES,
+    interval: float = PaLM.DEFAULT_INTERVAL,
+    max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME,
+    context_length: Optional[int] = None,
+    endpoint: Optional[str] = None,
+) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
+    """Returns Google model instance using REST to prompt API.
+    name (str): Name of model to use.
+    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
+    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
+        or other response object that does not conform to the expectation of how a well-formed response object from
+        this API should look like). If False, the API error responses are returned by __call__(), but no error will
+        be raised.
+    max_tries (int): Max. number of tries for API request.
+    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
+        at each retry.
+    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
+    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
+        natively provided by spacy-llm.
+    endpoint (Optional[str]): Endpoint to use. Defaults to standard endpoint.
+    RETURNS (PaLM): PaLM model instance.
+    """
+    default_endpoint = (
+        Endpoints.TEXT.value if name in {"text-bison-001"} else Endpoints.MSG.value
+    )
+    return PaLM(
+        name=name,
+        endpoint=endpoint or default_endpoint,
+        config=config,
+        strict=strict,
+        max_tries=max_tries,
+        interval=interval,
+        max_request_time=max_request_time,
+        context_length=None,
+    )
+
+
 @registry.llm_models("spacy.PaLM.v2")
 def palm_bison_v2(
     config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),

@@ -18,7 +60,7 @@ def palm_bison_v2(
     context_length: Optional[int] = None,
 ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
     """Returns Google instance for PaLM Bison model using REST to prompt API.
-    name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
+    name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
     config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
     strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
         or other response object that does not conform to the expectation of how a well-formed response object from

@@ -57,7 +99,7 @@ def palm_bison(
     endpoint: Optional[str] = None,
 ) -> PaLM:
     """Returns Google instance for PaLM Bison model using REST to prompt API.
-    name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
+    name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
     config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
     strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
         or other response object that does not conform to the expectation of how a well-formed response object from
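A small sketch of the endpoint dispatch in `google_v1` above (not part of the diff); the model names are illustrative, and the PaLM API key environment variable expected by spacy-llm must be set:

```python
# Sketch: "text-bison-001" is routed to the text endpoint, any other name to the
# message endpoint, unless an explicit endpoint argument is passed.
from spacy_llm.registry import registry

google_v1 = registry.llm_models.get("spacy.Google.v1")
text_model = google_v1(name="text-bison-001")  # defaults to Endpoints.TEXT
chat_model = google_v1(name="chat-bison-001")  # falls through to Endpoints.MSG
```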

spacy_llm/pipeline/llm.py (+4 / -3)

@@ -24,7 +24,7 @@
 logger.addHandler(logging.NullHandler())

 DEFAULT_MODEL_CONFIG = {
-    "@llm_models": "spacy.GPT-3-5.v2",
+    "@llm_models": "spacy.GPT-3-5.v3",
     "strict": True,
 }
 DEFAULT_CACHE_CONFIG = {

@@ -238,6 +238,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]:
                 else self._task.generate_prompts(noncached_doc_batch),
                 n_iters + 1,
             )
+
             responses_iters = tee(
                 self._model(
                     # Ensure that model receives Iterable[Iterable[Any]]. If task doesn't shard, its prompt is wrapped

@@ -251,7 +252,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]:
             )

             for prompt_data, response, doc in zip(
-                prompts_iters[1], responses_iters[0], noncached_doc_batch
+                prompts_iters[1], list(responses_iters[0]), noncached_doc_batch
             ):
                 logger.debug(
                     "Generated prompt for doc: %s\n%s",

@@ -266,7 +267,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]:
                         elem[1] if support_sharding else noncached_doc_batch[i]
                         for i, elem in enumerate(prompts_iters[2])
                     ),
-                    responses_iters[1],
+                    list(responses_iters[1]),
                 )
             )

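The switch to `list(responses_iters[...])` materializes one `tee` branch before it is zipped with the prompt iterators. A standalone illustration of the `itertools.tee` behaviour involved (plain Python, not spacy-llm code):

```python
# Consuming one tee branch fills a shared buffer, so the other branch still sees
# every element; wrapping a branch in list() simply forces that evaluation eagerly
# before it is combined with other iterables.
from itertools import tee

source = (x * x for x in range(3))
first_branch, second_branch = tee(source)

materialized = list(first_branch)        # forces the generator to run
assert materialized == [0, 1, 4]
assert list(second_branch) == [0, 1, 4]  # buffered values are still available
```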

spacy_llm/tests/models/test_cohere.py (+1 / -1)

@@ -84,7 +84,7 @@ def test_cohere_api_response_when_error():
 def test_cohere_error_unsupported_model():
     """Ensure graceful handling of error when model is not supported"""
     incorrect_model = "x-gpt-3.5-turbo"
-    with pytest.raises(ValueError, match="model not found"):
+    with pytest.raises(ValueError, match="Request to Cohere API failed"):
         Cohere(
             name=incorrect_model,
             config={},
