|
5 | 5 | LICENSE file in the root directory of this source tree.
|
6 | 6 | """
|
7 | 7 |
|
| 8 | +import inspect |
8 | 9 | import json
|
9 | 10 | import logging
|
10 | 11 | from abc import ABC
|
@@ -225,6 +226,13 @@ def __init_subclass__(cls, **kwargs):
|
225 | 226 | "The DEPRECATED 'process' method must not be implemented "
|
226 | 227 | "alongside 'process_input' or 'process_response'."
|
227 | 228 | )
|
| 229 | + if is_process_overridden and inspect.iscoroutinefunction(cls.process): |
| 230 | + # we don't want to add async capabilities to the deprecated function |
| 231 | + raise TypeError( |
| 232 | + f"Cannot create concrete class {cls.__name__}. " |
| 233 | + "The DEPRECATED 'process' method does not support async. " |
| 234 | + "Implement 'process_input' and/or 'process_response' instead." |
| 235 | + ) |
228 | 236 |
|
229 | 237 | return
|
230 | 238 |
|
@@ -875,15 +883,18 @@ async def _parse_and_process(self, request: Request) -> Response:
|
875 | 883 | prompt_hash, response_hash = (None, None)
|
876 | 884 | if input_direction:
|
877 | 885 | prompt_hash = prompt.hash()
|
878 |
| - result: Result | Reject = self.process_input( |
| 886 | + result = await self._handle_process_function( |
| 887 | + self.process_input, |
879 | 888 | metadata=metadata,
|
880 | 889 | parameters=parameters,
|
881 | 890 | prompt=prompt,
|
882 | 891 | request=request,
|
883 | 892 | )
|
| 893 | + |
884 | 894 | else:
|
885 | 895 | response_hash = response.hash()
|
886 |
| - result: Result | Reject = self.process_response( |
| 896 | + result = await self._handle_process_function( |
| 897 | + self.process_response, |
887 | 898 | metadata=metadata,
|
888 | 899 | parameters=parameters,
|
889 | 900 | prompt=prompt,
|
@@ -1014,7 +1025,16 @@ def _is_method_overridden(self, method_name: str) -> bool:
|
1014 | 1025 | # the method object directly from the Processor class, then it has been overridden.
|
1015 | 1026 | return instance_class_method_obj is not base_class_method_obj
|
1016 | 1027 |
|
1017 |
| - def process_input( |
| 1028 | + async def _process_fallback(self, **kwargs) -> Result | Reject: |
| 1029 | + warnings.warn( |
| 1030 | + f"{type(self).__name__} uses the deprecated 'process' method. " |
| 1031 | + "Implement 'process_input' and/or 'process_response' instead.", |
| 1032 | + DeprecationWarning, |
| 1033 | + stacklevel=2, |
| 1034 | + ) |
| 1035 | + return await self._handle_process_function(self.process, **kwargs) |
| 1036 | + |
| 1037 | + async def process_input( |
1018 | 1038 | self,
|
1019 | 1039 | prompt: PROMPT,
|
1020 | 1040 | metadata: Metadata,
|
@@ -1043,26 +1063,20 @@ def process_input(self, prompt, response, metadata, parameters, request):
|
1043 | 1063 |
|
1044 | 1064 | return Result(processor_result=result)
|
1045 | 1065 | """
|
1046 |
| - if self._is_method_overridden("process"): |
1047 |
| - warnings.warn( |
1048 |
| - f"{type(self).__name__} uses the deprecated 'process' method for input. " |
1049 |
| - "Implement 'process_input' instead.", |
1050 |
| - DeprecationWarning, |
1051 |
| - stacklevel=2, # Points the warning to the caller of process_input |
| 1066 | + if not self._is_method_overridden("process"): |
| 1067 | + raise NotImplementedError( |
| 1068 | + f"{type(self).__name__} must implement 'process_input' or the " |
| 1069 | + "deprecated 'process' method to handle input." |
1052 | 1070 | )
|
1053 |
| - return self.process( |
1054 |
| - prompt=prompt, |
1055 |
| - response=None, |
1056 |
| - metadata=metadata, |
1057 |
| - parameters=parameters, |
1058 |
| - request=request, |
1059 |
| - ) |
1060 |
| - raise NotImplementedError( |
1061 |
| - f"{type(self).__name__} must implement 'process_input' or the " |
1062 |
| - "deprecated 'process' method to handle input." |
| 1071 | + return await self._process_fallback( |
| 1072 | + prompt=prompt, |
| 1073 | + response=None, |
| 1074 | + metadata=metadata, |
| 1075 | + parameters=parameters, |
| 1076 | + request=request, |
1063 | 1077 | )
|
1064 | 1078 |
|
1065 |
| - def process_response( |
| 1079 | + async def process_response( |
1066 | 1080 | self,
|
1067 | 1081 | prompt: PROMPT | None,
|
1068 | 1082 | response: RESPONSE,
|
@@ -1096,23 +1110,17 @@ def process_response(self, prompt, response, metadata, parameters, request):
|
1096 | 1110 | return Result(processor_result=result)
|
1097 | 1111 | """
|
1098 | 1112 |
|
1099 |
| - if self._is_method_overridden("process"): |
1100 |
| - warnings.warn( |
1101 |
| - f"{type(self).__name__} uses the deprecated 'process' method for response. " |
1102 |
| - "Implement 'process_response' instead.", |
1103 |
| - DeprecationWarning, |
1104 |
| - stacklevel=2, # Points the warning to the caller of process_input |
| 1113 | + if not self._is_method_overridden("process"): |
| 1114 | + raise NotImplementedError( |
| 1115 | + f"{type(self).__name__} must implement 'process_response' or the " |
| 1116 | + "deprecated 'process' method to handle input." |
1105 | 1117 | )
|
1106 |
| - return self.process( |
1107 |
| - prompt=prompt, |
1108 |
| - response=response, |
1109 |
| - metadata=metadata, |
1110 |
| - parameters=parameters, |
1111 |
| - request=request, |
1112 |
| - ) |
1113 |
| - raise NotImplementedError( |
1114 |
| - f"{type(self).__name__} must implement 'process_response' or the " |
1115 |
| - "deprecated 'process' method to handle input." |
| 1118 | + return await self._process_fallback( |
| 1119 | + prompt=prompt, |
| 1120 | + response=response, |
| 1121 | + metadata=metadata, |
| 1122 | + parameters=parameters, |
| 1123 | + request=request, |
1116 | 1124 | )
|
1117 | 1125 |
|
1118 | 1126 | def process(
|
@@ -1159,6 +1167,13 @@ def process(self, prompt, response, metadata, parameters, request):
|
1159 | 1167 | "'process_input'/'process_response'."
|
1160 | 1168 | )
|
1161 | 1169 |
|
| 1170 | + async def _handle_process_function(self, func, **kwargs) -> Result | Reject: |
| 1171 | + if inspect.iscoroutinefunction(func): |
| 1172 | + result = await func(**kwargs) |
| 1173 | + else: |
| 1174 | + result = func(**kwargs) |
| 1175 | + return result |
| 1176 | + |
1162 | 1177 |
|
1163 | 1178 | def _validation_error_as_messages(err: ValidationError) -> list[str]:
|
1164 | 1179 | return [_error_details_to_str(e) for e in err.errors()]
|
|
0 commit comments