Skip to content

Commit b9fd438

Browse files
authored
add Python API for sample deletion (#759)
1 parent 867f130 commit b9fd438

File tree

10 files changed

+121
-22
lines changed

10 files changed

+121
-22
lines changed

apis/python/src/tiledbvcf/binding/libtiledbvcf.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ PYBIND11_MODULE(libtiledbvcf, m) {
134134
"ingest_samples",
135135
&Writer::ingest_samples,
136136
py::call_guard<py::gil_scoped_release>())
137+
.def(
138+
"delete_samples",
139+
&Writer::delete_samples,
140+
py::call_guard<py::gil_scoped_release>())
137141
.def("get_schema_version", &Writer::get_schema_version)
138142
.def("set_tiledb_config", &Writer::set_tiledb_config)
139143
.def("set_sample_batch_size", &Writer::set_sample_batch_size)

apis/python/src/tiledbvcf/binding/writer.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,18 @@ void Writer::ingest_samples() {
232232
check_error(writer, tiledb_vcf_writer_store(writer));
233233
}
234234

235+
void Writer::delete_samples(std::vector<std::string> samples_to_delete) {
236+
std::vector<const char*> samples;
237+
for (std::string& sample : samples_to_delete) {
238+
samples.emplace_back(sample.c_str());
239+
}
240+
241+
auto writer = ptr.get();
242+
check_error(
243+
writer,
244+
tiledb_vcf_writer_delete_samples(writer, samples.data(), samples.size()));
245+
}
246+
235247
void Writer::deleter(tiledb_vcf_writer_t* w) {
236248
tiledb_vcf_writer_free(&w);
237249
}

apis/python/src/tiledbvcf/binding/writer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ class Writer {
162162

163163
void ingest_samples();
164164

165+
void delete_samples(std::vector<std::string> samples);
166+
165167
/** Returns schema version number of the TileDB VCF dataset */
166168
int32_t get_schema_version();
167169

apis/python/src/tiledbvcf/dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,14 @@ def ingest_samples(
851851
self.writer.register_samples()
852852
self.writer.ingest_samples()
853853

854+
def delete_samples(
855+
self,
856+
sample_uris: List[str] = None,
857+
):
858+
if self.mode != "w":
859+
raise Exception("Dataset not open in write mode")
860+
self.writer.delete_samples(sample_uris)
861+
854862
def tiledb_stats(self) -> str:
855863
"""
856864
Get TileDB stats as a string.

apis/python/tests/test_tiledbvcf.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,15 +1197,8 @@ def test_ingest_mode_merged(tmp_path):
11971197
assert ds.count(regions=["chrX:9032893-9032893"]) == 0
11981198

11991199

1200-
# Ok to skip is missing bcftools in Windows CI job
1201-
@pytest.mark.skipif(
1202-
os.environ.get("CI") == "true"
1203-
and platform.system() == "Windows"
1204-
and shutil.which("bcftools") is None,
1205-
reason="no bcftools",
1206-
)
1207-
def test_ingest_with_stats_v3(tmp_path):
1208-
# tiledbvcf.config_logging("debug")
1200+
@pytest.fixture
1201+
def test_stats_bgzipped_inputs(tmp_path):
12091202
tmp_path_contents = os.listdir(tmp_path)
12101203
if "stats" in tmp_path_contents:
12111204
shutil.rmtree(os.path.join(tmp_path, "stats"))
@@ -1221,23 +1214,46 @@ def test_ingest_with_stats_v3(tmp_path):
12211214
check=True,
12221215
)
12231216
bgzipped_inputs = glob.glob(os.path.join(tmp_path, "stats", "*.gz"))
1224-
# print(f"bgzipped inputs: {bgzipped_inputs}")
12251217
for vcf_file in bgzipped_inputs:
12261218
assert subprocess.run("bcftools index " + vcf_file, shell=True).returncode == 0
12271219
if "outputs" in tmp_path_contents:
12281220
shutil.rmtree(os.path.join(tmp_path, "outputs"))
12291221
if "stats_test" in tmp_path_contents:
12301222
shutil.rmtree(os.path.join(tmp_path, "stats_test"))
1231-
# tiledbvcf.config_logging("trace")
1223+
return bgzipped_inputs
1224+
1225+
1226+
@pytest.fixture
1227+
def test_stats_sample_names(test_stats_bgzipped_inputs):
1228+
assert len(test_stats_bgzipped_inputs) == 8
1229+
return [os.path.basename(file).split(".")[0] for file in test_stats_bgzipped_inputs]
1230+
1231+
1232+
@pytest.fixture
1233+
def test_stats_v3_ingestion(tmp_path, test_stats_bgzipped_inputs):
1234+
assert len(test_stats_bgzipped_inputs) == 8
1235+
# print(f"bgzipped inputs: {test_stats_bgzipped_inputs}")
12321236
ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="w")
12331237
ds.create_dataset(
12341238
enable_variant_stats=True, enable_allele_count=True, variant_stats_version=3
12351239
)
1236-
ds.ingest_samples(bgzipped_inputs)
1240+
ds.ingest_samples(test_stats_bgzipped_inputs)
12371241
ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r")
1238-
sample_names = [os.path.basename(file).split(".")[0] for file in bgzipped_inputs]
1239-
data_frame = ds.read(
1240-
samples=sample_names,
1242+
return ds
1243+
1244+
1245+
# Ok to skip is missing bcftools in Windows CI job
1246+
@pytest.mark.skipif(
1247+
os.environ.get("CI") == "true"
1248+
and platform.system() == "Windows"
1249+
and shutil.which("bcftools") is None,
1250+
reason="no bcftools",
1251+
)
1252+
def test_ingest_with_stats_v3(
1253+
tmp_path, test_stats_v3_ingestion, test_stats_sample_names
1254+
):
1255+
data_frame = test_stats_v3_ingestion.read(
1256+
samples=test_stats_sample_names,
12411257
attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"],
12421258
set_af_filter="<0.2",
12431259
)
@@ -1249,8 +1265,8 @@ def test_ingest_with_stats_v3(tmp_path):
12491265
data_frame[data_frame["sample_name"] == "second"]["info_TILEDB_IAF"].iloc[0][0]
12501266
== 0.9375
12511267
)
1252-
data_frame = ds.read(
1253-
samples=sample_names,
1268+
data_frame = test_stats_v3_ingestion.read(
1269+
samples=test_stats_sample_names,
12541270
attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"],
12551271
scan_all_samples=True,
12561272
)
@@ -1260,25 +1276,45 @@ def test_ingest_with_stats_v3(tmp_path):
12601276
]["info_TILEDB_IAF"].iloc[0][0]
12611277
== 0.9375
12621278
)
1263-
ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r")
1264-
df = ds.read_variant_stats("chr1:1-10000")
1279+
df = test_stats_v3_ingestion.read_variant_stats("chr1:1-10000")
12651280
assert df.shape == (13, 5)
12661281
df = tiledbvcf.allele_frequency.read_allele_frequency(
12671282
os.path.join(tmp_path, "stats_test"), "chr1:1-10000"
12681283
)
12691284
assert df.pos.is_monotonic_increasing
12701285
df["an_check"] = (df.ac / df.af).round(0).astype("int32")
12711286
assert df.an_check.equals(df.an)
1272-
df = ds.read_variant_stats("chr1:1-10000")
1287+
df = test_stats_v3_ingestion.read_variant_stats("chr1:1-10000")
12731288
assert df.shape == (13, 5)
12741289
df = df.to_pandas()
1275-
df = ds.read_allele_count("chr1:1-10000")
1290+
df = test_stats_v3_ingestion.read_allele_count("chr1:1-10000")
12761291
assert df.shape == (7, 6)
12771292
df = df.to_pandas()
12781293
assert sum(df["pos"] == (0, 1, 1, 2, 2, 2, 3)) == 7
12791294
assert sum(df["count"] == (8, 5, 3, 4, 2, 2, 1)) == 7
12801295

12811296

1297+
@pytest.mark.skipif(
1298+
os.environ.get("CI") == "true"
1299+
and platform.system() == "Windows"
1300+
and shutil.which("bcftools") is None,
1301+
reason="no bcftools",
1302+
)
1303+
def test_delete_samples(tmp_path, test_stats_v3_ingestion, test_stats_sample_names):
1304+
# assert test_stats_v3_ingestion.samples() == test_stats_sample_names
1305+
assert "second" in test_stats_sample_names
1306+
assert "fifth" in test_stats_sample_names
1307+
assert "third" in test_stats_sample_names
1308+
ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="w")
1309+
# tiledbvcf.config_logging("trace")
1310+
ds.delete_samples(["second", "fifth"])
1311+
ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r")
1312+
sample_names = ds.samples()
1313+
assert "second" not in sample_names
1314+
assert "fifth" not in sample_names
1315+
assert "third" in sample_names
1316+
1317+
12821318
# Ok to skip is missing bcftools in Windows CI job
12831319
@pytest.mark.skipif(
12841320
os.environ.get("CI") == "true"

libtiledbvcf/src/c_api/tiledbvcf.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1823,6 +1823,21 @@ int32_t tiledb_vcf_writer_set_variant_stats_version(
18231823
return TILEDB_VCF_OK;
18241824
}
18251825

1826+
int32_t tiledb_vcf_writer_delete_samples(
1827+
tiledb_vcf_writer_t* writer, const char** samples, size_t nsamples) {
1828+
std::vector<std::string> encoded_samples;
1829+
for (size_t i = 0; i < nsamples; i++)
1830+
encoded_samples.emplace_back(samples[i]);
1831+
if (sanity_check(writer) == TILEDB_VCF_ERR)
1832+
return TILEDB_VCF_ERR;
1833+
1834+
if (SAVE_ERROR_CATCH(
1835+
writer, writer->writer_->delete_samples(encoded_samples)))
1836+
return TILEDB_VCF_ERR;
1837+
1838+
return TILEDB_VCF_OK;
1839+
}
1840+
18261841
/* ********************************* */
18271842
/* ERROR */
18281843
/* ********************************* */

libtiledbvcf/src/c_api/tiledbvcf.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,6 +1706,16 @@ tiledb_vcf_writer_set_compression_level(tiledb_vcf_writer_t* writer, int level);
17061706
TILEDBVCF_EXPORT int32_t tiledb_vcf_writer_set_variant_stats_version(
17071707
tiledb_vcf_writer_t* writer, uint8_t version);
17081708

1709+
/**
1710+
* Deletes samples from dataset
1711+
* @param writer VCF writer object
1712+
* @param samples samples to delete
1713+
* @param nsamples number of samples to delete
1714+
*/
1715+
TILEDBVCF_EXPORT int32_t tiledb_vcf_writer_delete_samples(
1716+
1717+
tiledb_vcf_writer_t* writer, const char** samples, size_t nsamples);
1718+
17091719
/* ********************************* */
17101720
/* ERROR */
17111721
/* ********************************* */

libtiledbvcf/src/dataset/tiledbvcfdataset.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,9 @@ void TileDBVCFDataset::delete_samples(
938938
const std::vector<std::string>& sample_names,
939939
const std::vector<std::string>& tiledb_config) {
940940
// Open dataset in read mode, required before calling `sample_exists`.
941-
open(uri);
941+
if (!open_) {
942+
open(uri, tiledb_config);
943+
}
942944

943945
// Define a function that deletes a sample from an array
944946
auto delete_sample = [&](Array& array, const std::string& sample) {

libtiledbvcf/src/write/writer.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,5 +1484,10 @@ void Writer::set_variant_stats_array_version(uint8_t version) {
14841484
creation_params_.variant_stats_array_version = version;
14851485
}
14861486

1487+
void Writer::delete_samples(std::vector<std::string> samples) {
1488+
dataset_->delete_samples(
1489+
ingestion_params_.uri, samples, ingestion_params_.tiledb_config);
1490+
}
1491+
14871492
} // namespace vcf
14881493
} // namespace tiledb

libtiledbvcf/src/write/writer.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,11 @@ class Writer {
382382
/** Set variant stats array version */
383383
void set_variant_stats_array_version(uint8_t version);
384384

385+
/**
386+
* @brief Delete samples from the writer's dataset.
387+
*/
388+
void delete_samples(std::vector<std::string> samples);
389+
385390
private:
386391
/* ********************************* */
387392
/* PRIVATE ATTRIBUTES */

0 commit comments

Comments
 (0)