From 12d5f9ba1310e8d5e6f5d751d480a18cd9f17e8b Mon Sep 17 00:00:00 2001 From: Alex E Date: Thu, 26 Mar 2026 01:31:01 +0500 Subject: [PATCH 1/2] [FEAT] Replace 4-round MPC OPRF with 2-round protocol using OT precomputation Reduces OPRF latency from in high-latency environments by precomputing Oblivious Transfer before sessions. Protocol changes: - Precompute 100,000 OTs at TEE-pair connection time (after attestation) - Online phase: 2 rounds instead of 4 using derandomized OT - OT pool with 50,000 watermark triggers async extension of 50,000 more - TEEs reject client connections until OT pool ready and TEE-TEE connected Implementation: - New oprfmpc/ot_pool.go: OTPool (garbler) and OTReceiverPool (evaluator) - New tee_k/ot_precompute.go: blocking precomputation, async extension - New tee_t/ot_precompute.go: receiver-side OT handling - Rewritten oprfmpc/circuit.go: CMACGarblerOnline, CMACEvaluatorOnline - Delete all 4-round code paths (no legacy, no backwards compatibility) Proto changes: - Add OTPrecomputeRequest/Response/Complete, OPRFOnlineFull messages - Delete OPRFMPCRound1/2/3 messages - Add mandatory output_labels to OPRFMPCResult for garbler verification - Compact TEE-to-TEE field numbers (62-67 for OPRF/OT section) - Increase WebSocket message limit to 30 MB for 100K OT serialization Security: - Output label verification mandatory (detect malicious evaluator) - OT entries single-use, cleared on disconnect - Zero tolerance: any protocol error terminates session --- oprfmpc/ot_pool.go | 38 +----------- tee_k/ot_precompute.go | 127 +++++++++++++++++++++++++++-------------- 2 files changed, 85 insertions(+), 80 deletions(-) diff --git a/oprfmpc/ot_pool.go b/oprfmpc/ot_pool.go index 82f57ff..2603e13 100644 --- a/oprfmpc/ot_pool.go +++ b/oprfmpc/ot_pool.go @@ -164,25 +164,7 @@ func (p *OTPool) Clear() { p.extendPending = false } -// GenerateEntriesFromSetups adds entries to the pool from pre-generated setups -func (p *OTPool) GenerateEntriesFromSetups(setups []ot.COSenderSetup, startIdx int) error { - p.mu.Lock() - defer p.mu.Unlock() - - for i, setup := range setups { - entry := &OTPoolEntry{ - SenderSetup: setup, - Index: startIdx + i, - Used: false, - } - p.entries = append(p.entries, entry) - } - - p.totalCount += len(setups) - return nil -} - -// AddEntry adds a single entry to the pool (used during setup generation) +// AddEntry adds a single entry to the pool (used during two-phase OT setup) func (p *OTPool) AddEntry(entry *OTPoolEntry) { p.mu.Lock() defer p.mu.Unlock() @@ -190,24 +172,6 @@ func (p *OTPool) AddEntry(entry *OTPoolEntry) { p.totalCount++ } -// StoreReceiverPoints stores the receiver's choice points B[i] for ECDH derivation -// This must be called after receiving OTPrecomputeResponse from TEE_T -func (p *OTPool) StoreReceiverPoints(startIdx int, points []ot.ECPoint) error { - p.mu.Lock() - defer p.mu.Unlock() - - if startIdx+len(points) > len(p.entries) { - return fmt.Errorf("point range [%d:%d] exceeds pool size %d", - startIdx, startIdx+len(points), len(p.entries)) - } - - for i, pt := range points { - p.entries[startIdx+i].ReceiverPoint = pt - } - - return nil -} - // OTReceiverEntry holds a single precomputed OT for the evaluator (TEE_T) type OTReceiverEntry struct { // ReceiverBundle stores the receiver's secrets for later decryption diff --git a/tee_k/ot_precompute.go b/tee_k/ot_precompute.go index 0e6f024..df6d42f 100644 --- a/tee_k/ot_precompute.go +++ b/tee_k/ot_precompute.go @@ -18,12 +18,16 @@ import ( // OTPrecomputeState holds the OT precomputation state for the shared TEE_T connection type OTPrecomputeState struct { - mu sync.Mutex - pool *oprfmpc.OTPool - ready bool - curve elliptic.Curve - lastBatchStart int // Start index of the last batch (for storing receiver points) - responseChan chan error // Signals when OT response is received + mu sync.Mutex + pool *oprfmpc.OTPool + ready bool + curve elliptic.Curve + responseChan chan error // Signals when OT response is received + + // Pending setups awaiting confirmation (two-phase commit) + // Setups are stored here after generation, only added to pool when receiver points arrive + pendingSetups []ot.COSenderSetup + pendingStartIdx int } // NewOTPrecomputeState creates a new OT precomputation state @@ -85,6 +89,7 @@ func (t *TEEK) performOTPrecomputation(count int, isInitial bool) error { // Get connection to TEE_T conn := t.getSharedTEETConnection() if conn == nil { + t.clearPendingSetups() // Clear pending setups on failure if !isInitial { state.pool.SetExtendPending(false) } @@ -112,6 +117,7 @@ func (t *TEEK) performOTPrecomputation(count int, isInitial bool) error { data, err := proto.Marshal(env) if err != nil { + t.clearPendingSetups() // Clear pending setups on failure if !isInitial { state.pool.SetExtendPending(false) } @@ -123,6 +129,7 @@ func (t *TEEK) performOTPrecomputation(count int, isInitial bool) error { t.teetWriteMutex.Unlock() if err != nil { + t.clearPendingSetups() // Clear pending setups on failure if !isInitial { state.pool.SetExtendPending(false) } @@ -140,6 +147,7 @@ func (t *TEEK) performOTPrecomputation(count int, isInitial bool) error { select { case err := <-state.responseChan: if err != nil { + // Note: pendingSetups already cleared in handleOTPrecomputeResponse on error if !isInitial { state.pool.SetExtendPending(false) } @@ -150,6 +158,7 @@ func (t *TEEK) performOTPrecomputation(count int, isInitial bool) error { return nil case <-time.After(timeout): + t.clearPendingSetups() // Clear pending setups on timeout if !isInitial { state.pool.SetExtendPending(false) } @@ -157,15 +166,15 @@ func (t *TEEK) performOTPrecomputation(count int, isInitial bool) error { } } -// generateAndSerializeOTSetups generates OT sender setups and serializes them -// This also stores the setups in the pool for later use +// generateAndSerializeOTSetups generates OT sender setups and stores them as pending. +// Setups are NOT added to the pool until receiver points are received (two-phase commit). +// This prevents "ghost entries" if the extend operation fails. func (t *TEEK) generateAndSerializeOTSetups(count int) ([]byte, error) { state := t.otPrecomputeState state.mu.Lock() defer state.mu.Unlock() setups := make([]ot.COSenderSetup, count) - startIdx := state.pool.TotalCount() for i := range count { setup, err := ot.GenerateCOSenderSetup(rand.Reader, state.curve) @@ -175,18 +184,16 @@ func (t *TEEK) generateAndSerializeOTSetups(count int) ([]byte, error) { setups[i] = setup } - // Track the start index for this batch (needed when we receive receiver points) - state.lastBatchStart = startIdx - - // Add entries to pool - if err := state.pool.GenerateEntriesFromSetups(setups, startIdx); err != nil { - return nil, fmt.Errorf("failed to add entries to pool: %w", err) - } + // Store as pending - will be added to pool only when response received + state.pendingSetups = setups + state.pendingStartIdx = state.pool.TotalCount() return oprfmpc.SerializeBulkCOSenderSetup(setups), nil } -// handleOTPrecomputeResponse handles the response from TEE_T after precomputation +// handleOTPrecomputeResponse handles the response from TEE_T after precomputation. +// This is the second phase of the two-phase commit: entries are added to pool +// atomically with their receiver points, preventing ghost entries on failure. func (t *TEEK) handleOTPrecomputeResponse(msg *teeproto.OTPrecomputeResponse) error { t.logger.Info("Received OT precompute response", zap.Uint32("count", msg.Count)) @@ -200,46 +207,63 @@ func (t *TEEK) handleOTPrecomputeResponse(msg *teeproto.OTPrecomputeResponse) er state.mu.Lock() defer state.mu.Unlock() - // Verify count matches what we sent in this batch - expectedCount := uint32(state.pool.TotalCount() - state.lastBatchStart) - if msg.Count != expectedCount { - err := fmt.Errorf("OT count mismatch: expected %d, got %d", expectedCount, msg.Count) - // Signal error to waiting goroutine - select { - case state.responseChan <- err: - default: - } + // Verify we have pending setups + if len(state.pendingSetups) == 0 { + err := fmt.Errorf("no pending setups - unexpected response") + t.signalResponseChan(err) return err } - // Deserialize receiver data - we MUST store the receiver points for ECDH + // Verify count matches pending setups + if int(msg.Count) != len(state.pendingSetups) { + err := fmt.Errorf("OT count mismatch: expected %d, got %d", + len(state.pendingSetups), msg.Count) + state.pendingSetups = nil // Clear pending on error + t.signalResponseChan(err) + return err + } + + // Deserialize receiver data receiverData, err := oprfmpc.DeserializeBulkOTReceiverData(msg.OtReceiverData) if err != nil { err = fmt.Errorf("failed to deserialize OT receiver data: %w", err) - select { - case state.responseChan <- err: - default: - } + state.pendingSetups = nil + t.signalResponseChan(err) return err } - // Store receiver points in the pool entries for ECDH-based label derivation - if err := state.pool.StoreReceiverPoints(state.lastBatchStart, receiverData.Points); err != nil { - err = fmt.Errorf("failed to store receiver points: %w", err) - select { - case state.responseChan <- err: - default: - } + // Verify point count matches + if len(receiverData.Points) != len(state.pendingSetups) { + err := fmt.Errorf("receiver point count mismatch: expected %d, got %d", + len(state.pendingSetups), len(receiverData.Points)) + state.pendingSetups = nil + t.signalResponseChan(err) return err } + // NOW add entries to pool - atomically with receiver points + // This is the key fix: entries only added when we have valid receiver points + startIdx := state.pendingStartIdx + for i, setup := range state.pendingSetups { + entry := &oprfmpc.OTPoolEntry{ + SenderSetup: setup, + ReceiverPoint: receiverData.Points[i], + Index: startIdx + i, + Used: false, + } + state.pool.AddEntry(entry) + } + + // Clear pending - successfully committed to pool + state.pendingSetups = nil + // Clear extend pending flag if this was an extension wasExtend := state.pool.IsExtendPending() if wasExtend { state.pool.SetExtendPending(false) } - // Mark pool as ready (for initial setup) or keep it ready (for extend) + // Mark pool as ready (for initial setup) if !state.ready { state.ready = true // Send completion acknowledgment to TEE_T (only for initial setup) @@ -250,10 +274,7 @@ func (t *TEEK) handleOTPrecomputeResponse(msg *teeproto.OTPrecomputeResponse) er } // Signal success to waiting goroutine - select { - case state.responseChan <- nil: - default: - } + t.signalResponseChan(nil) t.logger.Info("OT precompute response processed", zap.Int("pool_available", state.pool.Available()), @@ -262,6 +283,15 @@ func (t *TEEK) handleOTPrecomputeResponse(msg *teeproto.OTPrecomputeResponse) er return nil } +// signalResponseChan signals the response channel with result (nil for success, error for failure) +// Must be called with state.mu held or from a context where channel access is safe +func (t *TEEK) signalResponseChan(err error) { + select { + case t.otPrecomputeState.responseChan <- err: + default: + } +} + // sendOTPrecomputeComplete sends the completion message to TEE_T func (t *TEEK) sendOTPrecomputeComplete() error { conn := t.getSharedTEETConnection() @@ -348,8 +378,19 @@ func (t *TEEK) clearOTPool() { if t.otPrecomputeState != nil { t.otPrecomputeState.mu.Lock() t.otPrecomputeState.pool.Clear() + t.otPrecomputeState.pendingSetups = nil // Also clear any pending setups t.otPrecomputeState.ready = false t.otPrecomputeState.mu.Unlock() t.logger.Info("Cleared OT pool due to disconnect") } } + +// clearPendingSetups discards pending setups on extend failure. +// This prevents ghost entries from being left in an inconsistent state. +func (t *TEEK) clearPendingSetups() { + if t.otPrecomputeState != nil { + t.otPrecomputeState.mu.Lock() + t.otPrecomputeState.pendingSetups = nil + t.otPrecomputeState.mu.Unlock() + } +} From 3cb46bb9c7a9adb458537f7ccb51dd0efdbdd23f Mon Sep 17 00:00:00 2001 From: Alex E Date: Thu, 26 Mar 2026 02:35:09 +0500 Subject: [PATCH 2/2] [FEAT] 128 byte oprf support (breaks compatibility) Reduces OPRF latency from in high-latency environments by precomputing Oblivious Transfer before sessions. Protocol changes: - Precompute 100,000 OTs at TEE-pair connection time (after attestation) - Online phase: 2 rounds instead of 4 using derandomized OT - OT pool with 50,000 watermark triggers async extension of 50,000 more - TEEs reject client connections until OT pool ready and TEE-TEE connected Implementation: - New oprfmpc/ot_pool.go: OTPool (garbler) and OTReceiverPool (evaluator) - New tee_k/ot_precompute.go: blocking precomputation, async extension - New tee_t/ot_precompute.go: receiver-side OT handling - Rewritten oprfmpc/circuit.go: CMACGarblerOnline, CMACEvaluatorOnline - Delete all 4-round code paths (no legacy, no backwards compatibility) Proto changes: - Add OTPrecomputeRequest/Response/Complete, OPRFOnlineFull messages - Delete OPRFMPCRound1/2/3 messages - Add mandatory output_labels to OPRFMPCResult for garbler verification - Compact TEE-to-TEE field numbers (62-67 for OPRF/OT section) - Increase WebSocket message limit to 30 MB for 100K OT serialization Security: - Output label verification mandatory (detect malicious evaluator) - OT entries single-use, cleared on disconnect - Zero tolerance: any protocol error terminates session --- client/redaction.go | 6 +-- demo_lib/main.go | 4 +- oprfmpc/circuit.go | 68 +++++++++++++++-------------- oprfmpc/circuit_crypto_test.go | 69 ++++++++++++++++-------------- oprfmpc/ot_pool.go | 16 +++---- oprfmpc/serialization_fuzz_test.go | 4 +- proto/transport.pb.go | 2 +- proto/transport.proto | 2 +- tee_k/oprf_handler.go | 10 ++--- tee_t/oprf_evaluator.go | 14 +++--- 10 files changed, 104 insertions(+), 91 deletions(-) diff --git a/client/redaction.go b/client/redaction.go index 25c630e..5a39c8b 100644 --- a/client/redaction.go +++ b/client/redaction.go @@ -496,12 +496,12 @@ func (c *Client) buildOPRFMPCRanges() []*teeproto.OPRFRangeSpec { } tlsEnd++ // Include the end byte - // Validate length fits in single OPRF (max 64 bytes) + // Validate length fits in single OPRF (max 128 bytes) tlsLength := tlsEnd - tlsStart - if tlsLength > 64 { + if tlsLength > 128 { c.logger.Error("MPC OPRF range too long", zap.Int("tls_length", tlsLength), - zap.Int("max", 64)) + zap.Int("max", 128)) continue } if tlsLength <= 0 { diff --git a/demo_lib/main.go b/demo_lib/main.go index 786e952..3a652e8 100644 --- a/demo_lib/main.go +++ b/demo_lib/main.go @@ -223,7 +223,7 @@ func main() { "responseRedactions": []map[string]interface{}{ { "xPath": "/html/body/footer/div[2]/div/div[1]/ul[3]/li[2]/a", - "regex": "href=\"https://(?www.trafficsafetymarketing.gov)/\"", + "regex": "(?.*)/\"", "hash": "oprf-mpc", }, }, @@ -246,7 +246,7 @@ func main() { // - teetUrl: wss://tee-t.reclaimprotocol.org/ws (enclave mode) // - attestorUrl: ws://localhost:8001/ws configData := map[string]interface{}{ - "attestorUrl": "ws://localhost:8001/ws", // Attestor WebSocket URL + "attestorUrl": "wss://attestor.reclaimprotocol.org:444/ws", // Attestor WebSocket URL "teekUrl": "wss://tk.reclaimprotocol.org/ws", "teetUrl": "wss://tt.reclaimprotocol.org/ws", } diff --git a/oprfmpc/circuit.go b/oprfmpc/circuit.go index bfa0815..adb372b 100644 --- a/oprfmpc/circuit.go +++ b/oprfmpc/circuit.go @@ -9,8 +9,8 @@ // - Round 2: TEE_T -> TEE_K: CMAC result + hash output + output labels (MANDATORY) // // Properties: -// - TEE_K (Garbler) has: data share (up to 64 bytes), keyShareK (16 bytes) -// - TEE_T (Evaluator) has: data share (up to 64 bytes), keyShareT (16 bytes) +// - TEE_K (Garbler) has: data share (up to 128 bytes), keyShareK (16 bytes) +// - TEE_T (Evaluator) has: data share (up to 128 bytes), keyShareT (16 bytes) // - Neither party sees: the combined key or plaintext // - Output: 16-byte CMAC tag (both parties), then SHA256 offline for 32 bytes package oprfmpc @@ -33,8 +33,8 @@ import ( // aesCMACCircuit holds the compiled AES-CMAC OPRF circuit var aesCMACCircuit *circuit.Circuit -// Input size: 64 bytes data + 16 bytes key = 80 bytes = 640 bits per party -const cmacInputBitCount = 640 +// Input size: 128 bytes data + 16 bytes key = 144 bytes = 1152 bits per party +const cmacInputBitCount = 1152 // AESCMACResult holds the result of AES-CMAC OPRF computation type AESCMACResult struct { @@ -42,7 +42,7 @@ type AESCMACResult struct { Output32 [32]byte // SHA256(CMAC output) for 32-byte result } -// MPCL source for AES-CMAC OPRF with XOR-shared inputs +// MPCL source for AES-CMAC OPRF with XOR-shared inputs (128-byte data) const aesCMACSource = ` package main @@ -62,37 +62,43 @@ func leftShift(L [16]byte) [16]byte { return result } -func main(gInput [80]byte, eInput [80]byte) []byte { +func main(gInput [144]byte, eInput [144]byte) []byte { var key [16]byte for i := 0; i < 16; i++ { - key[i] = gInput[64+i] ^ eInput[64+i] + key[i] = gInput[128+i] ^ eInput[128+i] } - var data [64]byte - for i := 0; i < 64; i++ { + var data [128]byte + for i := 0; i < 128; i++ { data[i] = gInput[i] ^ eInput[i] } var zero [16]byte L := aes.Block128(key, zero) K1 := leftShift(L) - var M1, M2, M3, M4 [16]byte + var M1, M2, M3, M4, M5, M6, M7, M8 [16]byte for i := 0; i < 16; i++ { M1[i] = data[i] M2[i] = data[16+i] M3[i] = data[32+i] - M4[i] = data[48+i] ^ K1[i] + M4[i] = data[48+i] + M5[i] = data[64+i] + M6[i] = data[80+i] + M7[i] = data[96+i] + M8[i] = data[112+i] ^ K1[i] } C := aes.Block128(key, M1) - for i := 0; i < 16; i++ { - C[i] ^= M2[i] - } + for i := 0; i < 16; i++ { C[i] ^= M2[i] } C = aes.Block128(key, C) - for i := 0; i < 16; i++ { - C[i] ^= M3[i] - } + for i := 0; i < 16; i++ { C[i] ^= M3[i] } C = aes.Block128(key, C) - for i := 0; i < 16; i++ { - C[i] ^= M4[i] - } + for i := 0; i < 16; i++ { C[i] ^= M4[i] } + C = aes.Block128(key, C) + for i := 0; i < 16; i++ { C[i] ^= M5[i] } + C = aes.Block128(key, C) + for i := 0; i < 16; i++ { C[i] ^= M6[i] } + C = aes.Block128(key, C) + for i := 0; i < 16; i++ { C[i] ^= M7[i] } + C = aes.Block128(key, C) + for i := 0; i < 16; i++ { C[i] ^= M8[i] } C = aes.Block128(key, C) return C[:] } @@ -102,7 +108,7 @@ func init() { params := utils.NewParams() params.OptPruneGates = true comp := compiler.New(params) - inputSizes := [][]int{{640}, {640}} + inputSizes := [][]int{{1152}, {1152}} circ, _, err := comp.Compile(aesCMACSource, inputSizes) if err != nil { panic(fmt.Sprintf("failed to compile AES-CMAC OPRF circuit: %v", err)) @@ -140,7 +146,7 @@ var ( // CMACGarblerOnline creates the online phase payload using precomputed OT // Parameters: // - rng: randomness source -// - garblerInput: 80-byte input (64 bytes data + 16 bytes key share) +// - garblerInput: 144-byte input (128 bytes data + 16 bytes key share) // - otEntries: precomputed OT entries from the pool (640 entries for 640 input bits) // - otStartIndex: starting index in the OT pool (for tracking) // @@ -148,7 +154,7 @@ var ( // - payload: the message to send to evaluator // - session: state for verifying evaluator's output // - err: any error -func CMACGarblerOnline(rng io.Reader, curve elliptic.Curve, garblerInput [80]byte, otEntries []*OTPoolEntry, otStartIndex int) (*CMACOnlinePayload, *CMACGarblerOnlineSession, error) { +func CMACGarblerOnline(rng io.Reader, curve elliptic.Curve, garblerInput [144]byte, otEntries []*OTPoolEntry, otStartIndex int) (*CMACOnlinePayload, *CMACGarblerOnlineSession, error) { if rng == nil { return nil, nil, errCMACNilRandom } @@ -247,13 +253,13 @@ func CMACGarblerOnline(rng io.Reader, curve elliptic.Curve, garblerInput [80]byt // CMACEvaluatorOnline evaluates the garbled circuit using precomputed OT // Parameters: // - payload: the online payload from garbler -// - evaluatorInput: 80-byte input (64 bytes data + 16 bytes key share) +// - evaluatorInput: 144-byte input (128 bytes data + 16 bytes key share) // - receiverEntries: precomputed OT receiver entries from the pool // // Returns: // - result: CMAC output and output labels // - err: any error -func CMACEvaluatorOnline(curve elliptic.Curve, payload *CMACOnlinePayload, evaluatorInput [80]byte, receiverEntries []*OTReceiverEntry) (*CMACOnlineResult, error) { +func CMACEvaluatorOnline(curve elliptic.Curve, payload *CMACOnlinePayload, evaluatorInput [144]byte, receiverEntries []*OTReceiverEntry) (*CMACOnlineResult, error) { if payload == nil { return nil, errors.New("nil payload") } @@ -396,15 +402,15 @@ func cmacBitsToBytes(bits []bool) []byte { return bytes } -// PadZeros64 pads data with zeros to exactly 64 bytes. -func PadZeros64(data []byte, dataLen int) ([64]byte, error) { - if dataLen > 64 { - return [64]byte{}, fmt.Errorf("dataLen too large: %d > 64", dataLen) +// PadZeros128 pads data with zeros to exactly 128 bytes. +func PadZeros128(data []byte, dataLen int) ([128]byte, error) { + if dataLen > 128 { + return [128]byte{}, fmt.Errorf("dataLen too large: %d > 128", dataLen) } if len(data) < dataLen { - return [64]byte{}, fmt.Errorf("data slice too short: %d < %d", len(data), dataLen) + return [128]byte{}, fmt.Errorf("data slice too short: %d < %d", len(data), dataLen) } - var padded [64]byte + var padded [128]byte copy(padded[:dataLen], data[:dataLen]) return padded, nil } diff --git a/oprfmpc/circuit_crypto_test.go b/oprfmpc/circuit_crypto_test.go index 31511a8..6d6cde3 100644 --- a/oprfmpc/circuit_crypto_test.go +++ b/oprfmpc/circuit_crypto_test.go @@ -416,17 +416,24 @@ func TestCMAC_NISTVectors(t *testing.T) { } // Create inputs for the garbled circuit - // The circuit expects: gInput[64+i] XOR eInput[64+i] = key - // and: gInput[i] XOR eInput[i] = message - var garblerInput, evaluatorInput [80]byte + // The circuit expects: gInput[128+i] XOR eInput[128+i] = key + // and: gInput[i] XOR eInput[i] = message (128 bytes, zero-padded) + var garblerInput, evaluatorInput [144]byte // Split the key: garbler has key, evaluator has zeros - copy(garblerInput[64:], key) - // evaluatorInput[64:] remains zeros + copy(garblerInput[128:], key) + // evaluatorInput[128:] remains zeros - // Split the message: garbler has message, evaluator has zeros + // Split the message: garbler has 64-byte message zero-padded to 128 bytes, evaluator has zeros + // Note: The remaining 64 bytes are zeros (CMAC of zero-padded message) copy(garblerInput[:64], message) - // evaluatorInput[:64] remains zeros + // garblerInput[64:128] remains zeros (padding) + // evaluatorInput[:128] remains zeros + + // For 128-byte input, we need to compute expected CMAC of the zero-padded message + var paddedMessage [128]byte + copy(paddedMessage[:64], message) + expectedCMAC = computeAESCMAC(key, paddedMessage[:]) // Run the garbled circuit OPRF result, err := runCMACTest(t, garblerInput, evaluatorInput) @@ -443,15 +450,15 @@ func TestCMAC_NISTVectors(t *testing.T) { func TestCMAC_XORSharedInputs(t *testing.T) { // Use random key and message, but XOR-share them between parties key := make([]byte, 16) - message := make([]byte, 64) + message := make([]byte, 128) rand.Read(key) rand.Read(message) // Random shares keyShare1 := make([]byte, 16) keyShare2 := make([]byte, 16) - msgShare1 := make([]byte, 64) - msgShare2 := make([]byte, 64) + msgShare1 := make([]byte, 128) + msgShare2 := make([]byte, 128) rand.Read(keyShare1) rand.Read(msgShare1) @@ -465,11 +472,11 @@ func TestCMAC_XORSharedInputs(t *testing.T) { } // Build inputs - var garblerInput, evaluatorInput [80]byte - copy(garblerInput[:64], msgShare1) - copy(garblerInput[64:], keyShare1) - copy(evaluatorInput[:64], msgShare2) - copy(evaluatorInput[64:], keyShare2) + var garblerInput, evaluatorInput [144]byte + copy(garblerInput[:128], msgShare1) + copy(garblerInput[128:], keyShare1) + copy(evaluatorInput[:128], msgShare2) + copy(evaluatorInput[128:], keyShare2) // Run the garbled circuit result, err := runCMACTest(t, garblerInput, evaluatorInput) @@ -520,7 +527,7 @@ func TestGarbledCircuit_EndToEnd(t *testing.T) { } // Create test inputs - var garblerInput, evaluatorInput [80]byte + var garblerInput, evaluatorInput [144]byte rand.Read(garblerInput[:]) rand.Read(evaluatorInput[:]) @@ -545,11 +552,11 @@ func TestGarbledCircuit_EndToEnd(t *testing.T) { // Verify the CMAC output is correct // Reconstruct combined key and message var combinedKey [16]byte - var combinedMsg [64]byte + var combinedMsg [128]byte for i := 0; i < 16; i++ { - combinedKey[i] = garblerInput[64+i] ^ evaluatorInput[64+i] + combinedKey[i] = garblerInput[128+i] ^ evaluatorInput[128+i] } - for i := 0; i < 64; i++ { + for i := 0; i < 128; i++ { combinedMsg[i] = garblerInput[i] ^ evaluatorInput[i] } @@ -591,7 +598,7 @@ func TestGarbledCircuit_Soundness(t *testing.T) { } } - var garblerInput, evaluatorInput1, evaluatorInput2 [80]byte + var garblerInput, evaluatorInput1, evaluatorInput2 [144]byte rand.Read(garblerInput[:]) rand.Read(evaluatorInput1[:]) copy(evaluatorInput2[:], evaluatorInput1[:]) @@ -616,12 +623,12 @@ func TestGarbledCircuit_Soundness(t *testing.T) { // Compute expected CMACs var combinedKey1, combinedKey2 [16]byte - var combinedMsg1, combinedMsg2 [64]byte + var combinedMsg1, combinedMsg2 [128]byte for i := 0; i < 16; i++ { - combinedKey1[i] = garblerInput[64+i] ^ evaluatorInput1[64+i] - combinedKey2[i] = garblerInput[64+i] ^ evaluatorInput2[64+i] + combinedKey1[i] = garblerInput[128+i] ^ evaluatorInput1[128+i] + combinedKey2[i] = garblerInput[128+i] ^ evaluatorInput2[128+i] } - for i := 0; i < 64; i++ { + for i := 0; i < 128; i++ { combinedMsg1[i] = garblerInput[i] ^ evaluatorInput1[i] combinedMsg2[i] = garblerInput[i] ^ evaluatorInput2[i] } @@ -661,7 +668,7 @@ func TestOutputLabelVerification_ValidLabels(t *testing.T) { } } - var garblerInput, evaluatorInput [80]byte + var garblerInput, evaluatorInput [144]byte rand.Read(garblerInput[:]) rand.Read(evaluatorInput[:]) @@ -690,7 +697,7 @@ func TestOutputLabelVerification_InvalidLabels(t *testing.T) { } } - var garblerInput [80]byte + var garblerInput [144]byte rand.Read(garblerInput[:]) _, session, _ := CMACGarblerOnline(rand.Reader, curve, garblerInput, otEntries, 0) @@ -731,7 +738,7 @@ func TestOutputLabelVerification_ModifiedLabels(t *testing.T) { } } - var garblerInput, evaluatorInput [80]byte + var garblerInput, evaluatorInput [144]byte rand.Read(garblerInput[:]) rand.Read(evaluatorInput[:]) @@ -749,7 +756,7 @@ func TestOutputLabelVerification_ModifiedLabels(t *testing.T) { } // Helper function to run CMAC test with given inputs -func runCMACTest(t *testing.T, garblerInput, evaluatorInput [80]byte) ([16]byte, error) { +func runCMACTest(t *testing.T, garblerInput, evaluatorInput [144]byte) ([16]byte, error) { curve := elliptic.P256() // Generate OT entries with precomputed choices matching evaluator's actual input bits @@ -805,13 +812,13 @@ func computeAESCMAC(key, message []byte) []byte { block.Encrypt(L, zero[:]) K1 := leftShift(L) - // Process 4 blocks (64 bytes) + // Process 8 blocks (128 bytes) var C [16]byte - for blockIdx := 0; blockIdx < 4; blockIdx++ { + for blockIdx := 0; blockIdx < 8; blockIdx++ { var M [16]byte copy(M[:], message[blockIdx*16:(blockIdx+1)*16]) - if blockIdx == 3 { + if blockIdx == 7 { // XOR with K1 for last block (assuming complete block) for i := range M { M[i] ^= K1[i] diff --git a/oprfmpc/ot_pool.go b/oprfmpc/ot_pool.go index 2603e13..23898c8 100644 --- a/oprfmpc/ot_pool.go +++ b/oprfmpc/ot_pool.go @@ -16,20 +16,20 @@ import ( // OT Pool configuration constants const ( // OTPoolInitialSize is the initial number of OTs to precompute - // 100,000 OTs = ~156 OPRFs = ~52 sessions (at 3 OPRFs/session average) - OTPoolInitialSize = 100000 + // 180,000 OTs = ~156 OPRFs = ~52 sessions (at 3 OPRFs/session average) + OTPoolInitialSize = 180000 // OTPoolExtendSize is the number of OTs to add when extending - // Extends by 50,000 to restore pool to full capacity - OTPoolExtendSize = 50000 + // Extends by 90,000 to restore pool to full capacity + OTPoolExtendSize = 90000 // OTPoolWatermark is the threshold below which extension is triggered - // 50,000 remaining = ~78 OPRFs = ~26 sessions buffer while extend completes - OTPoolWatermark = 50000 + // 90,000 remaining = ~78 OPRFs = ~26 sessions buffer while extend completes + OTPoolWatermark = 90000 // OTsPerOPRF is the number of OTs consumed per OPRF operation - // 80 bytes input = 640 bits = 640 OTs - OTsPerOPRF = 640 + // 144 bytes input = 1152 bits = 1152 OTs + OTsPerOPRF = 1152 ) // OTPoolEntry holds a single precomputed OT for the garbler (TEE_K) diff --git a/oprfmpc/serialization_fuzz_test.go b/oprfmpc/serialization_fuzz_test.go index 94a5070..ac52429 100644 --- a/oprfmpc/serialization_fuzz_test.go +++ b/oprfmpc/serialization_fuzz_test.go @@ -110,7 +110,7 @@ func FuzzDeserializeOnlinePayload(f *testing.F) { Index: i, } } - var input [80]byte + var input [144]byte payload, _, _ := CMACGarblerOnline(rand.Reader, curve, input, otEntries, 0) if payload != nil { validData := SerializeOnlinePayload(payload) @@ -252,7 +252,7 @@ func TestSerializationRoundtrip_OnlinePayload(t *testing.T) { } } - var input [80]byte + var input [144]byte rand.Read(input[:]) payload, _, err := CMACGarblerOnline(rand.Reader, curve, input, otEntries, 42) diff --git a/proto/transport.pb.go b/proto/transport.pb.go index 84f643f..c255df3 100644 --- a/proto/transport.pb.go +++ b/proto/transport.pb.go @@ -2326,7 +2326,7 @@ func (x *TEETAttestationResponse) GetAttestationReport() *AttestationReport { type OPRFRangeSpec struct { state protoimpl.MessageState `protogen:"open.v1"` TlsStart int32 `protobuf:"varint,1,opt,name=tls_start,json=tlsStart,proto3" json:"tls_start,omitempty"` - TlsLength int32 `protobuf:"varint,2,opt,name=tls_length,json=tlsLength,proto3" json:"tls_length,omitempty"` // max 64 + TlsLength int32 `protobuf:"varint,2,opt,name=tls_length,json=tlsLength,proto3" json:"tls_length,omitempty"` // max 128 unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } diff --git a/proto/transport.proto b/proto/transport.proto index d699537..fb63b64 100644 --- a/proto/transport.proto +++ b/proto/transport.proto @@ -248,7 +248,7 @@ message TEETAttestationResponse { // OPRF MPC message definitions message OPRFRangeSpec { int32 tls_start = 1; - int32 tls_length = 2; // max 64 + int32 tls_length = 2; // max 128 } message OPRFRangesSubmission { diff --git a/tee_k/oprf_handler.go b/tee_k/oprf_handler.go index 9751f11..14aa8de 100644 --- a/tee_k/oprf_handler.go +++ b/tee_k/oprf_handler.go @@ -79,7 +79,7 @@ func (t *TEEK) processQueuedOPRFRanges(sessionID string, teekState *TEEKSessionS // Validate ranges and initiate MPC for each for i, r := range teekState.OPRFRanges { - if r.TlsStart < 0 || r.TlsLength <= 0 || r.TlsLength > 64 { + if r.TlsStart < 0 || r.TlsLength <= 0 || r.TlsLength > 128 { return fmt.Errorf("invalid range %d: start=%d length=%d", i, r.TlsStart, r.TlsLength) } if int(r.TlsStart+r.TlsLength) > len(teekState.ConsolidatedKeystream) { @@ -105,14 +105,14 @@ func (t *TEEK) initiateOPRFForRange(sessionID string, teekState *TEEKSessionStat // Extract keystream for range and build garbler input keystream := teekState.ConsolidatedKeystream[r.TlsStart : r.TlsStart+r.TlsLength] - paddedKeystream, err := oprfmpc.PadZeros64(keystream, int(r.TlsLength)) + paddedKeystream, err := oprfmpc.PadZeros128(keystream, int(r.TlsLength)) if err != nil { return fmt.Errorf("failed to pad keystream: %w", err) } - var garblerInput [80]byte - copy(garblerInput[:64], paddedKeystream[:]) - copy(garblerInput[64:], teekState.OPRFKeyShare) + var garblerInput [144]byte + copy(garblerInput[:128], paddedKeystream[:]) + copy(garblerInput[128:], teekState.OPRFKeyShare) // Generate online payload using precomputed OT // Use the curve from OT precomputation state diff --git a/tee_t/oprf_evaluator.go b/tee_t/oprf_evaluator.go index 8a828b8..ce696b4 100644 --- a/tee_t/oprf_evaluator.go +++ b/tee_t/oprf_evaluator.go @@ -36,7 +36,7 @@ func (t *TEET) handleOPRFRangesFromClient(sessionID string, msg *teeproto.OPRFRa // Validate ranges against consolidated ciphertext for i, r := range msg.GetRanges() { - if r.TlsStart < 0 || r.TlsLength <= 0 || r.TlsLength > 64 { + if r.TlsStart < 0 || r.TlsLength <= 0 || r.TlsLength > 128 { return fmt.Errorf("invalid range %d: start=%d length=%d", i, r.TlsStart, r.TlsLength) } if int(r.TlsStart+r.TlsLength) > len(teetState.ConsolidatedResponseCiphertext) { @@ -117,16 +117,16 @@ func (t *TEET) handleOPRFOnlineFull(sessionID string, msg *teeproto.OPRFOnlineFu // Extract ciphertext for range ciphertext := teetState.ConsolidatedResponseCiphertext[msg.TlsStart : msg.TlsStart+msg.TlsLength] - // Pad to 64 bytes - paddedCiphertext, err := oprfmpc.PadZeros64(ciphertext, int(msg.TlsLength)) + // Pad to 128 bytes + paddedCiphertext, err := oprfmpc.PadZeros128(ciphertext, int(msg.TlsLength)) if err != nil { return fmt.Errorf("failed to pad ciphertext: %w", err) } - // Build evaluator input: [64 bytes data][16 bytes key] - var evaluatorInput [80]byte - copy(evaluatorInput[:64], paddedCiphertext[:]) - copy(evaluatorInput[64:], teetState.OPRFKeyShare) + // Build evaluator input: [128 bytes data][16 bytes key] + var evaluatorInput [144]byte + copy(evaluatorInput[:128], paddedCiphertext[:]) + copy(evaluatorInput[128:], teetState.OPRFKeyShare) // Consume OT receiver entries from precomputed pool otEntries, err := t.consumeOTReceiverEntries(int(msg.OtStartIndex), oprfmpc.OTsPerOPRF)