Skip to content
Merged
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 134 additions & 3 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4350,9 +4350,7 @@ def _format(operand):

let mut new_layer = self.copy_empty_like(py, vars_mode)?;

for (node, _) in op_nodes {
new_layer.push_back(py, node.clone())?;
}
new_layer.extend(py, op_nodes.iter().map(|(inst, _)| (*inst).clone()))?;

let new_layer_op_nodes = new_layer.op_nodes(false).filter_map(|node_index| {
match new_layer.dag.node_weight(node_index) {
Expand Down Expand Up @@ -6347,6 +6345,139 @@ impl DAGCircuit {
Err(DAGCircuitError::new_err("Specified node is not an op node"))
}
}

/// Adds valid instances of [PackedInstruction] to the back of the Circuit.
pub fn extend<I>(&mut self, py: Python, iter: I) -> PyResult<Vec<NodeIndex>>
where
I: IntoIterator<Item = PackedInstruction>,
{
// Create HashSets to keep track of each bit/var's last node
let mut qubit_last_nodes: HashMap<Qubit, NodeIndex> = HashMap::default();
let mut clbit_last_nodes: HashMap<Clbit, NodeIndex> = HashMap::default();
// TODO: Refactor once Vars are in rust
// Dict [ Var: (int, VarWeight)]
let vars_last_nodes: Bound<PyDict> = PyDict::new_bound(py);

// Store new nodes to return
let mut new_nodes = vec![];
for instr in iter {
let op_name = instr.op.name();
let (all_cbits, vars): (Vec<Clbit>, Option<Vec<PyObject>>) = {
if self.may_have_additional_wires(py, &instr) {
let mut clbits: HashSet<Clbit> =
HashSet::from_iter(self.cargs_interner.get(instr.clbits).iter().copied());
let (additional_clbits, additional_vars) =
self.additional_wires(py, instr.op.view(), instr.condition())?;
for clbit in additional_clbits {
clbits.insert(clbit);
}
(clbits.into_iter().collect(), Some(additional_vars))
} else {
(self.cargs_interner.get(instr.clbits).to_vec(), None)
}
};

// Increment the operation count
self.increment_op(op_name);

// Get the correct qubit indices
let qubits_id = instr.qubits;

// Insert op-node to graph.
let new_node = self.dag.add_node(NodeType::Operation(instr));
new_nodes.push(new_node);

// Check all the qubits in this instruction.
for qubit in self.qargs_interner.get(qubits_id) {
// Retrieve each qubit's last node
let qubit_last_node = if let Some(node) = qubit_last_nodes.remove(qubit) {
node
} else {
let output_node = self.qubit_io_map[qubit.0 as usize][1];
let (edge_id, predecessor_node) = self
.dag
.edges_directed(output_node, Incoming)
.next()
.map(|edge| (edge.id(), edge.source()))
.unwrap();
self.dag.remove_edge(edge_id);
predecessor_node
};
qubit_last_nodes.entry(*qubit).or_insert(new_node);
self.dag
.add_edge(qubit_last_node, new_node, Wire::Qubit(*qubit));
}

// Check all the clbits in this instruction.
for clbit in all_cbits {
let clbit_last_node = if let Some(node) = clbit_last_nodes.remove(&clbit) {
node
} else {
let output_node = self.clbit_io_map[clbit.0 as usize][1];
let (edge_id, predecessor_node) = self
.dag
.edges_directed(output_node, Incoming)
.next()
.map(|edge| (edge.id(), edge.source()))
.unwrap();
self.dag.remove_edge(edge_id);
predecessor_node
};
clbit_last_nodes.entry(clbit).or_insert(new_node);
self.dag
.add_edge(clbit_last_node, new_node, Wire::Clbit(clbit));
}

// If available, check all the vars in this instruction
for var in vars.iter().flatten() {
let var_last_node = if let Some(result) = vars_last_nodes.get_item(var)? {
let node: usize = result.extract()?;
vars_last_nodes.del_item(var)?;
NodeIndex::new(node)
} else {
let output_node = self.var_output_map.get(py, var).unwrap();
let (edge_id, predecessor_node) = self
.dag
.edges_directed(output_node, Incoming)
.next()
.map(|edge| (edge.id(), edge.source()))
.unwrap();
self.dag.remove_edge(edge_id);
predecessor_node
};

vars_last_nodes.set_item(var, new_node.index())?;
if var_last_node == new_node {
// TODO: Fix instances of duplicate nodes for Vars
continue;
}
self.dag
.add_edge(var_last_node, new_node, Wire::Var(var.clone_ref(py)));
}
}

// Add the output_nodes back to qargs
for (qubit, node) in qubit_last_nodes {
let output_node = self.qubit_io_map[qubit.0 as usize][1];
self.dag.add_edge(node, output_node, Wire::Qubit(qubit));
}

// Add the output_nodes back to cargs
for (clbit, node) in clbit_last_nodes {
let output_node = self.clbit_io_map[clbit.0 as usize][1];
self.dag.add_edge(node, output_node, Wire::Clbit(clbit));
}

// Add the output_nodes back to vars
for item in vars_last_nodes.items() {
let (var, node): (PyObject, usize) = item.extract()?;
let output_node = self.var_output_map.get(py, &var).unwrap();
self.dag
.add_edge(NodeIndex::new(node), output_node, Wire::Var(var));
}

Ok(new_nodes)
}
}

/// Add to global phase. Global phase can only be Float or ParameterExpression so this
Expand Down