diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamMeasurements.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamMeasurements.kt index 374131f1f33..de90392c047 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamMeasurements.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamMeasurements.kt @@ -22,6 +22,8 @@ import org.wfanet.measurement.gcloud.spanner.appendClause import org.wfanet.measurement.gcloud.spanner.bind import org.wfanet.measurement.internal.kingdom.Measurement import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequest +import org.wfanet.measurement.kingdom.deploy.common.DuchyIds +import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.DuchyNotFoundException import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.readers.MeasurementReader class StreamMeasurements( @@ -77,15 +79,34 @@ class StreamMeasurements( } if (filter.hasUpdatedBefore()) { - conjuncts.add("(UpdateTime < @$UPDATED_BEFORE)") + conjuncts.add("UpdateTime < @$UPDATED_BEFORE") bind(UPDATED_BEFORE to filter.updatedBefore.toGcloudTimestamp()) } if (filter.hasCreatedBefore()) { - conjuncts.add("(CreateTime < @$CREATED_BEFORE)") + conjuncts.add("CreateTime < @$CREATED_BEFORE") bind(CREATED_BEFORE to filter.createdBefore.toGcloudTimestamp()) } + if (filter.externalDuchyId.isNotEmpty()) { + val duchyId: Long = + DuchyIds.getInternalId(filter.externalDuchyId) + ?: throw DuchyNotFoundException(filter.externalDuchyId) + conjuncts.add( + """ + @$DUCHY_ID_PARAM IN ( + SELECT DuchyId + FROM ComputationParticipants + WHERE + ComputationParticipants.MeasurementConsumerId = Measurements.MeasurementConsumerId + AND ComputationParticipants.MeasurementId = Measurements.MeasurementId + ) + """ + .trimIndent() + ) + bind(DUCHY_ID_PARAM).to(duchyId) + } + if (filter.hasAfter()) { @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf case fields cannot be null. when (filter.after.keyCase) { @@ -150,16 +171,7 @@ class StreamMeasurements( } override fun Flow.transform(): Flow { - // TODO(@tristanvuong): determine how to do this in the SQL query instead - if (requestFilter.externalDuchyId.isBlank()) { - return this - } - - return filter { value: MeasurementReader.Result -> - value.measurement.computationParticipantsList - .map { it.externalDuchyId } - .contains(requestFilter.externalDuchyId) - } + return this } companion object { @@ -168,14 +180,16 @@ class StreamMeasurements( const val EXTERNAL_MEASUREMENT_CONSUMER_CERTIFICATE_ID_PARAM = "externalMeasurementConsumerCertificateId" const val UPDATED_AFTER = "updatedAfter" + const val UPDATED_BEFORE = "updatedBefore" + const val CREATED_BEFORE = "createdBefore" const val STATES_PARAM = "states" + const val DUCHY_ID_PARAM = "duchyId" + object AfterParams { const val UPDATE_TIME = "after_updateTime" const val EXTERNAL_MEASUREMENT_CONSUMER_ID = "after_externalMeasurementConsumerId" const val EXTERNAL_MEASUREMENT_ID = "after_externalMeasurementId" const val EXTERNAL_COMPUTATION_ID = "after_externalComputationId" } - const val UPDATED_BEFORE = "updatedBefore" - const val CREATED_BEFORE = "createdBefore" } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt index 4362d96cda6..feb458017ae 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt @@ -1288,25 +1288,25 @@ abstract class MeasurementsServiceTest { } } ) - - val measurement2 = - measurementsService.createMeasurement( - createMeasurementRequest { - measurement = - MEASUREMENT.copy { - externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId - externalMeasurementConsumerCertificateId = - measurementConsumer.certificate.externalCertificateId - details = - details.copy { - protocolConfig = protocolConfig { - direct = ProtocolConfig.Direct.getDefaultInstance() - } - clearDuchyProtocolConfig() + measurementsService.createMeasurement( + createMeasurementRequest { + measurement = + MEASUREMENT.copy { + externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId + externalMeasurementConsumerCertificateId = + measurementConsumer.certificate.externalCertificateId + details = + details.copy { + protocolConfig = protocolConfig { + direct = ProtocolConfig.Direct.getDefaultInstance() } - } - } - ) + clearDuchyProtocolConfig() + } + } + } + ) + val measurement3 = + measurementsService.createMeasurement(createMeasurementRequest { measurement = measurement1 }) val streamMeasurementsRequest = streamMeasurementsRequest { limit = 2 @@ -1317,13 +1317,24 @@ abstract class MeasurementsServiceTest { measurementView = Measurement.View.COMPUTATION } - val measurements: List = + val responses: List = measurementsService.streamMeasurements(streamMeasurementsRequest).toList() - assertThat(measurements).hasSize(1) - assertThat(measurements[0].externalMeasurementId).isEqualTo(measurement1.externalMeasurementId) - assertThat(measurements[0].externalMeasurementId) - .isNotEqualTo(measurement2.externalMeasurementId) + val computationMeasurement1 = + measurementsService.getMeasurementByComputationId( + getMeasurementByComputationIdRequest { + externalComputationId = measurement1.externalComputationId + } + ) + val computationMeasurement3 = + measurementsService.getMeasurementByComputationId( + getMeasurementByComputationIdRequest { + externalComputationId = measurement3.externalComputationId + } + ) + assertThat(responses) + .containsExactly(computationMeasurement1, computationMeasurement3) + .inOrder() } @Test diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsService.kt index 08979920555..6b1cd364f45 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsService.kt @@ -16,7 +16,6 @@ package org.wfanet.measurement.kingdom.service.system.v1alpha import io.grpc.Status import io.grpc.StatusException -import java.util.logging.Logger import kotlin.time.Duration import kotlin.time.Duration.Companion.minutes import kotlin.time.Duration.Companion.seconds @@ -60,7 +59,8 @@ class ComputationsService( private val measurementsClient: MeasurementsCoroutineStub, private val duchyIdentityProvider: () -> DuchyIdentity = ::duchyIdentityFromContext, private val streamingTimeout: Duration = 10.minutes, - private val streamingThrottle: Duration = 1.seconds + private val streamingThrottle: Duration = 1.seconds, + private val streamingLimit: Int = DEFAULT_STREAMING_LIMIT, ) : ComputationsCoroutineImplBase() { override suspend fun getComputation(request: GetComputationRequest): Computation { val computationKey = @@ -189,6 +189,7 @@ class ComputationsService( externalDuchyId = duchyIdentityProvider().id } measurementView = Measurement.View.COMPUTATION + limit = streamingLimit } try { return measurementsClient.streamMeasurements(request) @@ -204,20 +205,24 @@ class ComputationsService( } companion object { - private val logger: Logger = Logger.getLogger(this::class.java.name) + /** + * Default limit for [org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequest] + * messages. + */ + const val DEFAULT_STREAMING_LIMIT = 50 + + private val STATES_SUBSCRIBED = + listOf( + Measurement.State.PENDING_REQUISITION_PARAMS, + Measurement.State.PENDING_PARTICIPANT_CONFIRMATION, + Measurement.State.PENDING_COMPUTATION, + Measurement.State.FAILED, + Measurement.State.CANCELLED, + Measurement.State.SUCCEEDED + ) } } -private val STATES_SUBSCRIBED = - listOf( - Measurement.State.PENDING_REQUISITION_PARAMS, - Measurement.State.PENDING_PARTICIPANT_CONFIRMATION, - Measurement.State.PENDING_COMPUTATION, - Measurement.State.FAILED, - Measurement.State.CANCELLED, - Measurement.State.SUCCEEDED - ) - private object ContinuationTokenConverter { fun encode(token: StreamActiveComputationsContinuationToken): String = token.toByteArray().base64UrlEncode() diff --git a/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsServiceTest.kt index 2afffda8c14..3cce42bcf6c 100644 --- a/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsServiceTest.kt @@ -542,6 +542,7 @@ class ComputationsServiceTest { } } } + limit = ComputationsService.DEFAULT_STREAMING_LIMIT } inOrder(internalMeasurementsServiceMock) {