@@ -79,6 +79,7 @@ pub struct WeightedAliasIndex<W: AliasableWeight> {
79
79
no_alias_odds : Box < [ W ] > ,
80
80
uniform_index : Uniform < u32 > ,
81
81
uniform_within_weight_sum : Uniform < W > ,
82
+ weight_sum : W ,
82
83
}
83
84
84
85
impl < W : AliasableWeight > WeightedAliasIndex < W > {
@@ -231,8 +232,42 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
231
232
no_alias_odds,
232
233
uniform_index,
233
234
uniform_within_weight_sum,
235
+ weight_sum,
234
236
} )
235
237
}
238
+
239
+ /// Reconstructs and returns the original weights used to create the distribution.
240
+ ///
241
+ /// `O(n)` time, where `n` is the number of weights.
242
+ ///
243
+ /// Note: Exact values may not be recovered if `W` is a float.
244
+ pub fn weights ( & self ) -> Vec < W > {
245
+ let n = self . aliases . len ( ) ;
246
+
247
+ // `n` was validated in the constructor.
248
+ let n_converted = W :: try_from_u32_lossy ( n as u32 ) . unwrap ( ) ;
249
+
250
+ // pre-calculate the total contribution each index receives from serving
251
+ // as an alias for other indices.
252
+ let mut alias_contributions = vec ! [ W :: ZERO ; n] ;
253
+ for j in 0 ..n {
254
+ if self . no_alias_odds [ j] < self . weight_sum {
255
+ let contribution = self . weight_sum - self . no_alias_odds [ j] ;
256
+ let alias_index = self . aliases [ j] as usize ;
257
+ alias_contributions[ alias_index] += contribution;
258
+ }
259
+ }
260
+
261
+ // Reconstruct each weight by combining its direct `no_alias_odds`
262
+ // with its total `alias_contributions` and scaling the result.
263
+ self . no_alias_odds
264
+ . iter ( )
265
+ . zip ( & alias_contributions)
266
+ . map ( |( & no_alias_odd, & alias_contribution) | {
267
+ ( no_alias_odd + alias_contribution) / n_converted
268
+ } )
269
+ . collect ( )
270
+ }
236
271
}
237
272
238
273
impl < W : AliasableWeight > Distribution < usize > for WeightedAliasIndex < W > {
@@ -271,6 +306,7 @@ where
271
306
no_alias_odds : self . no_alias_odds . clone ( ) ,
272
307
uniform_index : self . uniform_index ,
273
308
uniform_within_weight_sum : self . uniform_within_weight_sum . clone ( ) ,
309
+ weight_sum : self . weight_sum ,
274
310
}
275
311
}
276
312
}
@@ -503,6 +539,48 @@ mod test {
503
539
) ;
504
540
}
505
541
542
+ #[ test]
543
+ fn test_weights_reconstruction ( ) {
544
+ // Standard integers
545
+ {
546
+ let weights_i32 = vec ! [ 10 , 2 , 8 , 0 , 30 , 5 ] ;
547
+ let dist_i32 = WeightedAliasIndex :: new ( weights_i32. clone ( ) ) . unwrap ( ) ;
548
+ assert_eq ! ( weights_i32, dist_i32. weights( ) ) ;
549
+ }
550
+
551
+ // Uniform weights
552
+ {
553
+ let weights_u64 = vec ! [ 1 , 1 , 1 , 1 , 1 ] ;
554
+ let dist_u64 = WeightedAliasIndex :: new ( weights_u64. clone ( ) ) . unwrap ( ) ;
555
+ assert_eq ! ( weights_u64, dist_u64. weights( ) ) ;
556
+ }
557
+
558
+ // Floating point
559
+ {
560
+ const EPSILON : f64 = 1e-9 ;
561
+ let weights_f64 = vec ! [ 0.5 , 0.2 , 0.3 , 0.0 , 1.5 , 0.88 ] ;
562
+ let dist_f64 = WeightedAliasIndex :: new ( weights_f64. clone ( ) ) . unwrap ( ) ;
563
+ let reconstructed_f64 = dist_f64. weights ( ) ;
564
+
565
+ assert_eq ! ( weights_f64. len( ) , reconstructed_f64. len( ) ) ;
566
+ for ( original, reconstructed) in weights_f64. iter ( ) . zip ( reconstructed_f64. iter ( ) ) {
567
+ assert ! (
568
+ f64 :: abs( original - reconstructed) < EPSILON ,
569
+ "Weight reconstruction failed: original {}, reconstructed {}" ,
570
+ original,
571
+ reconstructed
572
+ ) ;
573
+ }
574
+ }
575
+
576
+ // Single item
577
+ {
578
+ let weights_single = vec ! [ 42_u32 ] ;
579
+ let dist_single = WeightedAliasIndex :: new ( weights_single. clone ( ) ) . unwrap ( ) ;
580
+ assert_eq ! ( weights_single, dist_single. weights( ) ) ;
581
+ }
582
+ }
583
+
506
584
#[ test]
507
585
fn value_stability ( ) {
508
586
fn test_samples < W : AliasableWeight > (
0 commit comments