[Spark] Add option to use a bed Array (#648)
DimitrisStaratzis authored May 8, 2024
1 parent 7cc19c1 commit a5aaf7a
Showing 60 changed files with 519 additions and 25 deletions.
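
For context, the new bed_array option is passed to the Spark reader the same way as the existing bedfile option; a minimal usage sketch, mirroring the tests added in this commit (the dataset and array URIs are hypothetical, and spark is assumed to be an active SparkSession):

Dataset<Row> df =
    spark
        .read()
        .format("io.tiledb.vcf")
        .option("uri", "s3://bucket/vcf_dataset")       // TileDB-VCF dataset (hypothetical URI)
        .option("bed_array", "s3://bucket/bed_array")   // new: BED regions stored as a TileDB array
        .option("new_partition_method", true)           // optional: partition by the BED regions
        .load();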
1 change: 1 addition & 0 deletions apis/spark/build.gradle
@@ -53,6 +53,7 @@ repositories {
dependencies {
compileOnly 'org.apache.spark:spark-sql_2.12:2.4.3'
compileOnly 'org.apache.spark:spark-core_2.12:2.4.3'
implementation 'io.tiledb:tiledb-java:0.24.0'
compile group: 'io.tiledb', name: 'tiledb-vcf-java', version: version
compile 'com.amazonaws:aws-java-sdk:1.11.650'

@@ -76,6 +76,14 @@ public Optional<URI> getBedURI() {
return Optional.empty();
}

/** @return Optional URI of the BED array */
public Optional<URI> getBedArrayURI() {
if (options.containsKey("bed_array")) {
return Optional.of(URI.create(options.get("bed_array")));
}
return Optional.empty();
}

/** @return Optional URI of the sample file */
public Optional<URI> getSampleURI() {
if (options.containsKey("samplefile")) {
127 changes: 116 additions & 11 deletions apis/spark/src/main/java/io/tiledb/vcf/VCFDataSourceReader.java
@@ -1,9 +1,11 @@
package io.tiledb.vcf;

import io.tiledb.java.api.*;
import io.tiledb.libvcfnative.VCFBedFile;
import io.tiledb.libvcfnative.VCFReader;
import io.tiledb.util.CredentialProviderUtils;
import java.net.URI;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -215,7 +217,19 @@ public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
// Create Spark input partitions
List<List<String>> regions = null;
if (options.getNewPartitionMethod().orElse(false)) {
-regions = computeRegionPartitionsFromBedFile(numRangePartitions);
+
+// Compute regions from bed array or bed file
+Optional<URI> bedArrayURI = options.getBedArrayURI();
+Optional<URI> bedURI = options.getBedURI();
+if (bedArrayURI.isPresent()) {
+regions = computeRegionPartitionsFromBedArray(numRangePartitions, bedArrayURI.get());
+} else if (bedURI.isPresent()) {
+regions = computeRegionPartitionsFromBedFile(numRangePartitions);
+} else {
+throw new RuntimeException(
+"Can't use new_partition_method without setting bed_file or bed_array");
+}

numRangePartitions = regions.size();
ranges_end = regions.size();
log.info("New partition method has yielded " + numRangePartitions + " range partitions");
@@ -247,10 +261,99 @@ public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
return inputPartitions;
}

List<List<String>> computeRegionPartitionsFromBedArray(
int desiredNumRangePartitions, URI arrayURI) {

// Read bed array

try {
Map<String, List<String>> mapOfRegions = new HashMap<>();
int counter = 0;

Context ctx = new Context();
Array bedArray = new Array(ctx, arrayURI.toString(), QueryType.TILEDB_READ);

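// The array metadata maps each logical BED column to the name of the
// buffer that stores it; the alias values read below are used as the
// query buffer names.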
String CONTIG = "alias contig";
String START = "alias start";
String END = "alias end";

String[] keys = new String[] {CONTIG, START, END};

NativeArray contigAliasNA = bedArray.getMetadata(CONTIG, Datatype.TILEDB_STRING_ASCII);
NativeArray startAliasNA = bedArray.getMetadata(START, Datatype.TILEDB_STRING_ASCII);
NativeArray endAliasNA = bedArray.getMetadata(END, Datatype.TILEDB_STRING_ASCII);

String contigAlias = new String((byte[]) contigAliasNA.toJavaArray());
String startAlias = new String((byte[]) startAliasNA.toJavaArray());
String endAlias = new String((byte[]) endAliasNA.toJavaArray());

Query query = new Query(bedArray);
query.setLayout(Layout.TILEDB_UNORDERED);

Pair<Long, Long> estSize = query.getEstResultSizeVar(ctx, contigAlias);

// TODO: unsafe casting needs to be addressed in the Java API
// Prepare buffers
query.setDataBuffer(
contigAlias,
new NativeArray(ctx, estSize.getSecond().intValue(), Datatype.TILEDB_STRING_ASCII));
query.setOffsetsBuffer(
contigAlias, new NativeArray(ctx, estSize.getFirst().intValue(), Datatype.TILEDB_UINT64));
query.setDataBuffer(
startAlias, new NativeArray(ctx, estSize.getFirst().intValue(), Datatype.TILEDB_UINT64));
query.setDataBuffer(
endAlias, new NativeArray(ctx, estSize.getFirst().intValue(), Datatype.TILEDB_UINT64));

do {
query.submit();
// get buffers
long[] contigOffsets = (long[]) query.getVarBuffer(contigAlias);
byte[] contigData = (byte[]) query.getBuffer(contigAlias);

String[] contigs = io.tiledb.java.api.Util.bytesToStrings(contigOffsets, contigData);
long[] start = (long[]) query.getBuffer(startAlias);
long[] end = (long[]) query.getBuffer(endAlias);

if (!(contigs.length == start.length && start.length == end.length)) {
throw new RuntimeException("There was an error reading the bed array");
}

// Put regions in map
for (int i = 0; i < contigs.length; i++) {
String shortContig = contigs[i].replace("chr", "");
String region = shortContig + ":" + start[i] + "-" + end[i] + ":" + counter;
counter++;

// Check if the key exists in the map
if (mapOfRegions.containsKey(shortContig)) {
// If the key exists, append the region string to the existing list
mapOfRegions.get(shortContig).add(region);
} else {
// If the key doesn't exist, create a new list with the region string
List<String> newList = new ArrayList<>();
newList.add(region);
mapOfRegions.put(shortContig, newList);
}
}
} while (query.getQueryStatus() == QueryStatus.TILEDB_INCOMPLETE);

List<List<String>> res = new LinkedList<>(mapOfRegions.values());

sortRegions(res, desiredNumRangePartitions);

return res;

} catch (TileDBError err) {
throw new RuntimeException(err);
}
}
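
For reference, computeRegionPartitionsFromBedArray assumes the BED array carries three metadata entries, "alias contig", "alias start", and "alias end", whose values name the buffers that hold the BED columns. A minimal sketch that inspects this contract with the same tiledb-java calls used above (the array URI is hypothetical):

static void printBedAliases(String arrayURI) {
  try {
    Context ctx = new Context();
    Array bed = new Array(ctx, arrayURI, QueryType.TILEDB_READ);
    for (String key : new String[] {"alias contig", "alias start", "alias end"}) {
      // Each alias value is stored as ASCII bytes and names one column of the BED array
      NativeArray value = bed.getMetadata(key, Datatype.TILEDB_STRING_ASCII);
      System.out.println(key + " -> " + new String((byte[]) value.toJavaArray()));
    }
    bed.close();
    ctx.close();
  } catch (TileDBError err) {
    throw new RuntimeException(err);
  }
}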

List<List<String>> computeRegionPartitionsFromBedFile(int desiredNumRangePartitions) {
Optional<URI> bedURI = options.getBedURI();
-if (!bedURI.isPresent()) {
-throw new RuntimeException("Can't use new_partition_method without setting bed_file");
+Optional<URI> bedArrayURI = options.getBedArrayURI();
+if (!bedURI.isPresent() && !bedArrayURI.isPresent()) {
+throw new RuntimeException(
+"Can't use new_partition_method without setting bed_file or bed_array");
}

log.info("Init VCFReader for partition calculation");
@@ -273,15 +376,22 @@ List<List<String>> computeRegionPartitionsFromBedFile(int desiredNumRangePartiti
Map<String, List<String>> mapOfRegions = bedFile.getContigRegionStrings();
List<List<String>> res = new LinkedList<>(mapOfRegions.values());

+sortRegions(res, desiredNumRangePartitions);
+
+bedFile.close();
+vcfReader.close();
+
+return res;
+}
+
+private void sortRegions(List<List<String>> res, int desiredNumRangePartitions) {
// Sort the region lists by the number of regions per contig, largest first
res.sort(Comparator.comparingInt(List<String>::size).reversed());

-// Keep splitting the larges region lists until we have the desired minimum number of range
+// Keep splitting the largest region lists until we have the desired minimum number of range
// partitions; we stop once the largest region list has fewer than 10 regions
while (res.size() < desiredNumRangePartitions && res.get(0).size() >= 10) {

List<String> top = res.remove(0);

List<String> first = new LinkedList<>(top.subList(0, top.size() / 2));
List<String> second = new LinkedList<>(top.subList(top.size() / 2, top.size()));
res.add(first);
@@ -291,10 +401,5 @@ List<List<String>> computeRegionPartitionsFromBedFile(int desiredNumRangePartiti
res.sort(Comparator.comparingInt(List::size));
Collections.reverse(res);
}

-bedFile.close();
-vcfReader.close();
-
-return res;
}
}
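
As a worked example of the splitting heuristic in sortRegions above, the following standalone sketch reproduces the loop on made-up region lists (contig names and region counts are illustrative only):

import java.util.*;

public class SortRegionsSketch {
  public static void main(String[] args) {
    // Three contigs with 12, 4, and 2 regions; ask for at least 4 range partitions.
    List<List<String>> res = new LinkedList<>();
    res.add(new LinkedList<>(Collections.nCopies(12, "1:0-100")));
    res.add(new LinkedList<>(Collections.nCopies(4, "2:0-100")));
    res.add(new LinkedList<>(Collections.nCopies(2, "3:0-100")));
    int desired = 4;

    // Largest list first, then halve the largest list until enough partitions exist
    res.sort(Comparator.comparingInt(List<String>::size).reversed());
    while (res.size() < desired && res.get(0).size() >= 10) {
      List<String> top = res.remove(0);
      res.add(new LinkedList<>(top.subList(0, top.size() / 2)));
      res.add(new LinkedList<>(top.subList(top.size() / 2, top.size())));
      res.sort(Comparator.comparingInt(List::size));
      Collections.reverse(res);
    }
    // Prints 6, 6, 4, 2: only the 12-region list was large enough to split
    res.forEach(l -> System.out.println(l.size()));
  }
}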
@@ -304,6 +304,12 @@ private void initVCFReader() {
if (bedURI.isPresent()) {
vcfReader.setBedFile(bedURI.get().toString());
}

// Set BED array
Optional<URI> bedArrayURI = options.getBedArrayURI();
if (bedArrayURI.isPresent()) {
vcfReader.setBedFile(bedArrayURI.get().toString());
}
} else {
if (rangePartitionInfo.getRegions().isEmpty()) {
throw new RuntimeException(
@@ -90,6 +90,12 @@ public void testBedURIOptionMissing() {
Assert.assertFalse(options.getBedURI().isPresent());
}

@Test
public void testBedArrayURIOptionMissing() {
VCFDataSourceOptions options = new VCFDataSourceOptions(new DataSourceOptions(new HashMap<>()));
Assert.assertFalse(options.getBedArrayURI().isPresent());
}

@Test
public void testBedURIOption() {
URI expectedURI = URI.create("s3://foo/bar");
@@ -100,12 +106,28 @@ public void testBedURIOption() {
Assert.assertEquals(expectedURI, options.getBedURI().get());
}

@Test
public void testBedArrayURIOption() {
URI expectedURI = URI.create("s3://foo/bar");
HashMap<String, String> optionMap = new HashMap<>();
optionMap.put("bed_array", expectedURI.toString());
VCFDataSourceOptions options = new VCFDataSourceOptions(new DataSourceOptions(optionMap));
Assert.assertTrue(options.getBedArrayURI().isPresent());
Assert.assertEquals(expectedURI, options.getBedArrayURI().get());
}

@Test
public void testSampleURIOptionMissing() {
VCFDataSourceOptions options = new VCFDataSourceOptions(new DataSourceOptions(new HashMap<>()));
Assert.assertFalse(options.getSampleURI().isPresent());
}

@Test
public void testArraySampleURIOptionMissing() {
VCFDataSourceOptions options = new VCFDataSourceOptions(new DataSourceOptions(new HashMap<>()));
Assert.assertFalse(options.getBedArrayURI().isPresent());
}

@Test
public void testSampleURIOption() {
URI expectedURI = URI.create("s3://foo/bar");
113 changes: 112 additions & 1 deletion apis/spark/src/test/java/io/tiledb/vcf/VCFDatasourceTest.java
@@ -25,7 +25,7 @@
public class VCFDatasourceTest extends SharedJavaSparkSession {

private String testSampleGroupURI(String sampleGroupName) {
Path arraysPath = Paths.get("src", "test", "resources", "arrays", "v3", sampleGroupName);
Path arraysPath = Paths.get("src", "test", "resources", "arrays", "v4", sampleGroupName);
return "file://".concat(arraysPath.toAbsolutePath().toString());
}

@@ -39,6 +39,16 @@ private String testSimpleBEDFile() {
return "file://".concat(path.toAbsolutePath().toString());
}

private String testSimpleBEDArray() {
Path arrayPath = Paths.get("src", "test", "resources", "arrays", "bed_array");
return "file://".concat(arrayPath.toAbsolutePath().toString());
}

private String testLargeBEDArray() {
Path arrayPath = Paths.get("src", "test", "resources", "arrays", "largebedarray");
return "file://".concat(arrayPath.toAbsolutePath().toString());
}

private String testLargeBEDFile() {
Path path = Paths.get("src", "test", "resources", "E001_15_coreMarks_dense.bed.gz");
return "file://".concat(path.toAbsolutePath().toString());
@@ -450,6 +460,93 @@ public void testSamplePartition() {
}
}

@Test
public void testNewPartitionWithBedArray() {
int rangePartitions = 32;
int samplePartitions = 2;
Dataset<Row> dfRead =
session()
.read()
.format("io.tiledb.vcf")
.option("uri", testSampleGroupURI("ingested_2samples", "v4"))
.option("bed_array", testLargeBEDArray())
.option("new_partition_method", true)
.option("range_partitions", rangePartitions)
.option("sample_partitions", samplePartitions)
.option("tiledb.vcf.log_level", "TRACE")
.load();

List<Row> rows =
dfRead
.select("sampleName", "contig", "posStart", "posEnd", "queryBedStart", "queryBedEnd")
.collectAsList();

// query from bed file line 184134 (0-indexed line numbers)
// 1 10600 540400 15_Quies 0 . 10600 540400 255,255,255

// NOTE: queryBedEnd returns the half-open value from the bed file,
// not the inclusive value used by tiledb-vcf
int expectedBedStart = 10600; // 0-indexed
int expectedBedEnd = 540400; // half-open

for (int i = 0; i < rows.size(); i++) {
System.out.println(
String.format(
"*** %s, %s, pos=%d-%d, query=%d-%d",
rows.get(i).getString(0),
rows.get(i).getString(1),
rows.get(i).getInt(2),
rows.get(i).getInt(3),
rows.get(i).getInt(4),
rows.get(i).getInt(5)));
Assert.assertEquals(expectedBedStart, rows.get(i).getInt(4));
Assert.assertEquals(expectedBedEnd, rows.get(i).getInt(5));
}
}

@Test
public void testNewPartitionWithBedArrayVsBedFile() {
int rangePartitions = 32;
int samplePartitions = 2;
Dataset<Row> dfRead =
session()
.read()
.format("io.tiledb.vcf")
.option("uri", testSampleGroupURI("ingested_2samples", "v4"))
.option("bedfile", testLargeBEDFile())
.option("new_partition_method", true)
.option("range_partitions", rangePartitions)
.option("sample_partitions", samplePartitions)
.option("tiledb.vcf.log_level", "TRACE")
.load();

List<Row> rows =
dfRead
.select("sampleName", "contig", "posStart", "posEnd", "queryBedStart", "queryBedEnd")
.orderBy("contig")
.collectAsList();

Dataset<Row> dfRead2 =
session()
.read()
.format("io.tiledb.vcf")
.option("uri", testSampleGroupURI("ingested_2samples", "v4"))
.option("bed_array", testLargeBEDArray())
.option("new_partition_method", true)
.option("range_partitions", rangePartitions)
.option("sample_partitions", samplePartitions)
.option("tiledb.vcf.log_level", "TRACE")
.load();

List<Row> rows2 =
dfRead2
.select("sampleName", "contig", "posStart", "posEnd", "queryBedStart", "queryBedEnd")
.orderBy("contig")
.collectAsList();

Assert.assertEquals(rows, rows2);
}

@Test
public void testNewPartition() {
int rangePartitions = 32;
@@ -521,6 +618,20 @@ public void testBedFile() {
Assert.assertEquals(10, rows.size());
}

@Test
public void testBedArray() {
Dataset<Row> dfRead =
session()
.read()
.format("io.tiledb.vcf")
.option("uri", testSampleGroupURI("ingested_2samples"))
.option("samplefile", testSampleFile())
.option("bed_array", testSimpleBEDArray())
.load();
List<Row> rows = dfRead.select("sampleName").collectAsList();
Assert.assertEquals(10, rows.size());
}

@Test
public void testSchemaShowTopN() {
Dataset<Row> dfRead = testSampleDataset();
1 change: 1 addition & 0 deletions apis/spark3/build.gradle
@@ -53,6 +53,7 @@ repositories {
dependencies {
compileOnly 'org.apache.spark:spark-sql_2.12:3.2.0'
compileOnly 'org.apache.spark:spark-core_2.12:3.2.0'
implementation 'io.tiledb:tiledb-java:0.24.0'
compile group: 'io.tiledb', name: 'tiledb-vcf-java', version: version
compile 'com.amazonaws:aws-java-sdk:1.11.650'
