diff --git a/projects/Baichuan/README.md b/projects/Baichuan/README.md
new file mode 100644
index 000000000..f10d06988
--- /dev/null
+++ b/projects/Baichuan/README.md
@@ -0,0 +1,51 @@
+### Baichuan
+#### Inference
+
+- CUDA: PASS
+```bash
+python projects/Baichuan/pipeline.py --mode=huggingface --model_path=/root/models/Baichuan2-7B-Chat
+```
+
+- XPU: PASS
+```bash
+python projects/Baichuan/pipeline.py --mode=huggingface --device=xpu --model_path=/root/models/Baichuan2-7B-Chat
+```
+
+#### Training
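+
+Before training, the Alpaca SFT data can be prepared with the helper script shipped in this project (a minimal sketch; run it from the repository root with LiBai importable, e.g. installed in editable mode or with the root on `PYTHONPATH`). The default output directory and tokenizer path are set inside `projects/Baichuan/utils/data_prepare.py`; `dataset_path` in `projects/Baichuan/configs/baichuan_sft.py` should point at that output directory.
+
+```bash
+python projects/Baichuan/utils/data_prepare.py
+```
+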
+- CUDA: PASS
+```bash
+export NUM_GPUS=8
+python3 -m oneflow.distributed.launch \
+ --nproc_per_node ${NUM_GPUS} \
+ --nnodes 1 \
+ --node_rank 0 \
+ --master_addr 127.0.0.1 \
+ --master_port 12345 \
+ tools/train_net.py --config-file=projects/Baichuan/configs/baichuan_sft.py \
+ graph.enabled=True \
+ train.input_placement_device="cuda" \
+ train.dist.device_type="cuda" \
+ train.dist.pipeline_parallel_size=${NUM_GPUS}
+```
+
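+Sample training log (the loss turns to NaN early in training):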
+```
+[09/19 14:39:40 lb.utils.events]: eta: 22:07:15 iteration: 87/18660 consumed_samples: 704 total_loss: 10.36 time: 4.2893 s/iter data_time: 0.0105 s/iter total_throughput: 1.87 samples/s lr: 6.99e-07
+[09/19 14:39:44 lb.utils.events]: eta: 22:07:07 iteration: 88/18660 consumed_samples: 712 total_loss: nan time: 4.2889 s/iter data_time: 0.0104 s/iter total_throughput: 1.87 samples/s lr: 7.07e-07
+NaN or Inf found in input tensor.
+```
+
+- XPU: OOM after 7 iterations
+```bash
+export NUM_GPUS=1
+python3 -m oneflow.distributed.launch \
+ --nproc_per_node ${NUM_GPUS} \
+ --nnodes 1 \
+ --node_rank 0 \
+ --master_addr 127.0.0.1 \
+ --master_port 12345 \
+ tools/train_net.py --config-file=projects/Baichuan/configs/baichuan_sft.py \
+ graph.enabled=False \
+ train.input_placement_device="xpu" \
+ train.dist.device_type="xpu" \
+ train.dist.pipeline_parallel_size=${NUM_GPUS}
+```
diff --git a/projects/Baichuan/baichuan.py b/projects/Baichuan/baichuan.py
new file mode 100644
index 000000000..ad4f3e457
--- /dev/null
+++ b/projects/Baichuan/baichuan.py
@@ -0,0 +1,653 @@
+# coding=utf-8
+# Copyright 2021 The OneFlow Authors. All rights reserved.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Tuple
+
+import oneflow as flow
+from oneflow import nn
+
+from libai.config import configurable
+from libai.inference.generator.generation_utils import Generator
+from libai.layers import Linear, RMSLayerNorm, VocabEmbedding
+from libai.layers.attention import AttnMaskType
+from libai.models.utils import init_method_normal, scaled_init_method_normal
+from libai.utils import distributed as dist
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return flow.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ cos = cos[position_ids].unsqueeze(1)
+ sin = sin[position_ids].unsqueeze(1)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+
+ def forward(self, x, seq_len=None, cos_cached=None, sin_cached=None):
+ if seq_len > self.max_position_embeddings:
+ raise ValueError(
+ f"The maximum supported length is {self.max_position_embeddings}, "
+ f"and the current length is{seq_len}."
+ )
+
+ return (
+ cos_cached[:seq_len].to_global(placement=x.placement),
+ sin_cached[:seq_len].to_global(placement=x.placement),
+ )
+
+
+class MLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ intermediate_size,
+ init_method=nn.init.xavier_normal_,
+ output_layer_init_method=None,
+ *,
+ layer_idx=0,
+ ):
+ super().__init__()
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ self.gate_proj = Linear(
+ hidden_size,
+ intermediate_size,
+ bias=False,
+ parallel="col",
+ init_method=init_method,
+ layer_idx=layer_idx,
+ )
+
+ self.up_proj = Linear(
+ hidden_size,
+ intermediate_size,
+ bias=False,
+ parallel="col",
+ init_method=init_method,
+ layer_idx=layer_idx,
+ )
+
+ self.down_proj = Linear(
+ intermediate_size,
+ hidden_size,
+ bias=False,
+ parallel="row",
+ init_method=output_layer_init_method,
+ layer_idx=layer_idx,
+ )
+
+ self.activation_func = nn.SiLU()
+
+ def forward(self, hidden_states):
+ gate_out = self.activation_func(self.gate_proj(hidden_states))
+ up_out = self.up_proj(hidden_states)
+ output = self.down_proj(gate_out * up_out)
+ return output
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_attention_heads,
+ max_position_embeddings,
+ init_method=nn.init.xavier_normal_,
+ output_layer_init_method=None,
+ scale_mask_softmax_fusion=False,
+ attn_mask_type=AttnMaskType.padding,
+ *,
+ layer_idx=0,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ self.num_heads = num_attention_heads
+ self.head_size = hidden_size // num_attention_heads
+ self.attn_mask_type = attn_mask_type
+
+ self.norm_factor = 1.0 / math.sqrt(float(self.head_size))
+
+ self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
+
+ self.query_key_value = Linear(
+ self.hidden_size,
+ self.hidden_size * 3,
+ bias=False,
+ parallel="col",
+ init_method=init_method,
+ layer_idx=layer_idx,
+ )
+
+ self.o_proj = Linear(
+ self.hidden_size,
+ self.hidden_size,
+ bias=False,
+ parallel="row",
+ init_method=output_layer_init_method,
+ layer_idx=layer_idx,
+ )
+
+ self.coeff = None
+
+ rotary_dim = self.head_size
+ self.rotary_embed = RotaryEmbedding(
+ dim=rotary_dim,
+ max_position_embeddings=max_position_embeddings,
+ )
+
+ def forward(
+ self,
+ hidden_states: flow.Tensor,
+ encoder_states: flow.Tensor = None,
+ attention_mask: flow.Tensor = None,
+ position_ids=None,
+ past_key_value: Tuple[flow.Tensor, flow.Tensor] = None,
+ cos_cached: flow.Tensor = None,
+ sin_cached: flow.Tensor = None,
+ use_cache: bool = False,
+ ):
+ if encoder_states is not None:
+ encoder_states = encoder_states.to_global(placement=hidden_states.placement)
+
+ if attention_mask is not None:
+ attention_mask = attention_mask.to_global(placement=hidden_states.placement)
+
+ bsz, tgt_len = hidden_states.size()[:2]
+
+ query_key_value = self.query_key_value(hidden_states)
+ query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size)
+ query_key_value = query_key_value.permute(
+ 0, 2, 1, 3
+ ) # [bsz, num_heads, src_len, 3 * head_size]
+ query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)
+
+ kv_seq_len = key.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_embed(
+ value, seq_len=kv_seq_len, cos_cached=cos_cached, sin_cached=sin_cached
+ )
+ query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ past_key, past_value = past_key_value
+ key = flow.cat((past_key.type_as(key), key), dim=2)
+ value = flow.cat((past_value.type_as(value), value), dim=2)
+
+ # query, key, value: [S(0), S(1)], shape: [bsz, num_heads, seq_length, head_size]
+ if use_cache:
+ past_key_value = (key, value)
+
+ # [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)]
+ attention_scores = flow.matmul(query, key, transpose_b=True, alpha=self.norm_factor)
+ attention_weights = attention_scores + attention_mask
+
+ attention_weights = flow.softmax(attention_weights, dim=-1)
+ # Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)]
+ context = flow.matmul(attention_weights, value)
+
+ # Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size]
+ context = context.transpose(1, 2)
+ output = self.o_proj(context.flatten(2))
+
+ if use_cache:
+ output = (output, past_key_value)
+
+ return output
+
+
+class CasualMask(nn.Module):
+ def __init__(self, max_positions=1024, dtype=flow.float16, *, layer_idx=0):
+ super().__init__()
+ self.dtype = dtype
+ self.mask = flow.full(
+ (max_positions, max_positions),
+ flow.finfo(dtype).min,
+ placement=dist.get_layer_placement(layer_idx),
+ sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
+ )
+ mask_cond = flow.arange(
+ self.mask.size(-1),
+ placement=dist.get_layer_placement(layer_idx),
+ sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
+ )
+ self.mask.masked_fill_(mask_cond < (mask_cond + 1).view(self.mask.size(-1), 1), 0)
+ self.mask = self.mask.to(dtype)
+
+ def forward(self, input_ids, past_length=0, attention_mask=None, input_dtype=None):
+ bsz, tgt_len = input_ids.size()
+ casual_mask = self.mask[:tgt_len, :tgt_len]
+ if past_length > 0:
+            # when past_key_values are used, prepend zeros (no masking) for the cached prefix positions
+ casual_mask = flow.cat(
+ [flow.zeros(tgt_len, past_length, dtype=self.dtype), casual_mask], dim=-1
+ )
+ casual_mask = (
+ casual_mask.unsqueeze(0).unsqueeze(1).expand(bsz, 1, tgt_len, tgt_len + past_length)
+ )
+ casual_mask = casual_mask.to_global(sbp=input_ids.sbp)
+ if attention_mask is not None:
+ bsz, src_len = attention_mask.size()
+ attention_mask = (
+ attention_mask[:, None, None, :]
+ .expand(bsz, 1, tgt_len, src_len)
+ .to(casual_mask.dtype)
+ )
+ inverted_attention_mask = 1.0 - attention_mask
+            inverted_attention_mask = inverted_attention_mask.masked_fill(
+                inverted_attention_mask.to(flow.bool), flow.finfo(casual_mask.dtype).min
+            )
+ inverted_attention_mask = inverted_attention_mask.to_global(
+ placement=casual_mask.placement
+ )
+ casual_mask = casual_mask + inverted_attention_mask
+ if input_dtype is not None:
+ casual_mask = casual_mask.to(input_dtype)
+ return casual_mask
+
+
+class DecoderLayer(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ intermediate_size,
+ num_attention_heads,
+ is_decoder=False,
+ rms_norm_eps=1e-5,
+ max_position_embeddings=None,
+ init_method=nn.init.xavier_normal_,
+ output_layer_init_method=None,
+ scale_mask_softmax_fusion=False,
+ attn_mask_type=AttnMaskType.padding,
+ *,
+ layer_idx=0,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.rms_norm_eps = rms_norm_eps
+ self.max_position_embeddings = max_position_embeddings
+ self.attn_mask_type = attn_mask_type
+
+ self.layer_idx = layer_idx
+ self.is_decoder = is_decoder
+
+ self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
+
+ self.init_method = init_method
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+ self.output_layer_init_method = output_layer_init_method
+
+ self.input_layernorm = RMSLayerNorm(
+ self.hidden_size, eps=self.rms_norm_eps, layer_idx=self.layer_idx
+ )
+
+ self.self_attn = self.build_attention()
+ self.post_attention_layernorm = RMSLayerNorm(
+ self.hidden_size, eps=self.rms_norm_eps, layer_idx=self.layer_idx
+ )
+
+ self.mlp = MLP(
+ self.hidden_size,
+ self.intermediate_size,
+ self.init_method,
+ output_layer_init_method=self.output_layer_init_method,
+ layer_idx=self.layer_idx,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ past_key_value=None,
+ cos_cached=None,
+ sin_cached=None,
+ use_cache=False,
+ ):
+ hidden_states = hidden_states.to_global(placement=dist.get_layer_placement(self.layer_idx))
+
+ # hidden_states shape: (batch_size, seq_length, hidden_size)
+ if attention_mask is not None:
+ attention_mask = attention_mask.to_global(
+ placement=dist.get_layer_placement(self.layer_idx)
+ )
+
+ if past_key_value is not None:
+ if self.is_decoder:
+ assert len(past_key_value) == 4
+ self_attn_past_key_value = past_key_value[:2]
+ else:
+ self_attn_past_key_value = past_key_value
+ else:
+ self_attn_past_key_value = None
+
+ layernorm_output = self.input_layernorm(hidden_states)
+ attention_output = self.self_attn(
+ layernorm_output,
+ attention_mask=attention_mask,
+ past_key_value=self_attn_past_key_value,
+ cos_cached=cos_cached,
+ sin_cached=sin_cached,
+ use_cache=use_cache,
+ )
+
+ if use_cache:
+ attention_output, presents = attention_output
+
+ hidden_states = hidden_states + attention_output
+
+ layernorm_output = self.post_attention_layernorm(hidden_states)
+
+ mlp_output = self.mlp(layernorm_output)
+
+ output = hidden_states + mlp_output
+
+ if use_cache:
+ output = (output, presents)
+ return output
+
+ def build_attention(self):
+ return MultiheadAttention(
+ self.hidden_size,
+ self.num_attention_heads,
+ self.max_position_embeddings,
+ init_method=self.init_method,
+ output_layer_init_method=self.output_layer_init_method,
+ scale_mask_softmax_fusion=self.scale_mask_softmax_fusion,
+ attn_mask_type=self.attn_mask_type,
+ layer_idx=self.layer_idx,
+ )
+
+
+class BaichuanModel(nn.Module):
+ def __init__(
+ self,
+ hidden_layers,
+ vocab_size,
+ hidden_size,
+ intermediate_size,
+ num_attention_heads,
+ max_position_embeddings=1024,
+ rms_norm_eps=1e-5,
+ initializer_range=0.02,
+ use_scaled_init_for_output_weights=True,
+ scale_mask_softmax_fusion=False,
+ amp_enabled=False,
+ ):
+ super().__init__()
+ init_method = init_method_normal(sigma=initializer_range)
+ if use_scaled_init_for_output_weights:
+ output_layer_init_method = scaled_init_method_normal(initializer_range, hidden_layers)
+ else:
+ output_layer_init_method = init_method
+
+ self.embed_tokens = VocabEmbedding(
+ vocab_size, hidden_size, init_method=init_method, amp_enabled=amp_enabled
+ )
+ self.layers = nn.ModuleList(
+ [
+ DecoderLayer(
+ hidden_size,
+ intermediate_size,
+ num_attention_heads,
+ rms_norm_eps=rms_norm_eps,
+ max_position_embeddings=max_position_embeddings,
+ init_method=init_method,
+ output_layer_init_method=output_layer_init_method,
+ scale_mask_softmax_fusion=scale_mask_softmax_fusion,
+ attn_mask_type=AttnMaskType.causal,
+ layer_idx=i,
+ )
+ for i in range(hidden_layers)
+ ]
+ )
+ self.norm = RMSLayerNorm(hidden_size, eps=rms_norm_eps, layer_idx=-1)
+
+ self._set_cos_sin_cache(
+ rotary_dim=hidden_size // num_attention_heads,
+ seq_len=max_position_embeddings,
+ dtype=flow.float32,
+ layer_idx=0,
+ )
+
+ def _set_cos_sin_cache(self, rotary_dim, seq_len, base=10000, dtype=None, layer_idx=0):
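+        # Precompute the rotary-embedding cos/sin tables for positions [0, seq_len);
+        # RotaryEmbedding.forward slices them by sequence length and
+        # apply_rotary_pos_emb indexes them with position_ids.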
+ position = flow.arange(
+ 0,
+ rotary_dim,
+ 2,
+ dtype=dtype,
+ sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
+ placement=dist.get_layer_placement(layer_idx),
+ )
+ inv_freq = 1.0 / (base ** (position / rotary_dim))
+
+ t = flow.arange(
+ seq_len,
+ dtype=inv_freq.dtype,
+ sbp=inv_freq.sbp,
+ placement=inv_freq.placement,
+ )
+
+ freqs = flow.einsum("i,j->ij", t, inv_freq)
+ emb = flow.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype))
+ self.register_buffer("sin_cached", emb.sin().to(dtype))
+
+ def forward(
+ self,
+ input_ids,
+ attention_mask=None,
+ past_key_values=None,
+ use_cache=False,
+ set_cache=None,
+ ):
+ if use_cache:
+ presents = []
+ input_ids = input_ids.to_global(placement=dist.get_layer_placement(0))
+ hidden_states = self.embed_tokens(input_ids)
+
+ for layer, past_key_value in zip(self.layers, past_key_values):
+ hidden_states = layer(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ past_key_value=past_key_value,
+ cos_cached=self.cos_cached,
+ sin_cached=self.sin_cached,
+                use_cache=use_cache,
+ )
+ if use_cache:
+ hidden_states, present = hidden_states
+ presents.append(present)
+
+ hidden_states = self.norm(hidden_states)
+
+ if use_cache:
+ set_cache(presents)
+
+ return hidden_states
+
+
+class CrossEntropyLoss(nn.Module):
+ def forward(self, logits: flow.Tensor, target: flow.Tensor):
+ assert logits.ndim == 3
+ assert target.ndim == 2
+ assert logits.shape[0:2] == target.shape
+
+ target = target.to_global(placement=logits.placement)
+ target = target * (target >= 0)
+
+ lm_loss = flow._C.cross_entropy(
+ logits.view(-1, logits.shape[-1]), target.view(-1), ignore_index=0
+ )
+ return lm_loss
+
+
+class SFTLoss(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.lm_loss = CrossEntropyLoss()
+
+ def forward(self, logits, lm_labels):
+ lm_loss = self.lm_loss(logits, lm_labels)
+ lm_loss = lm_loss.mean()
+ return {"lm_loss": lm_loss}
+
+
+class BaichuanForCausalLM(nn.Module, Generator):
+ @configurable
+ def __init__(
+ self,
+ hidden_layers,
+ vocab_size,
+ hidden_size,
+ intermediate_size,
+ num_attention_heads,
+ max_position_embeddings=1024,
+ rms_norm_eps=1e-5,
+ initializer_range=0.02,
+ use_scaled_init_for_output_weights=True,
+ scale_mask_softmax_fusion=False,
+ amp_enabled=False,
+ cfg=None,
+ ):
+ super().__init__()
+ self.cfg = cfg
+ self.model = BaichuanModel(
+ hidden_layers=hidden_layers,
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_attention_heads=num_attention_heads,
+ max_position_embeddings=max_position_embeddings,
+ rms_norm_eps=rms_norm_eps,
+ initializer_range=initializer_range,
+ use_scaled_init_for_output_weights=use_scaled_init_for_output_weights,
+ scale_mask_softmax_fusion=scale_mask_softmax_fusion,
+ amp_enabled=amp_enabled,
+ )
+ self.casual_mask = CasualMask(max_position_embeddings, layer_idx=0)
+ self.lm_head = Linear(hidden_size, vocab_size, bias=False, layer_idx=-1)
+ self.loss_func = SFTLoss()
+
+ self.past_key_values = [None] * hidden_layers
+ self.past_length = 0
+
+ def forward(self, input_ids, attention_mask=None, labels=None, use_cache=False):
+ input_ids = input_ids.to_global(placement=dist.get_layer_placement(0))
+ attention_mask = (
+ attention_mask.to_global(placement=dist.get_layer_placement(0))
+ if attention_mask is not None
+ else attention_mask
+ )
+ labels = (
+ labels.to_global(placement=dist.get_layer_placement(0))
+ if labels is not None
+ else labels
+ )
+
+ if use_cache and self.past_key_values[0] is not None:
+ self.past_length = self.past_key_values[0][0].size(-2)
+ else:
+ self.past_length = 0
+
+ mask = self.casual_mask(
+ input_ids,
+ past_length=self.past_length,
+ attention_mask=attention_mask,
+ input_dtype=self.lm_head.weight.dtype,
+ )
+
+ output = self.model(
+ input_ids,
+ attention_mask=mask,
+ past_key_values=self.past_key_values,
+ use_cache=use_cache,
+ set_cache=self.set_cache,
+ )
+
+ logits = self.lm_head(output)
+
+ if labels is not None:
+ lm_loss = self.loss_func(logits, labels)
+ return lm_loss
+ else:
+ return {"logits": logits}
+
+ def set_cache(self, past_key_values):
+ self.past_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]
+
+ if past_key_values is None:
+ past_key_values = [None] * self.cfg.hidden_layers
+
+ assert len(past_key_values) == self.cfg.hidden_layers, (
+ f"past_key_values's length {len(past_key_values)} doesn't match "
+ f"num_layers:' {self.cfg.hidden_layers}"
+ )
+
+ def prepare_inputs_for_generation(self, input_ids: flow.Tensor, **kwargs):
+ if "attention_mask" in kwargs:
+ attention_mask = kwargs.pop("attention_mask").float()
+ attention_mask = attention_mask - 1
+ attention_mask.masked_fill_(attention_mask == -1, flow.finfo(flow.float32).min)
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+ @classmethod
+ def from_config(cls, cfg):
+ return {
+ "hidden_layers": cfg.hidden_layers,
+ "vocab_size": cfg.vocab_size,
+ "hidden_size": cfg.hidden_size,
+ "intermediate_size": cfg.intermediate_size,
+ "num_attention_heads": cfg.num_attention_heads,
+ "max_position_embeddings": cfg.max_position_embeddings,
+ "rms_norm_eps": cfg.rms_norm_eps,
+ "initializer_range": cfg.initializer_range,
+ "use_scaled_init_for_output_weights": cfg.use_scaled_init_for_output_weights,
+ "scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
+ "amp_enabled": cfg.amp_enabled,
+ "cfg": cfg,
+ }
+
+ @staticmethod
+ def set_activation_checkpoint(model):
+ for module_block in model.modules():
+ # Old API in OneFlow 0.8
+ if hasattr(module_block, "origin"):
+ if isinstance(module_block.origin, DecoderLayer):
+ module_block.config.activation_checkpointing = True
+ else:
+ if isinstance(module_block.to(nn.Module), DecoderLayer):
+ module_block.to(nn.graph.GraphModule).activation_checkpointing = True
diff --git a/projects/Baichuan/baichuan_dataset.py b/projects/Baichuan/baichuan_dataset.py
new file mode 100644
index 000000000..d78efe9fe
--- /dev/null
+++ b/projects/Baichuan/baichuan_dataset.py
@@ -0,0 +1,19 @@
+import oneflow as flow
+from oneflow.utils.data import Dataset
+
+from libai.data.structures import DistTensorData, Instance
+
+
+class AlpacaDataset(Dataset):
+ def __init__(self, path, tokenizer):
+ self.data = flow.load(path)
+ self.tokenizer = tokenizer
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, index):
+ return Instance(
+ input_ids=DistTensorData(self.data[index]["input_ids"]),
+ labels=DistTensorData(self.data[index]["labels"]),
+ )
diff --git a/projects/Baichuan/configs/baichuan_config.py b/projects/Baichuan/configs/baichuan_config.py
new file mode 100644
index 000000000..d5b8a1401
--- /dev/null
+++ b/projects/Baichuan/configs/baichuan_config.py
@@ -0,0 +1,62 @@
+from omegaconf import DictConfig, OmegaConf
+
+from libai.config import LazyCall
+from projects.Baichuan.baichuan import BaichuanForCausalLM
+from projects.Baichuan.tokenizer import BaichuanTokenizer
+from configs.common.train import train
+
+
+cfg = dict(
+ # Model
+ model_type="baichuan",
+ hidden_act="silu",
+ hidden_size=4096,
+ initializer_range=0.02,
+ intermediate_size=11008,
+ max_position_embeddings=2048,
+ num_attention_heads=32,
+ hidden_layers=32,
+ pretraining_tp=1,
+ rms_norm_eps=1e-05,
+ rope_scaling=None,
+ tie_word_embeddings=False,
+ vocab_size=32000,
+ use_scaled_init_for_output_weights=False,
+ scale_mask_softmax_fusion=False,
+ amp_enabled=True,
+ # Inference
+ is_encoder_decoder=False,
+ max_length=256,
+ min_length=0,
+ do_sample=False,
+ early_stopping=False,
+ num_beams=1,
+ num_beam_groups=1,
+ diversity_penalty=0.0,
+ temperature=0.9,
+ top_k=50,
+ top_p=0.6,
+ typical_p=1.0,
+ repetition_penalty=1.0,
+ length_penalty=1.0,
+ no_repeat_ngram_size=0,
+ encoder_no_repeat_ngram_size=0,
+ num_return_sequences=1,
+ chunk_size_feed_forward=0,
+ output_scores=False,
+ use_cache=True,
+ bos_token_id=1,
+ eos_token_id=2,
+ pad_token_id=0,
+ # train
+ pretrained_model_path="/root/models/Baichuan2-7B-Chat",
+)
+
+cfg = DictConfig(cfg)
+
+model = LazyCall(BaichuanForCausalLM)(cfg=cfg)
+tokenization = OmegaConf.create()
+tokenization.make_vocab_size_divisible_by = 1
+tokenization.tokenizer = LazyCall(BaichuanTokenizer)(
+ # pretrained_model_path=cfg.pretrained_model_path + "/tokenizer.model"
+)
diff --git a/projects/Baichuan/configs/baichuan_sft.py b/projects/Baichuan/configs/baichuan_sft.py
new file mode 100644
index 000000000..738d5d199
--- /dev/null
+++ b/projects/Baichuan/configs/baichuan_sft.py
@@ -0,0 +1,100 @@
+import os
+from omegaconf import OmegaConf
+
+from libai.config import LazyCall
+from libai.evaluation import PPLEvaluator
+from libai.scheduler import WarmupExponentialLR
+from libai.data.build import build_nlp_test_loader, build_nlp_train_loader
+
+from configs.common.train import train
+from configs.common.models.graph import graph
+from configs.common.optim import optim
+
+from projects.Baichuan.configs.baichuan_config import cfg
+from projects.Baichuan.baichuan_dataset import AlpacaDataset
+from projects.Baichuan.tokenizer import BaichuanTokenizer
+from projects.Baichuan.baichuan import BaichuanForCausalLM
+
+
+# Hyperparameters
+weight_decay = 0.1
+learning_rate = 5e-5
+dataset_path = "./alpaca_data"
+pretrained_model_path = "/root/models/Llama-2-7b-chat-hf"
+
+# graph & optim
+graph["enabled"] = False
+optim.update(
+ dict(
+ lr=learning_rate,
+ weight_decay=weight_decay,
+ )
+)
+
+# tokenize
+tokenization = OmegaConf.create()
+tokenization.make_vocab_size_divisible_by = 1
+tokenization.tokenizer = LazyCall(BaichuanTokenizer)(
+ pretrained_model_path=os.path.join(pretrained_model_path, "tokenizer.model")
+)
+
+# model
+model = LazyCall(BaichuanForCausalLM)(cfg=cfg)
+
+# datasets
+dataloader = OmegaConf.create()
+dataloader.train = LazyCall(build_nlp_train_loader)(
+ dataset=[
+ LazyCall(AlpacaDataset)(
+ path=os.path.join(dataset_path, "train"), tokenizer=tokenization.tokenizer
+ )
+ ],
+)
+dataloader.test = [
+ LazyCall(build_nlp_test_loader)(
+ dataset=LazyCall(AlpacaDataset)(
+ path=os.path.join(dataset_path, "test"), tokenizer=tokenization.tokenizer
+ ),
+ ),
+]
+
+
+train.update(
+ dict(
+ output_dir="./sft_result",
+ train_micro_batch_size=1,
+ test_micro_batch_size=1,
+ train_epoch=3,
+ train_iter=1,
+ log_period=1,
+ warmup_ratio=1 / 3,
+ num_accumulation_steps=8,
+ rdma_enabled=True,
+ amp=dict(enabled=False),
+ train_with_fp16=True,
+ activation_checkpoint=dict(enabled=True),
+ input_placement_device="cuda",
+ checkpointer=dict(
+ period=5000,
+ max_to_keep=20,
+ ),
+ dist=dict(
+ data_parallel_size=1,
+ tensor_parallel_size=1,
+ pipeline_parallel_size=1,
+ pipeline_num_layers=cfg.hidden_layers,
+ device_type="cuda",
+ ),
+ evaluation=dict(
+ enabled=True,
+ evaluator=LazyCall(PPLEvaluator)(),
+ eval_period=1000,
+ eval_iter=1e5,
+ ),
+ scheduler=LazyCall(WarmupExponentialLR)(
+ warmup_factor=0.0,
+ gamma=1.0,
+ warmup_method="linear",
+ ),
+ )
+)
diff --git a/projects/Baichuan/pipeline.py b/projects/Baichuan/pipeline.py
new file mode 100644
index 000000000..00d22ea11
--- /dev/null
+++ b/projects/Baichuan/pipeline.py
@@ -0,0 +1,128 @@
+# coding=utf-8
+# Copyright 2021 The OneFlow Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import click
+
+from libai.inference.basic import BasePipeline
+from libai.utils import distributed as dist
+
+
+class TextGenerationPipeline(BasePipeline):
+ def load_pretrain_weight(self, libai_cfg_model, model_path, mode="huggingface"):
+ """load pretrained model.
+
+ Args:
+ libai_cfg_model (libai.models): Lazy config Model in Libai, you can import it
+ by `from libai.config.configs.common.models.bert
+ import pretrain_model as libai_cfg_model`
+ model_path (str): The directory path of pretrained model,
+ """
+ if mode == "huggingface":
+ from projects.Baichuan.utils.baichuan_loader import BaichuanLoaderHuggerFace
+
+ model_loader = BaichuanLoaderHuggerFace(
+ libai_cfg_model,
+ libai_cfg_model.cfg,
+ model_path,
+ )
+ model = model_loader.load()
+ model.eval()
+ return model
+
+ elif mode == "libai":
+ from projects.Baichuan.utils.baichuan_loader import BaichuanLoaderLiBai
+
+ model_loader = BaichuanLoaderLiBai(
+ libai_cfg_model,
+ libai_cfg_model.cfg,
+ model_path,
+ )
+ model = model_loader.load()
+ model.eval()
+ return model
+
+ elif mode == "random":
+ from libai.engine import DefaultTrainer
+
+ return DefaultTrainer.build_model(self.cfg)
+ else:
+ raise NotImplementedError
+
+ def _parse_parameters(self, **pipeline_parameters):
+ preprocess_params = {}
+ forward_params = {**pipeline_parameters}
+ postprocess_params = {}
+
+ return preprocess_params, forward_params, postprocess_params
+
+ def preprocess(self, inputs, **kwargs) -> dict:
+        # tokenizer encode
+ inputs = self.tokenizer.tokenize(inputs, add_bos=True, padding=True, device=self.device)
+ inputs = {
+ "input_ids": inputs,
+ }
+
+ return inputs
+
+ def forward(self, inputs, **kwargs) -> dict:
+ outputs = self.model.generate(inputs["input_ids"], max_length=50, **kwargs)
+ return {"return_ids": outputs}
+
+ def postprocess(self, model_output_dict, **kwargs) -> dict:
+ return_ids = model_output_dict["return_ids"]
+ records = [
+ {"generated_text": self.tokenizer.decode(return_ids[i])}
+ for i in range(return_ids.size(0))
+ ]
+ return records
+
+
+@click.command()
+@click.option(
+ "--config_file",
+ default="projects/Baichuan/configs/baichuan_config.py",
+ help="Path to the configuration file.",
+)
+@click.option("--model_path", default=None, help="Path to the model checkpoint.")
+@click.option(
+ "--mode",
+ default="libai",
+ help="Mode for the dataloader pipeline, e.g., 'libai' or 'huggingface'.",
+)
+@click.option(
+ "--device", default="cuda", help="Device to run the model on, e.g., 'cuda', 'xpu', 'npu'."
+)
+def main(config_file, model_path, mode, device):
+ pipeline = TextGenerationPipeline(
+ config_file,
+ data_parallel=1,
+ tensor_parallel=1,
+ pipeline_parallel=1,
+ pipeline_num_layers=32,
+ model_path=model_path,
+ mode=mode,
+ device=device,
+ )
+
+ text = [
+ "Give three tips for staying healthy.",
+ ]
+ output = pipeline(inputs=text)
+ if dist.is_main_process():
+ print(output)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/projects/Baichuan/tokenizer.py b/projects/Baichuan/tokenizer.py
new file mode 100644
index 000000000..61d52ac8a
--- /dev/null
+++ b/projects/Baichuan/tokenizer.py
@@ -0,0 +1,95 @@
+# coding=utf-8
+# Copyright 2021 The OneFlow Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import oneflow as flow
+import sentencepiece as spm
+
+import libai.utils.distributed as dist
+
+
+class BaichuanTokenizer:
+ def __init__(
+ self,
+ pretrained_model_path,
+ bos_token="",
+ eos_token="",
+ pad_token="",
+ bos_token_id=None,
+ eos_token_id=None,
+ ):
+ self.sp_model = spm.SentencePieceProcessor()
+ self.sp_model.Load(pretrained_model_path)
+
+ self.bos_token = bos_token
+ self.eos_token = eos_token
+ self.pad_token = pad_token
+ self.bos_token_id = self.sp_model.bos_id() if self.sp_model.bos_id() else bos_token_id
+ self.eos_token_id = self.sp_model.eos_id() if self.sp_model.eos_id() else eos_token_id
+ self.pad_token_id = 0
+
+ @property
+ def vocab_size(self):
+ return self.sp_model.get_piece_size()
+
+ def get_vocab(self):
+        vocab = {self.convert_id_to_token(i): i for i in range(self.vocab_size)}
+ return vocab
+
+ def encode(self, text):
+ tokens = self.sp_model.encode(text)
+ return tokens
+
+ def tokenize(
+ self,
+ text,
+ add_bos=False,
+ add_eos=False,
+ padding=False,
+ device="cuda",
+ max_length=4096,
+ **kwargs
+ ):
+ if isinstance(text, str):
+ tokens = [self.sp_model.encode(text)[:max_length]]
+
+ if isinstance(text, list):
+ tokens = [self.sp_model.encode(s)[:max_length] for s in text]
+ if padding:
+ max_length = max([len(i) for i in tokens])
+ tokens = [t + (max_length - len(t)) * [self.pad_token_id] for t in tokens]
+
+ if add_bos:
+ tokens = [[self.bos_token_id] + token for token in tokens]
+ if add_eos:
+ tokens = [token + [self.eos_token_id] for token in tokens]
+
+ if device:
+ sbp = kwargs.get("sbp", dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]))
+ placement = kwargs.get("placement", flow.placement(device, [0]))
+ return_token_ids = flow.tensor(tokens, sbp=sbp, placement=placement, dtype=flow.long)
+ else:
+ return_token_ids = flow.tensor(tokens, dtype=flow.long)
+ return return_token_ids
+
+ def decode(self, tokens):
+ if isinstance(tokens, flow.Tensor):
+ tokens = tokens.tolist()
+ return self.sp_model.decode(tokens)
+
+ def convert_token_to_id(self, token):
+ return self.sp_model.piece_to_id(token)
+
+ def convert_id_to_token(self, index):
+ return self.sp_model.IdToPiece(index)
diff --git a/projects/Baichuan/utils/baichuan_loader.py b/projects/Baichuan/utils/baichuan_loader.py
new file mode 100644
index 000000000..096247373
--- /dev/null
+++ b/projects/Baichuan/utils/baichuan_loader.py
@@ -0,0 +1,98 @@
+# coding=utf-8
+# Copyright 2021 The OneFlow Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from libai.models.utils.model_loader.base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
+
+
+class BaichuanLoaderHuggerFace(ModelLoaderHuggerFace):
+ def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
+ super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
+
+ self.base_model_prefix_1 = "model"
+ self.base_model_prefix_2 = "model"
+
+ def _convert_state_dict(self, flow_state_dict, cfg):
+ """Convert state_dict's keys to match model.
+
+ Args:
+ flow_state_dict (OrderedDict): model state dict.
+ cfg (dict): model's default config dict in LiBai.
+
+ Returns:
+            OrderedDict: converted state dict with keys matching the LiBai model.
+ """
+ # The converted checkpoint.
+ oneflow_state_dict = flow_state_dict.copy()
+ old_keys = list(oneflow_state_dict.keys())
+
+ # Get configs
+ num_attention_heads = cfg.get("num_attention_heads")
+ hidden_size = cfg.get("hidden_size")
+ head_size = int(hidden_size // num_attention_heads)
+
+ new_key_qkv = "model.layers.{}.self_attn.query_key_value.weight"
+ old_key_qkv = "model.layers.{}.self_attn.{}.weight"
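+        # HuggingFace Baichuan packs Q, K and V into a single `W_pack` weight per layer;
+        # reorder it to LiBai's fused `query_key_value` layout.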
+ for layer_idx in range(cfg.get("hidden_layers")):
+ w_pack = old_key_qkv.format(layer_idx, "W_pack")
+
+ qkv = oneflow_state_dict[w_pack]
+ qkv = self._fix_qkv_ordering(qkv, head_size, num_attention_heads, hidden_size)
+ oneflow_state_dict[new_key_qkv.format(layer_idx)] = qkv
+ oneflow_state_dict.pop(w_pack)
+
+ for k in old_keys:
+ if "inv_freq" in k:
+ oneflow_state_dict.pop(k)
+
+ return oneflow_state_dict
+
+ def _load_config_from_json(self, config_file):
+ """load config from `config.json`, and update default config.
+
+ Args:
+ config_file (str): Path of config file.
+ """
+ with open(config_file, mode="r", encoding="utf-8") as f:
+ cfg_dict = json.load(f)
+
+ # update libai_cfg by config.json
+ self._update_cfg("hidden_layers", cfg_dict["num_hidden_layers"])
+ self._update_cfg("hidden_size", cfg_dict["hidden_size"])
+ self._update_cfg("num_attention_heads", cfg_dict["num_attention_heads"])
+ self._update_cfg("max_position_embeddings", cfg_dict["max_position_embeddings"])
+ self._update_cfg("intermediate_size", cfg_dict["intermediate_size"])
+ self._update_cfg("rms_norm_eps", cfg_dict["rms_norm_eps"])
+ self._update_cfg("vocab_size", cfg_dict["vocab_size"])
+ self._update_cfg("initializer_range", cfg_dict["initializer_range"])
+ self._update_cfg(
+ "ffn_hidden_size",
+ cfg_dict.get("n_inner")
+ if cfg_dict.get("n_inner") is not None
+ else 4 * self.libai_cfg["hidden_size"],
+ )
+
+ # update libai_cfg by kwargs
+ for k, v in self.kwargs.items():
+ self._update_cfg(k, v)
+
+ self._update_cfg_log()
+
+
+class BaichuanLoaderLiBai(ModelLoaderLiBai):
+ def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
+ super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
+ self.base_model_prefix_2 = "model"
diff --git a/projects/Baichuan/utils/data_prepare.py b/projects/Baichuan/utils/data_prepare.py
new file mode 100644
index 000000000..378c73467
--- /dev/null
+++ b/projects/Baichuan/utils/data_prepare.py
@@ -0,0 +1,160 @@
+import copy
+import json
+import math
+import os
+from pathlib import Path
+from typing import Optional
+
+import oneflow as flow
+import requests
+from oneflow.utils.data import random_split
+from tqdm import tqdm
+
+from libai.config import instantiate
+from libai.utils.logger import setup_logger
+from projects.Baichuan.configs.baichuan_config import tokenization
+
+logger = setup_logger()
+
+
+def prepare(
+ destination_path: Path = Path("./data/libai_xpu_alpaca"),
+ checkpoint_dir: Path = Path("/root/models/Baichuan2-7B-Chat"),
+ test_split_fraction: float = 0.03865, # to get exactly 2000 test samples,
+ seed: int = 42,
+ mask_inputs: bool = False, # as in alpaca-lora
+ data_file_name: str = "alpaca_data_cleaned_archive.json",
+ data_file_url: str = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json", # noqa
+ ignore_index: int = -100,
+ max_seq_length: Optional[int] = 512,
+) -> None:
+ """Prepare the Alpaca dataset for instruction tuning.
+ The output is a training and test dataset saved as `train.pt` and `test.pt`,
+ which stores the preprocessed and tokenized prompts and labels.
+ """
+ if max_seq_length is None:
+ with open(os.path.join(checkpoint_dir, "config.json"), "r", encoding="utf-8") as file:
+ config = json.load(file)
+ max_seq_length = config["max_position_embeddings"]
+
+ destination_path.mkdir(parents=True, exist_ok=True)
+ data_file_path = destination_path / data_file_name
+ logger.info("Loading data file...")
+ download_if_missing(data_file_path, data_file_url)
+ with open(data_file_path, "r", encoding="utf-8") as file:
+ data = json.load(file)
+
+ logger.info("Loading tokenizer...")
+ tokenizer = instantiate(tokenization.tokenizer)
+
+ # Partition the dataset into train and test
+ num_of_test_samples = math.floor(test_split_fraction * len(data))
+ num_of_train_samples = len(data) - num_of_test_samples
+ train_set, test_set = random_split(
+ data,
+ [num_of_train_samples, num_of_test_samples],
+ generator=flow.Generator().manual_seed(seed),
+ )
+ train_set, test_set = list(train_set), list(test_set)
+
+ logger.info(f"train has {len(train_set):,} samples")
+ logger.info(f"test has {len(test_set):,} samples")
+
+ logger.info("Processing train split ...")
+ train_set = [
+ prepare_sample(
+ example=sample,
+ tokenizer=tokenizer,
+ max_length=max_seq_length,
+ )
+ for sample in tqdm(train_set)
+ ]
+ flow.save(train_set, destination_path / "train")
+
+ logger.info("Processing test split ...")
+ test_set = [
+ prepare_sample(
+ example=sample,
+ tokenizer=tokenizer,
+ max_length=max_seq_length,
+ )
+ for sample in tqdm(test_set)
+ ]
+ flow.save(test_set, destination_path / "test")
+
+ max_length = max([i["input_ids"].shape[0] for i in train_set])
+ logger.info("Max length of training dataset: {}".format(max_length))
+
+
+def download_if_missing(file_path: Path, file_url: str) -> None:
+ """Downloads the raw json data file and saves it in the given destination."""
+ if file_path.exists() and file_path.stat().st_size > 0:
+ return
+ with open(file_path, "w", encoding="utf-8") as f:
+ f.write(requests.get(file_url).text)
+
+
+def prepare_sample(example: dict, tokenizer, max_length: int) -> dict:
+ """Processes a single sample.
+ Each sample in the dataset consists of:
+ - instruction: A string describing the task
+ - input: A string holding a special input value for the instruction.
+ This only applies to some samples, and in others this is empty.
+ - output: The response string
+ This function processes this data to produce a prompt text and a label for
+ supervised training. The prompt text is formed as a single message including both
+ the instruction and the input. The label/target is the same message but with the
+ response attached.
+ Finally, both the prompt and the label get tokenized. If desired, all tokens
+ in the label that correspond to the original input prompt get masked out (default).
+ """
+ full_prompt = generate_prompt(example)
+ full_prompt_and_response = full_prompt + example["output"]
+
+ prompt = tokenizer.tokenize(full_prompt, add_bos=True, add_eos=False, device="cpu")[0]
+ example = tokenizer.tokenize(
+ full_prompt_and_response, add_bos=True, add_eos=True, device="cpu"
+ )[0]
+
+ padding = max_length - example.shape[0]
+ if padding > 0:
+ example = flow.cat((example, flow.zeros(padding, dtype=flow.long) - 1))
+ elif padding < 0:
+ example = example[:max_length]
+ labels = copy.deepcopy(example)
+ labels[: len(prompt)] = -1
+ example_mask = example.ge(0)
+ label_mask = labels.ge(0)
+ example[~example_mask] = 0
+ labels[~label_mask] = -1
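+    # Shift for next-token prediction: inputs drop the last token, labels drop the first.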
+ example = example[:-1]
+ labels = labels[1:]
+ example_mask = flow.where(
+ example_mask, flow.tensor(0, dtype=flow.float), flow.tensor(-float("inf"))
+ )
+ example_mask = example_mask[:-1]
+ return {
+ "input_ids": example,
+ "labels": labels,
+ }
+
+
+def generate_prompt(example: dict) -> str:
+ """Generates a standardized message to prompt the model with an instruction, optional input and a
+ 'response' field."""
+
+ if example["input"]:
+ return (
+ "Below is an instruction that describes a task, paired with an input that provides further context. " # noqa
+ "Write a response that appropriately completes the request.\n\n"
+ f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" # noqa
+ )
+ return (
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request.\n\n"
+ f"### Instruction:\n{example['instruction']}\n\n### Response:"
+ )
+
+
+if __name__ == "__main__":
+ prepare()