Skip to content

Conversation

@zeroecco
Copy link
Contributor

@zeroecco zeroecco commented Sep 5, 2025

🏎️ Speed up multi-GPU implementations by moving the lift step out of the critical segment-proving section and into the join task.
🧹 clean up some clippy warnings

Example on a single GPU:
Before:

task times:
 task_type | completed_count |      avg_seconds       | min_seconds | max_seconds
-----------+-----------------+------------------------+-------------+-------------
 Executor  |               1 |     3.8048860000000000 |    3.804886 |    3.804886
 Prove     |             281 |     2.3673917153024911 |    1.232411 |    2.540479
 Keccak    |              54 |     1.3369058888888889 |    1.314760 |    1.370198
 Union     |              53 | 0.27935541509433962264 |    0.270431 |    0.283290
 Join      |             280 | 0.27879625000000000000 |    0.268699 |    0.283521
 Resolve   |               1 | 0.27236400000000000000 |    0.272364 |    0.272364
 Finalize  |               1 | 0.02594100000000000000 |    0.025941 |    0.025941
(7 rows)

task times (totals):
 task_type | completed_count | total_secs
-----------+-----------------+------------
 Prove     |             281 | 665.237072
 Join      |             280 |  78.062950
 Keccak    |              54 |  72.192918
 Union     |              53 |  14.805837
 Executor  |               1 |   3.804886
 Resolve   |               1 |   0.272364
 Finalize  |               1 |   0.025941
(7 rows)

Effective Hz:
         hz          | total_cycles | elapsed_sec
---------------------+--------------+-------------
 353406.605524556257 | 294125568    |  832.258264
(1 row)

After:

 jobs_count
------------
        671
(1 row)

 task_type | remaining_count
-----------+-----------------
(0 rows)

task times:
 task_type | completed_count |      avg_seconds       | min_seconds | max_seconds
-----------+-----------------+------------------------+-------------+-------------
 Executor  |               1 |     3.7993170000000000 |    3.799317 |    3.799317
 Prove     |             281 |     1.9717433096085409 |    1.031125 |    2.367374
 Join      |             280 |     1.6322991214285714 |    0.250003 |    3.312511
 Keccak    |              54 |     1.3520483333333333 |    1.066685 |    1.449422
 Union     |              53 | 0.37663630188679245283 |    0.246680 |    0.460196
 Resolve   |               1 | 0.25521400000000000000 |    0.255214 |    0.255214
 Finalize  |               1 | 0.02820700000000000000 |    0.028207 |    0.028207
(7 rows)

task times (totals):
 task_type | completed_count | total_secs
-----------+-----------------+------------
 Prove     |             281 | 554.059870
 Join      |             280 | 457.043754
 Keccak    |              54 |  73.010610
 Union     |              53 |  19.961724
 Executor  |               1 |   3.799317
 Resolve   |               1 |   0.255214
 Finalize  |               1 |   0.028207
(7 rows)

Effective Hz:
         hz          | total_cycles | elapsed_sec
---------------------+--------------+-------------
 452561.847155696635 | 294125568    |  649.912426

@zeroecco zeroecco marked this pull request as ready for review September 5, 2025 00:02
@github-actions github-actions bot changed the title move lift to join BM-1546: move lift to join Sep 5, 2025
@github-actions github-actions bot changed the title BM-1546: move lift to join BM-1547: move lift to join Sep 5, 2025
let right_receipt: SuccinctReceipt<ReceiptClaim> =
deserialize_obj(&right_receipt).context("Failed to deserialize right receipt")?;
// Handle each receipt independently - they could be mixed types
let prover = agent.prover.as_ref().context("Missing prover from resolve task")?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the join task, not the resolve task, so the "Missing prover from resolve task" error message might be confusing?

Comment on lines +37 to +70
async {
match deserialize_obj::<SegmentReceipt>(&left_receipt_data) {
Ok(segment) => {
let receipt = prover
.lift(&segment)
.with_context(|| format!("Failed to lift segment {left_idx}"))?;
tracing::debug!("lifting complete {job_id} - {left_idx}");
Ok::<_, anyhow::Error>(receipt)
}
Err(_) => {
let receipt: SuccinctReceipt<ReceiptClaim> =
deserialize_obj(&left_receipt_data)
.context("Failed to deserialize left receipt")?;
Ok(receipt)
}
}
},
async {
match deserialize_obj::<SegmentReceipt>(&right_receipt_data) {
Ok(segment) => {
let receipt = prover
.lift(&segment)
.with_context(|| format!("Failed to lift segment {right_idx}"))?;
tracing::debug!("lifting complete {job_id} - {right_idx}");
Ok::<_, anyhow::Error>(receipt)
}
Err(_) => {
let receipt: SuccinctReceipt<ReceiptClaim> =
deserialize_obj(&right_receipt_data)
.context("Failed to deserialize right receipt")?;
Ok(receipt)
}
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: duplication here and in join_povw could be removed if you want:

diff --git a/bento/crates/workflow/src/tasks/join.rs b/bento/crates/workflow/src/tasks/join.rs
index 2758570c..240518b9 100644
--- a/bento/crates/workflow/src/tasks/join.rs
+++ b/bento/crates/workflow/src/tasks/join.rs
@@ -3,16 +3,40 @@
 // Use of this source code is governed by the Business Source License
 // as found in the LICENSE-BSL file.
 
+use std::rc::Rc;
+
 use crate::{
     Agent,
     redis::{self, AsyncCommands},
     tasks::{RECUR_RECEIPT_PATH, deserialize_obj, serialize_obj},
 };
 use anyhow::{Context, Result};
-use risc0_zkvm::{ReceiptClaim, SegmentReceipt, SuccinctReceipt};
+use risc0_zkvm::{ProverServer, ReceiptClaim, SegmentReceipt, SuccinctReceipt};
 use uuid::Uuid;
 use workflow_common::JoinReq;
 
+/// Lifts a receipt if it's a segment receipt, otherwise returns it as-is
+async fn lift_or_deserialize_receipt(
+    prover: &Rc<dyn ProverServer>,
+    receipt_data: &[u8],
+    idx: usize,
+    job_id: &Uuid,
+) -> Result<SuccinctReceipt<ReceiptClaim>> {
+    match deserialize_obj::<SegmentReceipt>(receipt_data) {
+        Ok(segment) => {
+            let receipt =
+                prover.lift(&segment).with_context(|| format!("Failed to lift segment {idx}"))?;
+            tracing::debug!("lifting complete {job_id} - {idx}");
+            Ok(receipt)
+        }
+        Err(_) => {
+            let receipt: SuccinctReceipt<ReceiptClaim> = deserialize_obj(receipt_data)
+                .with_context(|| format!("Failed to deserialize receipt {idx}"))?;
+            Ok(receipt)
+        }
+    }
+}
+
 /// Run the join operation
 pub async fn join(agent: &Agent, job_id: &Uuid, request: &JoinReq) -> Result<()> {
     let mut conn = agent.redis_pool.get().await?;
@@ -29,45 +53,11 @@ pub async fn join(agent: &Agent, job_id: &Uuid, request: &JoinReq) -> Result<()>
         )?;
 
     // Handle each receipt independently - they could be mixed types
-    let prover = agent.prover.as_ref().context("Missing prover from resolve task")?;
-    let left_idx = request.left;
-    let right_idx = request.right;
+    let prover = agent.prover.as_ref().context("Missing prover from join task")?;
 
     let (left_receipt, right_receipt) = tokio::try_join!(
-        async {
-            match deserialize_obj::<SegmentReceipt>(&left_receipt_data) {
-                Ok(segment) => {
-                    let receipt = prover
-                        .lift(&segment)
-                        .with_context(|| format!("Failed to lift segment {left_idx}"))?;
-                    tracing::debug!("lifting complete {job_id} - {left_idx}");
-                    Ok::<_, anyhow::Error>(receipt)
-                }
-                Err(_) => {
-                    let receipt: SuccinctReceipt<ReceiptClaim> =
-                        deserialize_obj(&left_receipt_data)
-                            .context("Failed to deserialize left receipt")?;
-                    Ok(receipt)
-                }
-            }
-        },
-        async {
-            match deserialize_obj::<SegmentReceipt>(&right_receipt_data) {
-                Ok(segment) => {
-                    let receipt = prover
-                        .lift(&segment)
-                        .with_context(|| format!("Failed to lift segment {right_idx}"))?;
-                    tracing::debug!("lifting complete {job_id} - {right_idx}");
-                    Ok::<_, anyhow::Error>(receipt)
-                }
-                Err(_) => {
-                    let receipt: SuccinctReceipt<ReceiptClaim> =
-                        deserialize_obj(&right_receipt_data)
-                            .context("Failed to deserialize right receipt")?;
-                    Ok(receipt)
-                }
-            }
-        }
+        lift_or_deserialize_receipt(prover, &left_receipt_data, request.left, job_id),
+        lift_or_deserialize_receipt(prover, &right_receipt_data, request.right, job_id)
     )?;
 
     tracing::trace!("Joining {job_id} - {} + {} -> {}", request.left, request.right, request.idx);
diff --git a/bento/crates/workflow/src/tasks/join_povw.rs b/bento/crates/workflow/src/tasks/join_povw.rs
index ea5e4b1b..0408c848 100644
--- a/bento/crates/workflow/src/tasks/join_povw.rs
+++ b/bento/crates/workflow/src/tasks/join_povw.rs
@@ -3,16 +3,42 @@
 // Use of this source code is governed by the Business Source License
 // as found in the LICENSE-BSL file.
 
+use std::rc::Rc;
+
 use crate::{
     Agent,
     redis::{self, AsyncCommands},
     tasks::{RECUR_RECEIPT_PATH, deserialize_obj, serialize_obj},
 };
 use anyhow::{Context, Result};
-use risc0_zkvm::{ReceiptClaim, SegmentReceipt, SuccinctReceipt, WorkClaim};
+use risc0_zkvm::{ProverServer, ReceiptClaim, SegmentReceipt, SuccinctReceipt, WorkClaim};
 use uuid::Uuid;
 use workflow_common::JoinReq;
 
+/// Lifts a receipt to POVW if it's a segment receipt, otherwise returns it as-is
+async fn lift_or_deserialize_povw_receipt(
+    prover: &Rc<dyn ProverServer>,
+    receipt_data: &[u8],
+    idx: usize,
+    job_id: &Uuid,
+) -> Result<SuccinctReceipt<WorkClaim<ReceiptClaim>>> {
+    match deserialize_obj::<SegmentReceipt>(receipt_data) {
+        Ok(segment_receipt) => {
+            let povw_receipt = prover
+                .lift_povw(&segment_receipt)
+                .with_context(|| format!("Failed to lift segment {idx} to POVW"))?;
+            tracing::debug!("POVW lifting complete {job_id} - {idx}");
+            Ok(povw_receipt)
+        }
+        Err(_) => {
+            let povw_receipt: SuccinctReceipt<WorkClaim<ReceiptClaim>> =
+                deserialize_obj(receipt_data)
+                    .with_context(|| format!("Failed to deserialize POVW receipt {idx}"))?;
+            Ok(povw_receipt)
+        }
+    }
+}
+
 /// Run a POVW join request
 pub async fn join_povw(agent: &Agent, job_id: &Uuid, request: &JoinReq) -> Result<()> {
     let mut conn = agent.redis_pool.get().await?;
@@ -32,46 +58,9 @@ pub async fn join_povw(agent: &Agent, job_id: &Uuid, request: &JoinReq) -> Resul
     // Handle each receipt independently - they could be mixed types
     let prover = agent.prover.as_ref().context("Missing prover from POVW join task")?;
 
-    let (left_receipt, right_receipt): (
-        SuccinctReceipt<WorkClaim<ReceiptClaim>>,
-        SuccinctReceipt<WorkClaim<ReceiptClaim>>,
-    ) = tokio::try_join!(
-        async {
-            match deserialize_obj::<SegmentReceipt>(&left_receipt_bytes) {
-                Ok(segment_receipt) => {
-                    // Successfully deserialized as segment receipt, now lift to POVW
-                    let povw_receipt = prover
-                        .lift_povw(&segment_receipt)
-                        .context("Failed to lift left segment to POVW")?;
-                    Ok::<_, anyhow::Error>(povw_receipt)
-                }
-                Err(_) => {
-                    // Failed to deserialize as segment, try as already-lifted POVW receipt
-                    let povw_receipt: SuccinctReceipt<WorkClaim<ReceiptClaim>> =
-                        deserialize_obj(&left_receipt_bytes)
-                            .context("Failed to deserialize left POVW receipt")?;
-                    Ok(povw_receipt)
-                }
-            }
-        },
-        async {
-            match deserialize_obj::<SegmentReceipt>(&right_receipt_bytes) {
-                Ok(segment_receipt) => {
-                    // Successfully deserialized as segment receipt, now lift to POVW
-                    let povw_receipt = prover
-                        .lift_povw(&segment_receipt)
-                        .context("Failed to lift right segment to POVW")?;
-                    Ok::<_, anyhow::Error>(povw_receipt)
-                }
-                Err(_) => {
-                    // Failed to deserialize as segment, try as already-lifted POVW receipt
-                    let povw_receipt: SuccinctReceipt<WorkClaim<ReceiptClaim>> =
-                        deserialize_obj(&right_receipt_bytes)
-                            .context("Failed to deserialize right POVW receipt")?;
-                    Ok(povw_receipt)
-                }
-            }
-        }
+    let (left_receipt, right_receipt) = tokio::try_join!(
+        lift_or_deserialize_povw_receipt(prover, &left_receipt_bytes, request.left, job_id),
+        lift_or_deserialize_povw_receipt(prover, &right_receipt_bytes, request.right, job_id)
     )?;
 
     tracing::debug!("Starting POVW join of receipts {} and {}", request.left, request.right);

@willpote willpote requested a review from ec2 as a code owner November 5, 2025 00:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants