Skip to content

Commit 3a52814

Browse files
bens-schreiberBen Schreiber
andauthored
Reconstruct weights in WeightedAliasIndex (#25)
Co-authored-by: Ben Schreiber <[email protected]>
1 parent 3dc934d commit 3a52814

File tree

3 files changed

+87
-1
lines changed

3 files changed

+87
-1
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
## [0.5.2]
8+
9+
### API Changes
10+
- Add `WeightedAliasIndex::weights()` to reconstruct the original weights in O(n)
11+
12+
### Testing
13+
- Added a test for `WeightedAliasIndex::weights()`
14+
715
## [0.5.1]
816

917
### Testing

benches/benches/weighted.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ pub fn bench(c: &mut Criterion) {
4949
(1000, 1_000_000, "1M"),
5050
];
5151
for (amount, length, len_name) in lens {
52-
let name = format!("weighted_sample_indices_{}_of_{}", amount, len_name);
52+
let name = format!("weighted_sample_indices_{amount}_of_{len_name}");
5353
c.bench_function(name.as_str(), |b| {
5454
let length = black_box(length);
5555
let amount = black_box(amount);

src/weighted/weighted_alias.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ pub struct WeightedAliasIndex<W: AliasableWeight> {
7979
no_alias_odds: Box<[W]>,
8080
uniform_index: Uniform<u32>,
8181
uniform_within_weight_sum: Uniform<W>,
82+
weight_sum: W,
8283
}
8384

8485
impl<W: AliasableWeight> WeightedAliasIndex<W> {
@@ -231,8 +232,42 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
231232
no_alias_odds,
232233
uniform_index,
233234
uniform_within_weight_sum,
235+
weight_sum,
234236
})
235237
}
238+
239+
/// Reconstructs and returns the original weights used to create the distribution.
240+
///
241+
/// `O(n)` time, where `n` is the number of weights.
242+
///
243+
/// Note: Exact values may not be recovered if `W` is a float.
244+
pub fn weights(&self) -> Vec<W> {
245+
let n = self.aliases.len();
246+
247+
// `n` was validated in the constructor.
248+
let n_converted = W::try_from_u32_lossy(n as u32).unwrap();
249+
250+
// pre-calculate the total contribution each index receives from serving
251+
// as an alias for other indices.
252+
let mut alias_contributions = vec![W::ZERO; n];
253+
for j in 0..n {
254+
if self.no_alias_odds[j] < self.weight_sum {
255+
let contribution = self.weight_sum - self.no_alias_odds[j];
256+
let alias_index = self.aliases[j] as usize;
257+
alias_contributions[alias_index] += contribution;
258+
}
259+
}
260+
261+
// Reconstruct each weight by combining its direct `no_alias_odds`
262+
// with its total `alias_contributions` and scaling the result.
263+
self.no_alias_odds
264+
.iter()
265+
.zip(&alias_contributions)
266+
.map(|(&no_alias_odd, &alias_contribution)| {
267+
(no_alias_odd + alias_contribution) / n_converted
268+
})
269+
.collect()
270+
}
236271
}
237272

238273
impl<W: AliasableWeight> Distribution<usize> for WeightedAliasIndex<W> {
@@ -271,6 +306,7 @@ where
271306
no_alias_odds: self.no_alias_odds.clone(),
272307
uniform_index: self.uniform_index,
273308
uniform_within_weight_sum: self.uniform_within_weight_sum.clone(),
309+
weight_sum: self.weight_sum,
274310
}
275311
}
276312
}
@@ -503,6 +539,48 @@ mod test {
503539
);
504540
}
505541

542+
#[test]
543+
fn test_weights_reconstruction() {
544+
// Standard integers
545+
{
546+
let weights_i32 = vec![10, 2, 8, 0, 30, 5];
547+
let dist_i32 = WeightedAliasIndex::new(weights_i32.clone()).unwrap();
548+
assert_eq!(weights_i32, dist_i32.weights());
549+
}
550+
551+
// Uniform weights
552+
{
553+
let weights_u64 = vec![1, 1, 1, 1, 1];
554+
let dist_u64 = WeightedAliasIndex::new(weights_u64.clone()).unwrap();
555+
assert_eq!(weights_u64, dist_u64.weights());
556+
}
557+
558+
// Floating point
559+
{
560+
const EPSILON: f64 = 1e-9;
561+
let weights_f64 = vec![0.5, 0.2, 0.3, 0.0, 1.5, 0.88];
562+
let dist_f64 = WeightedAliasIndex::new(weights_f64.clone()).unwrap();
563+
let reconstructed_f64 = dist_f64.weights();
564+
565+
assert_eq!(weights_f64.len(), reconstructed_f64.len());
566+
for (original, reconstructed) in weights_f64.iter().zip(reconstructed_f64.iter()) {
567+
assert!(
568+
f64::abs(original - reconstructed) < EPSILON,
569+
"Weight reconstruction failed: original {}, reconstructed {}",
570+
original,
571+
reconstructed
572+
);
573+
}
574+
}
575+
576+
// Single item
577+
{
578+
let weights_single = vec![42_u32];
579+
let dist_single = WeightedAliasIndex::new(weights_single.clone()).unwrap();
580+
assert_eq!(weights_single, dist_single.weights());
581+
}
582+
}
583+
506584
#[test]
507585
fn value_stability() {
508586
fn test_samples<W: AliasableWeight>(

0 commit comments

Comments
 (0)