@@ -15,7 +15,7 @@

 import grpc
 import torch
-
+import torch.cuda
 from transformers import AutoTokenizer, AutoModel

 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -70,14 +70,10 @@ def LoadModel(self, request, context):
         try:
             self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)  # trust_remote_code is needed to use the encode method with embeddings models like jinai-v2
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-            if request.CUDA:
+            if request.CUDA or torch.cuda.is_available():
                 try:
-                    # TODO: also tensorflow, make configurable
-                    import torch.cuda
-                    if torch.cuda.is_available():
-                        print("Loading model", model_name, "to CUDA.", file=sys.stderr)
-                        self.model = self.model.to("cuda")
+                    print("Loading model", model_name, "to CUDA.", file=sys.stderr)
+                    self.model = self.model.to("cuda")
                 except Exception as err:
                     print("Not using CUDA:", err, file=sys.stderr)
         except Exception as err:
@@ -113,6 +109,47 @@ def Embedding(self, request, context):
         print("Embeddings:", sentence_embeddings, file=sys.stderr)
         return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)

+    def Predict(self, request, context):
+        """
+        Generates text based on the given prompt and sampling parameters.
+
+        Args:
+            request: The predict request.
+            context: The gRPC context.
+
+        Returns:
+            backend_pb2.Reply: The predict result.
+        """
+        if request.TopP == 0:
+            request.TopP = 0.9
+
+        max_tokens = 200
+        if request.Tokens > 0:
+            max_tokens = request.Tokens
+
+        inputs = self.tokenizer(request.Prompt, return_tensors="pt").input_ids
+        outputs = self.model.generate(inputs, max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
+
+        generated_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        # Remove prompt from response if present
+        if request.Prompt in generated_text:
+            generated_text = generated_text.replace(request.Prompt, "")
+
+        return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
+
+    def PredictStream(self, request, context):
+        """
+        Generates text based on the given prompt and sampling parameters, and streams the results.
+
+        Args:
+            request: The predict stream request.
+            context: The gRPC context.
+
+        Returns:
+            backend_pb2.Reply: The predict stream result.
+        """
+        yield self.Predict(request, context)
+

 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
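For context, here is a minimal, hypothetical client sketch of how the new Predict and PredictStream RPCs could be exercised once the server is running. Only the field names (Prompt, Tokens, Temperature, TopP, CUDA) and the bytes Reply.message field come from the handler code in the diff above; the stub name BackendStub, the request message names ModelOptions and PredictOptions, the shape of the LoadModel call, and the address are assumptions about the generated gRPC bindings.

# Hypothetical client sketch; message/stub names are assumptions,
# only the field names mirror the handler code above.
import grpc

import backend_pb2
import backend_pb2_grpc


def run(address="localhost:50051"):
    with grpc.insecure_channel(address) as channel:
        stub = backend_pb2_grpc.BackendStub(channel)  # assumed stub name

        # Load a model first; ModelOptions and its Model field are assumed.
        stub.LoadModel(backend_pb2.ModelOptions(Model="gpt2", CUDA=False))

        # Unary Predict call; field names match those read in Predict().
        reply = stub.Predict(backend_pb2.PredictOptions(
            Prompt="Hello, world",
            Tokens=64,        # 0 falls back to 200 server-side
            Temperature=0.7,
            TopP=0.0,         # 0 is rewritten to 0.9 by the handler
        ))
        print(reply.message.decode("utf-8"))

        # PredictStream yields Reply messages; currently it wraps a single Predict call.
        for chunk in stub.PredictStream(backend_pb2.PredictOptions(Prompt="Hello again")):
            print(chunk.message.decode("utf-8"))


if __name__ == "__main__":
    run()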