7
7
from .model import Endpoints , PaLM
8
8
9
9
10
@registry.llm_models("spacy.Google.v1")
def google_v1(
    name: str,
    config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
    strict: bool = PaLM.DEFAULT_STRICT,
    max_tries: int = PaLM.DEFAULT_MAX_TRIES,
    interval: float = PaLM.DEFAULT_INTERVAL,
    max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME,
    context_length: Optional[int] = None,
    endpoint: Optional[str] = None,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
    """Returns Google model instance using REST to prompt API.
    name (str): Name of model to use.
    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
        or other response object that does not conform to the expectation of how a well-formed response object from
        this API should look like). If False, the API error responses are returned by __call__(), but no error will
        be raised.
    max_tries (int): Max. number of tries for API request.
    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
        at each retry.
    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
        natively provided by spacy-llm.
    endpoint (Optional[str]): Endpoint to use. Defaults to standard endpoint.
    RETURNS (PaLM): PaLM model instance.
    """
    # "text-bison-001" is a text-completion model; everything else is assumed
    # to speak the message (chat) endpoint.
    default_endpoint = (
        Endpoints.TEXT.value if name in {"text-bison-001"} else Endpoints.MSG.value
    )
    return PaLM(
        name=name,
        endpoint=endpoint or default_endpoint,
        config=config,
        strict=strict,
        max_tries=max_tries,
        interval=interval,
        max_request_time=max_request_time,
        # Bug fix: previously hard-coded to None, which silently discarded the
        # caller-supplied context_length argument.
        context_length=context_length,
    )
10
52
@registry .llm_models ("spacy.PaLM.v2" )
11
53
def palm_bison_v2 (
12
54
config : Dict [Any , Any ] = SimpleFrozenDict (temperature = 0 ),
@@ -18,7 +60,7 @@ def palm_bison_v2(
18
60
context_length : Optional [int ] = None ,
19
61
) -> Callable [[Iterable [Iterable [str ]]], Iterable [Iterable [str ]]]:
20
62
"""Returns Google instance for PaLM Bison model using REST to prompt API.
21
- name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
63
+ name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
22
64
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
23
65
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
24
66
or other response object that does not conform to the expectation of how a well-formed response object from
@@ -57,7 +99,7 @@ def palm_bison(
57
99
endpoint : Optional [str ] = None ,
58
100
) -> PaLM :
59
101
"""Returns Google instance for PaLM Bison model using REST to prompt API.
60
- name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
102
+ name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
61
103
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
62
104
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
63
105
or other response object that does not conform to the expectation of how a well-formed response object from
0 commit comments