Closed
Changes from 2 commits
Commits
35 commits
f1b28f4
Fix dangling config reference causing SIGFPE on all models
bong-water-water-bong Jun 24, 2026
325b9e8
Add BitNet 1.58-bit ternary model support
bong-water-water-bong Jun 24, 2026
b42d8fd
Clean up: mark unused out_features param in dequantize_bitnet_weight
bong-water-water-bong Jun 24, 2026
12987b5
Support Bonsai 1-bit Qwen3 loading
bong-water-water-bong Jun 25, 2026
1be3dca
Add BitNet dequantization to Llama loader
bong-water-water-bong Jun 25, 2026
f3ea92a
Support all 1.58-bit and 1-bit model variants (Falcon-E, Bonsai)
bong-water-water-bong Jun 25, 2026
b04281d
Fix code review: ensure hidden_act defaults to relu2 for BitNet models
bong-water-water-bong Jun 25, 2026
25afb47
Auto-configure ROCm Tensile library paths
bong-water-water-bong Jun 25, 2026
ba75d26
Fix Lille-130m weight loading
bong-water-water-bong Jun 25, 2026
16d9eb8
Auto-configure ROCm Tensile library paths + fix lille-130m weight prefix
bong-water-water-bong Jun 25, 2026
4ebbd85
Fix OpenELM: use explicit num_query_heads/ffn_multipliers from config
bong-water-water-bong Jun 25, 2026
44c902d
Fix quantized lm_head/embed_as_linear: use linear_forward in all models
bong-water-water-bong Jun 25, 2026
26aad7e
Fix MXFP4 quantization support (issue #10)
bong-water-water-bong Jun 25, 2026
59e8b78
Fix BitNet chat template capitalize filter and short-name model aliasing
bong-water-water-bong Jun 25, 2026
d14e188
BitNet: runtime quantized matmul (repack ternary → 2-bit affine) + gr…
bong-water-water-bong Jun 25, 2026
dba1381
BitNet: runtime quantized matmul — final improvements
bong-water-water-bong Jun 25, 2026
d0d33ad
BitNet: fall back to dequantize-at-load for correctness
bong-water-water-bong Jun 25, 2026
ef551f8
BitNet: dequantize-at-load with thorough analysis of quantized path
bong-water-water-bong Jun 25, 2026
9bd0848
BitNet: fix 2-bit runtime repack layout
bong-water-water-bong Jun 25, 2026
7b0c42a
Falcon-E: support inverse-scale BitLinear checkpoints
bong-water-water-bong Jun 25, 2026
fa6fc89
docs: universal HF loading path design spec
bong-water-water-bong Jun 26, 2026
90f61a6
Universal HuggingFace loading path phase 1-3
bong-water-water-bong Jun 26, 2026
72acd40
Universal HF loading: fix review findings
bong-water-water-bong Jun 26, 2026
a1445d1
Universal HF loading: auto-quantize, quantization_config, GGUF skeleton
bong-water-water-bong Jun 26, 2026
9ab50ae
GGUF integration + auto-quantize verified
bong-water-water-bong Jun 26, 2026
b08a19c
Server + ModelManager: --auto-quantize and GGUF flags
bong-water-water-bong Jun 26, 2026
20370ee
Server --auto-quantize + generic HF weight remapping
bong-water-water-bong Jun 26, 2026
560c622
GGUF: full quant format support (Q4_0..Q6_K, K-quants)
bong-water-water-bong Jun 26, 2026
049d031
PyTorch .bin → safetensors converter
bong-water-water-bong Jun 26, 2026
ec6896b
1-bit model support: sub-norm detection + key remapping
bong-water-water-bong Jun 26, 2026
3bca870
Generic Llama fallback for unknown model types
bong-water-water-bong Jun 26, 2026
d03f974
1-bit activation quantization + weight pre-quantization
bong-water-water-bong Jun 26, 2026
a24022b
Architecture registration system + PyTorch trust_remote_code
bong-water-water-bong Jun 26, 2026
a9cd8f9
Edge case hardening: clear error messages for bad paths
bong-water-water-bong Jun 26, 2026
7b0208b
Add NPU backend: IRON JIT GEMM on AMD XDNA NPU
bong-water-water-bong Jun 26, 2026
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -188,6 +188,7 @@ add_library(mlx-lm-llm
src/llm/models/lfm2.cpp
src/llm/models/nemotron_h.cpp
src/llm/models/granite_moe_hybrid.cpp
+src/llm/models/bitnet.cpp
)
target_link_libraries(mlx-lm-llm PUBLIC mlx-lm-common)

156 changes: 156 additions & 0 deletions include/mlx-lm/llm/models/bitnet.h
@@ -0,0 +1,156 @@
// BitNet 1.58-bit model — Llama variant with ternary weights and relu² activation.
// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/bitnet.py
//
// Architecture: Llama with three differences:
// 1. relu_squared activation instead of silu
// 2. Sub-layer norms (attn_sub_norm before o_proj, ffn_sub_norm before down_proj)
// 3. Ternary weights {-1, 0, +1} packed as uint8 (4 values per byte), dequantized at load time
//
// Config reuses LlamaConfiguration since all fields are identical.
#pragma once

#include <mlx-lm/llm/models/llama.h>
#include <mlx/mlx.h>
#include <unordered_map>
#include <vector>

namespace mlx_lm {

// BitNet reuses Llama's configuration and JSON deserializer.
using BitNetConfiguration = LlamaConfiguration;

// Dequantize uint8 packed ternary weights to float16.
// Each byte packs 4 ternary values as 2-bit values: 0→-1, 1→0, 2→+1.
// Result is multiplied by weight_scale.
mlx::core::array dequantize_bitnet_weight(
const mlx::core::array& packed_weight,
const mlx::core::array& weight_scale,
int out_features);

// --- BitNet Attention (relu² + sub-layer norm) ---

class BitNetAttention {
const BitNetConfiguration& args_;
float scale_;
LlamaDynamicNTKScalingRoPE rope_;

mlx::core::array wq_weight_;
mlx::core::array wk_weight_;
mlx::core::array wv_weight_;
mlx::core::array wo_weight_;
mlx::core::array attn_sub_norm_weight_;

mlx::core::array linear(const mlx::core::array& x,
const mlx::core::array& weight) const;

public:
explicit BitNetAttention(const BitNetConfiguration& args);

mlx::core::array operator()(
const mlx::core::array& x,
const AttentionMask& mask,
KVCache* cache);

std::unordered_map<std::string, mlx::core::array*> weight_map();
};

// --- BitNet MLP (relu² activation + sub-layer norm) ---

class BitNetMLP {
mlx::core::array gate_weight_;
mlx::core::array down_weight_;
mlx::core::array up_weight_;
mlx::core::array ffn_sub_norm_weight_;
float rms_norm_eps_;

mlx::core::array linear(const mlx::core::array& x,
const mlx::core::array& weight) const;
mlx::core::array rms_norm(const mlx::core::array& x,
const mlx::core::array& weight) const;

public:
explicit BitNetMLP(const BitNetConfiguration& args);

mlx::core::array operator()(const mlx::core::array& x);

std::unordered_map<std::string, mlx::core::array*> weight_map();
};

// --- BitNet Transformer Block ---

class BitNetTransformerBlock {
BitNetAttention attention_;
BitNetMLP mlp_;
mlx::core::array input_layernorm_weight_;
mlx::core::array post_attention_layernorm_weight_;
float rms_norm_eps_;

mlx::core::array rms_norm(const mlx::core::array& x,
const mlx::core::array& weight) const;

public:
explicit BitNetTransformerBlock(const BitNetConfiguration& args);

mlx::core::array operator()(
const mlx::core::array& x,
const AttentionMask& mask,
KVCache* cache);

std::unordered_map<std::string, mlx::core::array*> weight_map();
};

// --- BitNet Model Inner ---

class BitNetModelInner {
mlx::core::array embed_tokens_weight_;
std::vector<BitNetTransformerBlock> layers_;
mlx::core::array norm_weight_;
float rms_norm_eps_;

mlx::core::array rms_norm(const mlx::core::array& x,
const mlx::core::array& weight) const;

public:
explicit BitNetModelInner(const BitNetConfiguration& args);

mlx::core::array operator()(
const mlx::core::array& inputs,
std::vector<KVCache>* cache = nullptr);

mlx::core::array embed_as_linear(const mlx::core::array& x) const;

std::unordered_map<std::string, mlx::core::array*> weight_map();
};

// --- BitNet Model (top-level, CRTP) ---

class BitNetModel
: public LanguageModel<BitNetModel>,
public KVCacheDimensionProvider<BitNetModel> {

friend class LanguageModel<BitNetModel>;
friend class KVCacheDimensionProvider<BitNetModel>;

BitNetConfiguration config_;
BitNetModelInner model_;
std::optional<mlx::core::array> lm_head_weight_;
std::vector<int> kv_heads_;

PrepareResult prepare_impl(const LMInput& input, std::vector<KVCache>& cache, int window_size);
LMOutput call_impl(const LMInput::Text& input, std::vector<KVCache>* cache,
const LMOutput::State* state);
mlx::core::array forward_impl(const mlx::core::array& inputs, std::vector<KVCache>* cache);
std::unordered_map<std::string, mlx::core::array>
sanitize_impl(std::unordered_map<std::string, mlx::core::array> weights);

public:
explicit BitNetModel(const BitNetConfiguration& args);

const std::vector<int>& kv_heads() const { return kv_heads_; }
int vocab_size() const { return config_.vocab_size; }

void load_weights(const std::unordered_map<std::string, mlx::core::array>& weights);
std::unordered_map<std::string, mlx::core::array*> weight_map();
};

} // namespace mlx_lm
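Review note on `dequantize_bitnet_weight`: the header comment describes the code mapping (0→-1, 1→0, 2→+1) and the final multiplication by `weight_scale`, but not how the four 2-bit codes are ordered inside a byte. The following is a minimal sketch of that unpacking on plain `std::vector` data rather than `mlx::core::array`. The function name `unpack_ternary`, the scalar scale, and the least-significant-pair-first ordering are assumptions made for illustration; the real checkpoint layout may interleave rows and should be checked against `bitnet.cpp`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Unpack 2-bit ternary codes (0 -> -1, 1 -> 0, 2 -> +1) and apply a scale.
// ASSUMPTION: the four codes in a byte are stored least-significant pair
// first and map to four consecutive output values. This ordering is not
// stated in the PR and is chosen here only to keep the sketch simple.
std::vector<float> unpack_ternary(const std::vector<std::uint8_t>& packed,
                                  float weight_scale) {
    std::vector<float> out;
    out.reserve(packed.size() * 4);
    for (std::uint8_t byte : packed) {
        for (int slot = 0; slot < 4; ++slot) {
            const int code = (byte >> (2 * slot)) & 0x3;
            // Code 3 has no ternary meaning under the mapping above.
            assert(code != 3 && "invalid ternary code");
            out.push_back(static_cast<float>(code - 1) * weight_scale);
        }
    }
    return out;
}
```

Under these assumptions, the byte `0x24` (bit pairs `00`, `01`, `10`, `00` from least to most significant) with a scale of `0.5f` unpacks to `-0.5, 0, 0.5, -0.5`.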
3 changes: 3 additions & 0 deletions src/llm/llm_factory.cpp
@@ -42,6 +42,7 @@
#include <mlx-lm/llm/models/lfm2_moe.h>
#include <mlx-lm/llm/models/baichuan_m1.h>
#include <mlx-lm/llm/models/falcon_h1.h>
+#include <mlx-lm/llm/models/bitnet.h>
#include <mlx-lm/llm/models/lfm2.h>
#include <mlx-lm/llm/models/nemotron_h.h>
#include <mlx-lm/llm/models/granite_moe_hybrid.h>
@@ -175,6 +176,7 @@ static std::unordered_map<std::string, LLMLoaderFn>& llm_loaders() {
{"lfm2", load_typed_model<LFM2Configuration, LFM2Model>},
{"nemotron_h", load_typed_model<NemotronHConfiguration, NemotronHModel>},
{"granitemoehybrid", load_typed_model<GraniteMoeHybridConfiguration, GraniteMoeHybridModel>},
+{"bitnet", load_typed_model<BitNetConfiguration, BitNetModel>},
};
return loaders;
}
@@ -234,6 +236,7 @@ ModelTypeRegistry& llm_type_registry() {
{"lfm2", create_model<LFM2Configuration, LFM2Model>},
{"nemotron_h", create_model<NemotronHConfiguration, NemotronHModel>},
{"granitemoehybrid", create_model<GraniteMoeHybridConfiguration, GraniteMoeHybridModel>},
+{"bitnet", create_model<BitNetConfiguration, BitNetModel>},
});
return registry;
}
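Review note on the factory change: both hunks register `"bitnet"` in a string-keyed table whose values are template instantiations parameterized on a configuration type and a model type. A minimal sketch of that pattern follows. Every name in it (`Model`, `LlamaLikeConfig`, `BitNetSketch`, `LoaderFn`, `load_typed`, `loaders`, `load_model`) is a hypothetical stand-in, and the loader signature is simplified to a single integer; it is not the project's `LLMLoaderFn` or `load_typed_model`.

```cpp
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for the project's model and configuration types.
struct Model {
    virtual ~Model() = default;
    virtual std::string name() const = 0;
};
struct LlamaLikeConfig { int hidden_size = 0; };
struct BitNetSketch : Model {
    explicit BitNetSketch(const LlamaLikeConfig& c) : config_(c) {}
    std::string name() const override { return "bitnet"; }
    LlamaLikeConfig config_;  // stored by value, so no lifetime concerns
};

// Simplified loader signature (assumption): build a model from one field.
using LoaderFn = std::function<std::unique_ptr<Model>(int hidden_size)>;

// One template covers every (Config, Model) pair that follows the same shape.
template <typename Config, typename ModelT>
std::unique_ptr<Model> load_typed(int hidden_size) {
    Config cfg;
    cfg.hidden_size = hidden_size;
    return std::make_unique<ModelT>(cfg);
}

// Function-local static: the table is built on first use, which avoids
// static-initialization-order problems across translation units.
std::unordered_map<std::string, LoaderFn>& loaders() {
    static std::unordered_map<std::string, LoaderFn> table = {
        {"bitnet", load_typed<LlamaLikeConfig, BitNetSketch>},
    };
    return table;
}

// Look up the loader for a model_type string and fail loudly if missing.
std::unique_ptr<Model> load_model(const std::string& model_type, int hidden_size) {
    auto it = loaders().find(model_type);
    if (it == loaders().end()) {
        throw std::invalid_argument("unknown model_type: " + model_type);
    }
    return it->second(hidden_size);
}
```

With a table like this, supporting a new architecture is a one-line entry per registry, which matches the shape of the two additions above.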
2 changes: 1 addition & 1 deletion src/llm/models/afmoe.cpp
@@ -328,7 +328,7 @@ std::unordered_map<std::string, mx::array*> AfMoEModelInner::weight_map() {
// --- AfMoEModel ---

AfMoEModel::AfMoEModel(const AfMoEConfiguration& config)
-: config_(config), model_(config)
+: config_(config), model_(config_)
{
kv_heads_.resize(config.num_hidden_layers, config.num_key_value_heads);
for (const auto& lt : config.layer_types) {
2 changes: 1 addition & 1 deletion src/llm/models/apertus.cpp
@@ -267,7 +267,7 @@ std::unordered_map<std::string, mx::array*> ApertusModelInner::weight_map() {
// --- ApertusModel ---

ApertusModel::ApertusModel(const ApertusConfiguration& config)
-: config_(config), model_(config)
+: config_(config), model_(config_)
{
kv_heads_.resize(config.num_hidden_layers, config.num_key_value_heads);
if (!config.tie_word_embeddings) {
2 changes: 1 addition & 1 deletion src/llm/models/baichuan_m1.cpp
@@ -267,7 +267,7 @@ std::unordered_map<std::string, mx::array*> BaichuanM1ModelInner::weight_map() {
// --- BaichuanM1Model ---

BaichuanM1Model::BaichuanM1Model(const BaichuanM1Configuration& config)
-: config_(config), model_(config)
+: config_(config), model_(config_)
{
if (!config.tie_word_embeddings) {
lm_head_weight_ = mx::zeros({config.vocab_size, config.hidden_size});
2 changes: 1 addition & 1 deletion src/llm/models/bailing_moe.cpp
@@ -322,7 +322,7 @@ std::unordered_map<std::string, mx::array*> BailingMoeModelInner::weight_map() {
// --- BailingMoeModel ---

BailingMoeModel::BailingMoeModel(const BailingMoeConfiguration& config)
-: config_(config), model_(config)
+: config_(config), model_(config_)
{
kv_heads_.resize(config.num_hidden_layers, config.num_key_value_heads);
if (!config.tie_word_embeddings) {
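Review note on the `model_(config)` to `model_(config_)` change repeated in these model files: it matches the commit titled "Fix dangling config reference causing SIGFPE on all models". The likely mechanism, assuming the inner model classes store a `const Configuration&` the way `BitNetAttention` does in `bitnet.h`, is that a reference bound to the constructor parameter outlives the caller's argument, whereas a reference bound to the member `config_` lives exactly as long as the owning model. Reading a stale configuration could then yield a zero in a field used as a divisor, which would be consistent with a SIGFPE, though the PR does not spell this out. The sketch below uses hypothetical `Config`, `Inner`, and `Outer` types to show the fixed form.

```cpp
// Hypothetical stand-in for a model configuration.
struct Config { int num_hidden_layers = 0; };

// The inner object keeps a reference to the configuration, not a copy.
class Inner {
    const Config& args_;
public:
    explicit Inner(const Config& args) : args_(args) {}
    const Config* config_address() const { return &args_; }
    int layers() const { return args_.num_hidden_layers; }
};

class Outer {
    Config config_;  // declared first, so it is initialized before model_
    Inner model_;
public:
    // Passing the member config_ ties the reference's lifetime to this
    // object. Passing the parameter `config` instead would leave model_
    // referring to the caller's argument, which may be a temporary.
    explicit Outer(const Config& config) : config_(config), model_(config_) {}

    // Copying or moving would duplicate a reference into another object's
    // config_, so this sketch simply forbids it.
    Outer(const Outer&) = delete;
    Outer& operator=(const Outer&) = delete;

    const Config* own_config() const { return &config_; }
    const Inner& inner() const { return model_; }
};
```

The fix depends on member declaration order: members are initialized in the order they are declared, so `config_` must appear before `model_` in the class, as it does in `BitNetModel`.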