@@ -1181,6 +1181,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
                 size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
                 size += sizeof(int64_t) * (1 + op->src[0]->ne[2]) * op->src[1]->ne[2];
                 return true;
+            case GGML_OP_GET_ROWS:
+                size = 0; // GET_ROWS (standard and repacked) doesn't need a work buffer
+                return true;
             default:
                 // GGML_ABORT("fatal error");
                 break;
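For contrast with the new zero-byte case, the MUL_MAT_ID context lines above reserve a small int64 scratch region. A standalone check of that arithmetic, with illustrative shapes (`ne02` experts in src0, `ne12` rows in src1), not part of the patch:

    #include <cstdint>
    #include <cstdio>

    // Standalone check of the MUL_MAT_ID scratch arithmetic shown in the
    // context lines above, with illustrative shapes. GGML_PAD is ggml's
    // alignment macro, reproduced here so the sketch is self-contained.
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main() {
        size_t size = 37;                 // whatever the op accumulated so far
        const int64_t ne02 = 4, ne12 = 2; // illustrative expert/row counts
        size  = GGML_PAD(size, sizeof(int64_t));     // align up: 37 -> 40
        size += sizeof(int64_t) * (1 + ne02) * ne12; // (1 + 4) * 2 int64s = 80 B
        std::printf("scratch bytes: %zu\n", size);   // prints 120
    }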
@@ -1196,6 +1199,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
             case GGML_OP_MUL_MAT_ID:
                 forward_mul_mat_id(params, op);
                 return true;
+            case GGML_OP_GET_ROWS:
+                forward_get_rows(params, op);
+                return true;
             default:
                 // GGML_ABORT("fatal error");
                 break;
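Like the other forward_* kernels in this file, the new entry point is invoked once per thread and splits work by params->ith / params->nth. A self-contained sketch of that partitioning, mirroring the dr/ir0/ir1 math that appears in the kernel below (illustrative counts):

    #include <algorithm>
    #include <cstdio>

    // Sketch of the threading contract the new kernel relies on: each of the
    // nth calls claims a contiguous slice of rows, with the last slice
    // clamped to the total row count.
    int main() {
        const int nr = 10, nth = 4; // illustrative row/thread counts
        for (int ith = 0; ith < nth; ++ith) {
            const int dr  = (nr + nth - 1) / nth;   // rows per thread, rounded up
            const int ir0 = dr * ith;               // first row of this thread
            const int ir1 = std::min(ir0 + dr, nr); // past-the-end row, clamped
            std::printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
    }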
@@ -1401,6 +1407,132 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 #undef MMID_MATRIX_ROW
     }

+    void forward_get_rows(const ggml_compute_params * params,
+                          ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+
+        switch (src0->type) {
+            case GGML_TYPE_Q4_0: {
+                ggml_compute_forward_get_rows_q4_0x8(params, dst);
+            } break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+    }
+
+    static void ggml_compute_forward_get_rows_q4_0x8(
+            const ggml_compute_params * params,
+            ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int64_t nc = ne00;
+        const int64_t nr = ggml_nelements(src1);
+
+        assert(ne0 == nc);
+        assert(ne02 == ne11);
+        assert(nb00 == ggml_type_size(src0->type));
+        assert(ggml_nrows(dst) == nr);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        // rows per thread
+        const int dr = (nr + nth - 1) / nth;
+
+        // row range for this thread
+        const int ir0 = dr * ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        constexpr int nrows_interleaved = 8;
+        const size_t sizeof_one_repacked_block = sizeof(block_q4_0x8);
+
+        const int num_repacked_blocks_per_row_width = nc / QK4_0;
+
+        const size_t stride_between_actual_row_groups = num_repacked_blocks_per_row_width * sizeof_one_repacked_block;
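+        // e.g. nc = 4096 columns -> 4096 / 32 = 128 repacked blocks per group
+        // of 8 interleaved rows, so row groups sit 128 blocks apart in memory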
+
+        for (int64_t i = ir0; i < ir1; ++i) {
+            const int64_t i12 = i / (ne11 * ne10);
+            const int64_t i11 = (i - i12 * ne11 * ne10) / ne10;
+            const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10);
+            const int64_t i01 = *(int32_t *)((char *)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); // original logical row
+
+            GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+            int row_group_idx = i01 / nrows_interleaved;
+            const int row_idx_in_group = i01 % nrows_interleaved;
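+            // e.g. logical row i01 = 13 lives in row group 13 / 8 = 1,
+            // at slot 13 % 8 = 5 of that group's interleaved data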
+
+            const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03;
+
+            // Pointer to the first block_q4_0x8 of the identified row_group_idx
+            const block_q4_0x8 * p_first_repacked_block_of_group_x8 = (const block_q4_0x8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);
+
+            dequantize_row_q4_0x8(
+                p_first_repacked_block_of_group_x8,
+                (float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
+        }
+    }
+
+    /**
+     * Dequantizes a single logical row from data repacked with quant interleaving.
+     *
+     * @param p_repacked_group_column_blocks Pointer to the start of the 'block_q4_0x8' data for the row group.
+     * @param y                Output buffer for the dequantized float values.
+     * @param k                Total number of elements (columns) in the logical row.
+     * @param row_idx_in_group Index (0-7) of the logical row to dequantize.
+     */
+    static void dequantize_row_q4_0x8(
+            const block_q4_0x8 * GGML_RESTRICT p_repacked_group_column_blocks,
+            float * GGML_RESTRICT y,
+            int64_t k,
+            int row_idx_in_group) {
+        const int GGML_Q4_0_X8_INTERLEAVE_SIZE = 8;
+        assert(k % QK4_0 == 0);
+        assert(row_idx_in_group >= 0 && row_idx_in_group < GGML_Q4_0_X8_INTERLEAVE_SIZE);
+
+        const int nb = k / QK4_0;
+        const int bytes_for_half_elements = (QK4_0 / 2) / 2; // 16 elements -> 8 packed bytes
+
+        const int offset_to_second_half_data = bytes_for_half_elements * GGML_Q4_0_X8_INTERLEAVE_SIZE;
+        const uint64_t xor_mask = 0x8888888888888888ULL;
+        const int qk4_0_half_elements = QK4_0 / 2;
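+        // layout of one block_q4_0x8: 8 fp16 scales d[0..7], then 128 quant
+        // bytes; the first 64 hold each row's first 8 qs bytes (8 rows x 8
+        // bytes) and the next 64 hold each row's remaining 8 qs bytes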
+
+        for (int i = 0; i < nb; ++i) {
+            const block_q4_0x8 * current_column_repacked_block = &p_repacked_group_column_blocks[i];
+            const float d_val = GGML_FP16_TO_FP32(current_column_repacked_block->d[row_idx_in_group]);
+            float * y_curr = y + i * QK4_0;
+
+            const int8_t * qs_first_half_repacked_ptr = &(current_column_repacked_block->qs[row_idx_in_group * bytes_for_half_elements]);
+
+            uint64_t first_half_chunk_u64;
+            memcpy(&first_half_chunk_u64, qs_first_half_repacked_ptr, sizeof(uint64_t));
+            first_half_chunk_u64 ^= xor_mask; // reverse the XOR applied during repacking
+            const uint8_t * original_qs_first_half_bytes = (const uint8_t *)&first_half_chunk_u64;
+
+            const int8_t * qs_second_half_repacked_ptr = &(current_column_repacked_block->qs[offset_to_second_half_data + (row_idx_in_group * bytes_for_half_elements)]);
+
+            uint64_t second_half_chunk_u64;
+            memcpy(&second_half_chunk_u64, qs_second_half_repacked_ptr, sizeof(uint64_t));
+            second_half_chunk_u64 ^= xor_mask; // reverse the XOR applied during repacking
+            const uint8_t * original_qs_second_half_bytes = (const uint8_t *)&second_half_chunk_u64;
+
+            // dequantize all QK4_0 elements of this block: low nibbles are
+            // elements [0,16), high nibbles are elements [16,32)
+            for (int j = 0; j < bytes_for_half_elements; ++j) {
+                const uint8_t quant_byte_first = original_qs_first_half_bytes[j];
+                y_curr[j] = ((quant_byte_first & 0x0F) - 8) * d_val;
+                y_curr[j + qk4_0_half_elements] = ((quant_byte_first >> 4) - 8) * d_val;
+
+                const uint8_t quant_byte_second = original_qs_second_half_bytes[j];
+                const int out_idx_base_second_half = j + bytes_for_half_elements; // offset for the second set of low nibbles
+                y_curr[out_idx_base_second_half] = ((quant_byte_second & 0x0F) - 8) * d_val;
+                y_curr[out_idx_base_second_half + qk4_0_half_elements] = ((quant_byte_second >> 4) - 8) * d_val;
+            }
+        }
+    }
+
     int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
         GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
                        (int) NB_COLS, (int) INTER_SIZE);
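To make the nibble mechanics above concrete: each packed byte stores two 4-bit quants, repacking XORed every byte with 0x88, and decoding maps each nibble from [0,15] back to [-8,7] before scaling. A self-contained sketch of one byte's decode (not part of the patch; illustrative values):

    #include <cstdint>
    #include <cstdio>

    // Decode one repacked Q4_0 byte the way the inner loop does: undo the
    // 0x88 XOR applied at repack time, then split the byte into two quants.
    // Within a 32-element block, the low nibble of byte j is element j and
    // the high nibble is element j + 16.
    int main() {
        const float   d        = 0.5f;                 // per-row scale (fp32 here)
        const uint8_t repacked = uint8_t(0x3A ^ 0x88); // byte as stored after repacking

        const uint8_t original = repacked ^ 0x88;     // reverse the repack XOR
        const float lo = ((original & 0x0F) - 8) * d; // (10 - 8) * 0.5 =  1.0
        const float hi = ((original >> 4)  - 8) * d;  // ( 3 - 8) * 0.5 = -2.5
        std::printf("lo = %.2f, hi = %.2f\n", lo, hi);
    }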
@@ -1533,12 +1665,23 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
             //     if (op->src[1]->type == GGML_TYPE_Q8_0) {
             //         return true;
             //     }
+        } else if (op->op == GGML_OP_GET_ROWS
+                   && op->src[0]->buffer
+                   && (ggml_n_dims(op->src[0]) == 2)
+                   && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
+                   && ggml_repack_get_optimal_repack_type(op->src[0])) {
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[0]->type == GGML_TYPE_Q4_0) {
+                return true;
+            }
         }
         return false;
     }

     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
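Note the `ggml_n_dims(op->src[0]) == 2` gate: for 2-D weights the `i11 * nb02 + i12 * nb03` term in the kernel vanishes, so locating a source row reduces to pure group arithmetic. A self-contained sketch of that arithmetic (illustrative sizes; `block_bytes` stands in for `sizeof(block_q4_0x8)`, assumed here to be 144 = 8 fp16 scales + 128 quant bytes):

    #include <cstdio>

    // Locate a logical row inside Q4_0x8-repacked 2-D data: groups hold 8
    // interleaved rows, and each group spans nc / QK4_0 repacked blocks.
    int main() {
        const int QK4_0             = 32;   // elements per Q4_0 block
        const int nrows_interleaved = 8;    // rows per repacked group
        const int nc                = 4096; // columns per logical row
        const int block_bytes       = 144;  // assumed sizeof(block_q4_0x8)

        const int  i01   = 13;                      // row requested by GET_ROWS
        const int  group = i01 / nrows_interleaved; // -> 1
        const int  slot  = i01 % nrows_interleaved; // -> 5
        const long stride = (long) (nc / QK4_0) * block_bytes; // bytes per group

        std::printf("row %d: group %d, slot %d, byte offset %ld\n",
                    i01, group, slot, group * stride); // offset 18432
    }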