@@ -11,15 +11,15 @@ struct llama_ubatch {
1111 bool equal_seqs;
1212 // TODO: whole_seqs for embeddings?
1313
14- uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
14+ uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
1515 uint32_t n_seq_tokens; // tokens per sequence
1616 uint32_t n_seqs;
1717
1818 llama_token * token; // [n_tokens]
1919 float * embd; // [n_embd, n_tokens]
2020 llama_pos * pos; // [n_tokens]
21- int32_t * n_seq_id; // [n_seqs]
22- llama_seq_id ** seq_id; // [n_seqs]
21+ int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
22+ llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
2323 int8_t * output; // [n_tokens]
2424};
2525
@@ -49,13 +49,18 @@ struct llama_sbatch {
4949
5050 const llama_batch * batch = nullptr ;
5151
52- // buffers for the ubatch
53- std::vector<llama_token> ubatch_token;
54- std::vector<float > ubatch_embd;
55- std::vector<llama_pos> ubatch_pos;
56- std::vector<int32_t > ubatch_n_seq_id;
57- std::vector<llama_seq_id *> ubatch_seq_id;
58- std::vector<int8_t > ubatch_output;
52+ // buffers for the ubatches
53+ // TODO: very hacky, this needs a complete rework
54+ struct ubatch_data {
55+ std::vector<llama_token> token;
56+ std::vector<float > embd;
57+ std::vector<llama_pos> pos;
58+ std::vector<int32_t > n_seq_id;
59+ std::vector<llama_seq_id *> seq_id;
60+ std::vector<int8_t > output;
61+ };
62+
63+ std::vector<ubatch_data> udatas;
5964
6065 llama_ubatch reserve_ubatch (size_t n_ubatch, bool has_embd = false );
6166
0 commit comments