Skip to content

Commit e795c11

Browse files
authored
fix(queue-msg): fix flattening optimization passes (#2058)
`seq(conc(seq(a, b), ..), ..)` was flattening to `seq(conc(a, b, ..), ..)`, which is clearly incorrect.
2 parents 87ca603 + 1462296 commit e795c11

File tree

1 file changed

+88
-12
lines changed

1 file changed

+88
-12
lines changed

lib/queue-msg/src/optimize/passes.rs

+88-12
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,15 @@ impl<T: QueueMessageTypes> PurePass<T> for FlattenSeq {
210210
fn go<T: QueueMessageTypes>(msg: QueueMsg<T>) -> Vec<QueueMsg<T>> {
211211
match msg {
212212
QueueMsg::Sequence(new_seq) => new_seq.into_iter().flat_map(go).collect(),
213-
QueueMsg::Concurrent(c) => vec![conc(c.into_iter().flat_map(go))],
213+
QueueMsg::Concurrent(c) => vec![conc(c.into_iter().flat_map(|msg| {
214+
let mut msgs = go(msg);
215+
216+
match msgs.len() {
217+
0 => None,
218+
1 => Some(msgs.pop().unwrap()),
219+
_ => Some(seq(msgs)),
220+
}
221+
}))],
214222
QueueMsg::Aggregate {
215223
queue,
216224
data,
@@ -253,18 +261,16 @@ impl<T: QueueMessageTypes> PurePass<T> for FlattenConc {
253261
fn go<T: QueueMessageTypes>(msg: QueueMsg<T>) -> Vec<QueueMsg<T>> {
254262
match msg {
255263
QueueMsg::Concurrent(new_conc) => new_conc.into_iter().flat_map(go).collect(),
256-
// wrap in conc again
257-
// seq(conc(a.., conc(b..)), c..) == seq(conc(a.., b..), c..)
258-
// seq(conc(a.., conc(b..)), c..) != seq(a.., b.., c..)
259-
QueueMsg::Sequence(s) => vec![seq(s.into_iter().map(|msg| {
264+
QueueMsg::Sequence(s) => vec![seq(s.into_iter().flat_map(|msg| {
260265
let mut msgs = go(msg);
261266

262267
match msgs.len() {
263-
// return the original empty sequence
264-
0 => seq([]),
265-
// seq(a) == a
266-
1 => msgs.pop().unwrap(),
267-
_ => conc(msgs),
268+
0 => None,
269+
1 => Some(msgs.pop().unwrap()),
270+
// wrap in conc again
271+
// seq(conc(a.., conc(b..)), c..) == seq(conc(a.., b..), c..)
272+
// seq(conc(a.., conc(b..)), c..) != seq(a.., b.., c..)
273+
_ => Some(conc(msgs)),
268274
}
269275
}))],
270276
QueueMsg::Aggregate {
@@ -298,8 +304,10 @@ impl<T: QueueMessageTypes> PurePass<T> for FlattenConc {
298304
mod tests {
299305
use super::*;
300306
use crate::{
301-
data, effect, fetch, noop,
302-
test_utils::{DataA, DataB, DataC, FetchA, PrintAbc, SimpleMessage},
307+
aggregate, data, defer_relative, effect, event, fetch, noop,
308+
test_utils::{
309+
AggregatePrintAbc, DataA, DataB, DataC, FetchA, PrintAbc, SimpleEvent, SimpleMessage,
310+
},
303311
};
304312

305313
#[test]
@@ -349,4 +357,72 @@ mod tests {
349357
assert_eq!(optimized.ready, expected_output);
350358
assert_eq!(optimized.optimize_further, []);
351359
}
360+
361+
#[test]
362+
fn seq_conc_conc() {
363+
let msgs = vec![seq::<SimpleMessage>([
364+
conc([
365+
aggregate([], [], AggregatePrintAbc {}),
366+
aggregate([], [], AggregatePrintAbc {}),
367+
]),
368+
conc([
369+
aggregate([], [], AggregatePrintAbc {}),
370+
aggregate([], [], AggregatePrintAbc {}),
371+
]),
372+
conc([
373+
repeat(None, seq([event(SimpleEvent {}), defer_relative(10)])),
374+
repeat(None, seq([event(SimpleEvent {}), defer_relative(10)])),
375+
// this seq is the only message that should be flattened
376+
seq([
377+
effect(PrintAbc {
378+
a: DataA {},
379+
b: DataB {},
380+
c: DataC {},
381+
}),
382+
seq([
383+
aggregate([], [], AggregatePrintAbc {}),
384+
aggregate([], [], AggregatePrintAbc {}),
385+
aggregate([], [], AggregatePrintAbc {}),
386+
]),
387+
]),
388+
]),
389+
])];
390+
391+
let expected_output = vec![(
392+
vec![0],
393+
seq::<SimpleMessage>([
394+
conc([
395+
aggregate([], [], AggregatePrintAbc {}),
396+
aggregate([], [], AggregatePrintAbc {}),
397+
]),
398+
conc([
399+
aggregate([], [], AggregatePrintAbc {}),
400+
aggregate([], [], AggregatePrintAbc {}),
401+
]),
402+
conc([
403+
repeat(None, seq([event(SimpleEvent {}), defer_relative(10)])),
404+
repeat(None, seq([event(SimpleEvent {}), defer_relative(10)])),
405+
seq([
406+
effect(PrintAbc {
407+
a: DataA {},
408+
b: DataB {},
409+
c: DataC {},
410+
}),
411+
aggregate([], [], AggregatePrintAbc {}),
412+
aggregate([], [], AggregatePrintAbc {}),
413+
aggregate([], [], AggregatePrintAbc {}),
414+
]),
415+
]),
416+
]),
417+
)];
418+
419+
let optimized = Normalize::default().run_pass_pure(msgs.clone());
420+
421+
assert_eq!(optimized.optimize_further, expected_output);
422+
assert_eq!(optimized.ready, []);
423+
424+
let optimized = NormalizeFinal::default().run_pass_pure(msgs);
425+
assert_eq!(optimized.ready, expected_output);
426+
assert_eq!(optimized.optimize_further, []);
427+
}
352428
}

0 commit comments

Comments
 (0)