Skip to content

Commit

Permalink
Fix the bug in the llv2 that inflates the reach when maximum_frequenc…
Browse files Browse the repository at this point in the history
…y = 1. (#1257)

Fix the bug in the llv2 setup phase that causes the aggregator to call DestroyKeysAndCounts on combinedRegisterVector when max frequency is one. Add unit test for llv2 when the max frequency is one.

(cherry picked from commit c1ea552)
  • Loading branch information
ple13 committed Oct 4, 2023
1 parent c32fe55 commit 867222c
Show file tree
Hide file tree
Showing 12 changed files with 201 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ absl::Status AddAllFrequencyNoise(
// the destroyed register flag and all counts with random values.
absl::StatusOr<std::string> DestroyKeysAndCounts(
const CompleteSetupPhaseRequest& request) {
std::string source = request.combined_register_vector();
std::string source = request.requisition_register_vector();
std::string dest;

if (source.empty()) {
Expand Down Expand Up @@ -764,15 +764,18 @@ absl::StatusOr<CompleteInitializationPhaseResponse> CompleteInitializationPhase(
absl::StatusOr<CompleteSetupPhaseResponse> CompleteSetupPhase(
const CompleteSetupPhaseRequest& request) {
StartedThreadCpuTimer timer;

CompleteSetupPhaseResponse response;
std::string* response_crv = response.mutable_combined_register_vector();

// When maximum frequency is 1, the keys from the requisition register vector
// (received from the EDPs) will be replaced with the destroy register flag,
// and the counts will be replaced with a random value.
if (request.maximum_frequency() == 1) {
*response_crv = *DestroyKeysAndCounts(request);
response_crv->append(*DestroyKeysAndCounts(request));
} else {
*response_crv = request.combined_register_vector();
response_crv->append(request.requisition_register_vector());
}
response_crv->append(request.combined_register_vector());

if (request.has_noise_parameters()) {
const RegisterNoiseGenerationParameters& noise_parameters =
Expand Down Expand Up @@ -800,7 +803,8 @@ absl::StatusOr<CompleteSetupPhaseResponse> CompleteSetupPhase(

// Resize the space to hold all output data.
size_t pos = response_crv->size();
response_crv->resize(request.combined_register_vector().size() +
response_crv->resize(request.requisition_register_vector().size() +
request.combined_register_vector().size() +
total_noise_registers_count * kBytesPerCipherRegister);

RETURN_IF_ERROR(ValidateSetupNoiseParameters(noise_parameters));
Expand Down Expand Up @@ -1139,12 +1143,11 @@ CompleteExecutionPhaseTwoAtAggregator(
// non_empty_register_count could be negative if there is too few registers in
// the sketch and the number of noise registers is smaller than the baseline.
non_empty_register_count = std::max(non_empty_register_count, 0L);
ASSIGN_OR_RETURN(
int64_t reach,
EstimateReach(request.liquid_legions_parameters().decay_rate(),
request.liquid_legions_parameters().size(),
non_empty_register_count,
request.vid_sampling_interval_width()));
ASSIGN_OR_RETURN(int64_t reach,
EstimateReach(request.sketch_parameters().decay_rate(),
request.sketch_parameters().size(),
non_empty_register_count,
request.vid_sampling_interval_width()));
response.set_reach(reach);

response.set_elapsed_cpu_time_millis(timer.ElapsedMillis());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -655,12 +655,11 @@ CompleteReachOnlyExecutionPhaseAtAggregator(
}

// Estimate the reach
ASSIGN_OR_RETURN(
int64_t reach,
EstimateReach(request.liquid_legions_parameters().decay_rate(),
request.liquid_legions_parameters().size(),
non_empty_register_count,
request.vid_sampling_interval_width()));
ASSIGN_OR_RETURN(int64_t reach,
EstimateReach(request.sketch_parameters().decay_rate(),
request.sketch_parameters().size(),
non_empty_register_count,
request.vid_sampling_interval_width()));

response.set_reach(reach);
response.set_elapsed_cpu_time_millis(timer.ElapsedMillis());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,17 +489,18 @@ class LiquidLegionsV2Mill(
val llv2Details = token.computationDetails.liquidLegionsV2
require(AGGREGATOR == llv2Details.role) { "invalid role for this function." }
val inputBlobCount = token.participantCount - 1
val requisition = dataClients.readAllRequisitionBlobs(token, duchyId)
val combinedRegisterVector = readAndCombineAllInputBlobs(token, inputBlobCount)
val (bytes, nextToken) =
existingOutputOr(token) {
val request =
dataClients
.readAllRequisitionBlobs(token, duchyId)
.concat(readAndCombineAllInputBlobs(token, inputBlobCount))
.toCompleteSetupPhaseRequest(
llv2Details,
token.requisitionsCount,
token.participantCount
)
toCompleteSetupPhaseRequest(
requisition,
combinedRegisterVector,
llv2Details,
token.requisitionsCount,
token.participantCount
)
val cryptoResult: CompleteSetupPhaseResponse = cryptoWorker.completeSetupPhase(request)
logStageDurationMetric(
token,
Expand Down Expand Up @@ -530,16 +531,17 @@ class LiquidLegionsV2Mill(
private suspend fun completeSetupPhaseAtNonAggregator(token: ComputationToken): ComputationToken {
val llv2Details = token.computationDetails.liquidLegionsV2
require(NON_AGGREGATOR == llv2Details.role) { "invalid role for this function." }
val requisition = dataClients.readAllRequisitionBlobs(token, duchyId)
val (bytes, nextToken) =
existingOutputOr(token) {
val request =
dataClients
.readAllRequisitionBlobs(token, duchyId)
.toCompleteSetupPhaseRequest(
llv2Details,
token.requisitionsCount,
token.participantCount
)
toCompleteSetupPhaseRequest(
requisition,
ByteString.EMPTY,
llv2Details,
token.requisitionsCount,
token.participantCount
)
val cryptoResult: CompleteSetupPhaseResponse = cryptoWorker.completeSetupPhase(request)
logStageDurationMetric(
token,
Expand Down Expand Up @@ -688,7 +690,7 @@ class LiquidLegionsV2Mill(
curveId = llv2Parameters.ellipticCurveId.toLong()
flagCountTuples = readAndCombineAllInputBlobs(token, 1)
maximumFrequency = maximumRequestedFrequency
liquidLegionsParameters = liquidLegionsSketchParameters {
sketchParameters = liquidLegionsSketchParameters {
decayRate = llv2Parameters.sketchParameters.decayRate
size = llv2Parameters.sketchParameters.size
}
Expand Down Expand Up @@ -954,14 +956,17 @@ class LiquidLegionsV2Mill(
)
}

private fun ByteString.toCompleteSetupPhaseRequest(
private fun toCompleteSetupPhaseRequest(
requisition: ByteString,
combinedRegisterVector: ByteString,
llv2Details: LiquidLegionsSketchAggregationV2.ComputationDetails,
totalRequisitionsCount: Int,
participantCount: Int,
): CompleteSetupPhaseRequest {
val noiseConfig = llv2Details.parameters.noise
return completeSetupPhaseRequest {
combinedRegisterVector = this@toCompleteSetupPhaseRequest
this.requisitionRegisterVector = requisition
this.combinedRegisterVector = combinedRegisterVector
maximumFrequency = llv2Details.parameters.maximumFrequency.coerceAtLeast(1)
if (noiseConfig.hasReachNoiseConfig()) {
noiseParameters = registerNoiseGenerationParameters {
Expand Down Expand Up @@ -1012,6 +1017,7 @@ class LiquidLegionsV2Mill(
directoryPath = Paths.get("any_sketch_java/src/main/java/org/wfanet/anysketch/crypto")
)
}

private val logger: Logger = Logger.getLogger(this::class.java.name)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ class ReachOnlyLiquidLegionsV2Mill(
globalReachDpNoise = rollv2Parameters.noise.reachNoiseConfig.globalReachDpNoise
}
}
liquidLegionsParameters = liquidLegionsSketchParameters {
sketchParameters = liquidLegionsSketchParameters {
decayRate = rollv2Parameters.sketchParameters.decayRate
size = rollv2Parameters.sketchParameters.size
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,23 +101,27 @@ message CompleteInitializationPhaseResponse {

// The request to complete work in the setup phase.
message CompleteSetupPhaseRequest {
// The input combined register vector (CRV)
// The input register vector (RV) received from the EDPs.
// Each register contains a 3-tuple of (index, key, count), each of which is
// a 66 bytes ElGamal ciphertext. In other words, the CRV size should be
// divisible by 66*3.
// The CRV is only needed so the noise can be interleaved and hidden in the
// CRV. The registers in the CRV are unchanged, except for their orders.
bytes combined_register_vector = 1;
bytes requisition_register_vector = 1;
// The input combined register vector (CRV) received from other duchies.
// Do not set this field for non-aggregator. In the setup phase, only the
// aggregator receives data from non-aggregators.
bytes combined_register_vector = 2;
// The parameters required for generating noise registers.
// if unset, the worker only shuffles the register without adding any noise.
RegisterNoiseGenerationParameters noise_parameters = 2;
RegisterNoiseGenerationParameters noise_parameters = 3;
// The maximum frequency that should be computed.
// If set to 1, then no frequency histogram is returned.
int32 maximum_frequency = 3;
int32 maximum_frequency = 4;
// The mechanism used to generate noise.
LiquidLegionsV2NoiseConfig.NoiseMechanism noise_mechanism = 4;
LiquidLegionsV2NoiseConfig.NoiseMechanism noise_mechanism = 5;
// The maximum number of threads used by crypto actions.
int32 parallelism = 5;
int32 parallelism = 6;
}

// Response of the CompleteSetupPhase method.
Expand Down Expand Up @@ -251,7 +255,7 @@ message CompleteExecutionPhaseTwoAtAggregatorRequest {
// (maximum_frequency-1) will be the row dimension of the 2-D SKA matrix.
int32 maximum_frequency = 5;
// LiquidLegions parameters used for reach estimation.
LiquidLegionsSketchParameters liquid_legions_parameters = 6;
LiquidLegionsSketchParameters sketch_parameters = 6;
// Parameters for computing the noise baseline of the global reach DP noise
// registers added in the setup phase.
// The baseline is subtracted before reach is estimated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ message CompleteReachOnlyExecutionPhaseAtAggregatorRequest {
// The baseline is subtracted before reach is estimated.
GlobalReachDpNoiseBaseline reach_dp_noise_baseline = 5;
// LiquidLegions parameters used for reach estimation.
LiquidLegionsSketchParameters liquid_legions_parameters = 6;
LiquidLegionsSketchParameters sketch_parameters = 6;
// The sampling rate to be used by the LiquidLegionsV2 protocol.
// This is taken from the VidSamplingInterval.width parameter in the
// MeasurementSpec.
Expand Down
Loading

0 comments on commit 867222c

Please sign in to comment.