Skip to content

Commit

Permalink
Add backwards non-io-dependent latency counts
Browse files Browse the repository at this point in the history
  • Loading branch information
VonTum committed Feb 2, 2024
1 parent cf8e1b0 commit 6a11004
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 98 deletions.
4 changes: 2 additions & 2 deletions multiply_add.sus
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,15 @@ module b: HandShake hs -> {




// (a*b) + c
module multiply_add :
int a,
int b,
int c
-> int total {

reg int tmp = a * b;
total = tmp + c;
reg total = tmp + c;
}

module fibonnaci : -> int num {
Expand Down
301 changes: 206 additions & 95 deletions src/instantiation/latency_algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@ pub struct FanInOut {
pub delta_latency : i64
}

/// Builds the fanout adjacency list that mirrors the given fanin adjacency list.
///
/// `fanins[n]` lists the edges arriving at node `n`; each entry names the node the
/// edge comes from (`other`) and the latency difference along it (`delta_latency`).
/// The result has one list per node, where `result[m]` holds one entry for every
/// edge leaving `m`, naming the destination node and carrying the same `delta_latency`.
///
/// Entries are appended in order of increasing destination node, and for one
/// destination in the order its fanins are listed.
///
/// Panics if any `other` index is not smaller than `fanins.len()`.
fn convert_fanin_to_fanout(fanins : &[Vec<FanInOut>]) -> Vec<Vec<FanInOut>> {
    // One (initially empty) outgoing-edge list per node.
    let mut outgoing : Vec<Vec<FanInOut>> = (0..fanins.len()).map(|_| Vec::new()).collect();

    for (destination, incoming_edges) in fanins.iter().enumerate() {
        for edge in incoming_edges {
            // Flip the edge: store it at its source, pointing at its destination.
            let flipped = FanInOut { other: destination, delta_latency: edge.delta_latency };
            outgoing[edge.other].push(flipped);
        }
    }

    outgoing
}

/*
Algorithm:
Initialize all inputs at latency 0
Expand Down Expand Up @@ -42,130 +56,167 @@ fn count_latency_recursive(part_of_path : &mut [bool], absolute_latency : &mut [
Ok(())
}

fn count_latency(part_of_path : &mut [bool], absolute_latency : &mut [i64], fanouts : &[Vec<FanInOut>], start_node : usize, start_value : i64) -> Result<(), LatencyCountingError> {
for p in part_of_path.iter() {assert!(!*p);}

assert!(absolute_latency[start_node] == i64::MIN);
absolute_latency[start_node] = start_value;
count_latency_recursive(part_of_path, absolute_latency, fanouts, start_node).map_err(|mut nodes_involved| {
let mut nodes_iter = nodes_involved.iter().enumerate();
fn make_cycle_error_from_path(mut nodes_involved : Vec<usize>) -> LatencyCountingError {
let mut nodes_iter = nodes_involved.iter().enumerate();
let first_node_in_cycle = nodes_iter.next().unwrap().1;
for (idx, node) in nodes_iter {
if node == first_node_in_cycle {
nodes_involved.truncate(idx);
break;
}
}
return LatencyCountingError::PositiveNetLatencyCycle{nodes_involved}
})?;
LatencyCountingError::PositiveNetLatencyCycle{nodes_involved}
}

fn count_latency(part_of_path : &mut [bool], absolute_latency : &mut [i64], fanouts : &[Vec<FanInOut>], start_node : usize) -> Result<(), LatencyCountingError> {
assert!(absolute_latency[start_node] != i64::MIN);

for p in part_of_path.iter() {assert!(!*p);}
count_latency_recursive(part_of_path, absolute_latency, fanouts, start_node).map_err(make_cycle_error_from_path)?;
for p in part_of_path.iter() {assert!(!*p);}
Ok(())
}

fn solve_latencies(fanins : &[Vec<FanInOut>], fanouts : &[Vec<FanInOut>], inputs : &[usize], outputs : &[usize]) -> Result<Vec<i64>, LatencyCountingError> {
assert!(fanins.len() == fanouts.len());
/// Assigns `start_value` to `start_node` and propagates latencies from there.
///
/// `start_node` must still be unassigned (its slot in `absolute_latency` holds the
/// `i64::MIN` sentinel); this is asserted. After the assignment the work is handed
/// to `count_latency`, whose result is returned unchanged.
fn count_latency_from(part_of_path : &mut [bool], absolute_latency : &mut [i64], fanouts : &[Vec<FanInOut>], start_node : usize, start_value : i64) -> Result<(), LatencyCountingError> {
    // Only a node without a latency may be used as a fresh starting point.
    assert!(absolute_latency[start_node] == i64::MIN);
    absolute_latency[start_node] = start_value;

    count_latency(part_of_path, absolute_latency, fanouts, start_node)
}

let mut part_of_path : Vec<bool> = vec![false; fanouts.len()];
/// Negates every assigned latency in place.
///
/// Entries equal to `i64::MIN` are the "unassigned" sentinel and are left untouched
/// (which also avoids the overflow that negating `i64::MIN` would cause).
fn invert_latency(latencies : &mut [i64]) {
    latencies
        .iter_mut()
        .filter(|latency| **latency != i64::MIN)
        .for_each(|latency| *latency = -*latency);
}

// Forwards are all performed in the same block. This block is then also used as the output latencies
let mut absolute_latencies_forward : Vec<i64> = vec![i64::MIN; fanouts.len()];
let mut absolute_latencies_backward_combined : Vec<i64> = vec![i64::MAX; fanouts.len()];
struct LatencySolver<'d> {
fanins : &'d [Vec<FanInOut>],
fanouts : &'d [Vec<FanInOut>],
inputs : &'d [usize],
outputs : &'d [usize],

part_of_path : Vec<bool>,

// To find input latencies based on output latencies, we use a separate block to go backwards.
// These are done one at a time, such that we can find conflicting latencies.
let mut absolute_latencies_backward_temporary : Vec<i64> = vec![i64::MIN; fanouts.len()];

let mut output_was_covered : Vec<bool> = vec![false; outputs.len()];
let mut input_node_assignments : Vec<i64> = vec![i64::MIN; inputs.len()];

input_node_assignments[0] = 0; // Provide a seed to start the algorithm

let mut last_num_valid_start_nodes = 0;
loop {
let mut cur_num_valid_start_nodes = 0;
// Add newly discovered input assignments
for (input_wire, assignment) in zip(inputs.iter(), input_node_assignments.iter()) {
if *assignment != i64::MIN {
if absolute_latencies_forward[*input_wire] == i64::MIN {
count_latency(&mut part_of_path, &mut absolute_latencies_forward, fanouts, *input_wire, *assignment)?;
} else {
// Erroneous is unreachable, because conflicting assignments should have been caught when they're put into the input_node_assignments list
assert!(absolute_latencies_forward[*input_wire] == *assignment);
absolute_latencies_backward_temporary : Vec<i64>,
touched_backwards : Vec<bool>,

output_was_covered : Vec<bool>,
input_node_assignments : Vec<i64>,
}

impl<'d> LatencySolver<'d> {
/// Creates a solver over the given graph and port lists with all working buffers cleared.
///
/// `fanins` and `fanouts` must describe the same set of nodes, so their lengths are
/// asserted to be equal. The per-node buffers get one slot per node and the per-port
/// buffers one slot per port; latency slots start out as `i64::MIN` (unassigned) and
/// flag slots as `false`.
fn new(fanins : &'d [Vec<FanInOut>], fanouts : &'d [Vec<FanInOut>], inputs : &'d [usize], outputs : &'d [usize]) -> Self {
    assert!(fanins.len() == fanouts.len());

    let node_count = fanouts.len();

    Self{
        fanins,
        fanouts,
        inputs,
        outputs,
        // Per-node scratch buffers
        part_of_path : vec![false; node_count],
        absolute_latencies_backward_temporary : vec![i64::MIN; node_count],
        touched_backwards : vec![false; node_count],
        // Per-port bookkeeping
        output_was_covered : vec![false; outputs.len()],
        input_node_assignments : vec![i64::MIN; inputs.len()],
    }
}

/// Seeds the algorithm by pinning the first input port to latency 0.
///
/// `input_node_assignments` has one slot per input port, so it is empty when there
/// are no input ports. In that case there is nothing to pin and this is a no-op,
/// rather than an out-of-bounds panic on index 0.
fn seed(&mut self) {
    // Provide a seed to start the algorithm
    if let Some(first_assignment) = self.input_node_assignments.first_mut() {
        *first_assignment = 0;
    }
}

fn solve_latencies(&mut self) -> Result<Vec<i64>, LatencyCountingError> {
// Forwards are all performed in the same block. This block is then also used as the output latencies
let mut absolute_latencies_forward : Vec<i64> = vec![i64::MIN; self.fanouts.len()];

for t in self.touched_backwards.iter_mut() {*t = false;}

let mut last_num_valid_start_nodes = 0;
loop {
let mut cur_num_valid_start_nodes = 0;
// Add newly discovered input assignments
for (input_wire, assignment) in zip(self.inputs.iter(), self.input_node_assignments.iter()) {
if *assignment != i64::MIN {
if absolute_latencies_forward[*input_wire] == i64::MIN {
count_latency_from(&mut self.part_of_path, &mut absolute_latencies_forward, self.fanouts, *input_wire, *assignment)?;
} else {
// A mismatch here is unreachable, because conflicting assignments should have been caught when they were put into the input_node_assignments list
assert!(absolute_latencies_forward[*input_wire] == *assignment);
}
cur_num_valid_start_nodes += 1;
}
cur_num_valid_start_nodes += 1;
}
}
if cur_num_valid_start_nodes == last_num_valid_start_nodes {
break;
}
if cur_num_valid_start_nodes == last_num_valid_start_nodes {
break;
}

last_num_valid_start_nodes = cur_num_valid_start_nodes;

// Find new backwards starting nodes
let mut bad_ports = Vec::new();
for (output, was_covered) in zip(outputs.iter(), output_was_covered.iter_mut()) {
if absolute_latencies_forward[*output] != i64::MIN {
if !*was_covered { // Little optimization, so we only every cover a backwards latency once
*was_covered = true;
// new latency
// Reset temporary buffer
for v in absolute_latencies_backward_temporary.iter_mut() {
*v = i64::MIN;
}
count_latency(&mut part_of_path, &mut absolute_latencies_backward_temporary, fanins, *output, -absolute_latencies_forward[*output])?;

for (input, assignment) in zip(inputs.iter(), input_node_assignments.iter_mut()) {
let found_inv_latency = absolute_latencies_backward_temporary[*input];
if found_inv_latency != i64::MIN {
if *assignment == i64::MIN {
*assignment = -found_inv_latency;
} else {
if -found_inv_latency != *assignment {
// Error because two outputs are attempting to create differing input latencies
bad_ports.push((*output, -found_inv_latency, *assignment))
}// else we're fine
last_num_valid_start_nodes = cur_num_valid_start_nodes;

// Find new backwards starting nodes
let mut bad_ports = Vec::new();
for (output, was_covered) in zip(self.outputs.iter(), self.output_was_covered.iter_mut()) {
if absolute_latencies_forward[*output] != i64::MIN {
if !*was_covered { // Little optimization, so we only every cover a backwards latency once
*was_covered = true;
// new latency
// Reset temporary buffer
for v in &self.absolute_latencies_backward_temporary {assert!(*v == i64::MIN);}
count_latency_from(&mut self.part_of_path, &mut self.absolute_latencies_backward_temporary, self.fanins, *output, -absolute_latencies_forward[*output])?;

for (input, assignment) in zip(self.inputs.iter(), self.input_node_assignments.iter_mut()) {
let found_inv_latency = self.absolute_latencies_backward_temporary[*input];
if found_inv_latency != i64::MIN {
if *assignment == i64::MIN {
*assignment = -found_inv_latency;
} else {
if -found_inv_latency != *assignment {
// Error because two outputs are attempting to create differing input latencies
bad_ports.push((*output, -found_inv_latency, *assignment))
}// else we're fine
}
}
}
}

// Add backwards latencies to combined list
for (from, to) in zip(absolute_latencies_backward_temporary.iter(), absolute_latencies_backward_combined.iter_mut()) {
if *from != i64::MIN && -*from < *to {
*to = -*from;
for (t_b, v) in zip(self.touched_backwards.iter_mut(), self.absolute_latencies_backward_temporary.iter_mut()) {
if *v != i64::MIN {
*v = i64::MIN;
*t_b = true;
}
}
}
}
}
}

if !bad_ports.is_empty() {
return Err(LatencyCountingError::ConflictingPortLatency{bad_ports})
if !bad_ports.is_empty() {
return Err(LatencyCountingError::ConflictingPortLatency{bad_ports})
}
}
}

// Also add nodes in fanin not dependent on an input to this input-output cluster.
// Nodes in fanout are included implicitly due to forward being the default direction
invert_latency(&mut absolute_latencies_forward);
for (start_node, fanin_of_output) in self.touched_backwards.iter().enumerate() {
if *fanin_of_output && (absolute_latencies_forward[start_node] != i64::MIN) {
count_latency(&mut self.part_of_path, &mut absolute_latencies_forward, self.fanins, start_node)?;
}
}
invert_latency(&mut absolute_latencies_forward);

// Check end conditions
let nodes_not_reached : Vec<usize> = absolute_latencies_forward.iter().enumerate().filter_map(|(idx, v)| (*v == i64::MIN).then_some(idx)).collect();
if nodes_not_reached.is_empty() {
Ok(absolute_latencies_forward)
} else {
Err(LatencyCountingError::DisjointNodes{nodes_not_reached})
}
}

fn convert_fanin_to_fanout(fanins : &[Vec<FanInOut>]) -> Vec<Vec<FanInOut>> {
let mut fanouts : Vec<Vec<FanInOut>> = fanins.iter().map(|_| {
Vec::new()
}).collect();

for (id, fin) in fanins.iter().enumerate() {
for f in fin {
fanouts[f.other].push(FanInOut { other: id, delta_latency: f.delta_latency })
let nodes_not_reached : Vec<usize> = absolute_latencies_forward.iter().enumerate().filter_map(|(idx, v)| (*v == i64::MIN).then_some(idx)).collect();
if nodes_not_reached.is_empty() {
Ok(absolute_latencies_forward)
} else {
Err(LatencyCountingError::DisjointNodes{nodes_not_reached})
}
}
}

fanouts
/// Computes an absolute latency for every node of the graph.
///
/// Convenience entry point: builds a `LatencySolver` over the given fanin/fanout
/// lists and port lists, seeds it, and runs it. Returns the solver's result: one
/// latency per node on success, or a `LatencyCountingError`.
///
/// Panics if `fanins` and `fanouts` differ in length (asserted by `LatencySolver::new`).
fn solve_latencies(fanins : &[Vec<FanInOut>], fanouts : &[Vec<FanInOut>], inputs : &[usize], outputs : &[usize]) -> Result<Vec<i64>, LatencyCountingError> {
    let mut solver = LatencySolver::new(fanins, fanouts, inputs, outputs);
    solver.seed();
    solver.solve_latencies()
}

#[cfg(test)]
Expand All @@ -176,11 +227,16 @@ mod tests {
FanInOut{other, delta_latency}
}

/// Returns the indices of all nodes whose edge list is empty, in increasing order.
///
/// Given a fanin list this yields the inputs (nodes nothing feeds into); given a
/// fanout list it yields the outputs (nodes that feed nothing).
fn infer_ports(fanins : &[Vec<FanInOut>]) -> Vec<usize> {
    fanins
        .iter()
        .enumerate()
        .filter(|(_, edges)| edges.is_empty())
        .map(|(node, _)| node)
        .collect()
}

fn solve_latencies_infer_ports(fanins : &[Vec<FanInOut>]) -> Result<Vec<i64>, LatencyCountingError> {
let fanouts = convert_fanin_to_fanout(fanins);

let inputs : Box<[usize]> = fanins.iter().enumerate().filter_map(|(idx, v)| v.is_empty().then_some(idx)).collect();
let outputs : Box<[usize]> = fanouts.iter().enumerate().filter_map(|(idx, v)| v.is_empty().then_some(idx)).collect();
let inputs = infer_ports(&fanins);
let outputs = infer_ports(&fanouts);

solve_latencies(fanins, &fanouts, &inputs, &outputs)
}
Expand All @@ -200,7 +256,7 @@ mod tests {

#[test]
fn check_correct_latency_basic() {
let graph = [
let fanins = [
/*0*/vec![],
/*1*/vec![mk_fan(0, 0)],
/*2*/vec![mk_fan(1, 1),mk_fan(5, 1)],
Expand All @@ -212,7 +268,62 @@ mod tests {

let correct_latencies = [-1,-1,1,1,0,0,0];

let found_latencies = solve_latencies_infer_ports(&graph).unwrap();
let fanouts = convert_fanin_to_fanout(&fanins);

let inputs = vec![0, 4];
let outputs = vec![3, 6];

let found_latencies = solve_latencies(&fanins, &fanouts, &inputs, &outputs).unwrap();

assert!(latencies_equal(&found_latencies, &correct_latencies), "{found_latencies:?} =lat= {correct_latencies:?}");
}

/// Node 7 has no fanin but is deliberately left out of `inputs`. Its latency must
/// still be determined, by walking backwards from node 5 along the edge of latency 2.
#[test]
fn check_correct_latency_with_superfluous_input() {
    let fanins = [
        /*0*/vec![],
        /*1*/vec![mk_fan(0, 0)],
        /*2*/vec![mk_fan(1, 1),mk_fan(5, 1)],
        /*3*/vec![mk_fan(2, 0)],
        /*4*/vec![],
        /*5*/vec![mk_fan(4, 0),mk_fan(1, 1),mk_fan(7, 2)],
        /*6*/vec![mk_fan(5, 0)],
        /*7*/vec![] // superfluous input
    ];
    let fanouts = convert_fanin_to_fanout(&fanins);

    // Only the declared ports are handed to the solver; node 7 is not among them.
    let inputs : [usize; 2] = [0, 4];
    let outputs : [usize; 2] = [3, 6];

    // Node 5 sits at latency 0, so node 7 (2 cycles earlier) is expected at -2.
    let correct_latencies = [-1,-1,1,1,0,0,0,-2];

    let found_latencies = solve_latencies(&fanins, &fanouts, &inputs, &outputs).unwrap();

    assert!(latencies_equal(&found_latencies, &correct_latencies), "{found_latencies:?} =lat= {correct_latencies:?}");
}

/// Node 7 has no fanout but is deliberately left out of `outputs`. Its latency must
/// still be determined, by walking forwards from node 5 along the edge of latency 2.
#[test]
fn check_correct_latency_with_superfluous_output() {
    let fanins = [
        /*0*/vec![],
        /*1*/vec![mk_fan(0, 0)],
        /*2*/vec![mk_fan(1, 1),mk_fan(5, 1)],
        /*3*/vec![mk_fan(2, 0)],
        /*4*/vec![],
        /*5*/vec![mk_fan(4, 0),mk_fan(1, 1)],
        /*6*/vec![mk_fan(5, 0)],
        /*7*/vec![mk_fan(5, 2)] // superfluous output
    ];
    let fanouts = convert_fanin_to_fanout(&fanins);

    // Only the declared ports are handed to the solver; node 7 is not among them.
    let inputs : [usize; 2] = [0, 4];
    let outputs : [usize; 2] = [3, 6];

    // Node 5 sits at latency 0, so node 7 (2 cycles later) is expected at 2.
    let correct_latencies = [-1,-1,1,1,0,0,0,2];

    let found_latencies = solve_latencies(&fanins, &fanouts, &inputs, &outputs).unwrap();

    assert!(latencies_equal(&found_latencies, &correct_latencies), "{found_latencies:?} =lat= {correct_latencies:?}");
}
Expand Down
2 changes: 1 addition & 1 deletion tree-sitter-sus
Submodule tree-sitter-sus updated 5 files
+149 −20 grammar.js
+10 −1 package.json
+670 −49 src/grammar.json
+472 −13 src/node-types.json
+4,484 −579 src/parser.c

0 comments on commit 6a11004

Please sign in to comment.