Skip to content

Commit 72c447a

Browse files
committed
refactor(rpc): input size validation in graph_compute
Removes the detailed, step-by-step size calculations and overflow checks in favor of simpler direct comparisons: the element counts are 32-bit (`uint32_t`), so expressions like `n_nodes * sizeof(uint64_t)` cannot overflow a 64-bit `size_t`. Signed-off-by: Ville Vesilehto <[email protected]>
1 parent e38c4d7 commit 72c447a

File tree

1 file changed

+10
-64
lines changed

1 file changed

+10
-64
lines changed

ggml/src/ggml-rpc/ggml-rpc.cpp

+10-64
Original file line numberDiff line numberDiff line change
@@ -1306,81 +1306,28 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
13061306

13071307
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
13081308
// serialization format:
1309-
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
1310-
1311-
// Perform robust size checks with overflow protection
1312-
const size_t min_header_size = sizeof(uint32_t);
1313-
if (input.size() < min_header_size) {
1314-
GGML_LOG_ERROR("[%s] input message too small for n_nodes header: %zu bytes\n", __func__, input.size());
1315-
response.result = GGML_STATUS_FAILED;
1309+
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
1310+
if (input.size() < sizeof(uint32_t)) {
13161311
return false;
13171312
}
13181313
uint32_t n_nodes;
13191314
memcpy(&n_nodes, input.data(), sizeof(n_nodes));
1320-
1321-
// Calculate required size for nodes array
1322-
size_t nodes_array_bytes = n_nodes * sizeof(uint64_t);
1323-
1324-
// Calculate required size up to n_tensors field safely
1325-
size_t required_size_before_tensors = min_header_size;
1326-
1327-
// Check for overflow before adding nodes_array_bytes
1328-
if (SIZE_MAX - required_size_before_tensors < nodes_array_bytes) {
1329-
GGML_LOG_ERROR("[%s] integer overflow calculating size before tensors step 1: n_nodes=%u\n", __func__, n_nodes);
1330-
response.result = GGML_STATUS_FAILED; // Use correct status
1315+
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
13311316
return false;
13321317
}
1333-
required_size_before_tensors += nodes_array_bytes;
1334-
1335-
const size_t n_tensors_field_size = sizeof(uint32_t);
1336-
// Check for overflow before adding n_tensors_field_size
1337-
if (SIZE_MAX - required_size_before_tensors < n_tensors_field_size) {
1338-
GGML_LOG_ERROR("[%s] integer overflow calculating size before tensors step 2: n_nodes=%u\n", __func__, n_nodes);
1339-
response.result = GGML_STATUS_FAILED; // Use correct status
1340-
return false;
1341-
}
1342-
required_size_before_tensors += n_tensors_field_size;
1343-
1344-
if (input.size() < required_size_before_tensors) {
1345-
GGML_LOG_ERROR("[%s] input message too small for nodes array or n_tensors header: %zu bytes, needed %zu\n", __func__, input.size(), required_size_before_tensors);
1346-
response.result = GGML_STATUS_FAILED; // Use correct status
1347-
return false;
1348-
}
1349-
1350-
// Read n_tensors
1351-
const uint64_t * nodes_ptr = (const uint64_t *)(input.data() + sizeof(n_nodes));
1318+
const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
13521319
uint32_t n_tensors;
1353-
memcpy(&n_tensors, input.data() + min_header_size + nodes_array_bytes, sizeof(n_tensors));
1354-
1355-
// Calculate required size for tensors array
1356-
size_t tensors_array_bytes = n_tensors * sizeof(rpc_tensor);
1357-
1358-
// Calculate total required size safely
1359-
size_t required_total_size = required_size_before_tensors;
1360-
1361-
// Check for overflow before adding tensors_array_bytes
1362-
if (SIZE_MAX - required_total_size < tensors_array_bytes) {
1363-
GGML_LOG_ERROR("[%s] integer overflow calculating total required size: n_nodes=%u, n_tensors=%u\n", __func__, n_nodes, n_tensors);
1364-
response.result = GGML_STATUS_FAILED;
1320+
memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
1321+
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
13651322
return false;
13661323
}
1367-
required_total_size += tensors_array_bytes;
1368-
1369-
if (input.size() < required_total_size) {
1370-
GGML_LOG_ERROR("[%s] input message too small for tensors array: %zu bytes, needed %zu\n", __func__, input.size(), required_total_size);
1371-
response.result = GGML_STATUS_FAILED;
1372-
return false;
1373-
}
1374-
1375-
// Pointers are now safe to use based on size checks
1376-
const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + required_size_before_tensors);
1324+
const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
13771325
GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
13781326

1379-
// Estimate buffer size for context
1380-
size_t ctx_buf_size = ggml_tensor_overhead()*((size_t)n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1327+
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
13811328

13821329
struct ggml_init_params params = {
1383-
/*.mem_size =*/ ctx_buf_size,
1330+
/*.mem_size =*/ buf_size,
13841331
/*.mem_buffer =*/ NULL,
13851332
/*.no_alloc =*/ true,
13861333
};
@@ -1396,7 +1343,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
13961343
std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
13971344
for (uint32_t i = 0; i < n_nodes; i++) {
13981345
int64_t id;
1399-
memcpy(&id, &nodes_ptr[i], sizeof(id));
1346+
memcpy(&id, &nodes[i], sizeof(id));
14001347
graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
14011348

14021349
// Check if create_node failed for a *non-zero* ID.
@@ -1405,7 +1352,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14051352
if (graph->nodes[i] == nullptr && id != 0) {
14061353
GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
14071354
response.result = GGML_STATUS_FAILED;
1408-
// No need to free ctx, ggml_context_ptr handles it.
14091355
return false;
14101356
}
14111357
}

0 commit comments

Comments
 (0)