|
45 | 45 | "cell_type": "code", |
46 | 46 | "execution_count": null, |
47 | 47 | "metadata": { |
48 | | - "collapsed": false, |
49 | 48 | "jupyter": { |
50 | 49 | "outputs_hidden": false |
51 | 50 | }, |
|
56 | 55 | }, |
57 | 56 | "outputs": [], |
58 | 57 | "source": [ |
59 | | - "!pip install --upgrade sagemaker --quiet\n", |
60 | | - "!pip install ipywidgets==7.0.0 --quiet\n", |
61 | | - "!pip install langchain==0.0.148 --quiet\n", |
62 | | - "!pip install faiss-cpu --quiet" |
| 58 | + "!pip install --upgrade sagemaker --quiet" |
63 | 59 | ] |
64 | 60 | }, |
65 | 61 | { |
|
70 | 66 | }, |
71 | 67 | "outputs": [], |
72 | 68 | "source": [ |
73 | | - "import time\n", |
74 | | - "import sagemaker, boto3, json\n", |
75 | | - "from sagemaker.session import Session\n", |
76 | | - "from sagemaker.model import Model\n", |
77 | | - "from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n", |
78 | | - "from sagemaker.predictor import Predictor\n", |
| 69 | + "from sagemaker import Session\n", |
79 | 70 | "from sagemaker.utils import name_from_base\n", |
80 | | - "from typing import Any, Dict, List, Optional\n", |
81 | | - "from langchain.embeddings import SagemakerEndpointEmbeddings\n", |
82 | | - "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n", |
83 | | - "\n", |
84 | | - "sagemaker_session = Session()\n", |
85 | | - "aws_role = sagemaker_session.get_caller_identity_arn()\n", |
86 | | - "aws_region = boto3.Session().region_name\n", |
87 | | - "sess = sagemaker.Session()\n", |
88 | | - "model_version = \"1.*\"" |
89 | | - ] |
90 | | - }, |
91 | | - { |
92 | | - "cell_type": "code", |
93 | | - "execution_count": null, |
94 | | - "metadata": { |
95 | | - "tags": [] |
96 | | - }, |
97 | | - "outputs": [], |
98 | | - "source": [ |
99 | | - "def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type=\"application/json\"):\n", |
100 | | - " client = boto3.client(\"runtime.sagemaker\")\n", |
101 | | - " response = client.invoke_endpoint(\n", |
102 | | - " EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json\n", |
103 | | - " )\n", |
104 | | - " return response\n", |
105 | | - "\n", |
106 | | - "\n", |
107 | | - "def parse_response_model_flan_t5(query_response):\n", |
108 | | - " model_predictions = json.loads(query_response[\"Body\"].read())\n", |
109 | | - " generated_text = model_predictions[\"generated_texts\"]\n", |
110 | | - " return generated_text\n", |
| 71 | + "from sagemaker.jumpstart.model import JumpStartModel\n", |
111 | 72 | "\n", |
112 | | - "\n", |
113 | | - "def parse_response_multiple_texts_bloomz(query_response):\n", |
114 | | - " generated_text = []\n", |
115 | | - " model_predictions = json.loads(query_response[\"Body\"].read())\n", |
116 | | - " for x in model_predictions[0]:\n", |
117 | | - " generated_text.append(x[\"generated_text\"])\n", |
118 | | - " return generated_text" |
| 73 | + "sagemaker_session = Session()" |
119 | 74 | ] |
120 | 75 | }, |
121 | 76 | { |
|
135 | 90 | "source": [ |
136 | 91 | "_MODEL_CONFIG_ = {\n", |
137 | 92 | " \"huggingface-text2text-flan-t5-xxl\": {\n", |
138 | | - " \"instance type\": \"ml.g5.12xlarge\",\n", |
139 | | - " \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n", |
140 | | - " \"parse_function\": parse_response_model_flan_t5,\n", |
141 | | - " \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n", |
142 | | - " },\n", |
143 | | - " \"huggingface-textembedding-gpt-j-6b\": {\n", |
144 | | - " \"instance type\": \"ml.g5.24xlarge\",\n", |
145 | | - " \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n", |
| 93 | + " \"model_version\": \"2.*\",\n", |
| 94 | + " \"instance type\": \"ml.g5.12xlarge\"\n", |
146 | 95 | " },\n", |
147 | | - " # \"huggingface-textgeneration1-bloomz-7b1-fp16\": {\n", |
148 | | - " # \"instance type\": \"ml.g5.12xlarge\",\n", |
149 | | - " # \"env\": {},\n", |
150 | | - " # \"parse_function\": parse_response_multiple_texts_bloomz,\n", |
151 | | - " # \"prompt\": \"\"\"question: \\\"{question}\"\\\\n\\nContext: \\\"{context}\"\\\\n\\nAnswer:\"\"\",\n", |
| 96 | + " \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n", |
| 97 | + " \"model_version\": \"1.*\",\n", |
| 98 | + " \"instance type\": \"ml.g5.24xlarge\"\n", |
| 99 | + " }\n", |
| 100 | + " # \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n", |
| 101 | + " # \"model_version\": \"3.*\",\n", |
| 102 | + " # \"instance type\": \"ml.g5.12xlarge\"\n", |
152 | 103 | " # },\n", |
153 | 104 | " # \"huggingface-text2text-flan-ul2-bf16\": {\n", |
154 | | - " # \"instance type\": \"ml.g5.24xlarge\",\n", |
155 | | - " # \"env\": {\n", |
156 | | - " # \"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\",\n", |
157 | | - " # \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"\n", |
158 | | - " # },\n", |
159 | | - " # \"parse_function\": parse_response_model_flan_t5,\n", |
160 | | - " # \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n", |
161 | | - " # },\n", |
| 105 | + " # \"model_version\": \"2.*\",\n", |
| 106 | + " # \"instance type\": \"ml.g5.24xlarge\"\n", |
| 107 | + " # }\n", |
162 | 108 | "}" |
163 | 109 | ] |
164 | 110 | }, |
|
168 | 114 | "metadata": {}, |
169 | 115 | "outputs": [], |
170 | 116 | "source": [ |
171 | | - "newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n", |
172 | | - "\n", |
173 | 117 | "for model_id in _MODEL_CONFIG_:\n", |
174 | | - " endpoint_name = name_from_base(f\"jumpstart-example-raglc-{model_id}\")\n", |
175 | | - " inference_instance_type = _MODEL_CONFIG_[model_id][\"instance type\"]\n", |
176 | | - "\n", |
177 | | - " # Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.\n", |
178 | | - " deploy_image_uri = image_uris.retrieve(\n", |
179 | | - " region=None,\n", |
180 | | - " framework=None, # automatically inferred from model_id\n", |
181 | | - " image_scope=\"inference\",\n", |
| 118 | + " endpoint_name = name_from_base(f'jumpstart-example-raglc-{model_id}')\n", |
| 119 | + " inference_instance_type = _MODEL_CONFIG_[model_id]['instance type']\n", |
| 120 | + " model_version = _MODEL_CONFIG_[model_id]['model_version']\n", |
| 121 | + "\n", |
| 122 | + " print(f'Deploying {model_id}...')\n", |
| 123 | + "\n", |
| 124 | + " model = JumpStartModel(\n", |
182 | 125 | " model_id=model_id,\n", |
183 | | - " model_version=model_version,\n", |
184 | | - " instance_type=inference_instance_type,\n", |
185 | | - " )\n", |
186 | | - " # Retrieve the model uri.\n", |
187 | | - " model_uri = model_uris.retrieve(\n", |
188 | | - " model_id=model_id, model_version=model_version, model_scope=\"inference\"\n", |
189 | | - " )\n", |
190 | | - " model_inference = Model(\n", |
191 | | - " image_uri=deploy_image_uri,\n", |
192 | | - " model_data=model_uri,\n", |
193 | | - " role=aws_role,\n", |
194 | | - " predictor_cls=Predictor,\n", |
195 | | - " name=endpoint_name,\n", |
196 | | - " env=_MODEL_CONFIG_[model_id][\"env\"],\n", |
197 | | - " )\n", |
198 | | - " model_predictor_inference = model_inference.deploy(\n", |
199 | | - " initial_instance_count=1,\n", |
200 | | - " instance_type=inference_instance_type,\n", |
201 | | - " predictor_cls=Predictor,\n", |
202 | | - " endpoint_name=endpoint_name,\n", |
| 126 | + " model_version=model_version\n", |
203 | 127 | " )\n", |
204 | | - " print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")\n", |
205 | | - " _MODEL_CONFIG_[model_id][\"endpoint_name\"] = endpoint_name" |
| 128 | + "\n", |
| 129 | + " try:\n", |
| 130 | + " predictor = model.deploy(\n", |
| 131 | + " initial_instance_count=1,\n", |
| 132 | + " instance_type=inference_instance_type,\n", |
| 133 | + " endpoint_name=name_from_base(\n", |
| 134 | + " f\"jumpstart-example-raglc-{model_id}\"\n", |
| 135 | + " )\n", |
| 136 | + " )\n", |
| 137 | + " print(f\"Deployed endpoint: {predictor.endpoint_name}\")\n", |
| 138 | + " _MODEL_CONFIG_[model_id]['predictor'] = predictor\n", |
| 139 | + " except Exception as e:\n", |
| 140 | + " print(f\"Error deploying {model_id}: {str(e)}\")\n", |
| 141 | + "\n", |
| 142 | + "print(\"Deployment process completed.\")" |
206 | 143 | ] |
207 | 144 | }, |
208 | 145 | { |
|
229 | 166 | "metadata": {}, |
230 | 167 | "outputs": [], |
231 | 168 | "source": [ |
232 | | - "payload = {\n", |
233 | | - " \"text_inputs\": question,\n", |
234 | | - " \"max_length\": 100,\n", |
235 | | - " \"num_return_sequences\": 1,\n", |
236 | | - " \"top_k\": 50,\n", |
237 | | - " \"top_p\": 0.95,\n", |
238 | | - " \"do_sample\": True,\n", |
239 | | - "}\n", |
240 | | - "\n", |
241 | 169 | "list_of_LLMs = list(_MODEL_CONFIG_.keys())\n", |
242 | | - "list_of_LLMs.remove(\"huggingface-textembedding-gpt-j-6b\") # remove the embedding model\n", |
243 | | - "\n", |
| 170 | + "list_of_LLMs = [model for model in list_of_LLMs if \"textembedding\" not in model]\n", |
244 | 171 | "\n", |
245 | 172 | "for model_id in list_of_LLMs:\n", |
246 | | - " endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n", |
247 | | - " query_response = query_endpoint_with_json_payload(\n", |
248 | | - " json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n", |
249 | | - " )\n", |
250 | | - " generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n", |
251 | | - " print(f\"For model: {model_id}, the generated output is: {generated_texts[0]}\\n\")" |
| 173 | + " predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n", |
| 174 | + " response = predictor.predict({\n", |
| 175 | + " \"inputs\": question\n", |
| 176 | + " })\n", |
| 177 | + " print(f\"For model: {model_id}, the generated output is:\\n\")\n", |
| 178 | + " print(f\"{response[0]['generated_text']}\\n\")" |
252 | 179 | ] |
253 | 180 | }, |
254 | 181 | { |
|
283 | 210 | "metadata": {}, |
284 | 211 | "outputs": [], |
285 | 212 | "source": [ |
286 | | - "parameters = {\n", |
287 | | - " \"max_length\": 200,\n", |
288 | | - " \"num_return_sequences\": 1,\n", |
289 | | - " \"top_k\": 250,\n", |
290 | | - " \"top_p\": 0.95,\n", |
291 | | - " \"do_sample\": False,\n", |
292 | | - " \"temperature\": 1,\n", |
293 | | - "}\n", |
| 213 | + "prompt = f'Answer based on context:\\n\\n{context}\\n\\n{question}'\n", |
294 | 214 | "\n", |
295 | 215 | "for model_id in list_of_LLMs:\n", |
296 | | - " endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n", |
297 | | - "\n", |
298 | | - " prompt = _MODEL_CONFIG_[model_id][\"prompt\"]\n", |
299 | | - "\n", |
300 | | - " text_input = prompt.replace(\"{context}\", context)\n", |
301 | | - " text_input = text_input.replace(\"{question}\", question)\n", |
302 | | - " payload = {\"text_inputs\": text_input, **parameters}\n", |
303 | | - "\n", |
304 | | - " query_response = query_endpoint_with_json_payload(\n", |
305 | | - " json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n", |
306 | | - " )\n", |
307 | | - " generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n", |
308 | | - " print(\n", |
309 | | - " f\"{bold}For model: {model_id}, the generated output is: {generated_texts[0]}{unbold}{newline}\"\n", |
310 | | - " )" |
| 216 | + " predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n", |
| 217 | + " response = predictor.predict({\n", |
| 218 | + " \"inputs\": prompt\n", |
| 219 | + " })\n", |
| 220 | + " print(f\"For model: {model_id}, the generated output is:\\n\")\n", |
| 221 | + " print(f\"{response[0]['generated_text']}\\n\")" |
311 | 222 | ] |
312 | 223 | }, |
313 | 224 | { |
|
405 | 316 | "\n", |
406 | 317 | "\n", |
407 | 318 | "content_handler = ContentHandler()\n", |
| 319 | + "endpoint_name=_MODEL_CONFIG_[\n", |
| 320 | + " \"huggingface-textembedding-all-MiniLM-L6-v2\"\n", |
| 321 | + " ][\"predictor\"].endpoint_name\n", |
408 | 322 | "\n", |
409 | 323 | "embeddings = SagemakerEndpointEmbeddingsJumpStart(\n", |
410 | | - " endpoint_name=_MODEL_CONFIG_[\"huggingface-textembedding-gpt-j-6b\"][\"endpoint_name\"],\n", |
| 324 | + " endpoint_name=endpoint_name,\n", |
411 | 325 | " region_name=aws_region,\n", |
412 | 326 | " content_handler=content_handler,\n", |
413 | 327 | ")" |
|
428 | 342 | "source": [ |
429 | 343 | "from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint\n", |
430 | 344 | "\n", |
431 | | - "parameters = {\n", |
432 | | - " \"max_length\": 200,\n", |
433 | | - " \"num_return_sequences\": 1,\n", |
434 | | - " \"top_k\": 250,\n", |
435 | | - " \"top_p\": 0.95,\n", |
436 | | - " \"do_sample\": False,\n", |
437 | | - " \"temperature\": 1,\n", |
438 | | - "}\n", |
439 | | - "\n", |
440 | | - "\n", |
441 | 345 | "class ContentHandler(LLMContentHandler):\n", |
442 | 346 | " content_type = \"application/json\"\n", |
443 | 347 | " accepts = \"application/json\"\n", |
444 | 348 | "\n", |
445 | 349 | " def transform_input(self, prompt: str, model_kwargs={}) -> bytes:\n", |
446 | | - " input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n", |
| 350 | + " input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n", |
447 | 351 | " return input_str.encode(\"utf-8\")\n", |
448 | 352 | "\n", |
449 | 353 | " def transform_output(self, output: bytes) -> str:\n", |
450 | 354 | " response_json = json.loads(output.read().decode(\"utf-8\"))\n", |
451 | | - " return response_json[\"generated_texts\"][0]\n", |
| 355 | + " return response_json[0][\"generated_text\"]\n", |
452 | 356 | "\n", |
453 | 357 | "\n", |
454 | 358 | "content_handler = ContentHandler()\n", |
| 359 | + "endpoint_name=_MODEL_CONFIG_[\n", |
| 360 | + " \"huggingface-text2text-flan-t5-xxl\"\n", |
| 361 | + " ][\"predictor\"].endpoint_name\n", |
| 362 | + "\n", |
455 | 363 | "\n", |
456 | 364 | "sm_llm = SagemakerEndpoint(\n", |
457 | | - " endpoint_name=_MODEL_CONFIG_[\"huggingface-text2text-flan-t5-xxl\"][\"endpoint_name\"],\n", |
| 365 | + " endpoint_name=endpoint_name,\n", |
458 | 366 | " region_name=aws_region,\n", |
459 | 367 | " model_kwargs=parameters,\n", |
460 | 368 | " content_handler=content_handler,\n", |
|
1384 | 1292 | ], |
1385 | 1293 | "instance_type": "ml.t3.medium", |
1386 | 1294 | "kernelspec": { |
1387 | | - "display_name": "Python 3 (Data Science 2.0)", |
| 1295 | + "display_name": "Python 3 (ipykernel)", |
1388 | 1296 | "language": "python", |
1389 | | - "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-38" |
| 1297 | + "name": "python3" |
1390 | 1298 | }, |
1391 | 1299 | "language_info": { |
1392 | 1300 | "codemirror_mode": { |
|
1398 | 1306 | "name": "python", |
1399 | 1307 | "nbconvert_exporter": "python", |
1400 | 1308 | "pygments_lexer": "ipython3", |
1401 | | - "version": "3.8.13" |
| 1309 | + "version": "3.11.9" |
1402 | 1310 | } |
1403 | 1311 | }, |
1404 | 1312 | "nbformat": 4, |
|
0 commit comments