diff --git a/projects/Baichuan/README.md b/projects/Baichuan/README.md
new file mode 100644
index 000000000..f10d06988
--- /dev/null
+++ b/projects/Baichuan/README.md
@@ -0,0 +1,51 @@
+### Baichuan
+#### Inference
+
+- cuda PASS
+```bash
+python projects/Baichuan/pipeline.py --mode=huggingface --model_path=/root/models/Baichuan2-7B-Chat
+```
+
+- xpu PASS
+```bash
+python projects/Baichuan/pipeline.py --mode=huggingface --device=xpu --model_path=/root/models/Baichuan2-7B-Chat
+```
+
+#### Training
+- cuda PASS
+```bash
+export NUM_GPUS=8
+python3 -m oneflow.distributed.launch \
+    --nproc_per_node ${NUM_GPUS} \
+    --nnodes 1 \
+    --node_rank 0 \
+    --master_addr 127.0.0.1 \
+    --master_port 12345 \
+    tools/train_net.py --config-file=projects/Baichuan/configs/baichuan_sft.py \
+    graph.enabled=True \
+    train.input_placement_device="cuda" \
+    train.dist.device_type="cuda" \
+    train.dist.pipeline_parallel_size=${NUM_GPUS}
+```
+
+```
+[09/19 14:39:40 lb.utils.events]: eta: 22:07:15 iteration: 87/18660 consumed_samples: 704 total_loss: 10.36 time: 4.2893 s/iter data_time: 0.0105 s/iter total_throughput: 1.87 samples/s lr: 6.99e-07
+[09/19 14:39:44 lb.utils.events]: eta: 22:07:07 iteration: 88/18660 consumed_samples: 712 total_loss: nan time: 4.2889 s/iter data_time: 0.0104 s/iter total_throughput: 1.87 samples/s lr: 7.07e-07
+NaN or Inf found in input tensor.
+```
+
+- xpu OOM after 7 iterations
+```bash
+export NUM_GPUS=1
+python3 -m oneflow.distributed.launch \
+    --nproc_per_node ${NUM_GPUS} \
+    --nnodes 1 \
+    --node_rank 0 \
+    --master_addr 127.0.0.1 \
+    --master_port 12345 \
+    tools/train_net.py --config-file=projects/Baichuan/configs/baichuan_sft.py \
+    graph.enabled=False \
+    train.input_placement_device="xpu" \
+    train.dist.device_type="xpu" \
+    train.dist.pipeline_parallel_size=${NUM_GPUS}
+```
diff --git a/projects/Baichuan/baichuan.py b/projects/Baichuan/baichuan.py
new file mode 100644
index 000000000..ad4f3e457
--- /dev/null
+++ b/projects/Baichuan/baichuan.py
@@ -0,0 +1,653 @@
+# coding=utf-8
+# Copyright 2021 The OneFlow Authors. All rights reserved.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
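+#
+# Baichuan decoder-only model implemented with LiBai's model-parallel layers
+# (Linear, VocabEmbedding, RMSLayerNorm): rotary position embeddings, a
+# SiLU-gated MLP, and attention computed from a single fused query_key_value
+# projection (mapped from the HuggingFace checkpoint's W_pack weights by
+# projects/Baichuan/utils/baichuan_loader.py).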
+ +import math +from typing import Tuple + +import oneflow as flow +from oneflow import nn + +from libai.config import configurable +from libai.inference.generator.generation_utils import Generator +from libai.layers import Linear, RMSLayerNorm, VocabEmbedding +from libai.layers.attention import AttnMaskType +from libai.models.utils import init_method_normal, scaled_init_method_normal +from libai.utils import distributed as dist + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return flow.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + + def forward(self, x, seq_len=None, cos_cached=None, sin_cached=None): + if seq_len > self.max_position_embeddings: + raise ValueError( + f"The maximum supported length is {self.max_position_embeddings}, " + f"and the current length is{seq_len}." + ) + + return ( + cos_cached[:seq_len].to_global(placement=x.placement), + sin_cached[:seq_len].to_global(placement=x.placement), + ) + + +class MLP(nn.Module): + def __init__( + self, + hidden_size, + intermediate_size, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + *, + layer_idx=0, + ): + super().__init__() + + if output_layer_init_method is None: + output_layer_init_method = init_method + + self.gate_proj = Linear( + hidden_size, + intermediate_size, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.up_proj = Linear( + hidden_size, + intermediate_size, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.down_proj = Linear( + intermediate_size, + hidden_size, + bias=False, + parallel="row", + init_method=output_layer_init_method, + layer_idx=layer_idx, + ) + + self.activation_func = nn.SiLU() + + def forward(self, hidden_states): + gate_out = self.activation_func(self.gate_proj(hidden_states)) + up_out = self.up_proj(hidden_states) + output = self.down_proj(gate_out * up_out) + return output + + +class MultiheadAttention(nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + max_position_embeddings, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + scale_mask_softmax_fusion=False, + attn_mask_type=AttnMaskType.padding, + *, + layer_idx=0, + ): + super().__init__() + self.hidden_size = hidden_size + if output_layer_init_method is None: + output_layer_init_method = init_method + + self.num_heads = num_attention_heads + self.head_size = hidden_size // num_attention_heads + self.attn_mask_type = attn_mask_type + + self.norm_factor = 1.0 / math.sqrt(float(self.head_size)) + + self.scale_mask_softmax_fusion = scale_mask_softmax_fusion + + self.query_key_value = Linear( + self.hidden_size, + self.hidden_size * 3, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.o_proj = Linear( + self.hidden_size, + self.hidden_size, + bias=False, + parallel="row", + init_method=output_layer_init_method, + layer_idx=layer_idx, + ) + + self.coeff = None + + rotary_dim = self.head_size + self.rotary_embed = 
RotaryEmbedding( + dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + ) + + def forward( + self, + hidden_states: flow.Tensor, + encoder_states: flow.Tensor = None, + attention_mask: flow.Tensor = None, + position_ids=None, + past_key_value: Tuple[flow.Tensor, flow.Tensor] = None, + cos_cached: flow.Tensor = None, + sin_cached: flow.Tensor = None, + use_cache: bool = False, + ): + if encoder_states is not None: + encoder_states = encoder_states.to_global(placement=hidden_states.placement) + + if attention_mask is not None: + attention_mask = attention_mask.to_global(placement=hidden_states.placement) + + bsz, tgt_len = hidden_states.size()[:2] + + query_key_value = self.query_key_value(hidden_states) + query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size) + query_key_value = query_key_value.permute( + 0, 2, 1, 3 + ) # [bsz, num_heads, src_len, 3 * head_size] + query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1) + + kv_seq_len = key.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_embed( + value, seq_len=kv_seq_len, cos_cached=cos_cached, sin_cached=sin_cached + ) + query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids) + + if past_key_value is not None: + past_key, past_value = past_key_value + key = flow.cat((past_key.type_as(key), key), dim=2) + value = flow.cat((past_value.type_as(value), value), dim=2) + + # query, key, value: [S(0), S(1)], shape: [bsz, num_heads, seq_length, head_size] + if use_cache: + past_key_value = (key, value) + + # [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)] + attention_scores = flow.matmul(query, key, transpose_b=True, alpha=self.norm_factor) + attention_weights = attention_scores + attention_mask + + attention_weights = flow.softmax(attention_weights, dim=-1) + # Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)] + context = flow.matmul(attention_weights, value) + + # Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size] + context = context.transpose(1, 2) + output = self.o_proj(context.flatten(2)) + + if use_cache: + output = (output, past_key_value) + + return output + + +class CasualMask(nn.Module): + def __init__(self, max_positions=1024, dtype=flow.float16, *, layer_idx=0): + super().__init__() + self.dtype = dtype + self.mask = flow.full( + (max_positions, max_positions), + flow.finfo(dtype).min, + placement=dist.get_layer_placement(layer_idx), + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + ) + mask_cond = flow.arange( + self.mask.size(-1), + placement=dist.get_layer_placement(layer_idx), + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + ) + self.mask.masked_fill_(mask_cond < (mask_cond + 1).view(self.mask.size(-1), 1), 0) + self.mask = self.mask.to(dtype) + + def forward(self, input_ids, past_length=0, attention_mask=None, input_dtype=None): + bsz, tgt_len = input_ids.size() + casual_mask = self.mask[:tgt_len, :tgt_len] + if past_length > 0: + # in case past_key_values are used, we need to add a prefix ones mask to casual mask + casual_mask = flow.cat( + [flow.zeros(tgt_len, past_length, dtype=self.dtype), casual_mask], dim=-1 + ) + casual_mask = ( + casual_mask.unsqueeze(0).unsqueeze(1).expand(bsz, 1, tgt_len, tgt_len + past_length) + ) + casual_mask = casual_mask.to_global(sbp=input_ids.sbp) + if attention_mask is not None: + bsz, src_len = attention_mask.size() + attention_mask = ( + attention_mask[:, None, 
None, :] + .expand(bsz, 1, tgt_len, src_len) + .to(casual_mask.dtype) + ) + inverted_attention_mask = 1.0 - attention_mask + inverted_attention_mask.masked_fill( + inverted_attention_mask.to(flow.bool), flow.finfo(casual_mask.dtype).min + ) + inverted_attention_mask = inverted_attention_mask.to_global( + placement=casual_mask.placement + ) + casual_mask = casual_mask + inverted_attention_mask + if input_dtype is not None: + casual_mask = casual_mask.to(input_dtype) + return casual_mask + + +class DecoderLayer(nn.Module): + def __init__( + self, + hidden_size, + intermediate_size, + num_attention_heads, + is_decoder=False, + rms_norm_eps=1e-5, + max_position_embeddings=None, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + scale_mask_softmax_fusion=False, + attn_mask_type=AttnMaskType.padding, + *, + layer_idx=0, + ): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.rms_norm_eps = rms_norm_eps + self.max_position_embeddings = max_position_embeddings + self.attn_mask_type = attn_mask_type + + self.layer_idx = layer_idx + self.is_decoder = is_decoder + + self.scale_mask_softmax_fusion = scale_mask_softmax_fusion + + self.init_method = init_method + if output_layer_init_method is None: + output_layer_init_method = init_method + self.output_layer_init_method = output_layer_init_method + + self.input_layernorm = RMSLayerNorm( + self.hidden_size, eps=self.rms_norm_eps, layer_idx=self.layer_idx + ) + + self.self_attn = self.build_attention() + self.post_attention_layernorm = RMSLayerNorm( + self.hidden_size, eps=self.rms_norm_eps, layer_idx=self.layer_idx + ) + + self.mlp = MLP( + self.hidden_size, + self.intermediate_size, + self.init_method, + output_layer_init_method=self.output_layer_init_method, + layer_idx=self.layer_idx, + ) + + def forward( + self, + hidden_states, + attention_mask=None, + past_key_value=None, + cos_cached=None, + sin_cached=None, + use_cache=False, + ): + hidden_states = hidden_states.to_global(placement=dist.get_layer_placement(self.layer_idx)) + + # hidden_states shape: (batch_size, seq_length, hidden_size) + if attention_mask is not None: + attention_mask = attention_mask.to_global( + placement=dist.get_layer_placement(self.layer_idx) + ) + + if past_key_value is not None: + if self.is_decoder: + assert len(past_key_value) == 4 + self_attn_past_key_value = past_key_value[:2] + else: + self_attn_past_key_value = past_key_value + else: + self_attn_past_key_value = None + + layernorm_output = self.input_layernorm(hidden_states) + attention_output = self.self_attn( + layernorm_output, + attention_mask=attention_mask, + past_key_value=self_attn_past_key_value, + cos_cached=cos_cached, + sin_cached=sin_cached, + use_cache=use_cache, + ) + + if use_cache: + attention_output, presents = attention_output + + hidden_states = hidden_states + attention_output + + layernorm_output = self.post_attention_layernorm(hidden_states) + + mlp_output = self.mlp(layernorm_output) + + output = hidden_states + mlp_output + + if use_cache: + output = (output, presents) + return output + + def build_attention(self): + return MultiheadAttention( + self.hidden_size, + self.num_attention_heads, + self.max_position_embeddings, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + scale_mask_softmax_fusion=self.scale_mask_softmax_fusion, + attn_mask_type=self.attn_mask_type, + layer_idx=self.layer_idx, + ) + + +class 
BaichuanModel(nn.Module): + def __init__( + self, + hidden_layers, + vocab_size, + hidden_size, + intermediate_size, + num_attention_heads, + max_position_embeddings=1024, + rms_norm_eps=1e-5, + initializer_range=0.02, + use_scaled_init_for_output_weights=True, + scale_mask_softmax_fusion=False, + amp_enabled=False, + ): + super().__init__() + init_method = init_method_normal(sigma=initializer_range) + if use_scaled_init_for_output_weights: + output_layer_init_method = scaled_init_method_normal(initializer_range, hidden_layers) + else: + output_layer_init_method = init_method + + self.embed_tokens = VocabEmbedding( + vocab_size, hidden_size, init_method=init_method, amp_enabled=amp_enabled + ) + self.layers = nn.ModuleList( + [ + DecoderLayer( + hidden_size, + intermediate_size, + num_attention_heads, + rms_norm_eps=rms_norm_eps, + max_position_embeddings=max_position_embeddings, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + scale_mask_softmax_fusion=scale_mask_softmax_fusion, + attn_mask_type=AttnMaskType.causal, + layer_idx=i, + ) + for i in range(hidden_layers) + ] + ) + self.norm = RMSLayerNorm(hidden_size, eps=rms_norm_eps, layer_idx=-1) + + self._set_cos_sin_cache( + rotary_dim=hidden_size // num_attention_heads, + seq_len=max_position_embeddings, + dtype=flow.float32, + layer_idx=0, + ) + + def _set_cos_sin_cache(self, rotary_dim, seq_len, base=10000, dtype=None, layer_idx=0): + position = flow.arange( + 0, + rotary_dim, + 2, + dtype=dtype, + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + placement=dist.get_layer_placement(layer_idx), + ) + inv_freq = 1.0 / (base ** (position / rotary_dim)) + + t = flow.arange( + seq_len, + dtype=inv_freq.dtype, + sbp=inv_freq.sbp, + placement=inv_freq.placement, + ) + + freqs = flow.einsum("i,j->ij", t, inv_freq) + emb = flow.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype)) + self.register_buffer("sin_cached", emb.sin().to(dtype)) + + def forward( + self, + input_ids, + attention_mask=None, + past_key_values=None, + use_cache=False, + set_cache=None, + ): + if use_cache: + presents = [] + input_ids = input_ids.to_global(placement=dist.get_layer_placement(0)) + hidden_states = self.embed_tokens(input_ids) + + for layer, past_key_value in zip(self.layers, past_key_values): + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + cos_cached=self.cos_cached, + sin_cached=self.sin_cached, + use_cache=False, + ) + if use_cache: + hidden_states, present = hidden_states + presents.append(present) + + hidden_states = self.norm(hidden_states) + + if use_cache: + set_cache(presents) + + return hidden_states + + +class CrossEntropyLoss(nn.Module): + def forward(self, logits: flow.Tensor, target: flow.Tensor): + assert logits.ndim == 3 + assert target.ndim == 2 + assert logits.shape[0:2] == target.shape + + target = target.to_global(placement=logits.placement) + target = target * (target >= 0) + + lm_loss = flow._C.cross_entropy( + logits.view(-1, logits.shape[-1]), target.view(-1), ignore_index=0 + ) + return lm_loss + + +class SFTLoss(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lm_loss = CrossEntropyLoss() + + def forward(self, logits, lm_labels): + lm_loss = self.lm_loss(logits, lm_labels) + lm_loss = lm_loss.mean() + return {"lm_loss": lm_loss} + + +class BaichuanForCausalLM(nn.Module, Generator): + @configurable + def __init__( + self, + hidden_layers, + vocab_size, + 
hidden_size, + intermediate_size, + num_attention_heads, + max_position_embeddings=1024, + rms_norm_eps=1e-5, + initializer_range=0.02, + use_scaled_init_for_output_weights=True, + scale_mask_softmax_fusion=False, + amp_enabled=False, + cfg=None, + ): + super().__init__() + self.cfg = cfg + self.model = BaichuanModel( + hidden_layers=hidden_layers, + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=rms_norm_eps, + initializer_range=initializer_range, + use_scaled_init_for_output_weights=use_scaled_init_for_output_weights, + scale_mask_softmax_fusion=scale_mask_softmax_fusion, + amp_enabled=amp_enabled, + ) + self.casual_mask = CasualMask(max_position_embeddings, layer_idx=0) + self.lm_head = Linear(hidden_size, vocab_size, bias=False, layer_idx=-1) + self.loss_func = SFTLoss() + + self.past_key_values = [None] * hidden_layers + self.past_length = 0 + + def forward(self, input_ids, attention_mask=None, labels=None, use_cache=False): + input_ids = input_ids.to_global(placement=dist.get_layer_placement(0)) + attention_mask = ( + attention_mask.to_global(placement=dist.get_layer_placement(0)) + if attention_mask is not None + else attention_mask + ) + labels = ( + labels.to_global(placement=dist.get_layer_placement(0)) + if labels is not None + else labels + ) + + if use_cache and self.past_key_values[0] is not None: + self.past_length = self.past_key_values[0][0].size(-2) + else: + self.past_length = 0 + + mask = self.casual_mask( + input_ids, + past_length=self.past_length, + attention_mask=attention_mask, + input_dtype=self.lm_head.weight.dtype, + ) + + output = self.model( + input_ids, + attention_mask=mask, + past_key_values=self.past_key_values, + use_cache=use_cache, + set_cache=self.set_cache, + ) + + logits = self.lm_head(output) + + if labels is not None: + lm_loss = self.loss_func(logits, labels) + return lm_loss + else: + return {"logits": logits} + + def set_cache(self, past_key_values): + self.past_length = 0 if past_key_values is None else past_key_values[0][0].shape[2] + + if past_key_values is None: + past_key_values = [None] * self.cfg.hidden_layers + + assert len(past_key_values) == self.cfg.hidden_layers, ( + f"past_key_values's length {len(past_key_values)} doesn't match " + f"num_layers:' {self.cfg.hidden_layers}" + ) + + def prepare_inputs_for_generation(self, input_ids: flow.Tensor, **kwargs): + if "attention_mask" in kwargs: + attention_mask = kwargs.pop("attention_mask").float() + attention_mask = attention_mask - 1 + attention_mask.masked_fill_(attention_mask == -1, flow.finfo(flow.float32).min) + return {"input_ids": input_ids, "attention_mask": attention_mask} + + @classmethod + def from_config(cls, cfg): + return { + "hidden_layers": cfg.hidden_layers, + "vocab_size": cfg.vocab_size, + "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, + "num_attention_heads": cfg.num_attention_heads, + "max_position_embeddings": cfg.max_position_embeddings, + "rms_norm_eps": cfg.rms_norm_eps, + "initializer_range": cfg.initializer_range, + "use_scaled_init_for_output_weights": cfg.use_scaled_init_for_output_weights, + "scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion, + "amp_enabled": cfg.amp_enabled, + "cfg": cfg, + } + + @staticmethod + def set_activation_checkpoint(model): + for module_block in model.modules(): + # Old API in OneFlow 0.8 + if hasattr(module_block, "origin"): + if 
isinstance(module_block.origin, DecoderLayer): + module_block.config.activation_checkpointing = True + else: + if isinstance(module_block.to(nn.Module), DecoderLayer): + module_block.to(nn.graph.GraphModule).activation_checkpointing = True diff --git a/projects/Baichuan/baichuan_dataset.py b/projects/Baichuan/baichuan_dataset.py new file mode 100644 index 000000000..d78efe9fe --- /dev/null +++ b/projects/Baichuan/baichuan_dataset.py @@ -0,0 +1,19 @@ +import oneflow as flow +from oneflow.utils.data import Dataset + +from libai.data.structures import DistTensorData, Instance + + +class AlpacaDataset(Dataset): + def __init__(self, path, tokenizer): + self.data = flow.load(path) + self.tokenizer = tokenizer + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return Instance( + input_ids=DistTensorData(self.data[index]["input_ids"]), + labels=DistTensorData(self.data[index]["labels"]), + ) diff --git a/projects/Baichuan/configs/baichuan_config.py b/projects/Baichuan/configs/baichuan_config.py new file mode 100644 index 000000000..d5b8a1401 --- /dev/null +++ b/projects/Baichuan/configs/baichuan_config.py @@ -0,0 +1,62 @@ +from omegaconf import DictConfig, OmegaConf + +from libai.config import LazyCall +from projects.Baichuan.baichuan import BaichuanForCausalLM +from projects.Baichuan.tokenizer import BaichuanTokenizer +from configs.common.train import train + + +cfg = dict( + # Model + model_type="baichuan", + hidden_act="silu", + hidden_size=4096, + initializer_range=0.02, + intermediate_size=11008, + max_position_embeddings=2048, + num_attention_heads=32, + hidden_layers=32, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=None, + tie_word_embeddings=False, + vocab_size=32000, + use_scaled_init_for_output_weights=False, + scale_mask_softmax_fusion=False, + amp_enabled=True, + # Inference + is_encoder_decoder=False, + max_length=256, + min_length=0, + do_sample=False, + early_stopping=False, + num_beams=1, + num_beam_groups=1, + diversity_penalty=0.0, + temperature=0.9, + top_k=50, + top_p=0.6, + typical_p=1.0, + repetition_penalty=1.0, + length_penalty=1.0, + no_repeat_ngram_size=0, + encoder_no_repeat_ngram_size=0, + num_return_sequences=1, + chunk_size_feed_forward=0, + output_scores=False, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + pad_token_id=0, + # train + pretrained_model_path="/root/models/Baichuan2-7B-Chat", +) + +cfg = DictConfig(cfg) + +model = LazyCall(BaichuanForCausalLM)(cfg=cfg) +tokenization = OmegaConf.create() +tokenization.make_vocab_size_divisible_by = 1 +tokenization.tokenizer = LazyCall(BaichuanTokenizer)( + # pretrained_model_path=cfg.pretrained_model_path + "/tokenizer.model" +) diff --git a/projects/Baichuan/configs/baichuan_sft.py b/projects/Baichuan/configs/baichuan_sft.py new file mode 100644 index 000000000..738d5d199 --- /dev/null +++ b/projects/Baichuan/configs/baichuan_sft.py @@ -0,0 +1,100 @@ +import os +from omegaconf import OmegaConf + +from libai.config import LazyCall +from libai.evaluation import PPLEvaluator +from libai.scheduler import WarmupExponentialLR +from libai.data.build import build_nlp_test_loader, build_nlp_train_loader + +from configs.common.train import train +from configs.common.models.graph import graph +from configs.common.optim import optim + +from projects.Baichuan.configs.baichuan_config import cfg +from projects.Baichuan.baichuan_dataset import AlpacaDataset +from projects.Baichuan.tokenizer import BaichuanTokenizer +from projects.Baichuan.baichuan import BaichuanForCausalLM + + +# 
Hyperparameters +weight_decay = 0.1 +learning_rate = 5e-5 +dataset_path = "./alpaca_data" +pretrained_model_path = "/root/models/Llama-2-7b-chat-hf" + +# graph & optim +graph["enabled"] = False +optim.update( + dict( + lr=learning_rate, + weight_decay=weight_decay, + ) +) + +# tokenize +tokenization = OmegaConf.create() +tokenization.make_vocab_size_divisible_by = 1 +tokenization.tokenizer = LazyCall(BaichuanTokenizer)( + pretrained_model_path=os.path.join(pretrained_model_path, "tokenizer.model") +) + +# model +model = LazyCall(BaichuanForCausalLM)(cfg=cfg) + +# datasets +dataloader = OmegaConf.create() +dataloader.train = LazyCall(build_nlp_train_loader)( + dataset=[ + LazyCall(AlpacaDataset)( + path=os.path.join(dataset_path, "train"), tokenizer=tokenization.tokenizer + ) + ], +) +dataloader.test = [ + LazyCall(build_nlp_test_loader)( + dataset=LazyCall(AlpacaDataset)( + path=os.path.join(dataset_path, "test"), tokenizer=tokenization.tokenizer + ), + ), +] + + +train.update( + dict( + output_dir="./sft_result", + train_micro_batch_size=1, + test_micro_batch_size=1, + train_epoch=3, + train_iter=1, + log_period=1, + warmup_ratio=1 / 3, + num_accumulation_steps=8, + rdma_enabled=True, + amp=dict(enabled=False), + train_with_fp16=True, + activation_checkpoint=dict(enabled=True), + input_placement_device="cuda", + checkpointer=dict( + period=5000, + max_to_keep=20, + ), + dist=dict( + data_parallel_size=1, + tensor_parallel_size=1, + pipeline_parallel_size=1, + pipeline_num_layers=cfg.hidden_layers, + device_type="cuda", + ), + evaluation=dict( + enabled=True, + evaluator=LazyCall(PPLEvaluator)(), + eval_period=1000, + eval_iter=1e5, + ), + scheduler=LazyCall(WarmupExponentialLR)( + warmup_factor=0.0, + gamma=1.0, + warmup_method="linear", + ), + ) +) diff --git a/projects/Baichuan/pipeline.py b/projects/Baichuan/pipeline.py new file mode 100644 index 000000000..00d22ea11 --- /dev/null +++ b/projects/Baichuan/pipeline.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2021 The OneFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import click + +from libai.inference.basic import BasePipeline +from libai.utils import distributed as dist + + +class TextGenerationPipeline(BasePipeline): + def load_pretrain_weight(self, libai_cfg_model, model_path, mode="huggingface"): + """load pretrained model. 
+ + Args: + libai_cfg_model (libai.models): Lazy config Model in Libai, you can import it + by `from libai.config.configs.common.models.bert + import pretrain_model as libai_cfg_model` + model_path (str): The directory path of pretrained model, + """ + if mode == "huggingface": + from projects.Baichuan.utils.baichuan_loader import BaichuanLoaderHuggerFace + + model_loader = BaichuanLoaderHuggerFace( + libai_cfg_model, + libai_cfg_model.cfg, + model_path, + ) + model = model_loader.load() + model.eval() + return model + + elif mode == "libai": + from projects.Baichuan.utils.baichuan_loader import BaichuanLoaderLiBai + + model_loader = BaichuanLoaderLiBai( + libai_cfg_model, + libai_cfg_model.cfg, + model_path, + ) + model = model_loader.load() + model.eval() + return model + + elif mode == "random": + from libai.engine import DefaultTrainer + + return DefaultTrainer.build_model(self.cfg) + else: + raise NotImplementedError + + def _parse_parameters(self, **pipeline_parameters): + preprocess_params = {} + forward_params = {**pipeline_parameters} + postprocess_params = {} + + return preprocess_params, forward_params, postprocess_params + + def preprocess(self, inputs, **kwargs) -> dict: + # tokenizer encoderW + inputs = self.tokenizer.tokenize(inputs, add_bos=True, padding=True, device=self.device) + inputs = { + "input_ids": inputs, + } + + return inputs + + def forward(self, inputs, **kwargs) -> dict: + outputs = self.model.generate(inputs["input_ids"], max_length=50, **kwargs) + return {"return_ids": outputs} + + def postprocess(self, model_output_dict, **kwargs) -> dict: + return_ids = model_output_dict["return_ids"] + records = [ + {"generated_text": self.tokenizer.decode(return_ids[i])} + for i in range(return_ids.size(0)) + ] + return records + + +@click.command() +@click.option( + "--config_file", + default="projects/Baichuan/configs/baichuan_config.py", + help="Path to the configuration file.", +) +@click.option("--model_path", default=None, help="Path to the model checkpoint.") +@click.option( + "--mode", + default="libai", + help="Mode for the dataloader pipeline, e.g., 'libai' or 'huggingface'.", +) +@click.option( + "--device", default="cuda", help="Device to run the model on, e.g., 'cuda', 'xpu', 'npu'." +) +def main(config_file, model_path, mode, device): + pipeline = TextGenerationPipeline( + config_file, + data_parallel=1, + tensor_parallel=1, + pipeline_parallel=1, + pipeline_num_layers=32, + model_path=model_path, + mode=mode, + device=device, + ) + + text = [ + "Give three tips for staying healthy.", + ] + output = pipeline(inputs=text) + if dist.is_main_process(): + print(output) + + +if __name__ == "__main__": + main() diff --git a/projects/Baichuan/tokenizer.py b/projects/Baichuan/tokenizer.py new file mode 100644 index 000000000..61d52ac8a --- /dev/null +++ b/projects/Baichuan/tokenizer.py @@ -0,0 +1,95 @@ +# coding=utf-8 +# Copyright 2021 The OneFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
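+#
+# SentencePiece tokenizer wrapper for Baichuan. `tokenize` accepts a string or a
+# list of strings, truncates to `max_length`, optionally pads the batch and adds
+# BOS/EOS, and returns the token ids as a global oneflow tensor (sbp/placement
+# set via libai.utils.distributed) so they can be fed to the parallel model.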
+ +import oneflow as flow +import sentencepiece as spm + +import libai.utils.distributed as dist + + +class BaichuanTokenizer: + def __init__( + self, + pretrained_model_path, + bos_token="", + eos_token="", + pad_token="", + bos_token_id=None, + eos_token_id=None, + ): + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(pretrained_model_path) + + self.bos_token = bos_token + self.eos_token = eos_token + self.pad_token = pad_token + self.bos_token_id = self.sp_model.bos_id() if self.sp_model.bos_id() else bos_token_id + self.eos_token_id = self.sp_model.eos_id() if self.sp_model.eos_id() else eos_token_id + self.pad_token_id = 0 + + @property + def vocab_size(self): + return self.sp_model.get_piece_size() + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + return vocab + + def encode(self, text): + tokens = self.sp_model.encode(text) + return tokens + + def tokenize( + self, + text, + add_bos=False, + add_eos=False, + padding=False, + device="cuda", + max_length=4096, + **kwargs + ): + if isinstance(text, str): + tokens = [self.sp_model.encode(text)[:max_length]] + + if isinstance(text, list): + tokens = [self.sp_model.encode(s)[:max_length] for s in text] + if padding: + max_length = max([len(i) for i in tokens]) + tokens = [t + (max_length - len(t)) * [self.pad_token_id] for t in tokens] + + if add_bos: + tokens = [[self.bos_token_id] + token for token in tokens] + if add_eos: + tokens = [token + [self.eos_token_id] for token in tokens] + + if device: + sbp = kwargs.get("sbp", dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])) + placement = kwargs.get("placement", flow.placement(device, [0])) + return_token_ids = flow.tensor(tokens, sbp=sbp, placement=placement, dtype=flow.long) + else: + return_token_ids = flow.tensor(tokens, dtype=flow.long) + return return_token_ids + + def decode(self, tokens): + if isinstance(tokens, flow.Tensor): + tokens = tokens.tolist() + return self.sp_model.decode(tokens) + + def convert_token_to_id(self, token): + return self.sp_model.piece_to_id(token) + + def convert_id_to_token(self, index): + return self.sp_model.IdToPiece(index) diff --git a/projects/Baichuan/utils/baichuan_loader.py b/projects/Baichuan/utils/baichuan_loader.py new file mode 100644 index 000000000..096247373 --- /dev/null +++ b/projects/Baichuan/utils/baichuan_loader.py @@ -0,0 +1,98 @@ +# coding=utf-8 +# Copyright 2021 The OneFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from libai.models.utils.model_loader.base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai + + +class BaichuanLoaderHuggerFace(ModelLoaderHuggerFace): + def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs): + super().__init__(model, libai_cfg, pretrained_model_path, **kwargs) + + self.base_model_prefix_1 = "model" + self.base_model_prefix_2 = "model" + + def _convert_state_dict(self, flow_state_dict, cfg): + """Convert state_dict's keys to match model. 
+ + Args: + flow_state_dict (OrderedDict): model state dict. + cfg (dict): model's default config dict in LiBai. + + Returns: + OrderedDict: flow state dict. + """ + # The converted checkpoint. + oneflow_state_dict = flow_state_dict.copy() + old_keys = list(oneflow_state_dict.keys()) + + # Get configs + num_attention_heads = cfg.get("num_attention_heads") + hidden_size = cfg.get("hidden_size") + head_size = int(hidden_size // num_attention_heads) + + new_key_qkv = "model.layers.{}.self_attn.query_key_value.weight" + old_key_qkv = "model.layers.{}.self_attn.{}.weight" + for layer_idx in range(cfg.get("hidden_layers")): + w_pack = old_key_qkv.format(layer_idx, "W_pack") + + qkv = oneflow_state_dict[w_pack] + qkv = self._fix_qkv_ordering(qkv, head_size, num_attention_heads, hidden_size) + oneflow_state_dict[new_key_qkv.format(layer_idx)] = qkv + oneflow_state_dict.pop(w_pack) + + for k in old_keys: + if "inv_freq" in k: + oneflow_state_dict.pop(k) + + return oneflow_state_dict + + def _load_config_from_json(self, config_file): + """load config from `config.json`, and update default config. + + Args: + config_file (str): Path of config file. + """ + with open(config_file, mode="r", encoding="utf-8") as f: + cfg_dict = json.load(f) + + # update libai_cfg by config.json + self._update_cfg("hidden_layers", cfg_dict["num_hidden_layers"]) + self._update_cfg("hidden_size", cfg_dict["hidden_size"]) + self._update_cfg("num_attention_heads", cfg_dict["num_attention_heads"]) + self._update_cfg("max_position_embeddings", cfg_dict["max_position_embeddings"]) + self._update_cfg("intermediate_size", cfg_dict["intermediate_size"]) + self._update_cfg("rms_norm_eps", cfg_dict["rms_norm_eps"]) + self._update_cfg("vocab_size", cfg_dict["vocab_size"]) + self._update_cfg("initializer_range", cfg_dict["initializer_range"]) + self._update_cfg( + "ffn_hidden_size", + cfg_dict.get("n_inner") + if cfg_dict.get("n_inner") is not None + else 4 * self.libai_cfg["hidden_size"], + ) + + # update libai_cfg by kwargs + for k, v in self.kwargs.items(): + self._update_cfg(k, v) + + self._update_cfg_log() + + +class BaichuanLoaderLiBai(ModelLoaderLiBai): + def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs): + super().__init__(model, libai_cfg, pretrained_model_path, **kwargs) + self.base_model_prefix_2 = "model" diff --git a/projects/Baichuan/utils/data_prepare.py b/projects/Baichuan/utils/data_prepare.py new file mode 100644 index 000000000..378c73467 --- /dev/null +++ b/projects/Baichuan/utils/data_prepare.py @@ -0,0 +1,160 @@ +import copy +import json +import math +import os +from pathlib import Path +from typing import Optional + +import oneflow as flow +import requests +from oneflow.utils.data import random_split +from tqdm import tqdm + +from libai.config import instantiate +from libai.utils.logger import setup_logger +from projects.Baichuan.configs.baichuan_config import tokenization + +logger = setup_logger() + + +def prepare( + destination_path: Path = Path("./data/libai_xpu_alpaca"), + checkpoint_dir: Path = Path("/root/models/Baichuan2-7B-Chat"), + test_split_fraction: float = 0.03865, # to get exactly 2000 test samples, + seed: int = 42, + mask_inputs: bool = False, # as in alpaca-lora + data_file_name: str = "alpaca_data_cleaned_archive.json", + data_file_url: str = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json", # noqa + ignore_index: int = -100, + max_seq_length: Optional[int] = 512, +) -> None: + """Prepare the Alpaca dataset for instruction tuning. 
+ The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. + """ + if max_seq_length is None: + with open(os.path.join(checkpoint_dir, "config.json"), "r", encoding="utf-8") as file: + config = json.load(file) + max_seq_length = config["max_position_embeddings"] + + destination_path.mkdir(parents=True, exist_ok=True) + data_file_path = destination_path / data_file_name + logger.info("Loading data file...") + download_if_missing(data_file_path, data_file_url) + with open(data_file_path, "r", encoding="utf-8") as file: + data = json.load(file) + + logger.info("Loading tokenizer...") + tokenizer = instantiate(tokenization.tokenizer) + + # Partition the dataset into train and test + num_of_test_samples = math.floor(test_split_fraction * len(data)) + num_of_train_samples = len(data) - num_of_test_samples + train_set, test_set = random_split( + data, + [num_of_train_samples, num_of_test_samples], + generator=flow.Generator().manual_seed(seed), + ) + train_set, test_set = list(train_set), list(test_set) + + logger.info(f"train has {len(train_set):,} samples") + logger.info(f"test has {len(test_set):,} samples") + + logger.info("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + ) + for sample in tqdm(train_set) + ] + flow.save(train_set, destination_path / "train") + + logger.info("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + ) + for sample in tqdm(test_set) + ] + flow.save(test_set, destination_path / "test") + + max_length = max([i["input_ids"].shape[0] for i in train_set]) + logger.info("Max length of training dataset: {}".format(max_length)) + + +def download_if_missing(file_path: Path, file_url: str) -> None: + """Downloads the raw json data file and saves it in the given destination.""" + if file_path.exists() and file_path.stat().st_size > 0: + return + with open(file_path, "w", encoding="utf-8") as f: + f.write(requests.get(file_url).text) + + +def prepare_sample(example: dict, tokenizer, max_length: int) -> dict: + """Processes a single sample. + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). 
+ """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + + prompt = tokenizer.tokenize(full_prompt, add_bos=True, add_eos=False, device="cpu")[0] + example = tokenizer.tokenize( + full_prompt_and_response, add_bos=True, add_eos=True, device="cpu" + )[0] + + padding = max_length - example.shape[0] + if padding > 0: + example = flow.cat((example, flow.zeros(padding, dtype=flow.long) - 1)) + elif padding < 0: + example = example[:max_length] + labels = copy.deepcopy(example) + labels[: len(prompt)] = -1 + example_mask = example.ge(0) + label_mask = labels.ge(0) + example[~example_mask] = 0 + labels[~label_mask] = -1 + example = example[:-1] + labels = labels[1:] + example_mask = flow.where( + example_mask, flow.tensor(0, dtype=flow.float), flow.tensor(-float("inf")) + ) + example_mask = example_mask[:-1] + return { + "input_ids": example, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + if example["input"]: + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " # noqa + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" # noqa + ) + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Response:" + ) + + +if __name__ == "__main__": + prepare()