diff --git a/src/bin/doodle/format/deflate.rs b/src/bin/doodle/format/deflate.rs index 6d6d3c07..b95acfc7 100644 --- a/src/bin/doodle/format/deflate.rs +++ b/src/bin/doodle/format/deflate.rs @@ -148,10 +148,7 @@ fn length_record(start: usize, base: &BaseModule, extra_bits: usize) -> Format { ), ( "distance-code", - Format::Dynamic(DynFormat::Huffman( - var("distance-alphabet-code-lengths-value"), - None, - )), + Format::Apply(var("distance-alphabet-format")), ), ("distance-record", distance_record(base)), ]) @@ -256,6 +253,10 @@ pub fn main(module: &mut FormatModule, base: &BaseModule) -> FormatRef { let fixed_huffman = module.define_format( "deflate.fixed_huffman", record([ + ( + "format", + Format::Dynamic(DynFormat::Huffman(fixed_code_lengths(), None)), + ), ( "codes", repeat_until_last( @@ -270,10 +271,7 @@ pub fn main(module: &mut FormatModule, base: &BaseModule) -> FormatRef { )), ), record([ - ( - "code", - Format::Dynamic(DynFormat::Huffman(fixed_code_lengths(), None)), - ), + ("code", Format::Apply(var("format"))), ( "extra", Format::MatchVariant( @@ -498,6 +496,33 @@ pub fn main(module: &mut FormatModule, base: &BaseModule) -> FormatRef { "code-length-alphabet-code-lengths", repeat_count(add(var("hclen"), Expr::U8(4)), bits3.clone()), ), + ( + "code-length-alphabet-format", + Format::Dynamic(DynFormat::Huffman( + var("code-length-alphabet-code-lengths"), + Some(Expr::Seq(vec![ + Expr::U8(16), + Expr::U8(17), + Expr::U8(18), + Expr::U8(0), + Expr::U8(8), + Expr::U8(7), + Expr::U8(9), + Expr::U8(6), + Expr::U8(10), + Expr::U8(5), + Expr::U8(11), + Expr::U8(4), + Expr::U8(12), + Expr::U8(3), + Expr::U8(13), + Expr::U8(2), + Expr::U8(14), + Expr::U8(1), + Expr::U8(15), + ])), + )), + ), ( "literal-length-distance-alphabet-code-lengths", repeat_until_seq( @@ -611,33 +636,7 @@ pub fn main(module: &mut FormatModule, base: &BaseModule) -> FormatRef { )), ), record([ - ( - "code", - Format::Dynamic(DynFormat::Huffman( - var("code-length-alphabet-code-lengths"), - Some(Expr::Seq(vec![ - Expr::U8(16), - Expr::U8(17), - Expr::U8(18), - Expr::U8(0), - Expr::U8(8), - Expr::U8(7), - Expr::U8(9), - Expr::U8(6), - Expr::U8(10), - Expr::U8(5), - Expr::U8(11), - Expr::U8(4), - Expr::U8(12), - Expr::U8(3), - Expr::U8(13), - Expr::U8(2), - Expr::U8(14), - Expr::U8(1), - Expr::U8(15), - ])), - )), - ), + ("code", Format::Apply(var("code-length-alphabet-format"))), ( "extra", Format::Match( @@ -764,6 +763,20 @@ pub fn main(module: &mut FormatModule, base: &BaseModule) -> FormatRef { Box::new(add(Expr::AsU16(Box::new(var("hdist"))), Expr::U16(1))), )), ), + ( + "distance-alphabet-format", + Format::Dynamic(DynFormat::Huffman( + var("distance-alphabet-code-lengths-value"), + None, + )), + ), + ( + "literal-length-alphabet-format", + Format::Dynamic(DynFormat::Huffman( + var("literal-length-alphabet-code-lengths-value"), + None, + )), + ), ( "codes", repeat_until_last( @@ -778,13 +791,7 @@ pub fn main(module: &mut FormatModule, base: &BaseModule) -> FormatRef { )), ), record([ - ( - "code", - Format::Dynamic(DynFormat::Huffman( - var("literal-length-alphabet-code-lengths-value"), - None, - )), - ), + ("code", Format::Apply(var("literal-length-alphabet-format"))), ( "extra", Format::MatchVariant( diff --git a/src/lib.rs b/src/lib.rs index 62fbc193..135fe588 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![allow(clippy::new_without_default)] #![deny(rust_2018_idioms)] +use std::cell::RefCell; use std::collections::{HashMap, HashSet}; use std::ops::Add; use std::rc::Rc; @@ -113,6 +114,7 @@ pub enum ValueType { Record(Vec<(String, ValueType)>), Union(Vec<(String, ValueType)>), Seq(Box), + Format(Box), } #[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)] @@ -126,6 +128,7 @@ pub enum Value { Record(Vec<(String, Value)>), Variant(String, Box), Seq(Vec), + Format(Box), } impl Value { @@ -454,6 +457,8 @@ pub enum Format { MatchVariant(Expr, Vec<(Pattern, String, Format)>), /// Format generated dynamically Dynamic(DynFormat), + /// Apply a dynamic format + Apply(Expr), } #[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)] @@ -522,6 +527,14 @@ impl FormatModule { self.define_format_args(name, vec![], format) } + pub fn define_format_rec( + &mut self, + name: impl Into, + f: impl FnOnce(FormatRef) -> Format, + ) -> FormatRef { + self.define_format_args_rec(name, vec![], f) + } + pub fn define_format_args( &mut self, name: impl Into, @@ -544,6 +557,32 @@ impl FormatModule { FormatRef(level) } + pub fn define_format_args_rec( + &mut self, + name: impl Into, + args: Vec<(String, ValueType)>, + f: impl FnOnce(FormatRef) -> Format, + ) -> FormatRef { + let format_ref = FormatRef(self.names.len()); + let format = f(format_ref); + if let Err(()) = format.recursion_check(self, format_ref) { + panic!("format fails recursion check!"); + } + let mut scope = TypeScope::new(); + for (arg_name, arg_type) in &args { + scope.push(arg_name.clone(), arg_type.clone()); + } + let format_type = match self.infer_format_type(&mut scope, &format) { + Ok(t) => t, + Err(msg) => panic!("{msg}"), + }; + self.names.push(name.into()); + self.args.push(args); + self.formats.push(format); + self.format_types.push(format_type); + format_ref + } + fn get_name(&self, level: usize) -> &str { &self.names[level] } @@ -646,14 +685,25 @@ impl FormatModule { } Ok(t) } - Format::Dynamic(DynFormat::Huffman(_, _)) => { - // FIXME check expr type + Format::Dynamic(DynFormat::Huffman(lengths_expr, _opt_values_expr)) => { + match lengths_expr.infer_type_coerce_value(scope)? { + ValueType::Seq(t) => match &*t { + ValueType::U8 | ValueType::U16 => {} + _ => return Err(format!("Huffman: expected U8 or U16")), + }, + _ => return Err(format!("Huffman: expected Seq")), + } + // FIXME check opt_values_expr type let ts = vec![ // FIXME ("bits", alt???) ("@value".to_string(), ValueType::U16), ]; - Ok(ValueType::Record(ts)) + Ok(ValueType::Format(Box::new(ValueType::Record(ts)))) } + Format::Apply(expr) => match expr.infer_type_coerce_value(scope)? { + ValueType::Format(t) => Ok((*t).clone()), + _ => Err(format!("Apply: expected format")), + }, } } } @@ -715,6 +765,7 @@ enum Decoder { Match(Expr, Vec<(Pattern, Decoder)>), MatchVariant(Expr, Vec<(Pattern, String, Decoder)>), Dynamic(DynFormat), + Apply(Expr, RefCell>), } impl Expr { @@ -1237,7 +1288,8 @@ impl Format { .map(|(_, _, f)| f.match_bounds(module)) .reduce(Bounds::union) .unwrap(), - Format::Dynamic(DynFormat::Huffman(_, _)) => Bounds::new(1, None), + Format::Dynamic(DynFormat::Huffman(_, _)) => Bounds::exact(0), + Format::Apply(_) => Bounds::new(1, None), } } @@ -1275,6 +1327,7 @@ impl Format { branches.iter().any(|(_, _, f)| f.depends_on_next(module)) } Format::Dynamic(_) => false, + Format::Apply(_) => false, } } @@ -1288,6 +1341,93 @@ impl Format { } MatchTree::build(module, &fs, Rc::new(Next::Empty)).is_none() } + + fn recursion_check(&self, module: &FormatModule, format_ref: FormatRef) -> Result { + match self { + Format::ItemVar(level, _arg_exprs) => { + if format_ref.get_level() == *level { + Err(()) + } else { + Ok(module.get_format(*level).is_nullable(module)) + } + } + Format::Fail => Ok(false), + Format::EndOfInput => Ok(true), + Format::Align(_n) => Ok(true), + Format::Byte(_bs) => Ok(false), + Format::Union(branches) | Format::NondetUnion(branches) => { + let mut nullable = false; + for (_label, f) in branches { + nullable = nullable || f.recursion_check(module, format_ref)?; + } + Ok(nullable) + } + Format::Tuple(fields) => { + for f in fields { + if !f.recursion_check(module, format_ref)? { + return Ok(false); + } + } + Ok(true) + } + Format::Record(fields) => { + for (_label, f) in fields { + if !f.recursion_check(module, format_ref)? { + return Ok(false); + } + } + Ok(true) + } + Format::Repeat(a) => { + a.recursion_check(module, format_ref)?; + Ok(true) + } + Format::Repeat1(a) => { + a.recursion_check(module, format_ref)?; + Ok(false) + } + Format::RepeatCount(_expr, a) + | Format::RepeatUntilLast(_expr, a) + | Format::RepeatUntilSeq(_expr, a) => { + a.recursion_check(module, format_ref)?; + Ok(true) // FIXME bit sloppy but okay + } + Format::Peek(a) => { + a.recursion_check(module, format_ref)?; + Ok(true) + } + Format::PeekNot(a) => { + a.recursion_check(module, format_ref)?; + Ok(true) + } + Format::Slice(expr, a) => { + a.recursion_check(module, format_ref)?; + Ok(expr.bounds().min == 0) + } + Format::Bits(_a) => Ok(false), + Format::WithRelativeOffset(_expr, a) => { + a.recursion_check(module, format_ref)?; // technically okay if expr > 0 + Ok(true) + } + Format::Compute(_expr) => Ok(true), + Format::Match(_head, branches) => { + let mut nullable = false; + for (_pattern, f) in branches { + nullable = nullable || f.recursion_check(module, format_ref)?; + } + Ok(nullable) + } + Format::MatchVariant(_head, branches) => { + let mut nullable = false; + for (_label, _pattern, f) in branches { + nullable = nullable || f.recursion_check(module, format_ref)?; + } + Ok(nullable) + } + Format::Dynamic(DynFormat::Huffman(_, _)) => Ok(true), + Format::Apply(_expr) => Ok(false), + } + } } impl Format { @@ -1589,9 +1729,8 @@ impl<'a> MatchTreeStep<'a> { } tree } - Format::Dynamic(DynFormat::Huffman(_, _)) => { - Self::accept() // FIXME - } + Format::Dynamic(DynFormat::Huffman(_, _)) => Self::add_next(module, next), + Format::Apply(_expr) => Self::accept(), } } } @@ -1775,6 +1914,7 @@ pub enum TypeRef { U32, Tuple(Vec), Seq(Box), + Format(Box), } pub enum TypeDef { @@ -1807,6 +1947,7 @@ pub struct Compiler<'a> { record_map: HashMap, usize>, union_map: HashMap, usize>, decoder_map: HashMap<(usize, Rc>), usize>, + to_compile: Vec<(&'a Format, Rc>, usize)>, } impl<'a> Compiler<'a> { @@ -1815,12 +1956,14 @@ impl<'a> Compiler<'a> { let record_map = HashMap::new(); let union_map = HashMap::new(); let decoder_map = HashMap::new(); + let to_compile = Vec::new(); Compiler { module, program, record_map, union_map, decoder_map, + to_compile, } } @@ -1835,13 +1978,21 @@ impl<'a> Compiler<'a> { ); */ // decoder - let n = compiler.program.decoders.len(); - compiler.program.decoders.push(Decoder::Fail); - let d = Decoder::compile(&mut compiler, format)?; - compiler.program.decoders[n] = d; + compiler.queue_compile(format, Rc::new(Next::Empty)); + while let Some((f, next, n)) = compiler.to_compile.pop() { + let d = Decoder::compile_next(&mut compiler, &f, next)?; + compiler.program.decoders[n] = d; + } Ok(compiler.program) } + fn queue_compile(&mut self, f: &'a Format, next: Rc>) -> usize { + let n = self.program.decoders.len(); + self.program.decoders.push(Decoder::Fail); + self.to_compile.push((f, next, n)); + n + } + pub fn add_typedef(&mut self, t: TypeDef) -> TypeRef { let n = self.program.typedefs.len(); self.program.typedefs.push(t); @@ -2014,6 +2165,7 @@ impl TypeRef { TypeRef::Var(n) } ValueType::Seq(t) => TypeRef::Seq(Box::new(Self::from_value_type(compiler, &*t))), + ValueType::Format(t) => TypeRef::Format(Box::new(Self::from_value_type(compiler, &*t))), } } @@ -2042,6 +2194,7 @@ impl TypeRef { ValueType::Tuple(ts.iter().map(|t| t.to_value_type(typedefs)).collect()) } TypeRef::Seq(t) => ValueType::Seq(Box::new(t.to_value_type(typedefs))), + TypeRef::Format(t) => ValueType::Format(Box::new(t.to_value_type(typedefs))), } } } @@ -2076,13 +2229,7 @@ impl Decoder { let n = if let Some(n) = compiler.decoder_map.get(&(*level, next.clone())) { *n } else { - let d = Decoder::compile_next( - compiler, - compiler.module.get_format(*level), - next.clone(), - )?; - let n = compiler.program.decoders.len(); - compiler.program.decoders.push(d); + let n = compiler.queue_compile(compiler.module.get_format(*level), next.clone()); compiler.decoder_map.insert((*level, next.clone()), n); n }; @@ -2245,6 +2392,7 @@ impl Decoder { Ok(Decoder::MatchVariant(head.clone(), branches)) } Format::Dynamic(d) => Ok(Decoder::Dynamic(d.clone())), + Format::Apply(expr) => Ok(Decoder::Apply(expr.clone(), RefCell::new(HashMap::new()))), } } @@ -2480,9 +2628,16 @@ impl Decoder { } }; let f = make_huffman_codes(&lengths); - let d = Decoder::compile_one(&f).unwrap(); - d.parse(program, scope, input) - } + Ok((Value::Format(Box::new(f)), input)) + } + Decoder::Apply(expr, cache) => match expr.eval(scope) { + Value::Format(f) => cache + .borrow_mut() + .entry(*f.clone()) + .or_insert_with(|| Decoder::compile_one(&f).unwrap()) + .parse(program, scope, input), + _ => panic!("expected format value"), + }, } } } diff --git a/src/output/flat.rs b/src/output/flat.rs index 23385c1e..e1df61e9 100644 --- a/src/output/flat.rs +++ b/src/output/flat.rs @@ -166,6 +166,7 @@ fn check_covered( } } Format::Dynamic(_) => {} // FIXME + Format::Apply(_) => {} // FIXME } Ok(()) } @@ -269,6 +270,7 @@ impl<'module, W: io::Write> Context<'module, W> { Ok(()) } Format::Dynamic(_) => Ok(()), // FIXME + Format::Apply(_) => Ok(()), // FIXME } } } diff --git a/src/output/tree.rs b/src/output/tree.rs index 812d646c..ba07c7aa 100644 --- a/src/output/tree.rs +++ b/src/output/tree.rs @@ -121,6 +121,7 @@ impl<'module> MonoidalPrinter<'module> { } _ => self.is_atomic_value(value, None), }, + Value::Format(_) => false, } } } @@ -240,6 +241,7 @@ impl<'module> MonoidalPrinter<'module> { frag } Format::Dynamic(_) => self.compile_value(value), + Format::Apply(_) => self.compile_value(value), } } @@ -258,6 +260,7 @@ impl<'module> MonoidalPrinter<'module> { Value::Seq(vals) => self.compile_seq(vals, None), Value::Record(fields) => self.compile_record(fields, None), Value::Variant(label, value) => self.compile_variant(label, value, None), + Value::Format(f) => self.compile_format(f, Default::default()), } } @@ -890,6 +893,7 @@ impl<'module> MonoidalPrinter<'module> { Precedence::FORMAT_COMPOUND, ), Format::Dynamic(_) => Fragment::String("dynamic".into()), + Format::Apply(_) => Fragment::String("apply".into()), Format::ItemVar(var, args) => { let mut frag = Fragment::new();