diff --git a/ga4gh/backend.py b/ga4gh/backend.py index 9fd762505..919117bcf 100644 --- a/ga4gh/backend.py +++ b/ga4gh/backend.py @@ -274,23 +274,30 @@ def variantsGenerator(self, request): Returns a generator over the (variant, nextPageToken) pairs defined by the specified request. """ - variantSetIds = request.variantSetIds - startVariantSetIndex = 0 + # TODO this method should also use the interval search semantics + # in the readsGenerator above. + if len(request.variantSetIds) != 1: + raise exceptions.NotImplementedException( + "VariantSearch search over multiple variantSets not supported") + variantSetId = request.variantSetIds[0] + try: + variantSet = self._variantSetIdMap[request.variantSetIds[0]] + except KeyError: + raise exceptions.VariantSetNotFoundException(variantSetId) startPosition = request.start if request.pageToken is not None: - startVariantSetIndex, startPosition = self.parsePageToken( - request.pageToken, 2) - for variantSetIndex in range(startVariantSetIndex, len(variantSetIds)): - variantSetId = variantSetIds[variantSetIndex] - if variantSetId in self._variantSetIdMap: - variantSet = self._variantSetIdMap[variantSetId] - iterator = variantSet.getVariants( - request.referenceName, startPosition, request.end, - request.variantName, request.callSetIds) - for variant in iterator: - nextPageToken = "{0}:{1}".format( - variantSetIndex, variant.start + 1) - yield variant, nextPageToken + startPosition, = self.parsePageToken(request.pageToken, 1) + iterator = variantSet.getVariants( + request.referenceName, startPosition, request.end, + request.variantName, request.callSetIds) + variant = next(iterator, None) + while variant is not None: + nextVariant = next(iterator, None) + nextPageToken = None + if nextVariant is not None: + nextPageToken = "{}".format(nextVariant.start) + yield variant, nextPageToken + variant = nextVariant def callSetsGenerator(self, request): """ diff --git a/ga4gh/cli.py b/ga4gh/cli.py index 7799cf72b..1653b7512 100644 --- a/ga4gh/cli.py +++ b/ga4gh/cli.py @@ -362,12 +362,19 @@ def __init__(self, args): self._setRequest(request, args) def run(self): - if self._minimalOutput: - self._run(self._httpClient.searchVariants, 'id') - else: - results = self._httpClient.searchVariants(self._request) - for result in results: - self.printVariant(result) + # TODO this is a hack until we make a nicer interface to deal with + # multiple requests. The server does not support multiple values + # so we send of sequential requests instead. + request = self._request + variantSetIds = request.variantSetIds + for variantSetId in variantSetIds: + request.variantSetIds = [variantSetId] + if self._minimalOutput: + self._run(self._httpClient.searchVariants, 'id') + else: + results = self._httpClient.searchVariants(self._request) + for result in results: + self.printVariant(result) def printVariant(self, variant): """ diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py index ab0a83af5..dffbc3eb5 100644 --- a/tests/unit/test_backends.py +++ b/tests/unit/test_backends.py @@ -117,7 +117,10 @@ def testSearchVariantSets(self): isinstance(response, protocol.GASearchVariantSetsResponse)) def testSearchVariants(self): + variantSetIds = [ + variantSet.id for variantSet in self.getVariantSets(pageSize=1)] request = protocol.GASearchVariantsRequest() + request.variantSetIds = variantSetIds[:1] responseStr = self._backend.searchVariants(request.toJsonString()) response = protocol.GASearchVariantsResponse.fromJsonString( responseStr) diff --git a/tests/unit/test_views.py b/tests/unit/test_views.py index 2129c53d4..eaed71fe0 100644 --- a/tests/unit/test_views.py +++ b/tests/unit/test_views.py @@ -57,13 +57,15 @@ def sendRequest(self, path, request): versionedPath, headers=headers, data=request.toJsonString()) - def sendVariantsSearch( - self, variantSetIds=[""], referenceName="", start=0, end=0): + def sendVariantsSearch(self): + response = self.sendVariantSetsSearch() + variantSets = protocol.GASearchVariantSetsResponse().fromJsonString( + response.data).variantSets request = protocol.GASearchVariantsRequest() - request.variantSetIds = variantSetIds - request.referenceName = referenceName - request.start = start - request.end = end + request.variantSetIds = [variantSets[0].id] + request.referenceName = "1" + request.start = 0 + request.end = 1 return self.sendRequest('/variants/search', request) def sendVariantSetsSearch(self, datasetIds=[""]): @@ -165,7 +167,7 @@ def testVariantsSearch(self): self.assertEqual(200, response.status_code) responseData = protocol.GASearchVariantsResponse.fromJsonString( response.data) - self.assertEqual(responseData.variants, []) + self.assertEqual(len(responseData.variants), 1) def testVariantSetsSearch(self): response = self.sendVariantSetsSearch()