Add special case to avoid quantizing conv in Moonshine
- Add a define to prevent quantizing the first conv layers in the
  Moonshine preprocessor
- Add options to enable rotary positional embeddings in the Transformer
  Encoder spec.
njeffrie committed Nov 5, 2024
1 parent 6373848 commit c09c876
Showing 3 changed files with 28 additions and 0 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -22,6 +22,11 @@ option(BUILD_TESTS "Compile the tests" OFF)
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF)
option(MOONSHINE "Compile with moonshine specializations" OFF)

if (MOONSHINE)
  add_definitions(-DMOONSHINE)
endif()

if(ENABLE_PROFILING)
  message(STATUS "Enable profiling support")
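The new MOONSHINE option defaults to OFF; when it is enabled at configure time (for example with the usual -DMOONSHINE=ON cache flag, which is standard CMake usage rather than part of this commit), the MOONSHINE preprocessor definition is added and picked up by the model.cc change below.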
20 changes: 20 additions & 0 deletions python/ctranslate2/specs/transformer_spec.py
@@ -22,6 +22,11 @@ def __init__(
        relative_attention_bias: bool = False,
        ffn_glu: bool = False,
        rms_norm: bool = False,
        rotary_dim: Optional[int] = None,
        rotary_interleave: bool = True,
        rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
        rotary_scaling_factor: float = 1,
        rotary_base: float = 10000,
        multi_query_attention: bool = False,
    ):
        """Initializes a Transformer encoder specification.
@@ -66,6 +71,11 @@ def __init__(
                relative_attention_bias=relative_attention_bias,
                ffn_glu=ffn_glu,
                rms_norm=rms_norm,
                rotary_dim=rotary_dim,
                rotary_interleave=rotary_interleave,
                rotary_scaling_type=rotary_scaling_type,
                rotary_scaling_factor=rotary_scaling_factor,
                rotary_base=rotary_base,
                num_heads_kv=1 if multi_query_attention else None,
            )
            for _ in range(num_layers)
@@ -251,6 +261,11 @@ def __init__(
        relative_attention_bias=False,
        ffn_glu=False,
        rms_norm=False,
        rotary_dim=None,
        rotary_interleave=True,
        rotary_scaling_type=None,
        rotary_scaling_factor=1,
        rotary_base=10000,
        num_heads_kv=None,
        sliding_window=None,
    ):
@@ -259,6 +274,11 @@
            relative_position=relative_position,
            relative_attention_bias=relative_attention_bias,
            rms_norm=rms_norm,
            rotary_dim=rotary_dim,
            rotary_interleave=rotary_interleave,
            rotary_scaling_type=rotary_scaling_type,
            rotary_scaling_factor=rotary_scaling_factor,
            rotary_base=rotary_base,
            num_heads_kv=num_heads_kv,
            sliding_window=sliding_window,
        )
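To illustrate the new encoder options, here is a minimal sketch of a spec that enables rotary positional embeddings; the layer counts and rotary settings below are illustrative assumptions, not values taken from the commit.

from ctranslate2.specs import transformer_spec

# Hypothetical example: an encoder spec with rotary embeddings enabled.
# The argument values are illustrative only.
encoder_spec = transformer_spec.TransformerEncoderSpec(
    num_layers=6,
    num_heads=8,
    rotary_dim=64,            # head dimensions that receive rotary embeddings
    rotary_interleave=False,  # pair dimensions (i, i + dim/2) instead of interleaving
    rotary_base=10000,        # base of the inverse-frequency schedule
)

The new arguments are simply forwarded from the encoder spec down to each layer's attention spec, as the diff above shows.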
3 changes: 3 additions & 0 deletions src/models/model.cc
@@ -213,6 +213,9 @@ namespace ctranslate2 {
      if (device == Device::CUDA
#ifdef CT2_WITH_DNNL
          || true
#endif
#ifdef MOONSHINE
          || true
#endif
          ) {
        variable_weight_dtype = float_dtype;
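Because the added condition is unconditionally true when MOONSHINE is defined, such builds always take this branch and keep variable_weight_dtype at float_dtype, which is what prevents the Moonshine preprocessor's first conv layers from being quantized, as described in the commit message.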
