Commit 15fc6ea

chore: add normalize_embeddings to job input
1 parent a26d2b1 commit 15fc6ea

2 files changed: +3 -2 lines changed

README.md (+1 -1)
@@ -137,6 +137,6 @@ RUNPOD_AI_API_KEY='**************' RUNPOD_ENDPOINT_ID='*******' python predict.p
 ```
 To run with streaming enabled, use the `--stream` option. To set generation parameters, use the `--params_json` option to pass a JSON string of parameters:
 ```bash
-RUNPOD_AI_API_KEY='**************' RUNPOD_ENDPOINT_ID='*******' python predict.py --params_json '{"sentences": ["Explain The Great Gatsby in 4000 words.", "What is The Great Gatsby about?"]}'
+RUNPOD_AI_API_KEY='**************' RUNPOD_ENDPOINT_ID='*******' python predict.py --params_json '{"sentences": ["Explain The Great Gatsby in 4000 words.", "What is The Great Gatsby about?"], "normalize_embeddings": true}'
 ```
 You can generate the API key [here](https://www.runpod.io/console/serverless/user/settings) under API Keys.
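
The `--params_json` value has to be valid JSON, which means every key is double-quoted. One way to avoid hand-quoting mistakes is to build the string with `json.dumps`. The sketch below is illustrative only: the `params` and `params_json` names are assumptions and are not part of the repository.

```python
import json
import shlex

# Hypothetical payload mirroring the README example above.
params = {
    "sentences": [
        "Explain The Great Gatsby in 4000 words.",
        "What is The Great Gatsby about?",
    ],
    "normalize_embeddings": True,
}

# json.dumps quotes keys and renders Python True as JSON true.
params_json = json.dumps(params)

# shlex.quote makes the string safe to paste after --params_json in a shell.
print("python predict.py --params_json " + shlex.quote(params_json))
```

Round-tripping the string through `json.loads` before sending it is a cheap check that the payload parses.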

handler.py (+2 -1)
@@ -23,9 +23,10 @@ def load_model():
 def handler(job):
     job_input = job['input']
     sentences = job_input.pop("sentences")
+    normalize_embeddings = job_input.pop("normalize_embeddings", False)
     model = load_model()
 
-    embeddings = model.encode(sentences)
+    embeddings = model.encode(sentences, normalize_embeddings=normalize_embeddings)
     encoded_embeddings = json.dumps(embeddings, cls=NumpyArrayEncoder)
     decoded_embeddings = json.loads(encoded_embeddings)
     yield decoded_embeddings
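
The handler change can be sketched in isolation. This example assumes `model.encode` follows the sentence-transformers convention, where `normalize_embeddings=True` scales each vector to unit L2 length. The diff does not show which model `load_model()` returns, so `FakeModel` and `l2_normalize` are hypothetical stand-ins, and the JSON round-trip from the real handler is omitted.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)


class FakeModel:
    """Hypothetical stand-in for the real model: returns a fixed vector per sentence."""

    def encode(self, sentences, normalize_embeddings=False):
        vectors = [[3.0, 4.0] for _ in sentences]
        if normalize_embeddings:
            vectors = [l2_normalize(v) for v in vectors]
        return vectors


def handler(job, model=FakeModel()):
    job_input = job["input"]
    sentences = job_input.pop("sentences")
    # Missing key falls back to False, so existing callers keep the old behavior.
    normalize_embeddings = job_input.pop("normalize_embeddings", False)
    yield model.encode(sentences, normalize_embeddings=normalize_embeddings)


raw = next(handler({"input": {"sentences": ["a"]}}))
unit = next(handler({"input": {"sentences": ["a"], "normalize_embeddings": True}}))
print(raw, math.hypot(*unit[0]))
```

With unit-length vectors, the dot product of two embeddings equals their cosine similarity, which is the usual reason to request normalization.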
