Skip to content

Commit

Permalink
fix(server): fix quantization python requirements (huggingface#708)
Browse files — view the repository at this point in history
OlivierDehaene authored Jul 27, 2023
1 parent e64a658 commit 8bd0adb
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 21 deletions.
14 changes: 13 additions & 1 deletion server/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tokenizers = "0.13.3"
huggingface-hub = "^0.14.1"
transformers = "4.29.2"
einops = "^0.6.1"
texttable = "^1.6.7"

[tool.poetry.extras]
accelerate = ["accelerate"]
Expand Down
1 change: 1 addition & 0 deletions server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
texttable==1.6.7 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
Expand Down
1 change: 0 additions & 1 deletion server/text_generation_server/models/flash_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __init__(
if config.quantize == "gptq":
weights._set_gptq_params(model_id)


model = FlashRWForCausalLM(config, weights)

torch.distributed.barrier(group=self.process_group)
Expand Down
26 changes: 13 additions & 13 deletions server/text_generation_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,21 @@ async def Decode(self, request, context):


def serve(
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
):
async def serve_inner(
model_id: str,
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
model_id: str,
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
if sharded:
Expand Down
8 changes: 2 additions & 6 deletions server/text_generation_server/utils/gptq/quantize.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import math
import json
import os
import torch
import transformers

from texttable import Texttable
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import transformers
from huggingface_hub import HfApi
import numpy as np
import torch
from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files
Expand Down

0 comments on commit 8bd0adb

Please sign in to comment.