feat(server): support new falcon config (huggingface#712)
OlivierDehaene authored Jul 27, 2023
1 parent 2efd46e commit ab96b9a
Showing 2 changed files with 38 additions and 26 deletions.
11 changes: 3 additions & 8 deletions server/text_generation_server/models/__init__.py
@@ -200,13 +200,10 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    if model_type in ["RefinedWeb", "RefinedWebModel"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config_dict.get("alibi", False) or (
-                    model_type == "RefinedWebModel"
-                    and config_dict.get("multi_query", True)
-                ):
+                if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
@@ -215,9 +212,7 @@ def get_model(
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
                 )
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
-            )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
         else:
             if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRWSharded(
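
For context, a minimal sketch (not taken from this commit) of the config fields this dispatch keys on: newer Falcon checkpoints publish a config.json whose model_type is "falcon" rather than the older "RefinedWeb"/"RefinedWebModel" names, which is why the extra literal is added above. The snippet uses transformers' PretrainedConfig.get_config_dict; the model id tiiuae/falcon-7b and the expected printed values are illustrative assumptions that depend on the checkpoint revision.

```python
# Sketch only: inspect the fields get_model() dispatches on. The model id and the
# expected values are assumptions about a typical recent Falcon checkpoint.
from transformers import PretrainedConfig

config_dict, _ = PretrainedConfig.get_config_dict("tiiuae/falcon-7b")

print(config_dict.get("model_type"))                # e.g. "falcon" on recent revisions
print(config_dict.get("alibi", False))              # sharding is still rejected when True
print(config_dict.get("new_decoder_architecture"))  # consumed by RWConfig below
```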
53 changes: 35 additions & 18 deletions server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -49,18 +49,19 @@ def __init__(
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
-        n_layer=2,
-        n_head=8,
+        num_hidden_layers=None,
+        num_attention_heads=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
         bos_token_id=1,
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        n_head_kv=None,
+        num_kv_heads=None,
         multi_query=False,
         alibi=False,
+        new_decoder_architecture=None,
         bias=False,
         parallel_attn=False,
         **kwargs,
@@ -78,8 +79,16 @@ def __init__(
         # Backward compatibility with n_embed kwarg
         n_embed = kwargs.pop("n_embed", None)
         self.hidden_size = hidden_size if n_embed is None else n_embed
-        self.n_layer = n_layer
-        self.n_head = n_head
+        self.n_layer = (
+            num_hidden_layers
+            if num_hidden_layers is not None
+            else kwargs.pop("n_layer", 2)
+        )
+        self.n_head = (
+            num_attention_heads
+            if num_attention_heads is not None
+            else kwargs.pop("n_head", 8)
+        )
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
@@ -91,10 +100,21 @@ def __init__(
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
 
-        if n_head_kv is not None:
-            self.n_head_kv = n_head_kv
+        if num_kv_heads is not None:
+            self.n_head_kv = num_kv_heads
         else:
-            self.n_head_kv = 1 if multi_query else n_head
+            old_n_head_kv = kwargs.pop("n_head_kv", None)
+            if old_n_head_kv is not None:
+                self.n_head_kv = old_n_head_kv
+            else:
+                self.n_head_kv = 1 if multi_query else self.n_head
+
+        if new_decoder_architecture is not None:
+            self.new_decoder_architecture = new_decoder_architecture
+        elif model_type == "RefinedWeb":
+            self.new_decoder_architecture = True
+        else:
+            self.new_decoder_architecture = False
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
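To make the old-name/new-name mapping above concrete, here is a standalone paraphrase of RWConfig's resolution logic. resolve() is a hypothetical helper written for this illustration, and the layer/head counts are likewise illustrative; neither is code from the repository.

```python
# Hypothetical helper mirroring the hunks above: map either config dialect onto the
# legacy attribute names (n_layer, n_head, n_head_kv) plus new_decoder_architecture.
def resolve(model_type="RefinedWeb", num_hidden_layers=None, num_attention_heads=None,
            num_kv_heads=None, multi_query=False, new_decoder_architecture=None, **kwargs):
    n_layer = num_hidden_layers if num_hidden_layers is not None else kwargs.pop("n_layer", 2)
    n_head = num_attention_heads if num_attention_heads is not None else kwargs.pop("n_head", 8)
    if num_kv_heads is not None:
        n_head_kv = num_kv_heads
    else:
        old_n_head_kv = kwargs.pop("n_head_kv", None)
        n_head_kv = old_n_head_kv if old_n_head_kv is not None else (1 if multi_query else n_head)
    if new_decoder_architecture is None:
        # Only the legacy "RefinedWeb" (multi-group) type implies the new decoder layout.
        new_decoder_architecture = model_type == "RefinedWeb"
    return n_layer, n_head, n_head_kv, new_decoder_architecture

# New-style "falcon" fields (40B-class values, illustrative)
assert resolve(model_type="falcon", num_hidden_layers=60, num_attention_heads=128,
               num_kv_heads=8, new_decoder_architecture=True) == (60, 128, 8, True)

# Old-style RefinedWeb fields still work through the kwargs fallbacks
assert resolve(model_type="RefinedWeb", n_layer=60, n_head=128, n_head_kv=8) == (60, 128, 8, True)

# Old multi-query config without n_head_kv collapses to a single KV head
assert resolve(model_type="RefinedWebModel", n_layer=32, n_head=71, multi_query=True) == (32, 71, 1, False)
```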
@@ -530,26 +550,23 @@ def __init__(self, config, weights):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.model_type == "RefinedWebModel":
+
+        if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLayer(layer_id, config, weights)
+                    FlashRWLargeLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
-        elif config.model_type == "RefinedWeb":
+            self.cache_size = self.h[0].self_attention.num_groups
+        else:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLargeLayer(layer_id, config, weights)
+                    FlashRWLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_groups
-        else:
-            raise NotImplementedError(
-                f"model_type {config.model_type} is not supported."
-            )
+            self.cache_size = self.h[0].self_attention.num_heads_kv
 
         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",
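
A rough sketch of the behavioral change in the hunk above: layer selection previously keyed on the model_type string and raised for anything else (including "falcon"), whereas it now keys on new_decoder_architecture, which RWConfig sets for every config dialect. old_select/new_select are hypothetical helpers paraphrasing the two versions, not repository code.

```python
# Hypothetical paraphrase of the layer dispatch before and after this commit.
def old_select(model_type: str) -> str:
    if model_type == "RefinedWebModel":
        return "FlashRWLayer"
    if model_type == "RefinedWeb":
        return "FlashRWLargeLayer"
    raise NotImplementedError(f"model_type {model_type} is not supported.")

def new_select(new_decoder_architecture: bool) -> str:
    # cache_size follows the same split: num_heads_kv for FlashRWLayer,
    # num_groups for FlashRWLargeLayer.
    return "FlashRWLargeLayer" if new_decoder_architecture else "FlashRWLayer"

print(new_select(False))  # FlashRWLayer, e.g. a multi-query 7B-style Falcon config
print(new_select(True))   # FlashRWLargeLayer, e.g. a grouped-KV 40B-style config
# old_select("falcon") would have raised NotImplementedError before this change.
```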
