Skip to content

Commit 6e7ed40

Browse files
committed
aws#4725: Change model deployment to JumpStart
1 parent faf8648 commit 6e7ed40

1 file changed

Lines changed: 67 additions & 159 deletions

File tree

introduction_to_amazon_algorithms/jumpstart-foundation-models/question_answering_retrieval_augmented_generation/question_answering_langchain_jumpstart.ipynb

Lines changed: 67 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
"cell_type": "code",
4646
"execution_count": null,
4747
"metadata": {
48-
"collapsed": false,
4948
"jupyter": {
5049
"outputs_hidden": false
5150
},
@@ -56,10 +55,7 @@
5655
},
5756
"outputs": [],
5857
"source": [
59-
"!pip install --upgrade sagemaker --quiet\n",
60-
"!pip install ipywidgets==7.0.0 --quiet\n",
61-
"!pip install langchain==0.0.148 --quiet\n",
62-
"!pip install faiss-cpu --quiet"
58+
"!pip install --upgrade sagemaker --quiet"
6359
]
6460
},
6561
{
@@ -70,52 +66,11 @@
7066
},
7167
"outputs": [],
7268
"source": [
73-
"import time\n",
74-
"import sagemaker, boto3, json\n",
75-
"from sagemaker.session import Session\n",
76-
"from sagemaker.model import Model\n",
77-
"from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n",
78-
"from sagemaker.predictor import Predictor\n",
69+
"from sagemaker import Session\n",
7970
"from sagemaker.utils import name_from_base\n",
80-
"from typing import Any, Dict, List, Optional\n",
81-
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
82-
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
83-
"\n",
84-
"sagemaker_session = Session()\n",
85-
"aws_role = sagemaker_session.get_caller_identity_arn()\n",
86-
"aws_region = boto3.Session().region_name\n",
87-
"sess = sagemaker.Session()\n",
88-
"model_version = \"1.*\""
89-
]
90-
},
91-
{
92-
"cell_type": "code",
93-
"execution_count": null,
94-
"metadata": {
95-
"tags": []
96-
},
97-
"outputs": [],
98-
"source": [
99-
"def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type=\"application/json\"):\n",
100-
" client = boto3.client(\"runtime.sagemaker\")\n",
101-
" response = client.invoke_endpoint(\n",
102-
" EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json\n",
103-
" )\n",
104-
" return response\n",
105-
"\n",
106-
"\n",
107-
"def parse_response_model_flan_t5(query_response):\n",
108-
" model_predictions = json.loads(query_response[\"Body\"].read())\n",
109-
" generated_text = model_predictions[\"generated_texts\"]\n",
110-
" return generated_text\n",
71+
"from sagemaker.jumpstart.model import JumpStartModel\n",
11172
"\n",
112-
"\n",
113-
"def parse_response_multiple_texts_bloomz(query_response):\n",
114-
" generated_text = []\n",
115-
" model_predictions = json.loads(query_response[\"Body\"].read())\n",
116-
" for x in model_predictions[0]:\n",
117-
" generated_text.append(x[\"generated_text\"])\n",
118-
" return generated_text"
73+
"sagemaker_session = Session()"
11974
]
12075
},
12176
{
@@ -135,30 +90,21 @@
13590
"source": [
13691
"_MODEL_CONFIG_ = {\n",
13792
" \"huggingface-text2text-flan-t5-xxl\": {\n",
138-
" \"instance type\": \"ml.g5.12xlarge\",\n",
139-
" \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
140-
" \"parse_function\": parse_response_model_flan_t5,\n",
141-
" \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n",
142-
" },\n",
143-
" \"huggingface-textembedding-gpt-j-6b\": {\n",
144-
" \"instance type\": \"ml.g5.24xlarge\",\n",
145-
" \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
93+
" \"model_version\": \"2.*\",\n",
94+
" \"instance type\": \"ml.g5.12xlarge\"\n",
14695
" },\n",
147-
" # \"huggingface-textgeneration1-bloomz-7b1-fp16\": {\n",
148-
" # \"instance type\": \"ml.g5.12xlarge\",\n",
149-
" # \"env\": {},\n",
150-
" # \"parse_function\": parse_response_multiple_texts_bloomz,\n",
151-
" # \"prompt\": \"\"\"question: \\\"{question}\"\\\\n\\nContext: \\\"{context}\"\\\\n\\nAnswer:\"\"\",\n",
96+
" \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n",
97+
" \"model_version\": \"1.*\",\n",
98+
" \"instance type\": \"ml.g5.24xlarge\"\n",
99+
" }\n",
100+
" # \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n",
101+
" # \"model_version\": \"3.*\",\n",
102+
" # \"instance type\": \"ml.g5.12xlarge\"\n",
152103
" # },\n",
153104
" # \"huggingface-text2text-flan-ul2-bf16\": {\n",
154-
" # \"instance type\": \"ml.g5.24xlarge\",\n",
155-
" # \"env\": {\n",
156-
" # \"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\",\n",
157-
" # \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"\n",
158-
" # },\n",
159-
" # \"parse_function\": parse_response_model_flan_t5,\n",
160-
" # \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n",
161-
" # },\n",
105+
" # \"model_version\": \"2.*\",\n",
106+
" # \"instance type\": \"ml.g5.24xlarge\"\n",
107+
" # }\n",
162108
"}"
163109
]
164110
},
@@ -168,41 +114,32 @@
168114
"metadata": {},
169115
"outputs": [],
170116
"source": [
171-
"newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n",
172-
"\n",
173117
"for model_id in _MODEL_CONFIG_:\n",
174-
" endpoint_name = name_from_base(f\"jumpstart-example-raglc-{model_id}\")\n",
175-
" inference_instance_type = _MODEL_CONFIG_[model_id][\"instance type\"]\n",
176-
"\n",
177-
" # Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.\n",
178-
" deploy_image_uri = image_uris.retrieve(\n",
179-
" region=None,\n",
180-
" framework=None, # automatically inferred from model_id\n",
181-
" image_scope=\"inference\",\n",
118+
" endpoint_name = name_from_base(f'jumpstart-example-raglc-{model_id}')\n",
119+
" inference_instance_type = _MODEL_CONFIG_[model_id]['instance type']\n",
120+
" model_version = _MODEL_CONFIG_[model_id]['model_version']\n",
121+
"\n",
122+
" print(f'Deploying {model_id}...')\n",
123+
"\n",
124+
" model = JumpStartModel(\n",
182125
" model_id=model_id,\n",
183-
" model_version=model_version,\n",
184-
" instance_type=inference_instance_type,\n",
185-
" )\n",
186-
" # Retrieve the model uri.\n",
187-
" model_uri = model_uris.retrieve(\n",
188-
" model_id=model_id, model_version=model_version, model_scope=\"inference\"\n",
189-
" )\n",
190-
" model_inference = Model(\n",
191-
" image_uri=deploy_image_uri,\n",
192-
" model_data=model_uri,\n",
193-
" role=aws_role,\n",
194-
" predictor_cls=Predictor,\n",
195-
" name=endpoint_name,\n",
196-
" env=_MODEL_CONFIG_[model_id][\"env\"],\n",
197-
" )\n",
198-
" model_predictor_inference = model_inference.deploy(\n",
199-
" initial_instance_count=1,\n",
200-
" instance_type=inference_instance_type,\n",
201-
" predictor_cls=Predictor,\n",
202-
" endpoint_name=endpoint_name,\n",
126+
" model_version=model_version\n",
203127
" )\n",
204-
" print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")\n",
205-
" _MODEL_CONFIG_[model_id][\"endpoint_name\"] = endpoint_name"
128+
"\n",
129+
" try:\n",
130+
" predictor = model.deploy(\n",
131+
" initial_instance_count=1,\n",
132+
" instance_type=inference_instance_type,\n",
133+
" endpoint_name=name_from_base(\n",
134+
" f\"jumpstart-example-raglc-{model_id}\"\n",
135+
" )\n",
136+
" )\n",
137+
" print(f\"Deployed endpoint: {predictor.endpoint_name}\")\n",
138+
" _MODEL_CONFIG_[model_id]['predictor'] = predictor\n",
139+
" except Exception as e:\n",
140+
" print(f\"Error deploying {model_id}: {str(e)}\")\n",
141+
"\n",
142+
"print(\"Deployment process completed.\")"
206143
]
207144
},
208145
{
@@ -229,26 +166,16 @@
229166
"metadata": {},
230167
"outputs": [],
231168
"source": [
232-
"payload = {\n",
233-
" \"text_inputs\": question,\n",
234-
" \"max_length\": 100,\n",
235-
" \"num_return_sequences\": 1,\n",
236-
" \"top_k\": 50,\n",
237-
" \"top_p\": 0.95,\n",
238-
" \"do_sample\": True,\n",
239-
"}\n",
240-
"\n",
241169
"list_of_LLMs = list(_MODEL_CONFIG_.keys())\n",
242-
"list_of_LLMs.remove(\"huggingface-textembedding-gpt-j-6b\") # remove the embedding model\n",
243-
"\n",
170+
"list_of_LLMs = [model for model in list_of_LLMs if \"textembedding\" not in model]\n",
244171
"\n",
245172
"for model_id in list_of_LLMs:\n",
246-
" endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n",
247-
" query_response = query_endpoint_with_json_payload(\n",
248-
" json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n",
249-
" )\n",
250-
" generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n",
251-
" print(f\"For model: {model_id}, the generated output is: {generated_texts[0]}\\n\")"
173+
" predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n",
174+
" response = predictor.predict({\n",
175+
" \"inputs\": question\n",
176+
" })\n",
177+
" print(f\"For model: {model_id}, the generated output is:\\n\")\n",
178+
" print(f\"{response[0]['generated_text']}\\n\")"
252179
]
253180
},
254181
{
@@ -283,31 +210,15 @@
283210
"metadata": {},
284211
"outputs": [],
285212
"source": [
286-
"parameters = {\n",
287-
" \"max_length\": 200,\n",
288-
" \"num_return_sequences\": 1,\n",
289-
" \"top_k\": 250,\n",
290-
" \"top_p\": 0.95,\n",
291-
" \"do_sample\": False,\n",
292-
" \"temperature\": 1,\n",
293-
"}\n",
213+
"prompt = f'Answer based on context:\\n\\n{context}\\n\\n{question}'\n",
294214
"\n",
295215
"for model_id in list_of_LLMs:\n",
296-
" endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n",
297-
"\n",
298-
" prompt = _MODEL_CONFIG_[model_id][\"prompt\"]\n",
299-
"\n",
300-
" text_input = prompt.replace(\"{context}\", context)\n",
301-
" text_input = text_input.replace(\"{question}\", question)\n",
302-
" payload = {\"text_inputs\": text_input, **parameters}\n",
303-
"\n",
304-
" query_response = query_endpoint_with_json_payload(\n",
305-
" json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n",
306-
" )\n",
307-
" generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n",
308-
" print(\n",
309-
" f\"{bold}For model: {model_id}, the generated output is: {generated_texts[0]}{unbold}{newline}\"\n",
310-
" )"
216+
" predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n",
217+
" response = predictor.predict({\n",
218+
" \"inputs\": prompt\n",
219+
" })\n",
220+
" print(f\"For model: {model_id}, the generated output is:\\n\")\n",
221+
" print(f\"{response[0]['generated_text']}\\n\")"
311222
]
312223
},
313224
{
@@ -405,9 +316,12 @@
405316
"\n",
406317
"\n",
407318
"content_handler = ContentHandler()\n",
319+
"endpoint_name=_MODEL_CONFIG_[\n",
320+
" \"huggingface-textembedding-all-MiniLM-L6-v2\"\n",
321+
" ][\"predictor\"].endpoint_name\n",
408322
"\n",
409323
"embeddings = SagemakerEndpointEmbeddingsJumpStart(\n",
410-
" endpoint_name=_MODEL_CONFIG_[\"huggingface-textembedding-gpt-j-6b\"][\"endpoint_name\"],\n",
324+
" endpoint_name=endpoint_name,\n",
411325
" region_name=aws_region,\n",
412326
" content_handler=content_handler,\n",
413327
")"
@@ -428,33 +342,27 @@
428342
"source": [
429343
"from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint\n",
430344
"\n",
431-
"parameters = {\n",
432-
" \"max_length\": 200,\n",
433-
" \"num_return_sequences\": 1,\n",
434-
" \"top_k\": 250,\n",
435-
" \"top_p\": 0.95,\n",
436-
" \"do_sample\": False,\n",
437-
" \"temperature\": 1,\n",
438-
"}\n",
439-
"\n",
440-
"\n",
441345
"class ContentHandler(LLMContentHandler):\n",
442346
" content_type = \"application/json\"\n",
443347
" accepts = \"application/json\"\n",
444348
"\n",
445349
" def transform_input(self, prompt: str, model_kwargs={}) -> bytes:\n",
446-
" input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n",
350+
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
447351
" return input_str.encode(\"utf-8\")\n",
448352
"\n",
449353
" def transform_output(self, output: bytes) -> str:\n",
450354
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
451-
" return response_json[\"generated_texts\"][0]\n",
355+
" return response_json[0][\"generated_text\"]\n",
452356
"\n",
453357
"\n",
454358
"content_handler = ContentHandler()\n",
359+
"endpoint_name=_MODEL_CONFIG_[\n",
360+
" \"huggingface-text2text-flan-t5-xxl\"\n",
361+
" ][\"predictor\"].endpoint_name\n",
362+
"\n",
455363
"\n",
456364
"sm_llm = SagemakerEndpoint(\n",
457-
" endpoint_name=_MODEL_CONFIG_[\"huggingface-text2text-flan-t5-xxl\"][\"endpoint_name\"],\n",
365+
" endpoint_name=endpoint_name,\n",
458366
" region_name=aws_region,\n",
459367
" model_kwargs=parameters,\n",
460368
" content_handler=content_handler,\n",
@@ -1384,9 +1292,9 @@
13841292
],
13851293
"instance_type": "ml.t3.medium",
13861294
"kernelspec": {
1387-
"display_name": "Python 3 (Data Science 2.0)",
1295+
"display_name": "Python 3 (ipykernel)",
13881296
"language": "python",
1389-
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-38"
1297+
"name": "python3"
13901298
},
13911299
"language_info": {
13921300
"codemirror_mode": {
@@ -1398,7 +1306,7 @@
13981306
"name": "python",
13991307
"nbconvert_exporter": "python",
14001308
"pygments_lexer": "ipython3",
1401-
"version": "3.8.13"
1309+
"version": "3.11.9"
14021310
}
14031311
},
14041312
"nbformat": 4,

0 commit comments

Comments (0)