Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 55 additions & 121 deletions 00-llm-sagemaker-jumpstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"source": [
"## Step 1. Deploy a LLM in SageMaker JumpStart\n",
"\n",
"To better illustrate the idea, let's first deploy the model that is required to perform the demo. You will need to (1) install the required python packages, (2) authenticate the use of AWS services by using an AWS role, (3) define the parse response functions, (4) select a model, (5) deploy the model\n",
"To better illustrate the idea, let's first deploy the model that is required to perform the demo. You will need to (1) install the required python packages, (2) authenticate the use of AWS services by using an AWS role, (3) select a model, (4) deploy the model\n",
"\n",
"When you deploy a model from JumpStart, SageMaker hosts the model and deploys an endpoint that you can use for inference. In this notebook, we focus on the deployment of Flan T5 and demo with the Flan T5 SageMaker endpoint. \n",
"\n",
Expand Down Expand Up @@ -110,6 +110,7 @@
"from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n",
"from sagemaker.predictor import Predictor\n",
"from sagemaker.utils import name_from_base\n",
"from sagemaker.jumpstart.model import JumpStartModel\n",
"\n",
"sagemaker_session = Session()\n",
"aws_role = sagemaker_session.get_caller_identity_arn()\n",
Expand All @@ -125,51 +126,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Define the Response Parser\n",
"\n",
"The below cell includes three functions that outline how you are going to interact with the model’s endpoint, and how you want the output to be formatted. \n",
"\n",
"The **query_endpoint_with_json_payload** function specifies that the input to the model's endpoint is any string of text formatted as json and encoded in utf-8 format. \n",
"\n",
"Both the **parse_response_model_flan_t5** and **parse_response_multiple_texts_bloomz** functions are specific to their respective models, and they ensure that the output from the endpoint is formatted as json and includes the generated text.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"#Defining the parse functions for all of the available models\n",
"\n",
"def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type=\"application/json\"):\n",
" client = boto3.client(\"runtime.sagemaker\")\n",
" response = client.invoke_endpoint(\n",
" EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json\n",
" )\n",
" return response\n",
"\n",
"def parse_response_model_flan_t5(query_response):\n",
" model_predictions = json.loads(query_response[\"Body\"].read())\n",
" generated_text = model_predictions[\"generated_texts\"]\n",
" return generated_text\n",
"\n",
"\n",
"def parse_response_multiple_texts_bloomz(query_response):\n",
" generated_text = []\n",
" model_predictions = json.loads(query_response[\"Body\"].read())\n",
" for x in model_predictions[0]:\n",
" generated_text.append(x[\"generated_text\"])\n",
" return generated_text"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Select an LLM to Deploy\n",
"### 3. Select an LLM to Deploy\n",
"\n",
"As mentioned previously, Amazon SageMaker Jumpstart provides access to hundreds of built-in algorithms with pretrained models from popular model hubs. You can check [the available models](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html) on Amazon SageMaker Jumpstart to get the full available model list.\n",
"\n",
Expand All @@ -189,22 +146,19 @@
"outputs": [],
"source": [
"_MODEL_CONFIG_ = {\n",
" \"huggingface-text2text-flan-t5-small\": {\n",
" \"huggingface-text2text-flan-t5-small\": { \n",
" \"model_predictor\": \"predic-flan-t5\",\n",
" \"instance type\": \"ml.g5.xlarge\",\n",
" \"env\": {\"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
" \"parse_function\": parse_response_model_flan_t5,\n",
" \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n",
" },\n",
" # \"huggingface-textgeneration1-bloomz-7b1-fp16\": {\n",
" # \"instance type\": \"ml.g5.12xlarge\",\n",
" # \"env\": {},\n",
" # \"parse_function\": parse_response_multiple_texts_bloomz,\n",
" # \"prompt\": \"\"\"question: \\\"{question}\"\\\\n\\nContext: \\\"{context}\"\\\\n\\nAnswer:\"\"\",\n",
" # },\n",
" \"huggingface-textgeneration1-bloomz-7b1-fp16\": {\n",
" \"model_predictor\": \"predic-bloomz-7b\",\n",
" \"instance type\": \"ml.g5.xlarge\",\n",
" \"prompt\": \"\"\"question: \\\"{question}\"\\\\n\\nContext: \\\"{context}\"\\\\n\\nAnswer:\"\"\",\n",
" },\n",
" # \"huggingface-text2text-flan-ul2-bf16\": {\n",
" # \"model_predictor\": \"predic-flan-ul2\",\n",
" # \"instance type\": \"ml.g5.24xlarge\",\n",
" # \"env\": {\"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
" # \"parse_function\": parse_response_model_flan_t5,\n",
" # \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n",
" # }\n",
"}"
Expand All @@ -227,6 +181,8 @@
"**model_inference**: the object containing all of the model's attributes \n",
"**model_predictor_inference**: the object that will be used to deploy the model \n",
"\n",
"It's also possible to deploy a Sagemaker endpoint by using a low-code deployment with the JumpStartModel class. Using the model ID to define your model as a JumpStart model, and the deploy method to automatically deploy your model for inference. this is the method you will use. \n",
"\n",
"The following cell can take around 5-10 minutes to process as we are deploying the model endpoint here.\n",
"\n",
"Please note if you decide to deploy multiple models, this cell will take longer to execute.\n"
Expand All @@ -245,36 +201,22 @@
"for model_id in _MODEL_CONFIG_:\n",
" endpoint_name = name_from_base(f\"jumpstart-example-{model_id}\")\n",
" inference_instance_type = _MODEL_CONFIG_[model_id][\"instance type\"]\n",
"\n",
" # Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.\n",
" deploy_image_uri = image_uris.retrieve(\n",
" region=None,\n",
" framework=None, # automatically inferred from model_id\n",
" image_scope=\"inference\",\n",
" model_id=model_id,\n",
" model_version=model_version,\n",
" instance_type=inference_instance_type,\n",
" )\n",
" # Retrieve the model uri.\n",
" model_uri = model_uris.retrieve(\n",
" model_id=model_id, model_version=model_version, model_scope=\"inference\"\n",
" )\n",
" model_inference = Model(\n",
" image_uri=deploy_image_uri,\n",
" model_data=model_uri,\n",
" role=aws_role,\n",
" predictor_cls=Predictor,\n",
" name=endpoint_name,\n",
" env=_MODEL_CONFIG_[model_id][\"env\"],\n",
" \n",
" model_inference = JumpStartModel(\n",
" model_id=model_id, \n",
" model_version=model_version\n",
" )\n",
" model_predictor_inference = model_inference.deploy(\n",
" \n",
" _MODEL_CONFIG_[model_id][\"model_predictor\"] = model_inference.deploy(\n",
" initial_instance_count=1,\n",
" instance_type=inference_instance_type,\n",
" predictor_cls=Predictor,\n",
" endpoint_name=endpoint_name,\n",
" endpoint_name=endpoint_name\n",
" )\n",
" \n",
" _MODEL_CONFIG_[model_id][\"endpoint_name\"] = endpoint_name\n",
" \n",
" print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")\n",
" _MODEL_CONFIG_[model_id][\"endpoint_name\"] = endpoint_name"
" "
]
},
{
Expand Down Expand Up @@ -373,22 +315,20 @@
"outputs": [],
"source": [
"payload = {\n",
" \"text_inputs\": question,\n",
" \"max_length\": 100,\n",
" \"num_return_sequences\": 1,\n",
" \"top_k\": 10,\n",
" \"top_p\": 0.95, #0.95,\n",
" \"do_sample\": True,\n",
" \"inputs\": question,\n",
" \"parameters\": {\n",
" \"max_length\": 100, \n",
" \"num_return_sequences\": 1, \n",
" \"top_k\": 10,\n",
" \"top_p\": 0.95, \n",
" \"do_sample\": True\n",
" },\n",
"}\n",
"\n",
"\n",
"for model_id in _MODEL_CONFIG_:\n",
" endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n",
" query_response = query_endpoint_with_json_payload(\n",
" json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n",
" )\n",
" generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n",
" print(f\"For model: {model_id}, the generated output is: {generated_texts[0]}\\n\")"
" query_response = _MODEL_CONFIG_[model_id][\"model_predictor\"].predict(payload)\n",
" print(f\"For model: {model_id}, the generated output is: {query_response[0]['generated_text']}\\n\")"
]
},
{
Expand Down Expand Up @@ -439,22 +379,19 @@
"outputs": [],
"source": [
"payload = {\n",
" \"text_inputs\": question2,\n",
" \"max_length\": 100,\n",
" \"num_return_sequences\": 10,\n",
" \"top_k\": 3,\n",
" \"top_p\": 0.95, #0.95,\n",
" \"do_sample\": True,\n",
" \"inputs\": question2,\n",
" \"parameters\": {\n",
" \"max_length\": 100, \n",
" \"num_return_sequences\": 1, \n",
" \"top_k\": 10,\n",
" \"top_p\": 0.95, \n",
" \"do_sample\": True\n",
" },\n",
"}\n",
"\n",
"\n",
"for model_id in _MODEL_CONFIG_:\n",
" endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n",
" query_response = query_endpoint_with_json_payload(\n",
" json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n",
" )\n",
" generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n",
" print(f\"For model: {model_id}, the generated output is: {generated_texts[0]}\\n\")"
" query_response = _MODEL_CONFIG_[model_id][\"model_predictor\"].predict(payload)\n",
" print(f\"For model: {model_id}, the generated output is: {query_response[0]['generated_text']}\\n\")"
]
},
{
Expand Down Expand Up @@ -503,30 +440,25 @@
"outputs": [],
"source": [
"parameters = {\n",
" \"max_length\": 200,\n",
" \"num_return_sequences\": 10,\n",
" \"max_length\": 200, \n",
" \"num_return_sequences\": 10, \n",
" \"top_k\": 250,\n",
" \"top_p\": 0.95,\n",
" \"top_p\": 0.95, \n",
" \"do_sample\": False,\n",
" #\"temperature\": 1,\n",
"}\n",
"\n",
"\n",
"for model_id in _MODEL_CONFIG_:\n",
" endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n",
"\n",
" prompt = _MODEL_CONFIG_[model_id][\"prompt\"]\n",
"\n",
" text_input = prompt.replace(\"{context}\", context)\n",
" text_input = text_input.replace(\"{question}\", question2)\n",
" payload = {\"text_inputs\": text_input, **parameters}\n",
"\n",
" query_response = query_endpoint_with_json_payload(\n",
" json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n",
" )\n",
" generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n",
" print(\n",
" f\"{bold}For model: {model_id}, the generated output is: {generated_texts[0]}{unbold}{newline}\"\n",
" )"
" payload = {\"inputs\": text_input, \"parameters\":parameters }\n",
" \n",
" query_response = _MODEL_CONFIG_[model_id][\"model_predictor\"].predict(payload)\n",
" print(f\"For model: {model_id}, the generated output is: {query_response[0]['generated_text']}\\n\")"
]
},
{
Expand Down Expand Up @@ -554,7 +486,9 @@
"outputs": [],
"source": [
"sagemaker_client = boto3.client('sagemaker')\n",
"sagemaker_client.delete_endpoint(EndpointName=endpoint_name)"
"\n",
"for model_id in _MODEL_CONFIG_:\n",
" sagemaker_client.delete_endpoint(EndpointName=_MODEL_CONFIG_[model_id][\"endpoint_name\"])"
]
}
],
Expand Down