
Commit 5e335ea

feat(transformers): support also text generation (#1630)
* feat(transformers): support also text generation

Signed-off-by: Ettore Di Giacinto <[email protected]>

* embedded: set seed -1

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent d5d82ba · commit 5e335ea

File tree

7 files changed, +51 −8 lines

  backend/python/transformers/transformers_server.py
  embedded/models/dolphin-2.5-mixtral-8x7b.yaml
  embedded/models/llava.yaml
  embedded/models/mistral-openorca.yaml
  embedded/models/mixtral-instruct.yaml
  embedded/models/tinyllama-chat.yaml
  examples/configurations/phi-2.yaml


backend/python/transformers/transformers_server.py

+45 −8

@@ -15,7 +15,7 @@
 
 import grpc
 import torch
-
+import torch.cuda
 from transformers import AutoTokenizer, AutoModel
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -70,14 +70,10 @@ def LoadModel(self, request, context):
         try:
             self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True) # trust_remote_code is needed to use the encode method with embeddings models like jinai-v2
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-            if request.CUDA:
+            if request.CUDA or torch.cuda.is_available():
                 try:
-                    # TODO: also tensorflow, make configurable
-                    import torch.cuda
-                    if torch.cuda.is_available():
-                        print("Loading model", model_name, "to CUDA.", file=sys.stderr)
-                        self.model = self.model.to("cuda")
+                    print("Loading model", model_name, "to CUDA.", file=sys.stderr)
+                    self.model = self.model.to("cuda")
                 except Exception as err:
                     print("Not using CUDA:", err, file=sys.stderr)
         except Exception as err:
@@ -113,6 +109,47 @@ def Embedding(self, request, context):
         print("Embeddings:", sentence_embeddings, file=sys.stderr)
         return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
 
+    def Predict(self, request, context):
+        """
+        Generates text based on the given prompt and sampling parameters.
+
+        Args:
+            request: The predict request.
+            context: The gRPC context.
+
+        Returns:
+            backend_pb2.Reply: The predict result.
+        """
+        if request.TopP == 0:
+            request.TopP = 0.9
+
+        max_tokens = 200
+        if request.Tokens > 0:
+            max_tokens = request.Tokens
+
+        inputs = self.tokenizer.tokenizer(request.Prompt, return_tensors="pt").input_ids
+        outputs = self.model.generate(inputs,max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
+
+        generated_text = self.tokenizer.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        # Remove prompt from response if present
+        if request.Prompt in generated_text:
+            generated_text = generated_text.replace(request.Prompt, "")
+
+        return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
+
+    def PredictStream(self, request, context):
+        """
+        Generates text based on the given prompt and sampling parameters, and streams the results.
+
+        Args:
+            request: The predict stream request.
+            context: The gRPC context.
+
+        Returns:
+            backend_pb2.Result: The predict stream result.
+        """
+        yield self.Predict(request, context)
+
 
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
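Note: the new Predict handler is essentially the standard Hugging Face tokenize → generate → decode loop exposed over gRPC. The sketch below reproduces that flow standalone; the model name, prompt, and sampling values are placeholders, and AutoModelForCausalLM, max_new_tokens and do_sample are assumptions about the usual transformers generation API rather than the exact calls in the diff above.

# Minimal standalone sketch of the generation flow behind Predict():
# tokenize the prompt, sample a continuation, decode it, strip the echoed prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
if torch.cuda.is_available():  # mirrors the CUDA handling in LoadModel
    model = model.to("cuda")

prompt = "Q: What does this backend do?\nA:"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
outputs = model.generate(
    inputs,
    max_new_tokens=200,  # the server defaults to 200 when request.Tokens is 0
    do_sample=True,
    temperature=0.2,
    top_p=0.9,           # the server defaults TopP to 0.9 when it is 0
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Remove the prompt from the response if present, as Predict() does
print(generated_text.replace(prompt, ""))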

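On the client side, the new RPC could be exercised roughly as follows. This is a hypothetical sketch: the stub and message names (BackendStub, ModelOptions, PredictOptions), the Model field, and the server address are assumptions about the code generated from the backend's protobuf definitions; only the field names Prompt, Tokens, Temperature, TopP and CUDA, plus the bytes payload in Reply.message, appear in the server code above.

# Hypothetical gRPC client for the new Predict RPC (stub/message names assumed).
import grpc
import backend_pb2
import backend_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:  # address is illustrative
    stub = backend_pb2_grpc.BackendStub(channel)            # assumed stub name
    stub.LoadModel(backend_pb2.ModelOptions(Model="microsoft/phi-2", CUDA=False))
    reply = stub.Predict(backend_pb2.PredictOptions(
        Prompt="Instruct: say hello\nOutput:",
        Tokens=64,          # the server falls back to 200 when this is 0
        Temperature=0.2,
        TopP=0.95,
    ))
    print(reply.message.decode("utf-8"))                    # Reply carries bytes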
embedded/models/dolphin-2.5-mixtral-8x7b.yaml

+1

@@ -5,6 +5,7 @@ parameters:
   temperature: 0.2
   top_k: 40
   top_p: 0.95
+  seed: -1
 template:
   chat_message: |
     <|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "user"}}user{{end}}

embedded/models/llava.yaml

+1

@@ -17,6 +17,7 @@ parameters:
   temperature: 0.2
   top_k: 40
   top_p: 0.95
+  seed: -1
 
 template:
   chat: |

embedded/models/mistral-openorca.yaml

+1

@@ -5,6 +5,7 @@ parameters:
   temperature: 0.2
   top_k: 40
   top_p: 0.95
+  seed: -1
 template:
   chat_message: |
     <|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "user"}}user{{end}}

embedded/models/mixtral-instruct.yaml

+1

@@ -4,6 +4,7 @@ parameters:
   model: huggingface://TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF/mixtral-8x7b-instruct-v0.1.Q2_K.gguf
   temperature: 0.2
   top_k: 40
+  seed: -1
   top_p: 0.95
 template:
   chat: &chat |

embedded/models/tinyllama-chat.yaml

+1

@@ -4,6 +4,7 @@ parameters:
   model: huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q8_0.gguf
   temperature: 0.2
   top_k: 40
+  seed: -1
   top_p: 0.95
 template:
   chat_message: |

examples/configurations/phi-2.yaml

+1

@@ -10,6 +10,7 @@ parameters:
   temperature: 0.2
   top_k: 40
   top_p: 0.95
+  seed: -1
 template:
   chat: &template |
     Instruct: {{.Input}}
