feat(server): Support for env value for GPTQ_BITS and GPTQ_GROUPSIZE. (huggingface#580)

# What does this PR do?

Some models are already converted and do not have those values in the
file; this enables users to use them with less friction.

Went for a pure env-based approach because adding flags would end up
(imo) very tedious to maintain. There's a lot of sanitization to do:
those flags would be errors if not used in conjunction with `--quantize
gptq`. The flags would also need to exist in both the launcher and the
server, and be passed through all the function calls in between.

This PR is intended as an easy escape hatch, not the de facto method to
use GPTQ in TGI.

Fixes huggingface#500
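
For context, an invocation for an already-converted GPTQ checkpoint that
lacks the `gptq_bits`/`gptq_groupsize` tensors would look something like
this (the model id is a placeholder, and `--quantize gptq` is still
required):

GPTQ_BITS=4 GPTQ_GROUPSIZE=128 text-generation-launcher --model-id my-org/my-gptq-model --quantize gptq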
Narsil authored Jul 12, 2023
1 parent f018143 commit 5bd2ab6
Showing 2 changed files with 24 additions and 7 deletions.
16 changes: 12 additions & 4 deletions server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -20,12 +20,12 @@
     FastLayerNorm,
     get_linear,
 )
+from safetensors import SafetensorError
 
 
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-
     if config.quantize == "gptq":
         return _load_multi_mqa_gptq(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
@@ -74,8 +74,17 @@ def _load_multi_mqa_gptq(
     qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
 
     g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
-    bits = weights.get_tensor("gptq_bits").item()
-    groupsize = weights.get_tensor("gptq_groupsize").item()
+    try:
+        bits = weights.get_tensor("gptq_bits").item()
+        groupsize = weights.get_tensor("gptq_groupsize").item()
+    except SafetensorError as e:
+        try:
+            import os
+
+            bits = int(os.getenv("GPTQ_BITS"))
+            groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+        except Exception:
+            raise e
 
     weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
 
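(If the env vars are unset, `os.getenv` returns `None` and `int(None)`
raises a `TypeError`; the inner `except Exception` then re-raises the
original `SafetensorError`, so the user still sees the more informative
missing-tensor error.)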
@@ -99,7 +108,6 @@ def _load_multi_mqa_gptq(
 def _load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
15 changes: 12 additions & 3 deletions server/text_generation_server/utils/weights.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 from typing import List, Dict, Optional
-from safetensors import safe_open
+from safetensors import safe_open, SafetensorError
 import torch
 
 
@@ -120,8 +120,17 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
                 torch.testing.assert_close(w2, w[0])
             g_idx = w[0]
 
-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            try:
+                bits = self.get_tensor("gptq_bits").item()
+                groupsize = self.get_tensor("gptq_groupsize").item()
+            except SafetensorError as e:
+                try:
+                    import os
+
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                except Exception:
+                    raise e
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
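As an alternative to the env vars, anyone re-converting a checkpoint can
store the metadata tensors directly in the safetensors file, so the
primary code path succeeds without the fallback. A minimal sketch,
assuming an already-quantized state dict (the helper and file name below
are placeholders):

import torch
from safetensors.torch import save_file

# Hypothetical helper returning qweight/qzeros/scales/g_idx tensors per layer.
tensors = load_quantized_state_dict()
# Scalar tensors; read back via weights.get_tensor("gptq_bits").item().
tensors["gptq_bits"] = torch.tensor(4)
tensors["gptq_groupsize"] = torch.tensor(128)
save_file(tensors, "model.safetensors")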
