diff --git a/deserve_client/README.md b/deserve_client/README.md new file mode 100644 index 0000000..ff7a99d --- /dev/null +++ b/deserve_client/README.md @@ -0,0 +1,18 @@ +# DeServe Client + +## How To Run + +For completion: +```bash +python3 -m deserve_client.client complete meta-llama/Meta-Llama-3-8B-Instruct "Here is a text prompt." +``` + +For dumping traces of prefill: +```bash +python3 -m deserve_client.client trace meta-llama/Meta-Llama-3-8B-Instruct "Here is a text prompt." +``` + +For verifying the correctness of the trace: +```bash +python3 -m deserve_client.client verify meta-llama/Meta-Llama-3-8B-Instruct "Here is a text prompt." +``` diff --git a/deserve_client/__init__.py b/deserve_client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deserve_client/client.py b/deserve_client/client.py new file mode 100644 index 0000000..9f9d090 --- /dev/null +++ b/deserve_client/client.py @@ -0,0 +1,100 @@ +import pickle +from typing import Any + +import requests +import safetensors.torch +import torch +import typer +from transformers import AutoTokenizer # type: ignore + +from deserve_client.model import ( + CheckCtx, + Transformer, + VerifyCtx, + llama_3_8b_args, + main_device, +) +from deserve_controller.controller_api import app +from deserve_worker.trace import OpId + +cli = typer.Typer() +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + + +def loads(b: bytes) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + """ + Load tensors and metadata from bytes + """ + + metadata_length = int.from_bytes(b[:4], byteorder="big") + metadata = pickle.loads(b[4 : 4 + metadata_length]) + tensors = safetensors.torch.load(b[4 + metadata_length :]) + return tensors, metadata + + +@cli.command() +def complete(model: str, prompt: str, entry_point: str = "http://localhost:19000"): + response = requests.post( + f"{entry_point}/complete", + json={"model": model, "prompt": prompt}, + stream=True, + ) + if response.status_code != 200: + typer.echo("Error") + return + + for chunk in response.iter_content(): + if chunk: + print(chunk.decode("utf-8"), end="", flush=True) + + +@cli.command() +def trace(model: str, prompt: str, entry_point: str = "http://localhost:19000"): + response = requests.post( + f"{entry_point}/trace", + json={"model": model, "prompt": prompt}, + stream=True, + ) + if response.status_code != 200: + typer.echo("Error") + return + + tensors = {} + for chunk in response.iter_content(chunk_size=None): + if chunk: + temp_tensors, _ = loads(chunk) + tensors.update(temp_tensors) + print(list(tensors.keys())) + + +@cli.command() +def verify(model: str, prompt: str, entry_point: str = "http://localhost:19000"): + response = requests.post( + f"{entry_point}/trace", + json={"model": model, "prompt": prompt}, + stream=True, + ) + if response.status_code != 200: + typer.echo("Error") + return + tensors: dict[str, torch.Tensor] = {} + for chunk in response.iter_content(chunk_size=None): + if chunk: + temp_tensors, _ = loads(chunk) + tensors.update(temp_tensors) + + traces = {OpId.from_str(k): v for k, v in tensors.items()} + transformer = Transformer(llama_3_8b_args) + tokens = tokenizer(prompt, return_tensors="pt")["input_ids"].to(main_device) + result = transformer.forward(tokens, CheckCtx(0.03, traces)) + if isinstance(result, torch.Tensor): + print("No difference found") + else: + if not transformer.verify(tokens, VerifyCtx(result.op_id, 0.03, traces)): + print("Difference found for", result.op_id) + else: + print("Difference found but 
verification failed") + + +if __name__ == "__main__": + cli() diff --git a/deserve_client/model.py b/deserve_client/model.py new file mode 100644 index 0000000..ca92562 --- /dev/null +++ b/deserve_client/model.py @@ -0,0 +1,658 @@ +import math +import os +from dataclasses import dataclass +from typing import Any, Mapping, Optional + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import AutoTokenizer # type: ignore + +from deserve_worker.trace import ComponentId, LayerId, OpId + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") +torch.set_default_dtype(torch.float16) +main_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 500000 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +llama_3_8b_args = ModelArgs( + n_kv_heads=8, + vocab_size=128256, + multiple_of=1024, + ffn_dim_multiplier=1.3, + norm_eps=1e-5, + rope_theta=500000.0, +) + + +@dataclass +class Diff: + op_id: OpId + diff: float + + +@dataclass +class CheckCtx: + threshold: float + traces: dict[OpId, torch.Tensor] + + def check(self, op_id: OpId, x: torch.Tensor) -> torch.Tensor | Diff: + y = self.traces[op_id].to(main_device) + if torch.allclose(x, y, atol=self.threshold): + return y + else: + return Diff(op_id, torch.max(torch.abs(x - y)).item()) + + +@dataclass +class VerifyCtx: + op_id: OpId + threshold: float + traces: dict[OpId, torch.Tensor] + + def get_trace(self, op_id: OpId) -> torch.Tensor: + return self.traces[op_id].to(main_device) + + def verify(self, x: torch.Tensor) -> bool: + y = self.traces[self.op_id].to(main_device) + return torch.allclose(x, y, atol=self.threshold) + + +class RMSNorm(nn.Module): + def __init__(self, component_id: ComponentId, dim: int, eps: float = 1e-6): + super().__init__() + self.component_id = component_id + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def verify(self, x: torch.Tensor, ctx: VerifyCtx) -> bool: + op = ctx.op_id.op + if op == "output": + return ctx.verify(self._norm(x.float()).type_as(x)) + else: + output = ctx.get_trace(self.component_id.with_op("weighted_output")) + return ctx.verify(output * self.weight) + + def forward(self, x: torch.Tensor, ctx: CheckCtx) -> torch.Tensor | Diff: + output = ctx.check( + self.component_id.with_op("output"), self._norm(x.float()).type_as(x) + ) + if isinstance(output, Diff): + return output + return ctx.check( + self.component_id.with_op("weighted_output"), output * self.weight + ) + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + 
+def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + def __init__(self, component_id: ComponentId, args: ModelArgs): + super().__init__() + self.component_id = component_id + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + self.n_local_heads = args.n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = torch.nn.utils.skip_init( + nn.Linear, + args.dim, + args.n_heads * self.head_dim, + bias=False, + ) + self.wk = torch.nn.utils.skip_init( + nn.Linear, + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wv = torch.nn.utils.skip_init( + nn.Linear, + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wo = torch.nn.utils.skip_init( + nn.Linear, + args.n_heads * self.head_dim, + args.dim, + bias=False, + ) + + def verify( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ctx: VerifyCtx, + ) -> bool: + bsz, seqlen, _ = x.shape + op = ctx.op_id.op + if op == "xq": + xq = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim) + return ctx.verify(xq) + elif op == "xk": + xk = self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + return ctx.verify(xk) + elif op == "xv": + xv = self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + return ctx.verify(xv) + elif op == "xq_rotary" or op == "xk_rotary": + xq, xk = self.wq(x), self.wk(x) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + if op == "xq_rotary": + return ctx.verify(xq) + else: + return ctx.verify(xk) + elif op == "scores": + xq = ctx.get_trace(self.component_id.with_op("xq_rotary")) + keys = ctx.get_trace(self.component_id.with_op("xk_rotary")) + keys = repeat_kv( + keys, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose( + 1, 2 + ) # (bs, n_local_heads, cache_len + seqlen, head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + return ctx.verify(scores) + elif op == "output": + scores = ctx.get_trace(self.component_id.with_op("scores")) + values = ctx.get_trace(self.component_id.with_op("xv")) + values = repeat_kv( + values, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = values.transpose( + 1, 2 + ) # (bs, n_local_heads, cache_len + seqlen, head_dim) + output = torch.matmul( + scores, values + ) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, 
seqlen, -1) + return ctx.verify(output) + elif op == "weighted_output": + output = ctx.get_trace(self.component_id.with_op("output")) + return ctx.verify(self.wo(output)) + assert False + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ctx: CheckCtx, + ) -> torch.Tensor | Diff: + bsz, seqlen, _ = x.shape + + xq = ctx.check( + self.component_id.with_op("xq"), + self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim), + ) + if isinstance(xq, Diff): + return xq + + xk = ctx.check( + self.component_id.with_op("xk"), + self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim), + ) + if isinstance(xk, Diff): + return xk + + xv = ctx.check( + self.component_id.with_op("xv"), + self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim), + ) + if isinstance(xv, Diff): + return xv + + xq_new, xk_new = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq = ctx.check(self.component_id.with_op("xq_rotary"), xq_new) + if isinstance(xq, Diff): + return xq + + xk = ctx.check(self.component_id.with_op("xk_rotary"), xk_new) + if isinstance(xk, Diff): + return xk + + keys = xk.clone() + values = xv.clone() + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv( + keys, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = repeat_kv( + values, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + values = values.transpose( + 1, 2 + ) # (bs, n_local_heads, cache_len + seqlen, head_dim) + scores_new = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores_new = ( + scores_new + mask + ) # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores_new = F.softmax(scores_new.float(), dim=-1).type_as(xq) + + # check scores + scores = ctx.check(self.component_id.with_op("scores"), scores_new) + if isinstance(scores, Diff): + return scores + + output_new = torch.matmul( + scores, values + ) # (bs, n_local_heads, seqlen, head_dim) + output_new = output_new.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + output = ctx.check(self.component_id.with_op("output"), output_new) + if isinstance(output, Diff): + return output + + return ctx.check(self.component_id.with_op("weighted_output"), self.wo(output)) + + +class FeedForward(nn.Module): + def __init__( + self, + component_id: ComponentId, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + self.component_id = component_id + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = torch.nn.utils.skip_init( + nn.Linear, + dim, + hidden_dim, + bias=False, + ) + self.w2 = torch.nn.utils.skip_init( + nn.Linear, + hidden_dim, + dim, + bias=False, + ) + self.w3 = torch.nn.utils.skip_init( + nn.Linear, + dim, + hidden_dim, + bias=False, + ) + + def verify(self, x: torch.Tensor, ctx: VerifyCtx) -> bool: + op = ctx.op_id.op + if op == "w1": + return ctx.verify(F.silu(self.w1(x))) + elif op == "w3": + return ctx.verify(self.w3(x)) + elif op == "w2": + w1 = ctx.get_trace(self.component_id.with_op("w1")) + w3 = ctx.get_trace(self.component_id.with_op("w3")) + return ctx.verify(self.w2(w1 * w3)) + assert False + 
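+    # Worked example: with the Llama-3-8B arguments (dim=4096, multiple_of=1024,
+    # ffn_dim_multiplier=1.3) the constructor above receives hidden_dim = 4 * 4096
+    # = 16384, reduces it to int(2 * 16384 / 3) = 10922, scales it to
+    # int(1.3 * 10922) = 14198, and rounds up to the next multiple of 1024, giving
+    # the familiar FFN width of 14336. forward below recomputes the SwiGLU MLP
+    # w2(silu(w1(x)) * w3(x)), checking w1, w3 and w2 against the recorded trace
+    # and returning the first Diff instead of the output tensor.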
+ def forward( + self, + x: torch.Tensor, + ctx: CheckCtx, + ) -> torch.Tensor | Diff: + # check w1, w3, w2 + w1 = ctx.check(self.component_id.with_op("w1"), F.silu(self.w1(x))) + if isinstance(w1, Diff): + return w1 + + w3 = ctx.check(self.component_id.with_op("w3"), self.w3(x)) + if isinstance(w3, Diff): + return w3 + + return ctx.check(self.component_id.with_op("w2"), self.w2(w1 * w3)) + + +class TraceLinear(nn.Module): + def __init__( + self, + component_id: ComponentId, + in_features: int, + out_features: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.component_id = component_id + self.linear = nn.Linear( + in_features, out_features, bias=False, device=device, dtype=dtype + ) + + @torch.inference_mode() + def forward(self, x: torch.Tensor, ctx: CheckCtx) -> torch.Tensor | Diff: + return ctx.check(self.component_id.with_op("output"), self.linear(x)) + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ) -> torch.nn.modules.module._IncompatibleKeys: + return self.linear.load_state_dict(state_dict, strict, assign) # type: ignore + + +class TraceEmbedding(nn.Module): + def __init__( + self, + component_id: ComponentId, + num_embeddings: int, + embedding_dim: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.component_id = component_id + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, device=device, dtype=dtype + ) + + @torch.inference_mode() + def forward(self, x: torch.Tensor, ctx: CheckCtx) -> torch.Tensor | Diff: + return ctx.check(self.component_id.with_op("output"), self.embedding(x)) + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ) -> torch.nn.modules.module._IncompatibleKeys: + return self.embedding.load_state_dict(state_dict, strict, assign) # type: ignore + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: LayerId, args: ModelArgs): + super().__init__() + self.layer_id = layer_id + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(layer_id.with_component("attention"), args) + self.feed_forward = FeedForward( + layer_id.with_component("feed_forward"), + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.attention_norm = RMSNorm( + layer_id.with_component("attention_norm"), args.dim, eps=args.norm_eps + ) + self.ffn_norm = RMSNorm( + layer_id.with_component("ffn_norm"), args.dim, eps=args.norm_eps + ) + + def verify( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ctx: VerifyCtx, + ) -> bool: + layer = ctx.op_id.layer + component = ctx.op_id.component + op = ctx.op_id.op + if component == "feed_forward": + if op == "res": + return ctx.verify( + ctx.get_trace(OpId(layer, "attention", "res")) + + ctx.get_trace(OpId(layer, "feed_forward", "w2")) + ) + else: + return self.feed_forward.verify( + ctx.get_trace(OpId(layer, "ffn_norm", "weighted_output")), ctx + ) + elif component == "ffn_norm": + return self.ffn_norm.verify( + ctx.get_trace(OpId(layer, "attention", "res")), ctx + ) + elif component == "attention_norm": + return self.attention_norm.verify(x, ctx) + elif component == "attention": + if op == "res": + return ctx.verify( + x + ctx.get_trace(OpId(layer, "attention", "weighted_output")) + ) + 
else: + return self.attention.verify( + ctx.get_trace(OpId(layer, "attention_norm", "weighted_output")), + freqs_cis, + mask, + ctx, + ) + assert False + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ctx: CheckCtx, + ) -> torch.Tensor | Diff: + attn_norm = self.attention_norm.forward(x, ctx) + if isinstance(attn_norm, Diff): + return attn_norm + + attn = self.attention.forward(attn_norm, freqs_cis, mask, ctx) + if isinstance(attn, Diff): + return attn + + h = ctx.check( + self.layer_id.with_component("attention").with_op("res"), x + attn + ) + if isinstance(h, Diff): + return h + + ffn_norm = self.ffn_norm.forward(h, ctx) + if isinstance(ffn_norm, Diff): + return ffn_norm + + ffn = self.feed_forward.forward(ffn_norm, ctx) + if isinstance(ffn, Diff): + return ffn + + return ctx.check( + self.layer_id.with_component("feed_forward").with_op("res"), h + ffn + ) + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + cache_dir = "~/.cache/fleece-worker/models/llama-3-8b-instruct-slice/" + cache_dir = os.path.expanduser(cache_dir) + + self.tok_embeddings = torch.nn.utils.skip_init( + TraceEmbedding, + ComponentId("tok_embeddings", "main"), + params.vocab_size, + params.dim, + ) + self.tok_embeddings.load_state_dict( + torch.load(cache_dir + "tok_embeddings.pt", map_location="cpu") + ) + self.tok_embeddings.to(main_device) + + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + layer = TransformerBlock(LayerId.from_str(f"{layer_id:02}"), params) + layer.load_state_dict( + torch.load(cache_dir + f"layers.{layer_id}.pt", map_location="cpu") + ) + layer.to(main_device) + self.layers.append(layer) + + self.norm = RMSNorm( + ComponentId("norm", "main"), params.dim, eps=params.norm_eps + ) + self.norm.load_state_dict(torch.load(cache_dir + "norm.pt", map_location="cpu")) + self.norm.to(main_device) + self.output = torch.nn.utils.skip_init( + TraceLinear, ComponentId("output", "main"), params.dim, params.vocab_size + ) + self.output.load_state_dict( + torch.load(cache_dir + "output.pt", map_location="cpu") + ) + self.output.to(main_device) + + self.freqs_cis = precompute_freqs_cis( + params.dim // params.n_heads, + params.max_seq_len * 2, + params.rope_theta, + ) + + @torch.inference_mode() + def verify(self, tokens: torch.Tensor, ctx: VerifyCtx) -> bool: + _bsz, seqlen = tokens.shape + layer = ctx.op_id.layer + if layer.isdigit(): + mask = None + if seqlen > 1: + mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device) + mask = torch.triu(mask, diagonal=1) + mask = torch.hstack( + [torch.zeros((seqlen, 0), device=tokens.device), mask] + ).type_as(tokens) + + layer_int = int(layer) + if layer_int < 0 or layer_int >= self.n_layers: + assert False + + if layer_int == 0: + input = ctx.get_trace(OpId("tok_embeddings", "main", "output")) + else: + input = ctx.get_trace( + OpId(f"{layer_int - 1:02}", "feed_forward", "res") + ) + + return self.layers[layer_int].verify(input, self.freqs_cis, mask, ctx) + elif layer == "tok_embeddings": + return self.tok_embeddings.verify(tokens, ctx) + elif layer == "norm": + num_layers = self.n_layers + return self.norm.verify( + ctx.get_trace(OpId(f"{num_layers - 1:02}", "feed_forward", "res")), ctx + ) + elif layer == "output": + return self.output.verify( + ctx.get_trace(OpId("norm", "main", "output")), ctx + ) + assert False + + 
@torch.inference_mode() + def forward(self, tokens: torch.Tensor, ctx: CheckCtx) -> torch.Tensor | Diff: + _bsz, seqlen = tokens.shape + + h = self.tok_embeddings.forward(tokens, ctx) + if isinstance(h, Diff): + return h + + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[0:seqlen] + + mask = None + if seqlen > 1: + mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device) + + mask = torch.triu(mask, diagonal=1) + + # When performing key-value caching, we compute the attention scores + # only for the new sequence. Thus, the matrix of scores is of size + # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for + # j > cache_len + i, since row i corresponds to token cache_len + i. + mask = torch.hstack( + [torch.zeros((seqlen, 0), device=tokens.device), mask] + ).type_as(h) + + for layer in self.layers: + h = layer.forward(h, freqs_cis, mask, ctx) + if isinstance(h, Diff): + return h + h = self.norm.forward(h, ctx) + if isinstance(h, Diff): + return h + + output = self.output.forward(h, ctx) + if isinstance(output, Diff): + return output + return output.float() diff --git a/deserve_client/py.typed b/deserve_client/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/deserve_client/pyproject.toml b/deserve_client/pyproject.toml new file mode 100644 index 0000000..b73ffc6 --- /dev/null +++ b/deserve_client/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "deserve_client" +version = "0.0.1" +authors = [ + { name="Example Author", email="author@example.com" }, +] +description = "Deserve Client" +requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] diff --git a/deserve_controller/README.md b/deserve_controller/README.md new file mode 100644 index 0000000..8098686 --- /dev/null +++ b/deserve_controller/README.md @@ -0,0 +1,7 @@ +# DeServe Controller + +## How to run + +```bash +python3 -m deserve_controller.controller_api --port= +``` \ No newline at end of file diff --git a/deserve_controller/__init__.py b/deserve_controller/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deserve_controller/controller_api.py b/deserve_controller/controller_api.py new file mode 100644 index 0000000..643821e --- /dev/null +++ b/deserve_controller/controller_api.py @@ -0,0 +1,275 @@ +import argparse +import logging +import pickle +import queue +import traceback +import uuid +from typing import Any, Generator, Optional + +import requests +import safetensors.torch +import torch +from cachetools import TTLCache +from fastapi import FastAPI, HTTPException, Request, Response +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from transformers import AutoTokenizer # type: ignore + +controller_url: str +app = FastAPI() +logger = logging.getLogger("uvicorn") +workers: TTLCache[str, str] = TTLCache(maxsize=128, ttl=2) +model2layers = { + "meta-llama/Meta-Llama-3-70B-Instruct": 80, + "meta-llama/Meta-Llama-3-8B-Instruct": 32, +} +model2alias = { + "meta-llama/Meta-Llama-3-70B-Instruct": "llama-3-70b-instruct-slice", + "meta-llama/Meta-Llama-3-8B-Instruct": "llama-3-8b-instruct-slice", +} +token_channels: dict[str, queue.Queue[Optional[str]]] = {} +trace_channels: dict[str, queue.Queue[dict[str, torch.Tensor]]] = {} +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + +STOP_TOKEN_IDS = [128001, 128009] + + +def dumps(tensors: dict[str, torch.Tensor], 
metadata: dict[str, Any]) -> bytes: + """ + Dump tensors and metadata into bytes + """ + + metadata_bytes = pickle.dumps(metadata) + tensors_bytes = safetensors.torch.save(tensors) + return ( + len(metadata_bytes).to_bytes(4, byteorder="big") + + metadata_bytes + + tensors_bytes + ) + + +def loads(b: bytes) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + """ + Load tensors and metadata from bytes + """ + + metadata_length = int.from_bytes(b[:4], byteorder="big") + metadata = pickle.loads(b[4 : 4 + metadata_length]) + tensors = safetensors.torch.load(b[4 + metadata_length :]) + return tensors, metadata + + +class RegisterRequest(BaseModel): + worker_id: str + worker_url: str + + +@app.post("/register") +def register(request: RegisterRequest) -> str: + workers[request.worker_id] = request.worker_url + return "ok" + + +class HeartbeatRequest(BaseModel): + worker_id: str + worker_url: str + + +@app.post("/heartbeat") +def heartbeat(request: HeartbeatRequest) -> str: + workers[request.worker_id] = request.worker_url + return "ok" + + +class CompleteRequest: + pass # discuss about implementation details (how to send, how to retrieve) + + +class PlanStep(BaseModel): + worker_id: str + worker_url: str + layers: list[str] + + +def generate_plan(model: str, worker_ids: list[str]) -> list[PlanStep]: + alias = model2alias[model] + num_layer_total = model2layers[model] + num_layer_worker = num_layer_total // len(worker_ids) + layers = [ + (i * num_layer_worker, (i + 1) * num_layer_worker) + for i in range(len(worker_ids) - 1) + ] + if len(layers) == 0: + layers.append((0, num_layer_total)) + else: + layers.append((layers[-1][1], num_layer_total)) + plans: list[PlanStep] = [] + for worker_id, layer in zip(worker_ids, layers): + plans.append( + PlanStep( + worker_id=worker_id, + worker_url=workers[worker_id], + layers=[f"{alias}/layers.{i}" for i in range(layer[0], layer[1])], + ) + ) + plans[0].layers.insert(0, f"{alias}/tok_embeddings") + plans[-1].layers.append(f"{alias}/norm") + plans[-1].layers.append(f"{alias}/output") + return plans + + +def relay_tokens( + channel: queue.Queue[Optional[str]], +) -> Generator[bytes, None, None]: + while True: + value = channel.get() + if value is None: + break + yield value.encode("utf-8") + + +class OnlineCompleteRequest(BaseModel): + model: str + prompt: str + + +@app.post("/complete") +def complete(request: OnlineCompleteRequest) -> StreamingResponse: + model = request.model + prompt = request.prompt + + if model not in model2layers: + raise HTTPException(status_code=404, detail="Model not found") + + task_id = str(uuid.uuid4()) + + # init channel for relay + token_channel = queue.Queue[Optional[str]]() + token_channels[task_id] = token_channel + + # generate request + tokens = tokenizer(prompt, return_tensors="pt")["input_ids"] + plan = generate_plan(model, list(workers.keys())) + tensors = {"x": tokens} + metadata = { + "task_id": task_id, + "round": 0, + "plan": plan, + "sampling_params": {"temperature": 0.0, "top_p": 1.0, "max_total_len": 2048}, + } + first_worker_url = plan[0].worker_url + response = requests.post( + f"{first_worker_url}/forward", data=dumps(tensors, metadata) + ) + if response.status_code != 200: + raise HTTPException(status_code=500, detail="Worker error") + + return StreamingResponse(relay_tokens(token_channel)) + + +class OfflineCompleteRequest(BaseModel): + model: str + prompts: list[str] + + +@app.post("/offline-complete") +def offline_complete(request: OfflineCompleteRequest) -> None: + pass + + +def relay_traces( + channel: 
queue.Queue[dict[str, torch.Tensor]], total: int +) -> Generator[bytes, None, None]: + cnt = 0 + while cnt < total: + value = channel.get() + cnt += 1 + if value is None: + break + bytes = dumps(value, {}) + yield bytes + + +class TraceRequest(BaseModel): + model: str + prompt: str + + +@app.post("/trace") +def trace(request: TraceRequest) -> Response: + model = request.model + prompt = request.prompt + + if model not in model2layers: + raise HTTPException(status_code=404, detail="Model not found") + + task_id = str(uuid.uuid4()) + + # init channel for relay, but we don't handle it inside tracing + token_channel = queue.Queue[Optional[str]]() + token_channels[task_id] = token_channel + + # init traces + trace_channel = queue.Queue[dict[str, torch.Tensor]]() + trace_channels[task_id] = trace_channel + + # generate request + tokens = tokenizer(prompt, return_tensors="pt")["input_ids"] + online_workers = list(workers.keys()) + plan = generate_plan(model, online_workers) + tensors = {"x": tokens} + metadata = { + "task_id": task_id, + "round": 0, + "plan": plan, + "sampling_params": {"temperature": 0.0, "top_p": 1.0, "max_total_len": 2048}, + } + first_worker_url = plan[0].worker_url + response = requests.post(f"{first_worker_url}/trace", data=dumps(tensors, metadata)) + if response.status_code != 200: + raise HTTPException(status_code=500, detail="Worker error") + return StreamingResponse(relay_traces(trace_channel, len(online_workers))) + + +class UpdateTaskRequest(BaseModel): + task_id: str + output_tokens: list[list[int]] # [bsz, seqlen], in normal case, bsz=1 and seqlen=1 + + +@app.post("/update_tasks") +def update_tasks(requests: list[UpdateTaskRequest]) -> None: + for request in requests: + task_id = request.task_id + for token_ids in request.output_tokens: + token_id = token_ids[0] + if token_id in STOP_TOKEN_IDS: + token_channels[task_id].put(None) + else: + token = tokenizer.decode(token_id) + if task_id in token_channels: + token_channels[task_id].put(token) + else: + logger.warning(f"Task {task_id} not found") + + +@app.post("/update_traces") +async def update_traces(requests: Request) -> None: + body = await requests.body() + tensors, metadata = loads(body) + task_id = metadata["task_id"] + if task_id in trace_channels: + trace_channels[task_id].put(tensors) + else: + logger.warning(f"Task {task_id} not found") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=19000) + args = parser.parse_args() + + controller_url = f"http://localhost:{args.port}" + + import uvicorn + + uvicorn.run(app, host="127.0.0.1", port=args.port) diff --git a/deserve_controller/py.typed b/deserve_controller/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/deserve_controller/pyproject.toml b/deserve_controller/pyproject.toml new file mode 100644 index 0000000..c80b66b --- /dev/null +++ b/deserve_controller/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "deserve_controller" +version = "0.0.1" +authors = [ + { name="Example Author", email="author@example.com" }, +] +description = "Deserve Controller" +requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] diff --git a/deserve_worker/README.md b/deserve_worker/README.md new file mode 100644 index 0000000..4778647 --- /dev/null +++ b/deserve_worker/README.md @@ -0,0 +1,67 @@ +# DeServe Worker + +## How to run + +```bash +python3 -m deserve_worker.worker_api 
+``` + +For example, + +```bash +python3 -m deserve_worker.worker_api 8080 worker0 +``` + +## API + +### Inference + +To inference, you need to pass a plan and other metadata in the request body. You have to send it to the first worker. The plan is a list of workers with their layers. The first worker will send the request to the next worker in the plan. The last worker will return the token to the controller. Here is an example: + +```python +plan = [ + { + "worker_id": worker_id0, + "worker_url": "http://localhost:8080", + "layers": [ + "llama-3-8b-instruct-slice/tok_embeddings", + *[f"llama-3-8b-instruct-slice/layers.{i}" for i in range(0, 16)], + ], + }, + { + "worker_id": worker_id1, + "worker_url": "http://localhost:8081", + "layers": [ + *[f"llama-3-8b-instruct-slice/layers.{i}" for i in range(16, 32)], + "llama-3-8b-instruct-slice/norm", + "llama-3-8b-instruct-slice/output", + ], + }, +] + +metadata = { + "task_id": task_id, + "round": 0, + "plan": plan, + "sampling_params": { + "temperature": 0.0, + "top_p": 1.0, + "max_total_len": 2048, + }, +} + +tensors = {"x": tokens} + +requests.post( + "http://localhost:8080/forward", data=dumps(tensors, metadata) +) +``` + +### Trace + +To trace, the plan is also required. It is worth noting that trace use different kernel for computation and dumping. + + +### Cancel + +You should not cancel a task. It's used for freeing resources like KV caches. \ No newline at end of file diff --git a/deserve_worker/__init__.py b/deserve_worker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deserve_worker/command.py b/deserve_worker/command.py new file mode 100644 index 0000000..b64892d --- /dev/null +++ b/deserve_worker/command.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass + +import torch + +from deserve_worker.kvcache.kvcache import KVCacheManager +from deserve_worker.layer_storage import LayerStorage +from deserve_worker.task import TaskData +from deserve_worker.trace import OpId + + +@dataclass +class BatchForward: + xs: torch.Tensor + layer_storage: LayerStorage + task_datas: list[TaskData] + kvcache_manager: KVCacheManager + need_sample: bool # to be eliminated in the future, because we can infer this from LayerStorage + + +@dataclass +class SingleTrace: + x: torch.Tensor + layer_storage: LayerStorage + task_data: TaskData + kvcache_manager: KVCacheManager + need_sample: bool + + +@dataclass +class BatchResult: + xs: torch.Tensor + task_ids: list[str] + + +@dataclass +class BatchUpdate: + tokens: list[torch.Tensor] + task_ids: list[str] + cancel_ids: list[str] + + +@dataclass +class TraceResult: + x: torch.Tensor + task_id: str + trace: dict[OpId, torch.Tensor] diff --git a/deserve_worker/kvcache/__init__.py b/deserve_worker/kvcache/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deserve_worker/kvcache/block_pool.py b/deserve_worker/kvcache/block_pool.py new file mode 100644 index 0000000..7700d11 --- /dev/null +++ b/deserve_worker/kvcache/block_pool.py @@ -0,0 +1,56 @@ +from typing import Optional + +import torch + + +class BlockPool: + def __init__( + self, + num_blocks: int, + block_size: int, + main_device: torch.device, + main_dtype: torch.dtype, + ): + self.num_blocks = num_blocks + self.block_size = block_size + self.main_device = main_device + self.main_dtype = main_dtype + self.fetch_size = 1024 + + self.block_ks = torch.randn( + num_blocks, block_size, 8, 128, device=main_device, dtype=main_dtype + ) + self.block_vs = torch.randn( + num_blocks, block_size, 8, 128, device=main_device, 
dtype=main_dtype + ) + self.block_bitmap = torch.ones( + (num_blocks,), device=main_device, dtype=torch.bool + ) + self.block_buffer = torch.empty(0, device=main_device, dtype=torch.int32) + + def alloc(self, size: int) -> Optional[torch.Tensor]: + if size > self.block_buffer.shape[0]: + fetch_size = max(self.fetch_size, size - self.block_buffer.shape[0]) + block_avails = torch.nonzero(self.block_bitmap)[:fetch_size] + self.block_bitmap[block_avails] = False + self.block_buffer = torch.cat([self.block_buffer, block_avails]) + if size > self.block_buffer.shape[0]: + return None + result = self.block_buffer[:size] + self.block_buffer = self.block_buffer[size:] + return result + + def alloc_consecutive(self, size: int) -> Optional[torch.Tensor]: + output, invert_indices, counts = torch.unique_consecutive( + self.block_bitmap, return_counts=True, return_inverse=True + ) + avail_bitmap: torch.Tensor = (counts >= size) & output + avail_indices = avail_bitmap.nonzero().flatten() + if avail_indices.shape[0] == 0: + return None + else: + index = avail_indices[0] + return (invert_indices == index).nonzero().flatten() + + def recycle(self, blocks: torch.Tensor) -> None: + self.block_bitmap[blocks] = True diff --git a/deserve_worker/kvcache/kvcache.py b/deserve_worker/kvcache/kvcache.py new file mode 100644 index 0000000..8e0e4c6 --- /dev/null +++ b/deserve_worker/kvcache/kvcache.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +KV_CACHE_BLOCK_SIZE = 256 + +main_dtype = torch.float16 +main_device = torch.device("cuda") +torch.set_default_dtype(main_dtype) # type: ignore + + +class KVCache(ABC): + @abstractmethod + def renew(self, bsz: int, seqlen: int, start_pos: int) -> bool: + pass + + @abstractmethod + def clear(self) -> None: + pass + + +class KVCacheManager(ABC): + @abstractmethod + def alloc(self, bsz: int, seqlen: int) -> Optional[KVCache]: + pass + + @abstractmethod + def recycle(self, kvcache: KVCache) -> None: + pass + + @abstractmethod + def renew(self, kvcache: KVCache, bsz: int, seqlen: int, start_pos: int) -> bool: + pass diff --git a/deserve_worker/kvcache/packed_kvcache.py b/deserve_worker/kvcache/packed_kvcache.py new file mode 100644 index 0000000..a717aa0 --- /dev/null +++ b/deserve_worker/kvcache/packed_kvcache.py @@ -0,0 +1,98 @@ +from typing import Optional, cast + +import torch + +from deserve_worker.kvcache.block_pool import BlockPool +from deserve_worker.kvcache.kvcache import KVCache, KVCacheManager, main_device + + +def del_tensor(t: torch.Tensor) -> None: + t.detach() + t.grad = None + t.untyped_storage().resize_(0) + + +class PackedKVCacheManager(KVCacheManager): + def __init__(self, block_pool: BlockPool): + self.block_pool = block_pool + self.block_size = block_pool.block_size + + def get_kv_cache_length(self, cur: int, seqlen: int) -> int: + while cur < seqlen: + cur += self.block_size + return cur + + def alloc(self, bsz: int, seqlen: int) -> Optional[KVCache]: + len_token = self.get_kv_cache_length(0, seqlen) + len_block = len_token // self.block_size + total_block = len_block * bsz + blocks = self.block_pool.alloc(total_block) + # the consecutive block table is in shape of [bsz, len_block], which corresponds to [bsz, len_block * block_size, 8, 128] in memory + if blocks is None: + return None + else: + return PackedKVCache(blocks.reshape(bsz, -1), self) + + def recycle(self, kvcache: KVCache) -> None: + kvcache = cast(PackedKVCache, kvcache) + self.block_pool.recycle(kvcache.csct_block_table.flatten()) + 
kvcache.csct_block_table = torch.empty( + (0, 0), device=main_device, dtype=torch.int32 + ) + + def renew(self, kvcache: KVCache, bsz: int, seqlen: int, start_pos: int) -> bool: + kvcache = cast(PackedKVCache, kvcache) + if ( + start_pos + seqlen + > kvcache.csct_block_table.shape[1] * self.block_pool.block_size + ): + len_token = self.get_kv_cache_length( + kvcache.csct_block_table.shape[1] * self.block_size, start_pos + seqlen + ) + len_block = len_token // self.block_size + total_block = len_block * bsz + blocks = self.block_pool.alloc(total_block) + if blocks is None: + return False + else: + # the original blocks are viewed as [bsz, old_len_block * block_size, 8, 128] + # the new blocks are viewed as [bsz, len_block * block_size, 8, 128] + # we need to copy the old blocks to the new blocks + old_len_block = kvcache.csct_block_table.shape[1] + old_blocks = kvcache.csct_block_table.flatten() + old_block_ks = self.block_pool.block_ks[ + old_blocks[0] : old_blocks[-1] + 1 + ].view(bsz, old_len_block * self.block_size, 8, 128) + new_block_ks = self.block_pool.block_ks[ + blocks[0] : blocks[-1] + 1 + ].view(bsz, len_block * self.block_size, 8, 128) + new_block_ks[:, :start_pos, :, :] = old_block_ks[:, :start_pos, :, :] + + old_block_vs = self.block_pool.block_vs[ + old_blocks[0] : old_blocks[-1] + 1 + ].view(bsz, old_len_block * self.block_size, 8, 128) + new_block_vs = self.block_pool.block_vs[ + blocks[0] : blocks[-1] + 1 + ].view(bsz, len_block * self.block_size, 8, 128) + new_block_vs[:, :start_pos, :, :] = old_block_vs[:, :start_pos, :, :] + + self.block_pool.recycle(old_blocks) + kvcache.csct_block_table = blocks.reshape(bsz, -1) + + return True + + +class PackedKVCache(KVCache): + def __init__( + self, + csct_block_table: torch.Tensor, + manager: PackedKVCacheManager, + ): + self.csct_block_table = csct_block_table # consecutive block table + self.manager = manager + + def renew(self, bsz: int, seqlen: int, start_pos: int) -> bool: + return self.manager.renew(self, bsz, seqlen, start_pos) + + def clear(self) -> None: + return self.manager.recycle(self) diff --git a/deserve_worker/kvcache/paged_kvcache.py b/deserve_worker/kvcache/paged_kvcache.py new file mode 100644 index 0000000..d140ca0 --- /dev/null +++ b/deserve_worker/kvcache/paged_kvcache.py @@ -0,0 +1,95 @@ +import queue +from typing import Optional, cast + +import torch + +from deserve_worker.kvcache.block_pool import BlockPool + +from .kvcache import KVCache, KVCacheManager, main_device, main_dtype + + +class PagedKVCacheManager(KVCacheManager): + def __init__( + self, + block_pool: BlockPool, + ): + self.block_pool = block_pool + + def get_kv_cache_length(self, cur: int, seqlen: int) -> int: + while cur < seqlen: + cur += self.block_pool.block_size + return cur + + def alloc(self, bsz: int, seqlen: int) -> Optional["PagedKVCache"]: + len_token = self.get_kv_cache_length(0, seqlen) + len_block = len_token // self.block_pool.block_size + total_block = len_block * bsz + blocks = self.block_pool.alloc(total_block) + if blocks is None: + return None + else: + return PagedKVCache(blocks.reshape(bsz, -1), self) + + def recycle(self, kvcache: KVCache) -> None: + kvcache = cast(PagedKVCache, kvcache) + self.block_pool.recycle(kvcache.block_table.flatten()) + kvcache.block_table = torch.empty((0, 0), device=main_device, dtype=torch.int32) + + def renew(self, kvcache: KVCache, bsz: int, seqlen: int, start_pos: int) -> bool: + kvcache = cast(PagedKVCache, kvcache) + if ( + start_pos + seqlen + > kvcache.block_table.shape[1] * 
self.block_pool.block_size + ): + len_block = ( + self.get_kv_cache_length( + kvcache.block_table.shape[1] * self.block_pool.block_size, + start_pos + seqlen, + ) + // self.block_pool.block_size + ) + total_block = (len_block - kvcache.block_table.shape[1]) * bsz + blocks = self.block_pool.alloc(total_block) + if blocks is None: + return False + else: + new_block_table = torch.zeros( + ( + bsz, + len_block, + ), + device=main_device, + dtype=torch.int32, + ) + new_block_table[:, : kvcache.block_table.shape[1]] = ( + kvcache.block_table[:, :] + ) + new_block_table[:, kvcache.block_table.shape[1] :] = blocks.reshape( + bsz, -1 + ) + kvcache.block_table = new_block_table + return True + + +class PagedKVCache(KVCache): + def __init__( + self, + block_table: torch.Tensor, + manager: PagedKVCacheManager, + ): + self.block_table = block_table + self.manager = manager + + def renew( + self, + bsz: int, + seqlen: int, + start_pos: int, + ) -> bool: + return self.manager.renew(self, bsz, seqlen, start_pos) + + def clear(self) -> None: + self.manager.recycle(self) + + def shape(self) -> torch.Size: + return self.block_table.shape diff --git a/deserve_worker/layer_storage.py b/deserve_worker/layer_storage.py new file mode 100644 index 0000000..df0eb5c --- /dev/null +++ b/deserve_worker/layer_storage.py @@ -0,0 +1,272 @@ +import os +import threading +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Optional + +import requests +import torch + +from deserve_worker.task import TaskData +from deserve_worker.trace import ComponentId, LayerId, OpId + +from .kvcache.kvcache import KVCache, KVCacheManager +from .model.llama import ( + ModelArgs, + RMSNorm, + TraceEmbedding, + TraceLinear, + TransformerBlock, +) + +EOS_TOKEN_ID = 128001 # for llama 3 only +STOP_TOKEN_IDS = [128001, 128009] + +llama_2_7b_args = { + "dim": 4096, + "multiple_of": 256, + "n_heads": 32, + "n_layers": 32, + "norm_eps": 1e-06, + "vocab_size": 32000, +} + +llama_2_13b_args = { + "dim": 5120, + "multiple_of": 256, + "n_heads": 40, + "n_layers": 40, + "norm_eps": 1e-05, + "vocab_size": 32000, +} + +llama_2_70b_args = { + "dim": 8192, + "multiple_of": 4096, + "ffn_dim_multiplier": 1.3, + "n_heads": 64, + "n_kv_heads": 8, + "n_layers": 80, + "norm_eps": 1e-05, + "vocab_size": 32000, +} + +llama_3_8b_args = { + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": 128256, + "multiple_of": 1024, + "ffn_dim_multiplier": 1.3, + "norm_eps": 1e-05, + "rope_theta": 500000.0, +} + +llama_3_70b_args = { + "dim": 8192, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "n_heads": 64, + "n_kv_heads": 8, + "n_layers": 80, + "norm_eps": 1e-05, + "vocab_size": 128256, + "rope_theta": 500000.0, +} + + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. 
+ + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + + +class LayerManager: + def __init__(self, main_device: torch.device): + self.main_device = main_device + self.network_executor = ThreadPoolExecutor(max_workers=80) + self.cache_dir = os.path.expanduser("~/.cache/fleece-worker/models") + self.layer_storages: dict[frozenset[str], LayerStorage] = {} + self.layers: dict[str, torch.nn.Module] = {} + self.mutex = threading.Lock() + + def get_layer_storage(self, layer_names: list[str]) -> "LayerStorage": + frozen_layer_names = frozenset(layer_names) + if frozen_layer_names not in self.layer_storages: + with self.mutex: + self.layer_storages[frozen_layer_names] = LayerStorage( + self.preload_layers(layer_names), self.main_device + ) + return self.layer_storages[frozen_layer_names] + + def fetch_layer(self, full_layer_name: str) -> str: + model_name, layer_name = full_layer_name.split("/") + path = os.path.join(self.cache_dir, model_name, f"{layer_name}.pt") + if not os.path.exists(path): # TODO lock + os.makedirs(os.path.join(self.cache_dir, model_name), exist_ok=True) + with requests.get( + f"https://huggingface.co/colearn/{model_name}/resolve/main/{layer_name}.pt", + stream=True, + ) as r: + r.raise_for_status() + with open(path, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + return path + + def preload_layers(self, full_layer_names: list[str]) -> dict[str, torch.nn.Module]: + threads: list[tuple[str, Future[str]]] = [] + result = {} + for full_layer_name in full_layer_names: + thread = self.network_executor.submit(self.fetch_layer, full_layer_name) + threads.append((full_layer_name, thread)) + for full_layer_name, thread in threads: + path = thread.result() + model_name, layer_name = full_layer_name.split("/") + if model_name.startswith("llama-2-7b"): + model_args = ModelArgs(**llama_2_7b_args) # type: ignore + elif model_name.startswith("llama-2-13b"): + model_args = ModelArgs(**llama_2_13b_args) # type: ignore + elif model_name.startswith("llama-2-70b"): + model_args = ModelArgs(**llama_2_70b_args) # type: ignore + elif model_name.startswith("llama-3-8b"): + model_args = ModelArgs(**llama_3_8b_args) # type: ignore + elif model_name.startswith("llama-3-70b"): + model_args = ModelArgs(**llama_3_70b_args) # type: ignore + else: + raise NotImplementedError("Unknown model") + if layer_name == "tok_embeddings": + l = torch.nn.utils.skip_init( # type: ignore + # torch.nn.Embedding, + TraceEmbedding, + ComponentId("tok_embeddings", "main"), + model_args.vocab_size, + model_args.dim, + ) + elif layer_name.startswith("layer"): + l = TransformerBlock(LayerId(f"{int(layer_name[7:]):02}"), model_args) + elif layer_name == "norm": + l = RMSNorm( + ComponentId("norm", "main"), model_args.dim, eps=model_args.norm_eps + ) + elif layer_name == "output": + l = torch.nn.utils.skip_init( # type: ignore + # torch.nn.Linear, + TraceLinear, + ComponentId("output", "main"), + model_args.dim, + model_args.vocab_size, + ) + else: + raise NotImplementedError("Unknown layers") + l.load_state_dict(torch.load(path, map_location="cpu")) + l.to(self.main_device) + print("Loaded", full_layer_name) + self.layers[full_layer_name] = l + for full_layer_name in 
full_layer_names: + result[full_layer_name] = self.layers[full_layer_name] + return result + + +class LayerStorage: + def __init__(self, layers: dict[str, torch.nn.Module], main_device: torch.device): + self.main_device = main_device + self.layers = layers + + def clear(self) -> None: + self.layers.clear() + + @torch.inference_mode() + def forward( + self, + h: torch.Tensor, + bsz_list: list[int], + start_pos_list: list[int], + global_freqs_cis: torch.Tensor, + kvcache_list: list[dict[int, KVCache]], + kvcache_manager: KVCacheManager, + traces: Optional[dict[OpId, torch.Tensor]], + ) -> torch.Tensor: + _, seqlen = h.shape[:2] + for full_layer_name in self.layers: + _, layer_name = full_layer_name.split("/") + if layer_name == "tok_embeddings": + h = self.layers[full_layer_name](h, traces) + # h = self.layers[full_layer_name](h) + elif layer_name.startswith("layers."): + layer_id = int(layer_name.split(".")[1]) + cur_kvcache_list = [] + for i, kv_cache in enumerate(kvcache_list): + kv_cache[layer_id].renew(1, seqlen, start_pos_list[i]) + cur_kvcache_list.append(kv_cache[layer_id]) + h = self.layers[full_layer_name]( + h, + bsz_list, + start_pos_list, + global_freqs_cis, + cur_kvcache_list, + kvcache_manager, + traces, + ) + elif layer_name == "norm": + h = self.layers[full_layer_name](h, traces) + elif layer_name == "output": + h = self.layers[full_layer_name](h, traces) + # h = self.layers[full_layer_name](h) + else: + raise NotImplementedError("Unknown layers") + return h + + @torch.inference_mode() + def sample( + self, merged_h: torch.Tensor, task_datas: list[TaskData] + ) -> tuple[list[torch.Tensor], list[str], list[torch.Tensor], list[str], list[str]]: + ongoing_tokens = [] + ongoing_ids = [] + all_tokens = [] + all_ids = [] + done_ids = [] + for ptr, task_data in enumerate(task_datas): + h = merged_h[ptr : ptr + 1] + _, seqlen = h.shape[:2] + task_data.start_pos += seqlen + task_data.round += 1 + sampling_params = task_data.sampling_params + if task_data.start_pos >= sampling_params.max_total_len: + next_token = torch.tensor([[EOS_TOKEN_ID]]) + elif sampling_params.temperature > 0: + probs = torch.softmax(h[:, -1] / sampling_params.temperature, dim=-1) + next_token = sample_top_p(probs, sampling_params.top_p) + next_token = next_token.reshape(1, -1) + else: + next_token = torch.argmax(h[:, -1], dim=-1) + next_token = next_token.reshape(1, -1) + next_token = next_token.to("cpu") + all_ids.append(task_data.task_id) + all_tokens.append(next_token) + if next_token[0][0] in STOP_TOKEN_IDS: + done_ids.append(task_data.task_id) + else: + ongoing_ids.append(task_data.task_id) + ongoing_tokens.append(next_token) + return ongoing_tokens, ongoing_ids, all_tokens, all_ids, done_ids diff --git a/deserve_worker/llm_engine.py b/deserve_worker/llm_engine.py new file mode 100644 index 0000000..333e0fc --- /dev/null +++ b/deserve_worker/llm_engine.py @@ -0,0 +1,210 @@ +import queue +from typing import Optional + +import torch + +from deserve_worker.layer_storage import LayerStorage +from deserve_worker.task import TaskData +from deserve_worker.trace import OpId + +from .command import BatchForward, BatchResult, BatchUpdate, SingleTrace, TraceResult +from .kvcache.kvcache import KVCacheManager, main_device + +EOS_TOKEN_ID = 128001 # for llama 3 only +STOP_TOKEN_IDS = [128001, 128009] + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, enable_flash_attn: bool = False +) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. 
+ + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + if enable_flash_attn: + freqs_cis = torch.stack([freqs.cos(), freqs.sin()]) # flash_attn + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +global_freqs_cis = precompute_freqs_cis(128, 8192, 500000.0, False).to(main_device) +flash_global_freqs_cis = precompute_freqs_cis(128, 8192, 500000.0, True).to(main_device) + + +class LLMEngine: + def __init__( + self, + max_total_bsz: int, + sender: queue.Queue[BatchResult | BatchUpdate | TraceResult], + ): + self.max_total_bsz = max_total_bsz + self.sender = sender + self.receiver = queue.Queue[BatchForward | SingleTrace]() + + def run(self) -> None: + q = self.receiver + while True: + commands: list[BatchForward | SingleTrace] = [q.get()] + while True: + try: + new = q.get(block=False) + commands.append(new) + except queue.Empty: + break + traces = [ + command for command in commands if isinstance(command, SingleTrace) + ] + forwards = [ + command for command in commands if isinstance(command, BatchForward) + ] + self.handle_trace(traces) + self.handle_forward(forwards) + + def handle_forward(self, forwards: list[BatchForward]) -> None: + prefill_tasks = [task for task in forwards if task.xs.shape[1] > 1] + decode_tasks = [task for task in forwards if task.xs.shape[1] == 1] + + for task in prefill_tasks: + h = self.step_forward( + task.xs, + task.layer_storage, + task.task_datas, + task.kvcache_manager, + flash_global_freqs_cis, + None, + ) + self.post_forward(h, task) + + print( + f"prefill_tasks: {len(prefill_tasks)}, decode_tasks: {sum(task.xs.shape[0] for task in decode_tasks)}" + ) + + decode_tasks.sort(key=lambda task: task.xs.shape[0], reverse=False) + while len(decode_tasks) > 0: + total_bsz = 0 + todo_tasks = [] + for i in reversed(range(len(decode_tasks))): + cur_bsz = decode_tasks[i].xs.shape[0] + if total_bsz + cur_bsz > self.max_total_bsz: + continue + total_bsz += cur_bsz + todo_tasks.append(decode_tasks.pop(i)) + new_task_datas = [] + for task in todo_tasks: + new_task_datas.extend(task.task_datas) + new_xs = torch.cat([task.xs for task in todo_tasks]) + # TODO: check if all tasks share same information + new_task = BatchForward( + xs=new_xs, + layer_storage=todo_tasks[0].layer_storage, + task_datas=new_task_datas, + need_sample=todo_tasks[0].need_sample, + kvcache_manager=todo_tasks[0].kvcache_manager, + ) + h = self.step_forward( + new_task.xs, + new_task.layer_storage, + new_task.task_datas, + new_task.kvcache_manager, + flash_global_freqs_cis, + None, + ) + self.post_forward(h, new_task) + + def handle_trace(self, tasks: list[SingleTrace]) -> None: + print(f"trace_tasks: {len(tasks)}") + for task in tasks: + traces: dict[OpId, torch.Tensor] = {} + h = self.step_forward( + task.x, + task.layer_storage, + [task.task_data], + task.kvcache_manager, + global_freqs_cis, + 
traces, + ) + self.post_trace(h, traces, task) + + def step_forward( + self, + h: torch.Tensor, + layer_storage: LayerStorage, + task_datas: list[TaskData], + kvcache_manager: KVCacheManager, + global_freqs_cis: torch.Tensor, + traces: Optional[dict[OpId, torch.Tensor]], + ) -> torch.Tensor: + # we need to check that all tasks share the same layer storage + with torch.inference_mode(): + bsz_list = [1 for _ in range(len(task_datas))] + start_pos_list = [task.start_pos for task in task_datas] + kvcache_list = [task.kvcaches for task in task_datas] + result = layer_storage.forward( + h, + bsz_list, + start_pos_list, + global_freqs_cis, + kvcache_list, + kvcache_manager, + traces, + ) + return result + + def post_forward(self, merged_h: torch.Tensor, tasks: BatchForward) -> None: + if tasks.need_sample: + layer_storage = tasks.layer_storage + ongoing_tokens, ongoing_ids, all_tokens, all_ids, done_ids = ( + layer_storage.sample(merged_h, tasks.task_datas) + ) + if len(ongoing_tokens) > 0: + self.sender.put(BatchResult(torch.cat(ongoing_tokens), ongoing_ids)) + self.sender.put(BatchUpdate(all_tokens, all_ids, done_ids)) + else: + seqlen = tasks.xs.shape[1] + for task in tasks.task_datas: + task.start_pos += seqlen + self.sender.put( + BatchResult(merged_h, [task.task_id for task in tasks.task_datas]) + ) + + def post_trace( + self, h: torch.Tensor, traces: dict[OpId, torch.Tensor], task: SingleTrace + ) -> None: + task_data = task.task_data + if task.need_sample: + layer_storage = task.layer_storage + ongoing_tokens, ongoing_ids, all_tokens, all_ids, done_ids = ( + layer_storage.sample(h, [task_data]) + ) + if len(ongoing_tokens) > 0: + # at most have one + self.sender.put( + TraceResult(torch.cat(ongoing_tokens), ongoing_ids[0], traces) + ) + self.sender.put(BatchUpdate(all_tokens, all_ids, done_ids)) + else: + seqlen = task.x.shape[1] + task_data.start_pos += seqlen + self.sender.put(TraceResult(h, task_data.task_id, traces)) + + def add_batch_forward(self, forwards: BatchForward) -> None: + self.receiver.put(forwards) + + def add_trace(self, trace: SingleTrace) -> None: + self.receiver.put(trace) diff --git a/deserve_worker/model/__init__.py b/deserve_worker/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deserve_worker/model/llama.py b/deserve_worker/model/llama.py new file mode 100644 index 0000000..61f6132 --- /dev/null +++ b/deserve_worker/model/llama.py @@ -0,0 +1,608 @@ +# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement. 
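+#
+# Worker-side Llama implementation. Unlike the client-side copy in
+# deserve_client/model.py, the modules here take an optional `traces` dict and
+# record intermediate tensors under their OpId via trace_op, and the attention
+# path can use the flash-attn KV-cache kernel together with the paged/packed
+# KV-cache managers imported below.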
+ +import math +import pickle +from dataclasses import dataclass +from typing import Any, List, Mapping, Optional, Tuple, cast + +import safetensors.torch +import torch +import torch.nn.functional as F +from flash_attn import flash_attn_with_kvcache # type: ignore +from torch import nn + +from deserve_worker.kvcache.paged_kvcache import PagedKVCache, PagedKVCacheManager +from deserve_worker.trace import ComponentId, LayerId, OpId + +from ..kvcache.kvcache import KVCache, KVCacheManager +from ..kvcache.packed_kvcache import PackedKVCache, PackedKVCacheManager + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 500000 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +def trace_op( + traces: Optional[dict[OpId, torch.Tensor]], op_id: OpId, op_value: torch.Tensor +) -> None: + if traces is not None: + traces[op_id] = op_value + + +class RMSNorm(torch.nn.Module): + def __init__(self, component_id: ComponentId, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.component_id = component_id + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + @torch.inference_mode() + def forward( + self, + x: torch.Tensor, + traces: Optional[dict[OpId, torch.Tensor]], + ) -> torch.Tensor: + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + trace_op(traces, self.component_id.with_op("output"), output) + result = output * self.weight + trace_op(traces, self.component_id.with_op("weighted_output"), result) + return result + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. 
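+
+    Example:
+        For xq_ viewed as complex with shape (bsz, seqlen, n_heads, head_dim // 2), freqs_cis
+        has shape (seqlen, head_dim // 2) and is reshaped to (1, seqlen, 1, head_dim // 2) so
+        that it broadcasts over the batch and head dimensions.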
+ + """ + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """Multi-head attention module.""" + + def __init__(self, component_id: ComponentId, args: ModelArgs): + """ + Initialize the Attention module. + + Args: + args (ModelArgs): Model configuration parameters. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_local_heads (int): Number of local query heads. + n_local_kv_heads (int): Number of local key and value heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (ColumnParallelLinear): Linear transformation for queries. + wk (ColumnParallelLinear): Linear transformation for keys. + wv (ColumnParallelLinear): Linear transformation for values. + wo (RowParallelLinear): Linear transformation for output. + cache_k (torch.Tensor): Cached keys for attention. + cache_v (torch.Tensor): Cached values for attention. 
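+
+        Note:
+            In this worker the q/k/v/o projections are plain nn.Linear layers created with
+            torch.nn.utils.skip_init (model parallelism is disabled), and the KV cache is not
+            stored on the module; it is supplied per request via kvcache_list and kvcache_manager.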
+ + """ + super().__init__() + self.component_id = component_id + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + # model_parallel_size = fs_init.get_model_parallel_world_size() + self.n_local_heads = args.n_heads # // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads # // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = torch.nn.utils.skip_init( # type: ignore + nn.Linear, + args.dim, + args.n_heads * self.head_dim, + bias=False, + ) + self.wk = torch.nn.utils.skip_init( # type: ignore + nn.Linear, + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wv = torch.nn.utils.skip_init( # type: ignore + nn.Linear, + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wo = torch.nn.utils.skip_init( # type: ignore + nn.Linear, + args.n_heads * self.head_dim, + args.dim, + bias=False, + ) + + @torch.inference_mode() + def forward( + self, + x: torch.Tensor, + bsz_list: List[int], + start_pos_list: List[int], + global_freqs_cis: torch.Tensor, + kvcache_list: list[KVCache], + kvcache_manager: KVCacheManager, + traces: Optional[dict[OpId, torch.Tensor]], + ) -> torch.Tensor: + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position for caching. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + mask (torch.Tensor, optional): Attention mask tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + _, seqlen, _ = x.shape + xq_, xk_, xv_ = self.wq(x), self.wk(x), self.wv(x) + + if isinstance(kvcache_manager, PagedKVCacheManager): + cache_seqlens = [] + for i, bsz in enumerate(bsz_list): + cache_seqlens += [start_pos_list[i]] * bsz + cache_seqlens_tch = torch.tensor( + cache_seqlens, dtype=torch.int32, device=x.device + ) + bsz = cache_seqlens_tch.shape[0] + paged_kv_cache_list = cast(list[PagedKVCache], kvcache_list) + + max_len = max([kvcache.shape()[1] for kvcache in paged_kv_cache_list]) + block_table = torch.zeros( + (bsz, max_len), dtype=torch.int32, device=x.device + ) + start = 0 + for i, bsz in enumerate(bsz_list): + block_table[ + start : start + bsz, : paged_kv_cache_list[i].shape()[1] + ] = paged_kv_cache_list[i].block_table + start += bsz + + bsz = cache_seqlens_tch.shape[0] + xq = xq_.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk_.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv_.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + cos = global_freqs_cis[0].type_as(xq) + sin = global_freqs_cis[1].type_as(xq) + output = flash_attn_with_kvcache( + xq, + kvcache_manager.block_pool.block_ks, + kvcache_manager.block_pool.block_vs, + xk, + xv, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens_tch, + block_table=block_table, + causal=True, + rotary_interleaved=True, + ) + output = output.view(bsz, seqlen, -1) + return self.wo(output) # type: ignore + else: + kvcache_manager = cast(PackedKVCacheManager, kvcache_manager) + start = 0 + output_list = [] + for i, bsz in enumerate(bsz_list): + xq = xq_[start : start + bsz].view( + bsz, seqlen, self.n_local_heads, self.head_dim + ) + xk = xk_[start : start + bsz].view( + bsz, seqlen, self.n_local_kv_heads, self.head_dim + ) + xv = xv_[start : start + bsz].view( + bsz, seqlen, self.n_local_kv_heads, self.head_dim + ) + trace_op(traces, self.component_id.with_op("xq"), xq) + trace_op(traces, self.component_id.with_op("xk"), xk) + 
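+                # per-request q/k/v projections are traced before rotary embedding; the rotated
+                # versions are traced separately below as xq_rotary/xk_rotary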
trace_op(traces, self.component_id.with_op("xv"), xv) + start += bsz + + start_pos = start_pos_list[i] + # remember consecutive block table [bsz, len] corresponds to memory [bsz, len * block_size, 8, 128] + kv_cache: PackedKVCache = cast(PackedKVCache, kvcache_list[i]) + csct_block_table = kv_cache.csct_block_table.flatten() + block_bsz, block_len = kv_cache.csct_block_table.shape[:2] + cache_k = kvcache_manager.block_pool.block_ks[ + csct_block_table[0] : csct_block_table[-1] + 1 + ].view(block_bsz, block_len * kvcache_manager.block_size, 8, 128) + cache_v = kvcache_manager.block_pool.block_vs[ + csct_block_table[0] : csct_block_table[-1] + 1 + ].view(block_bsz, block_len * kvcache_manager.block_size, 8, 128) + + freqs_cis = global_freqs_cis[start_pos : start_pos + seqlen] + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=x.device + ) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(x) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + trace_op(traces, self.component_id.with_op("xq_rotary"), xq) + trace_op(traces, self.component_id.with_op("xk_rotary"), xk) + + cache_k[:bsz, start_pos : start_pos + seqlen] = xk + cache_v[:bsz, start_pos : start_pos + seqlen] = xv + + keys = cache_k[:bsz, : start_pos + seqlen] + values = cache_v[:bsz, : start_pos + seqlen] + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv( + keys, self.n_rep + ) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv( + values, self.n_rep + ) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + if mask is not None: + scores = ( + scores + mask + ) # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + trace_op(traces, self.component_id.with_op("scores"), scores) + output = torch.matmul( + scores, values + ) # (bs, n_local_heads, seqlen, head_dim) + + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + trace_op(traces, self.component_id.with_op("output"), output) + output_list.append(output) + output = torch.cat([x for x in output_list]) + result = self.wo(output) + trace_op(traces, self.component_id.with_op("weighted_output"), result) + return result # type: ignore + + +class FeedForward(nn.Module): + def __init__( + self, + component_id: ComponentId, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. 
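+
+        Example:
+            With the Llama-3-8B settings (dim=4096, hidden_dim=4*4096=16384, multiple_of=1024,
+            ffn_dim_multiplier=1.3) the effective hidden dimension is
+            int(2 * 16384 / 3) = 10922 -> int(1.3 * 10922) = 14198, rounded up to 14336.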
+ + """ + super().__init__() + self.component_id = component_id + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = torch.nn.utils.skip_init( # type: ignore + nn.Linear, + dim, + hidden_dim, + bias=False, + ) + self.w2 = torch.nn.utils.skip_init( # type: ignore + nn.Linear, + hidden_dim, + dim, + bias=False, + ) + self.w3 = torch.nn.utils.skip_init( # type: ignore + nn.Linear, + dim, + hidden_dim, + bias=False, + ) + + @torch.inference_mode() + def forward( + self, x: torch.Tensor, traces: Optional[dict[OpId, torch.Tensor]] + ) -> torch.Tensor: + w1 = F.silu(self.w1(x)) + w3 = self.w3(x) + w2 = self.w2(w1 * w3) + trace_op(traces, self.component_id.with_op("w1"), w1) + trace_op(traces, self.component_id.with_op("w3"), w3) + trace_op(traces, self.component_id.with_op("w2"), w2) + + return w2 # type: ignore + + +class TraceLinear(nn.Module): + def __init__( + self, + component_id: ComponentId, + in_features: int, + out_features: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.component_id = component_id + self.linear = nn.Linear( + in_features, out_features, bias=False, device=device, dtype=dtype + ) + + @torch.inference_mode() + def forward( + self, x: torch.Tensor, traces: Optional[dict[OpId, torch.Tensor]] + ) -> torch.Tensor: + out = self.linear(x) + trace_op(traces, self.component_id.with_op("output"), out) + return out # type: ignore + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ) -> torch.nn.modules.module._IncompatibleKeys: + return self.linear.load_state_dict(state_dict, strict, assign) # type: ignore + + +class TraceEmbedding(nn.Module): + def __init__( + self, + component_id: ComponentId, + num_embeddings: int, + embedding_dim: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.component_id = component_id + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, device=device, dtype=dtype + ) + + @torch.inference_mode() + def forward( + self, x: torch.Tensor, traces: Optional[dict[OpId, torch.Tensor]] + ) -> torch.Tensor: + out = self.embedding(x) + trace_op(traces, self.component_id.with_op("output"), out) + return out # type: ignore + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ) -> torch.nn.modules.module._IncompatibleKeys: + return self.embedding.load_state_dict(state_dict, strict, assign) # type: ignore + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: LayerId, args: ModelArgs): + """ + Initialize a TransformerBlock. + + Args: + layer_id (int): Identifier for the layer. + args (ModelArgs): Model configuration parameters. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. 
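+
+        Note:
+            The block is pre-normalized: h = x + attention(attention_norm(x)) and
+            out = h + feed_forward(ffn_norm(h)); both residual sums are traced under the "res" op id.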
+ + """ + super().__init__() + self.layer_id = layer_id + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(layer_id.with_component("attention"), args) + self.feed_forward = FeedForward( + layer_id.with_component("feed_forward"), + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.attention_norm = RMSNorm( + layer_id.with_component("attention_norm"), args.dim, eps=args.norm_eps + ) + self.ffn_norm = RMSNorm( + layer_id.with_component("ffn_norm"), args.dim, eps=args.norm_eps + ) + + @torch.inference_mode() + def forward( + self, + x: torch.Tensor, + bsz_list: List[int], + start_pos_list: List[int], + global_freqs_cis: torch.Tensor, + kvcache_list: list[KVCache], + kvcache_manager: KVCacheManager, + traces: Optional[dict[OpId, torch.Tensor]], + ) -> torch.Tensor: + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position for attention caching. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention.forward( + self.attention_norm(x, traces), + bsz_list, + start_pos_list, + global_freqs_cis, + kvcache_list, + kvcache_manager, + traces, + ) + trace_op(traces, self.layer_id.with_component("attention").with_op("res"), h) + out = h + self.feed_forward.forward(self.ffn_norm(h, traces), traces) + trace_op( + traces, self.layer_id.with_component("feed_forward").with_op("res"), out + ) + return out + + +def dumps(tensors: dict[str, torch.Tensor], metadata: dict[str, Any]) -> bytes: + """ + Dump tensors and metadata into bytes + """ + + metadata_bytes = pickle.dumps(metadata) + tensors_bytes = safetensors.torch.save(tensors) + return ( + len(metadata_bytes).to_bytes(4, byteorder="big") + + metadata_bytes + + tensors_bytes + ) + + +def loads(b: bytes) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + """ + Load tensors and metadata from bytes + """ + + metadata_length = int.from_bytes(b[:4], byteorder="big") + metadata = pickle.loads(b[4 : 4 + metadata_length]) + tensors = safetensors.torch.load(b[4 + metadata_length :]) + return tensors, metadata diff --git a/deserve_worker/py.typed b/deserve_worker/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/deserve_worker/pyproject.toml b/deserve_worker/pyproject.toml new file mode 100644 index 0000000..b9621c9 --- /dev/null +++ b/deserve_worker/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "deserve_worker" +version = "0.0.1" +authors = [ + { name="Example Author", email="author@example.com" }, +] +description = "Deserve Worker" +requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] diff --git a/deserve_worker/task.py b/deserve_worker/task.py new file mode 100644 index 0000000..718da07 --- /dev/null +++ b/deserve_worker/task.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass + +from pydantic import BaseModel + +from .kvcache.kvcache import KVCache + + +class PlanStep(BaseModel): + worker_id: str + worker_url: str + layers: list[str] + + +class SamplingParams(BaseModel): + temperature: float + top_p: float + max_total_len: int + + +class TaskInfo(BaseModel): + task_id: str 
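+    # identifies the request across workers; each worker keys its local TaskData (defined below) on it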
+ plan: list[PlanStep] + round: int + sampling_params: SamplingParams + + +@dataclass +class TaskData: + task_id: str + start_pos: int + plan: list[PlanStep] + round: int + sampling_params: SamplingParams + kvcaches: dict[int, KVCache] + """ + When flash attention is enabled, we use paged attention, otherwise the standard attention is adopted. + """ diff --git a/deserve_worker/trace.py b/deserve_worker/trace.py new file mode 100644 index 0000000..2f0c039 --- /dev/null +++ b/deserve_worker/trace.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass + + +@dataclass +class LayerId: + layer: str + + def with_component(self, component: str) -> "ComponentId": + return ComponentId(self.layer, component) + + def __str__(self) -> str: + return self.layer + + @staticmethod + def from_str(s: str) -> "LayerId": + return LayerId(s) + + +@dataclass +class ComponentId: + layer: str + component: str + + def with_op(self, op: str) -> "OpId": + return OpId(self.layer, self.component, op) + + def __str__(self) -> str: + return f"{self.layer}.{self.component}" + + @staticmethod + def from_str(s: str) -> "ComponentId": + layer, component = s.split(".") + return ComponentId(layer, component) + + +@dataclass +class OpId: + layer: str + component: str + op: str + + def __hash__(self) -> int: + return hash((self.layer, self.component, self.op)) + + def __str__(self) -> str: + return f"{self.layer}.{self.component}.{self.op}" + + @staticmethod + def from_str(s: str) -> "OpId": + layer, component, op = s.split(".") + return OpId(layer, component, op) diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py new file mode 100644 index 0000000..6cb1914 --- /dev/null +++ b/deserve_worker/worker.py @@ -0,0 +1,320 @@ +import queue +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, cast + +import requests +import torch +from transformers import AutoTokenizer # type: ignore + +from deserve_worker.kvcache.block_pool import BlockPool +from deserve_worker.kvcache.packed_kvcache import PackedKVCacheManager # type: ignore + +from .command import BatchForward, BatchResult, BatchUpdate, SingleTrace, TraceResult +from .kvcache.kvcache import KVCache, main_device, main_dtype +from .kvcache.paged_kvcache import PagedKVCacheManager +from .layer_storage import LayerManager +from .llm_engine import LLMEngine +from .model.llama import dumps +from .task import PlanStep, SamplingParams, TaskData, TaskInfo + +EOS_TOKEN_ID = 128001 # for llama 3 only + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + + +class Worker: + def __init__( + self, worker_id: str, worker_url: str, max_total_bsz: int, controller_url: str + ): + self.worker_id = worker_id + self.worker_url = worker_url + self.controller_url = controller_url + self.task_datas: dict[str, TaskData] = {} + self.relay_queue = queue.Queue[BatchResult | BatchUpdate | TraceResult]() + self.llm_engine = LLMEngine(max_total_bsz, self.relay_queue) + self.layer_manager = LayerManager(main_device) + self.block_pool = BlockPool(11600, 256, main_device, main_dtype) + # TODO: in future, different cache manager could allocate on same memory + self.paged_kvcache_manager = PagedKVCacheManager(self.block_pool) + self.packed_kvcache_manager = PackedKVCacheManager(self.block_pool) + self.network_executor = ThreadPoolExecutor(max_workers=max_total_bsz) + + threading.Thread(target=self.llm_engine.run, daemon=True).start() + threading.Thread(target=self.relay, daemon=True).start() + 
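+        # relay streams engine outputs to the next worker in the plan or back to the controller;
+        # heartbeat (started next) re-registers this worker with the controller once per second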
threading.Thread(target=self.heartbeat, daemon=True).start() + + def locate_in_plan(self, plan: list[PlanStep]) -> Optional[int]: + return next( + (i for i, worker in enumerate(plan) if worker.worker_id == self.worker_id), + None, + ) + + def init_forward_task_data( + self, x: torch.Tensor, index: int, task_info: TaskInfo + ) -> TaskData: + if task_info.round == 0: + kvcaches = {} + for full_layer_name in task_info.plan[index].layers: + _, layer_name = full_layer_name.split("/") + if layer_name.startswith("layers."): + layer_id = int(layer_name.split(".")[1]) + kvcaches[layer_id] = self.paged_kvcache_manager.alloc( + x.shape[0], x.shape[1] + ) + + # TODO: need double check whether request is repeated + task_data = TaskData( + task_id=task_info.task_id, + start_pos=0, + plan=task_info.plan, + round=0, + sampling_params=task_info.sampling_params, + kvcaches=cast(dict[int, KVCache], kvcaches), + ) + self.task_datas[task_info.task_id] = task_data + else: + task_data = self.task_datas[task_info.task_id] + task_data.round = task_info.round + + return task_data + + def init_trace_task_data( + self, x: torch.Tensor, index: int, task_info: TaskInfo + ) -> TaskData: + if task_info.round == 0: + kvcaches = {} + for full_layer_name in task_info.plan[index].layers: + _, layer_name = full_layer_name.split("/") + if layer_name.startswith("layers."): + layer_id = int(layer_name.split(".")[1]) + kvcaches[layer_id] = self.packed_kvcache_manager.alloc( + x.shape[0], x.shape[1] + ) + + task_data = TaskData( + task_id=task_info.task_id, + start_pos=0, + plan=task_info.plan, + round=0, + sampling_params=task_info.sampling_params, + kvcaches=cast(dict[int, KVCache], kvcaches), + ) + self.task_datas[task_info.task_id] = task_data + else: + task_data = self.task_datas[task_info.task_id] + task_data.round = task_info.round + + return task_data + + def batch_forward( + self, + xs: torch.Tensor, + task_infos: list[TaskInfo], + ) -> None: + plan = task_infos[0].plan + index = self.locate_in_plan(plan) + assert index is not None + layer_storage = self.layer_manager.get_layer_storage( + task_infos[0].plan[index].layers + ) + task_datas = [ + self.init_forward_task_data(xs, index, task_info) + for task_info in task_infos + ] + self.llm_engine.add_batch_forward( + BatchForward( + xs=xs.to(main_device), + layer_storage=layer_storage, + task_datas=task_datas, + need_sample=(index == len(plan) - 1), + kvcache_manager=self.paged_kvcache_manager, + ) + ) + + def forward( + self, + x: torch.Tensor, + task_id: str, + round: int, + plan: list[PlanStep], + sampling_params: SamplingParams, + ) -> None: + index = self.locate_in_plan(plan) + if index is None: + return None + + layer_storage = self.layer_manager.get_layer_storage(plan[index].layers) + forward = BatchForward( + xs=x.to(main_device), + layer_storage=layer_storage, + task_datas=[ + self.init_forward_task_data( + x, + index, + TaskInfo( + task_id=task_id, + plan=plan, + round=round, + sampling_params=sampling_params, + ), + ) + ], + need_sample=(index == len(plan) - 1), + kvcache_manager=self.paged_kvcache_manager, + ) + self.llm_engine.add_batch_forward(forward) + + def trace( + self, + x: torch.Tensor, + task_id: str, + round: int, + plan: list[PlanStep], + sampling_params: SamplingParams, + ) -> None: + index = self.locate_in_plan(plan) + if index is None: + return None + + layer_storage = self.layer_manager.get_layer_storage(plan[index].layers) + trace = SingleTrace( + x=x.to(main_device), + layer_storage=layer_storage, + task_data=self.init_trace_task_data( + x, + 
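+                # init_trace_task_data allocates packed KV caches, which selects the attention
+                # path that supports per-op tracing (flash attention is only used for plain forwards)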
index, + TaskInfo( + task_id=task_id, + plan=plan, + round=round, + sampling_params=sampling_params, + ), + ), + kvcache_manager=self.packed_kvcache_manager, + need_sample=(index == len(plan) - 1), + ) + self.llm_engine.add_trace(trace) + + def relay(self) -> None: + q = self.relay_queue + while True: + result = q.get() + if isinstance(result, BatchResult): + task_id = result.task_ids[0] + task_info = self.task_datas[task_id] + plan = task_info.plan + index = self.locate_in_plan(plan) + assert index is not None + next_index = (index + 1) % len(plan) + next_worker_url = plan[next_index].worker_url + data = dumps( + {"x": result.xs}, + { + "task_infos": [ + { + "task_id": task_id, + "round": self.task_datas[task_id].round, + "plan": plan, + "sampling_params": self.task_datas[ + task_id + ].sampling_params, + } + for task_id in result.task_ids + ], + }, + ) + self.network_executor.submit( + requests.post, + f"{next_worker_url}/batch_forward", + data=data, + ) + elif isinstance(result, BatchUpdate): + updated_tasks = [] + for tokens, task_id in zip(result.tokens, result.task_ids): + updated_tasks.append( + { + "task_id": task_id, + "output_tokens": tokens.tolist(), + } + ) + self.network_executor.submit( + requests.post, + f"{self.controller_url}/update_tasks", + json=updated_tasks, + ) + for task_id in result.cancel_ids: + self.cancel(task_id, None, self.task_datas[task_id].plan) + elif isinstance(result, TraceResult): + task_id = result.task_id + task_info = self.task_datas[task_id] + plan = task_info.plan + index = self.locate_in_plan(plan) + assert index is not None + next_index = (index + 1) % len(plan) + if next_index != 0: + next_worker_url = plan[next_index].worker_url + data = dumps( + {"x": result.x}, + { + "task_id": task_id, + "round": self.task_datas[task_id].round, + "plan": plan, + "sampling_params": self.task_datas[task_id].sampling_params, + }, + ) + self.network_executor.submit( + requests.post, + f"{next_worker_url}/trace", + data=data, + ) + data = dumps( + {str(key): value for key, value in result.trace.items()}, + { + "task_id": task_id, + }, + ) + self.network_executor.submit( + requests.post, + f"{self.controller_url}/update_traces", + data=data, + ) + + def cancel( + self, task_id: str, start_index: Optional[int], plan: list[PlanStep] + ) -> None: + index = next( + (i for i, x in enumerate(plan) if x.worker_id == self.worker_id), None + ) + if index is None: + return + if start_index is None: + start_index = index + + task_info = self.task_datas.pop(task_id, None) + if task_info is not None: + for kvcache in task_info.kvcaches.values(): + kvcache.clear() + next_index = (index + 1) % len(plan) + if next_index != start_index: + requests.post( + f"{plan[next_index].worker_url}/cancel", + json={ + "task_id": task_id, + "start_index": index, + "plan": [step.model_dump() for step in plan], + }, + ) + + def heartbeat(self): + while True: + self.network_executor.submit( + requests.post, + f"{self.controller_url}/heartbeat", + json={ + "worker_id": self.worker_id, + "worker_url": self.worker_url, + }, + ) + time.sleep(1) diff --git a/deserve_worker/worker_api.py b/deserve_worker/worker_api.py new file mode 100644 index 0000000..19aa6eb --- /dev/null +++ b/deserve_worker/worker_api.py @@ -0,0 +1,91 @@ +import argparse +import traceback +from concurrent.futures import ThreadPoolExecutor + +import uvicorn +from fastapi import FastAPI, Request +from pydantic import BaseModel + +from .model.llama import loads +from .task import PlanStep, SamplingParams, TaskInfo +from .worker import 
Worker + +app = FastAPI() +worker: Worker +runtime_executor = ThreadPoolExecutor(max_workers=96) + + +@app.post("/batch_forward") +async def batch_forward(request: Request) -> str: + body = await request.body() + tensors, metadata = loads(body) + runtime_executor.submit( + worker.batch_forward, + tensors["x"], + [TaskInfo.model_validate(task_info) for task_info in metadata["task_infos"]], + ) + return "ok" + + +@app.post("/forward") +async def forward(request: Request) -> str: + try: + body = await request.body() + tensors, metadata = loads(body) + runtime_executor.submit( + worker.forward, + tensors["x"], + metadata["task_id"], + metadata["round"], + [PlanStep.model_validate(step) for step in metadata["plan"]], + SamplingParams.model_validate(metadata["sampling_params"]), + ) + except Exception as e: + traceback.print_exc() + return "ok" + + +@app.post("/trace") +async def trace(request: Request) -> str: + try: + body = await request.body() + tensors, metadata = loads(body) + runtime_executor.submit( + worker.trace, + tensors["x"], + metadata["task_id"], + metadata["round"], + [PlanStep.model_validate(step) for step in metadata["plan"]], + SamplingParams.model_validate(metadata["sampling_params"]), + ) + except Exception as e: + traceback.print_exc() + return "ok" + + +class CancelRequest(BaseModel): + task_id: str + start_index: int + plan: list[PlanStep] + + +@app.post("/cancel") +async def cancel(request: CancelRequest) -> str: + runtime_executor.submit( + worker.cancel, request.task_id, request.start_index, request.plan + ) + return "ok" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("id", type=str) + parser.add_argument("--batch-size", type=int) + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str) + args = parser.parse_args() + + worker = Worker( + args.id, f"http://localhost:{args.port}", args.batch_size, args.controller_url + ) + uvicorn.run(app, host="127.0.0.1", port=args.port)
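+    # Example invocation (worker id, port, and controller URL below are illustrative):
+    #   python3 -m deserve_worker.worker_api worker0 --batch-size 32 --port 8080 \
+    #     --controller-url http://localhost:19000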