
Commit 947c1b6

Support for max_window_layers (#157)
1 parent d4e2fc1 commit 947c1b6

4 files changed: +87 -3 lines

fast_llm/layers/transformer/attention.py (+18 -1)

@@ -274,6 +274,20 @@ def _query_key_value_backward(
         input_grad.add_(self.key_value.backward(key_grad, context.pop("key_value")))
         return input_grad
 
+
+    def _decide_window_size(self) -> int | None:
+        # NOTE: This is a temporary solution for Qwen 2.X
+        # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71
+        # TODO: make universal per layer config
+        window_size = self._config.window_size
+        if (
+            self._config.max_window_layers is not None
+            and self._layer_index < self._config.max_window_layers
+        ):
+            window_size = None
+
+        return window_size
+
     def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]:
         sequence_first = kwargs[TransformerKwargs.sequence_first]
         query, key_value = self._query_key_value(input_, sequence_first)

@@ -323,13 +337,16 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
             query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q])
             key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k])
 
+
+        window_size = self._decide_window_size()
+
         if self._use_flash_attention:
             input_ = flash_attn(
                 query,
                 key,
                 value,
                 dropout_p=self._config.attention_dropout if self.training else 0.0,
-                window_size=self._config.window_size,
+                window_size=window_size,
                 causal=True,
                 generator=self._tensor_space.distributed.tp_generator,
                 softmax_scale=self._softmax_scale,
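
For orientation, a minimal standalone sketch of the mapping _decide_window_size implements (StubConfig and decide_window_size are illustrative stand-ins, not part of this commit): with window_size=512 and max_window_layers=2, layers 0 and 1 fall back to full attention and layers 2 and above use the 512-token sliding window.

# Illustrative sketch only: mirrors the logic of Attention._decide_window_size
# with stand-in names; this is not Fast-LLM code.
from dataclasses import dataclass


@dataclass
class StubConfig:
    window_size: int | None
    max_window_layers: int | None


def decide_window_size(config: StubConfig, layer_index: int) -> int | None:
    window_size = config.window_size
    if config.max_window_layers is not None and layer_index < config.max_window_layers:
        window_size = None  # bottom layers run full attention
    return window_size


config = StubConfig(window_size=512, max_window_layers=2)
print([decide_window_size(config, i) for i in range(4)])  # [None, None, 512, 512]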

fast_llm/layers/transformer/config.py (+13 -1)

@@ -443,6 +443,12 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig):
         hint=FieldHint.feature,
         valid=skip_valid_if_none(check_field(Assert.geq, 0)),
     )
+    max_window_layers: int | None = Field(
+        default=None,
+        desc="The number of bottom layers that use full attention instead of SWA (Sliding Window Attention); layers at or above this index use the sliding window (`window_size`).",
+        hint=FieldHint.optional,
+        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
+    )
     # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto
     mlp_recompute_level: MLPRecomputeLevel = Field(
         default=MLPRecomputeLevel.none,

@@ -571,4 +577,10 @@ def _validate(self) -> None:
             Assert.geq(scale, 0)
 
     def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:
-        return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)
+        use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)
+
+        # Config parameter `window_size` can only be used with flash attention.
+        if not use_flash_attention:
+            Assert.is_(self.window_size, None)
+
+        return use_flash_attention
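
A short sketch of how the two settings interact, reusing the mock pattern from the tests below; it follows directly from the code above, but read it as an illustration rather than a supported recipe.

# Sketch: window_size (and hence max_window_layers) only takes effect with flash
# attention, which do_use_flash_attention grants only for half-precision dtypes.
import unittest.mock

from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.layers.transformer.config import TransformerConfig

config = TransformerConfig(use_flash_attention=True, window_size=512, max_window_layers=2)
distributed_config = unittest.mock.Mock(spec=DistributedConfig)

# Half-precision training dtype: flash attention is used, so the window settings apply.
distributed_config.training_dtype = DataType.bfloat16
assert config.do_use_flash_attention(distributed_config) is True

# Full-precision dtype: flash attention is disabled, and a non-None window_size
# now trips the new assertion.
distributed_config.training_dtype = DataType.float32
try:
    config.do_use_flash_attention(distributed_config)
except AssertionError:
    print("window_size requires flash attention")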

tests/test_attention.py (+22)

@@ -0,0 +1,22 @@
+import unittest.mock
+from fast_llm.layers.transformer.attention import Attention
+from fast_llm.layers.transformer.config import TransformerConfig
+
+
+def test_decide_window_size():
+    attention = unittest.mock.Mock(spec=Attention)
+    attention._decide_window_size = Attention._decide_window_size.__get__(attention)  # Attach real method
+
+    # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers)
+    attention._config = TransformerConfig(window_size=512, max_window_layers=2)
+    attention._layer_index = 2
+    assert attention._decide_window_size() == 512
+
+    # Arrange - Case 2: window_size is None (layer_index < max_window_layers)
+    attention._config = TransformerConfig(window_size=512, max_window_layers=2)
+    attention._layer_index = 1
+    assert attention._decide_window_size() is None
+
+    # Arrange - Case 3: max_window_layers is None (always return window_size)
+    attention._config = TransformerConfig(window_size=512, max_window_layers=None)
+    assert attention._decide_window_size() == 512
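
A note on the binding trick in this test: Attention._decide_window_size.__get__(attention) works because plain Python functions are descriptors, so calling __get__ with an instance returns a bound method, even when the "instance" is a Mock. A toy illustration with a made-up class:

# Toy example of attaching a real method to a Mock via the descriptor protocol.
# Greeter is invented for illustration; it is not part of Fast-LLM.
import unittest.mock


class Greeter:
    def greet(self) -> str:
        return f"hello, {self.target}"


fake = unittest.mock.Mock(spec=Greeter)
fake.target = "world"
fake.greet = Greeter.greet.__get__(fake)  # bind the real method to the mock
print(fake.greet())  # hello, world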

tests/test_config.py (+34 -1)

@@ -1,8 +1,15 @@
 import pathlib
+import pytest
 import subprocess
-
+import unittest.mock
 import yaml
 
+
+from fast_llm.layers.transformer.config import TransformerConfig
+from fast_llm.utils import Assert
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.engine.config_utils.data_type import DataType
+
 from fast_llm.models.auto import trainer_registry
 
 

@@ -51,3 +58,29 @@ def test_validate_example_config():
         (pathlib.Path(__file__).parents[1] / "examples" / "mistral.yaml").read_text()
     )
     trainer_registry["gpt"].from_dict(fast_llm_config_dict)
+
+
+def test_do_use_flash_attention():
+    # Create a mock DistributedConfig
+    mock_distributed_config = unittest.mock.Mock(spec=DistributedConfig)
+
+    # Test case 1: use_flash_attention is True and training_dtype is float16
+    config = TransformerConfig(use_flash_attention=True, window_size=None)
+    mock_distributed_config.training_dtype = DataType.float16
+    assert config.do_use_flash_attention(mock_distributed_config) is True
+
+    # Test case 2: use_flash_attention is False
+    config = TransformerConfig(use_flash_attention=False, window_size=None)
+    mock_distributed_config.training_dtype = DataType.float16
+    assert config.do_use_flash_attention(mock_distributed_config) is False
+
+    # Test case 3: use_flash_attention is True but training_dtype is not float16 or bfloat16
+    config = TransformerConfig(use_flash_attention=True, window_size=None)
+    mock_distributed_config.training_dtype = DataType.float32
+    assert config.do_use_flash_attention(mock_distributed_config) is False
+
+    # Test case 4: use_flash_attention is False and window_size is not None
+    config = TransformerConfig(use_flash_attention=False, window_size=512)
+    mock_distributed_config.training_dtype = DataType.float32
+    with pytest.raises(AssertionError):
+        config.do_use_flash_attention(mock_distributed_config)
