diff --git a/variation-commons-mongodb/src/main/java/uk/ac/ebi/eva/commons/mongodb/writers/VariantSourceMongoWriter.java b/variation-commons-mongodb/src/main/java/uk/ac/ebi/eva/commons/mongodb/writers/VariantSourceMongoWriter.java index 6b602128..2de859e1 100644 --- a/variation-commons-mongodb/src/main/java/uk/ac/ebi/eva/commons/mongodb/writers/VariantSourceMongoWriter.java +++ b/variation-commons-mongodb/src/main/java/uk/ac/ebi/eva/commons/mongodb/writers/VariantSourceMongoWriter.java @@ -17,6 +17,9 @@ package uk.ac.ebi.eva.commons.mongodb.writers; import com.mongodb.client.model.IndexOptions; +import com.mongodb.client.model.UpdateOneModel; +import com.mongodb.client.model.UpdateOptions; +import com.mongodb.client.model.WriteModel; import org.bson.Document; import org.springframework.batch.item.data.MongoItemWriter; import org.springframework.data.mongodb.core.MongoOperations; @@ -25,6 +28,7 @@ import uk.ac.ebi.eva.commons.core.models.IVariantSource; import uk.ac.ebi.eva.commons.mongodb.entities.VariantSourceMongo; +import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -60,9 +64,27 @@ private void createIndexes() { @Override public void write(List<? extends IVariantSource> items) throws Exception { - List<VariantSourceMongo> convertedList = items.stream() + List<VariantSourceMongo> variantSourceMongoList = items.stream() .map(VariantSourceMongo::new) .collect(Collectors.toList()); - super.write(convertedList); + + List<WriteModel<Document>> writes = new ArrayList<>(); + for (VariantSourceMongo variantSourceMongo : variantSourceMongoList) { + // include only shard keys as part of query + Document query = new Document() + .append(VariantSourceMongo.STUDYID_FIELD, variantSourceMongo.getStudyId()) + .append(VariantSourceMongo.FILEID_FIELD, variantSourceMongo.getFileId()) + .append(VariantSourceMongo.FILENAME_FIELD, variantSourceMongo.getFileName()); + Document update = new Document("$set", convertToMongo(variantSourceMongo)); + writes.add(new UpdateOneModel<>(query, update, new UpdateOptions().upsert(true))); + } + + 
if (!writes.isEmpty()) { + mongoOperations.getCollection(collection).bulkWrite(writes); + } + } + + private Document convertToMongo(VariantSourceMongo variantSourceMongo) { + return (Document) mongoOperations.getConverter().convertToMongoType(variantSourceMongo); } } diff --git a/variation-commons-mongodb/src/test/java/uk/ac/ebi/eva/commons/mongodb/writers/VariantSourceMongoWriterTest.java b/variation-commons-mongodb/src/test/java/uk/ac/ebi/eva/commons/mongodb/writers/VariantSourceMongoWriterTest.java index 228fe896..9974c29f 100644 --- a/variation-commons-mongodb/src/test/java/uk/ac/ebi/eva/commons/mongodb/writers/VariantSourceMongoWriterTest.java +++ b/variation-commons-mongodb/src/test/java/uk/ac/ebi/eva/commons/mongodb/writers/VariantSourceMongoWriterTest.java @@ -45,9 +45,7 @@ import java.util.stream.Collectors; import java.util.stream.StreamSupport; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.*; /** @@ -73,6 +71,8 @@ public class VariantSourceMongoWriterTest { private static final String STUDY_ID = "1"; + private static final String FILE_NAME = "CHICKEN_SNPS_LAYER"; + private static final String STUDY_NAME = "small"; private static final StudyType STUDY_TYPE = StudyType.COLLECTION; @@ -98,7 +98,7 @@ public void tearDown() throws Exception { @Test public void shouldWriteAllFieldsIntoMongoDb() throws Exception { - MongoCollection<Document> fileCollection = mongoOperations.getCollection (COLLECTION_FILES_NAME); + MongoCollection<Document> fileCollection = mongoOperations.getCollection(COLLECTION_FILES_NAME); VariantSourceMongoWriter filesWriter = new VariantSourceMongoWriter( mongoOperations, COLLECTION_FILES_NAME); @@ -108,7 +108,7 @@ public void shouldWriteAllFieldsIntoMongoDb() throws Exception { FindIterable<Document> cursor = fileCollection.find(); int count = 0; - for (Document next: cursor) { + for (Document next : cursor) { count++; assertNotNull(next.get(VariantSourceMongo.FILEID_FIELD)); 
assertNotNull(next.get(VariantSourceMongo.FILENAME_FIELD)); @@ -129,6 +129,61 @@ public void shouldWriteAllFieldsIntoMongoDb() throws Exception { assertEquals(1, count); } + @Test + public void shouldDoUpdateInCaseOfExistingDocument() throws Exception { + MongoCollection<Document> fileCollection = mongoOperations.getCollection(COLLECTION_FILES_NAME); + VariantSourceMongoWriter filesWriter = new VariantSourceMongoWriter( + mongoOperations, COLLECTION_FILES_NAME); + + // make an entry into the database + VariantSourceMongo variantSource = new VariantSourceMongo(FILE_ID, FILE_NAME, STUDY_ID, STUDY_NAME, + StudyType.AGGREGATE, Aggregation.BASIC, null, null, null); + filesWriter.write(Collections.singletonList(variantSource)); + FindIterable<Document> cursor = fileCollection.find(); + int count = 0; + for (Document next : cursor) { + count++; + assertTrue(next.get(VariantSourceMongo.FILEID_FIELD).equals(FILE_ID)); + assertTrue(next.get(VariantSourceMongo.STUDYID_FIELD).equals(STUDY_ID)); + assertTrue(next.get(VariantSourceMongo.FILENAME_FIELD).equals(FILE_NAME)); + assertTrue(next.get(VariantSourceMongo.STUDYNAME_FIELD).equals(STUDY_NAME)); + assertTrue(next.get(VariantSourceMongo.STUDYTYPE_FIELD).equals(StudyType.AGGREGATE.name())); + assertTrue(next.get(VariantSourceMongo.AGGREGATION_FIELD).equals(Aggregation.BASIC.name())); + assertNotNull(next.get(VariantSourceMongo.DATE_FIELD)); + assertTrue(next.get(VariantSourceMongo.SAMPLES_FIELD).equals(new Document())); + assertTrue(next.get(VariantSourceMongo.METADATA_FIELD).equals(new Document())); + } + assertEquals(1, count); + + // insert another document with same fileId, studyId and fileName + variantSource = getVariantSource(); + filesWriter.write(Collections.singletonList(variantSource)); + cursor = fileCollection.find(); + count = 0; + for (Document next : cursor) { + count++; + + assertTrue(next.get(VariantSourceMongo.FILEID_FIELD).equals(FILE_ID)); + assertTrue(next.get(VariantSourceMongo.STUDYID_FIELD).equals(STUDY_ID)); + 
assertTrue(next.get(VariantSourceMongo.FILENAME_FIELD).equals(FILE_NAME)); + assertTrue(next.get(VariantSourceMongo.STUDYNAME_FIELD).equals(STUDY_NAME)); + + // existing document should be updated with new values from the document + assertTrue(next.get(VariantSourceMongo.STUDYTYPE_FIELD).equals(StudyType.COLLECTION.name())); + assertTrue(next.get(VariantSourceMongo.AGGREGATION_FIELD).equals(Aggregation.NONE.name())); + assertNotNull(next.get(VariantSourceMongo.SAMPLES_FIELD)); + assertNotNull(next.get(VariantSourceMongo.DATE_FIELD)); + + Document meta = (Document) next.get(VariantSourceMongo.METADATA_FIELD); + assertNotNull(meta); + assertNotNull(meta.get("ALT")); + assertNotNull(meta.get("FILTER")); + assertNotNull(meta.get("INFO")); + assertNotNull(meta.get("FORMAT")); + } + assertEquals(1, count); + } + @Test public void shouldWriteSamplesWithDotsInName() throws Exception { MongoCollection<Document> fileCollection = mongoOperations.getCollection(COLLECTION_FILES_NAME); @@ -147,7 +202,7 @@ public void shouldWriteSamplesWithDotsInName() throws Exception { FindIterable<Document> cursor = fileCollection.find(); - for (Document next: cursor) { + for (Document next : cursor) { Document samples = (Document) next.get(VariantSourceMongo.SAMPLES_FIELD); Set<String> keySet = samples.keySet(); @@ -158,8 +213,8 @@ public void shouldWriteSamplesWithDotsInName() throws Exception { @Test public void shouldCreateUniqueFileIndex() throws Exception { - MongoCollection<Document> fileCollection = mongoOperations.getCollection (COLLECTION_FILES_NAME); - VariantSourceMongoWriter filesWriter = new VariantSourceMongoWriter( mongoOperations, COLLECTION_FILES_NAME); + MongoCollection<Document> fileCollection = mongoOperations.getCollection(COLLECTION_FILES_NAME); + VariantSourceMongoWriter filesWriter = new VariantSourceMongoWriter(mongoOperations, COLLECTION_FILES_NAME); VariantSourceMongo variantSource = getVariantSource(); filesWriter.write(Collections.singletonList(variantSource)); @@ -167,13 +222,13 @@ public void 
shouldCreateUniqueFileIndex() throws Exception { ListIndexesIterable<Document> indexesInfo = fileCollection.listIndexes(); Set<String> createdIndexes = StreamSupport.stream( - indexesInfo.map(index -> index.get("name").toString()).spliterator(), false) - .collect(Collectors.toSet()); + indexesInfo.map(index -> index.get("name").toString()).spliterator(), false) + .collect(Collectors.toSet()); Set<String> expectedIndexes = new HashSet<>(); expectedIndexes.addAll(Arrays.asList("sid_1_fid_1_fname_1", "_id_")); assertEquals(expectedIndexes, createdIndexes); - for(Document indexInfo: indexesInfo) { + for (Document indexInfo : indexesInfo) { if ("sid_1_fid_1_fname_1".equals(indexInfo.get("name").toString())) { assertNotNull(indexInfo); assertEquals("true", indexInfo.get(UNIQUE_INDEX).toString()); @@ -196,9 +251,9 @@ private VariantSourceMongo getVariantSource() throws Exception { metadata.put("FILTER", "All filters passed"); metadata.put("INFO", "INFO field"); metadata.put("FORMAT", "FORMAT field"); - return new VariantSourceMongo(FILE_ID, "CHICKEN_SNPS_LAYER", STUDY_ID, STUDY_NAME, STUDY_TYPE, + return new VariantSourceMongo(FILE_ID, FILE_NAME, STUDY_ID, STUDY_NAME, STUDY_TYPE, AGGREGATION, samplesPosition, metadata, - new VariantGlobalStatsMongo(0,0,0, 0, + new VariantGlobalStatsMongo(0, 0, 0, 0, 0, 0, 0, 0, 0)); } }