From 39e867a9970824b5d3438048fc56ebb95d4a5016 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Mon, 9 Jun 2025 17:44:58 +0200 Subject: [PATCH 01/17] Fix #840: Add GraphML serializer This commit adds a GraphML serializer: ```python def write_graphml( graphs: list[PyGraph | PyDiGraph], keys: list[tuple[str, Domain, str, Type, Any]], path: str, /, compression: str | None = ..., ) -> None: ... ``` `keys` is a list of tuples: id, domain, name of the key, type, and default value. This commit also introduces the `read_graphml_with_keys` function, which returns the key definitions in the same format, along with the list of parsed graphs. The implementation preserves the ids of graphs, nodes, and edges when possible. If some ids conflict, fresh ids are generated in the written GraphML file. The `read_graphml` function has also been updated to store the graph id in the graph attributes, just like node and edge ids are stored in the corresponding attributes. The `write_graphml` function supports gzip compression, as does `read_graphml`. Note that the JSON node-link serializer (the other part of #840) was already implemented in #1091. Compared to #1462: - Keys are passed explicitly instead of being inferred (which allows using the types `float` and `int`, and specifying default values); - Attributes for graphs, nodes, and edges are taken from the weight of elements, instead of relying on callbacks. This allows write_graphml to act as a proper reciprocal of read_graphml. Round-trip tests have been added. - IDs are taken from attributes when possible, instead of being generated from indices. - Multiple graphs can be written to the same file. - Gzip compression is supported. - Tests have been added. Regarding @IvanIsCoding's comment (https://github.com/Qiskit/rustworkx/pull/1462#issuecomment-2951935390), about using https://github.com/jonasbb/petgraph-graphml: - Rustworkx's `graphml.rs` introduces an internal `Graph` data structure, which is used for `read_graphml`.
It is natural to have `write_graphml` rely on the same data structure. - `petgraph-graphml` generates ids from indices, which prevents us from preserving ids across the `read_graphml`/`write_graphml` round trip. --- rustworkx/rustworkx.pyi | 26 ++ src/graphml.rs | 563 +++++++++++++++++++++++++++++++++++++--- src/lib.rs | 4 + tests/test_graphml.py | 205 +++++++++++++-- 4 files changed, 741 insertions(+), 57 deletions(-) diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index a9631e8c7d..a8f30e7a2a 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -68,6 +68,20 @@ class ColoringStrategy: Saturation: Any IndependentSet: Any +@final +class Domain: + Node: Domain + Edge: Domain + Graph: Domain + All: Domain + +@final +class Type: + Boolean: Type + Int: Type + Float: Type + Long: Type + # Cartesian product def digraph_cartesian_product( @@ -680,11 +694,23 @@ def directed_random_bipartite_graph( # Read Write +def read_graphml_with_keys( + path: str, + /, + compression: str | None = ..., +) -> tuple[list[tuple[str, Domain, str, Type, Any]], list[PyGraph | PyDiGraph]]: ... def read_graphml( path: str, /, compression: str | None = ..., ) -> list[PyGraph | PyDiGraph]: ... +def write_graphml( + graphs: list[PyGraph | PyDiGraph], + keys: list[tuple[str, Domain, str, Type, Any]], + path: str, + /, + compression: str | None = ..., +) -> None: ... 
def digraph_node_link_json( graph: PyDiGraph[_S, _T], /, diff --git a/src/graphml.rs b/src/graphml.rs index b5d61b9981..a0c459c870 100644 --- a/src/graphml.rs +++ b/src/graphml.rs @@ -12,26 +12,29 @@ #![allow(clippy::borrow_as_ptr)] +use std::borrow::{Borrow, Cow}; use std::convert::From; use std::ffi::OsStr; use std::fs::File; -use std::io::{BufRead, BufReader}; +use std::io::{BufRead, BufReader, BufWriter}; use std::iter::FromIterator; use std::num::{ParseFloatError, ParseIntError}; use std::path::Path; use std::str::ParseBoolError; use flate2::bufread::GzDecoder; -use hashbrown::HashMap; +use flate2::write::GzEncoder; +use flate2::Compression; +use hashbrown::{HashMap, HashSet}; use indexmap::IndexMap; -use quick_xml::events::{BytesStart, Event}; +use quick_xml::events::{BytesDecl, BytesStart, BytesText, Event}; use quick_xml::name::QName; use quick_xml::Error as XmlError; -use quick_xml::Reader; +use quick_xml::{Reader, Writer}; use petgraph::algo; -use petgraph::{Directed, Undirected}; +use petgraph::{Directed, EdgeType, Undirected}; use pyo3::exceptions::PyException; use pyo3::prelude::*; @@ -46,6 +49,7 @@ pub enum Error { NotFound(String), UnSupported(String), InvalidDoc(String), + IO(String), } impl From for Error { @@ -76,6 +80,13 @@ impl From for Error { } } +impl From for Error { + #[inline] + fn from(e: std::io::Error) -> Error { + Error::IO(format!("Input/output error: {}", e)) + } +} + impl From for PyErr { #[inline] fn from(error: Error) -> PyErr { @@ -84,7 +95,8 @@ impl From for PyErr { | Error::ParseValue(msg) | Error::NotFound(msg) | Error::UnSupported(msg) - | Error::InvalidDoc(msg) => PyException::new_err(msg), + | Error::InvalidDoc(msg) + | Error::IO(msg) => PyException::new_err(msg), } } } @@ -112,15 +124,32 @@ fn xml_attribute<'a>(element: &'a BytesStart<'a>, key: &[u8]) -> Result for Domain { + type Error = (); + + fn try_from(value: &[u8]) -> Result { + match value { + b"node" => Ok(Domain::Node), + b"edge" => Ok(Domain::Edge), + b"graph" 
=> Ok(Domain::Graph), + b"all" => Ok(Domain::All), + _ => Err(()), + } + } +} + +#[pyclass(eq)] +#[derive(Clone, Copy, PartialEq)] +pub enum Type { Boolean, Int, Float, @@ -129,7 +158,20 @@ enum Type { Long, } -#[derive(Clone)] +impl Into<&'static str> for Type { + fn into(self) -> &'static str { + match self { + Type::Boolean => "boolean", + Type::Int => "int", + Type::Float => "float", + Type::Double => "double", + Type::String => "string", + Type::Long => "long", + } + } +} + +#[derive(Clone, PartialEq)] enum Value { Boolean(bool), Int(isize), @@ -140,6 +182,27 @@ enum Value { UnDefined, } +impl Value { + fn serialize(&self) -> Option> { + match self { + Value::Boolean(val) => Some(Cow::from(val.to_string())), + Value::Int(val) => Some(Cow::from(val.to_string())), + Value::Float(val) => Some(Cow::from(val.to_string())), + Value::Double(val) => Some(Cow::from(val.to_string())), + Value::String(val) => Some(Cow::from(val)), + Value::Long(val) => Some(Cow::from(val.to_string())), + Value::UnDefined => None, + } + } + + fn to_id(&self) -> PyResult<&str> { + match self { + Value::String(value_str) => Ok(value_str), + _ => Err(PyException::new_err("Expected string value for id")), + } + } +} + impl<'py> IntoPyObject<'py> for Value { type Target = PyAny; type Output = Bound<'py, Self::Target>; @@ -158,6 +221,41 @@ impl<'py> IntoPyObject<'py> for Value { } } +impl Value { + fn from_pyobject<'py>(ob: &Bound<'py, PyAny>, ty: Type) -> PyResult { + let value = match ty { + Type::Boolean => Value::Boolean(ob.extract::()?), + Type::Int => Value::Int(ob.extract::()?), + Type::Float => Value::Float(ob.extract::()?), + Type::Double => Value::Double(ob.extract::()?), + Type::String => Value::String(ob.extract::()?), + Type::Long => Value::Long(ob.extract::()?), + }; + Ok(value) + } +} + +impl<'py> FromPyObject<'py> for Value { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + if let Ok(value) = ob.extract::() { + return Ok(Value::Boolean(value)); + } + if let Ok(value) = 
ob.extract::() { + return Ok(Value::Int(value)); + } + if let Ok(value) = ob.extract::() { + return Ok(Value::Float(value)); + } + if let Ok(value) = ob.extract::() { + return Ok(Value::Double(value)); + } + if let Ok(value) = ob.extract::() { + return Ok(Value::String(value)); + } + Ok(Value::UnDefined) + } +} + struct Key { name: String, ty: Type, @@ -200,6 +298,7 @@ enum Direction { } struct Graph { + id: Option, dir: Direction, nodes: Vec, edges: Vec, @@ -207,11 +306,12 @@ struct Graph { } impl Graph { - fn new<'a, I>(dir: Direction, default_attrs: I) -> Self + fn new<'a, I>(id: Option, dir: Direction, default_attrs: I) -> Self where I: Iterator, { Self { + id, dir, nodes: Vec::new(), edges: Vec::new(), @@ -273,9 +373,13 @@ impl<'py> IntoPyObject<'py> for Graph { type Output = Bound<'py, Self::Target>; type Error = PyErr; - fn into_pyobject(self, py: Python<'py>) -> Result { + fn into_pyobject(mut self, py: Python<'py>) -> Result { macro_rules! make_graph { ($graph:ident) => { + // Write the graph id from GraphML doc into the graph data payload. 
+ if let Some(id) = self.id { + self.attributes.insert(String::from("id"), Value::String(id.clone())); + } let mut mapping = HashMap::with_capacity(self.nodes.len()); for mut node in self.nodes { // Write the node id from GraphML doc into the node data payload @@ -340,6 +444,172 @@ impl<'py> IntoPyObject<'py> for Graph { } } +struct GraphElementInfo { + attributes: HashMap, + id: Option, +} + +impl Default for GraphElementInfo { + fn default() -> Self { + Self { + attributes: HashMap::new(), + id: None, + } + } +} + +struct GraphElementInfos { + vec: Vec<(Index, GraphElementInfo)>, + id_taken: HashSet, +} + +impl GraphElementInfos { + fn new() -> Self { + Self { + vec: vec![], + id_taken: HashSet::new(), + } + } + + fn insert<'py>( + &mut self, + py: Python<'py>, + index: Index, + weight: Option<&Py>, + ) -> PyResult<()> { + let element_info = weight + .and_then(|data| { + data.extract::>(py) + .ok() + .map(|mut attributes| -> PyResult { + let id = attributes + .remove_entry("id") + .map(|(id, value)| -> PyResult> { + let value_str = value.to_id()?; + if self.id_taken.contains(value_str) { + attributes.insert(id, value); + Ok(None) + } else { + self.id_taken.insert(value_str.to_string()); + Ok(Some(value_str.to_string())) + } + }) + .unwrap_or_else(|| Ok(None))?; + Ok(GraphElementInfo { + attributes: attributes.into_iter().collect(), + id, + }) + }) + }) + .unwrap_or_else(|| Ok(GraphElementInfo::default()))?; + self.vec.push((index, element_info)); + Ok(()) + } +} + +impl Graph { + fn try_from_stable<'py, Ty: EdgeType>( + py: Python<'py>, + dir: Direction, + pygraph: &StablePyGraph, + attrs: &PyObject, + ) -> PyResult { + let mut attrs: Option> = attrs.extract(py).ok(); + let id = attrs + .as_mut() + .and_then(|attributes| { + attributes + .remove("id") + .map(|v| v.to_id().map(|id| id.to_string())) + }) + .transpose()?; + let mut graph = Graph::new(id, dir, std::iter::empty()); + if let Some(attributes) = attrs { + graph.attributes.extend(attributes); + } + let 
mut node_infos = GraphElementInfos::new(); + for node_index in pygraph.node_indices() { + node_infos.insert(py, node_index, pygraph.node_weight(node_index))?; + } + let mut edge_infos = GraphElementInfos::new(); + for edge_index in pygraph.edge_indices() { + edge_infos.insert(py, edge_index, pygraph.edge_weight(edge_index))?; + } + let mut node_ids = HashMap::new(); + let mut fresh_index_counter = 0; + for (node_index, element_info) in node_infos.vec { + let id = element_info.id.unwrap_or_else(|| loop { + let id = format!("n{fresh_index_counter}"); + fresh_index_counter += 1; + if node_infos.id_taken.contains(&id) { + continue; + } + node_infos.id_taken.insert(id.clone()); + break id; + }); + graph.nodes.push(Node { + id: id.clone(), + data: element_info.attributes, + }); + node_ids.insert(node_index, id); + } + for (edge_index, element_info) in edge_infos.vec { + if let Some((source, target)) = pygraph.edge_endpoints(edge_index) { + let source = node_ids + .get(&source) + .ok_or(PyException::new_err("Missing source"))?; + let target = node_ids + .get(&target) + .ok_or(PyException::new_err("Missing target"))?; + graph.edges.push(Edge { + id: element_info.id, + source: source.clone(), + target: target.clone(), + data: element_info.attributes, + }); + } + } + Ok(graph) + } +} + +impl<'py> TryFrom<&Bound<'py, PyGraph>> for Graph { + type Error = PyErr; + + fn try_from(value: &Bound<'py, PyGraph>) -> PyResult { + let pygraph = value.borrow(); + return Graph::try_from_stable( + value.py(), + Direction::UnDirected, + &pygraph.graph, + &pygraph.attrs, + ); + } +} + +impl<'py> TryFrom<&Bound<'py, PyDiGraph>> for Graph { + type Error = PyErr; + + fn try_from(value: &Bound<'py, PyDiGraph>) -> PyResult { + let pygraph = value.borrow(); + return Graph::try_from_stable( + value.py(), + Direction::Directed, + &pygraph.graph, + &pygraph.attrs, + ); + } +} + +impl<'py> FromPyObject<'py> for Graph { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + match ob.downcast::() { 
+ Ok(graph) => Graph::try_from(graph), + Err(_) => Graph::try_from(ob.downcast::()?), + } + } +} + enum State { Start, Graph, @@ -386,6 +656,23 @@ impl Default for GraphML { } } +/// Given maps from ids to keys, return a map from key name to ids and keys. +fn build_key_name_map<'a>( + key_for_items: &'a IndexMap, + key_for_all: &'a IndexMap, +) -> HashMap { + // `key_for_items` is iterated before `key_for_all` since last + // items take precedence in the collected map. Similarly, + // the map `for_all` take precedence over kind-specific maps in + // `last_node_set_data`, `last_edge_set_data` and + // `last_graph_set_attribute`. + key_for_all + .iter() + .chain(key_for_items.iter()) + .map(|(id, key)| (key.name.clone(), (id, key))) + .collect() +} + impl GraphML { fn create_graph<'a>(&mut self, element: &'a BytesStart<'a>) -> Result<(), Error> { let dir = match xml_attribute(element, b"edgedefault")?.as_bytes() { @@ -399,6 +686,7 @@ impl GraphML { }; self.graphs.push(Graph::new( + xml_attribute(element, b"id").ok(), dir, self.key_for_graph.values().chain(self.key_for_all.values()), )); @@ -428,6 +716,24 @@ impl GraphML { Ok(()) } + fn get_keys(&self, domain: Domain) -> &IndexMap { + match domain { + Domain::Node => &self.key_for_nodes, + Domain::Edge => &self.key_for_edges, + Domain::Graph => &self.key_for_graph, + Domain::All => &self.key_for_all, + } + } + + fn get_keys_mut(&mut self, domain: Domain) -> &mut IndexMap { + match domain { + Domain::Node => &mut self.key_for_nodes, + Domain::Edge => &mut self.key_for_edges, + Domain::Graph => &mut self.key_for_graph, + Domain::All => &mut self.key_for_all, + } + } + fn add_graphml_key<'a>(&mut self, element: &'a BytesStart<'a>) -> Result { let id = xml_attribute(element, b"id")?; let ty = match xml_attribute(element, b"attr.type")?.as_bytes() { @@ -450,38 +756,18 @@ impl GraphML { ty, default: Value::UnDefined, }; - - match xml_attribute(element, b"for")?.as_bytes() { - b"node" => { - self.key_for_nodes.insert(id, 
key); - Ok(Domain::Node) - } - b"edge" => { - self.key_for_edges.insert(id, key); - Ok(Domain::Edge) - } - b"graph" => { - self.key_for_graph.insert(id, key); - Ok(Domain::Graph) - } - b"all" => { - self.key_for_all.insert(id, key); - Ok(Domain::All) - } - _ => Err(Error::InvalidDoc(format!( - "Invalid 'for' attribute in key with id={}.", - id, - ))), - } + let domain: Domain = xml_attribute(element, b"for")? + .as_bytes() + .try_into() + .map_err(|()| { + Error::InvalidDoc(format!("Invalid 'for' attribute in key with id={}.", id,)) + })?; + self.get_keys_mut(domain).insert(id, key); + Ok(domain) } fn last_key_set_value(&mut self, val: String, domain: Domain) -> Result<(), Error> { - let elem = match domain { - Domain::Node => self.key_for_nodes.last_mut(), - Domain::Edge => self.key_for_edges.last_mut(), - Domain::Graph => self.key_for_graph.last_mut(), - Domain::All => self.key_for_all.last_mut(), - }; + let elem = self.get_keys_mut(domain).last_mut(); if let Some((_, key)) = elem { key.set_value(val)?; @@ -715,6 +1001,141 @@ impl GraphML { graph } + + fn write_data( + writer: &mut Writer, + keys: &HashMap, + data: &HashMap, + ) -> Result<(), Error> { + for (key_name, value) in data { + let (id, key) = keys + .get(key_name) + .ok_or_else(|| Error::NotFound(format!("Unknown key {key_name}")))?; + if key.default == *value { + continue; + } + let mut elem = BytesStart::new("data"); + elem.push_attribute(("key", id.as_str())); + writer.write_event(Event::Start(elem.borrow()))?; + if let Some(contents) = value.serialize() { + writer.write_event(Event::Text(BytesText::new(contents.borrow())))?; + } + writer.write_event(Event::End(elem.to_end()))?; + } + Ok(()) + } + + fn write_elem_data( + writer: &mut Writer, + keys: &HashMap, + elem: BytesStart, + data: &HashMap, + ) -> Result<(), Error> { + if data.is_empty() { + writer.write_event(Event::Empty(elem))?; + return Ok(()); + } + writer.write_event(Event::Start(elem.borrow()))?; + Self::write_data(writer, keys, data)?; 
+ writer.write_event(Event::End(elem.to_end()))?; + Ok(()) + } + + fn write_keys( + writer: &mut Writer, + key_for: &str, + map: &IndexMap, + ) -> Result<(), quick_xml::Error> { + for (id, key) in map { + let mut elem = BytesStart::new("key"); + elem.push_attribute(("id", id.as_str())); + elem.push_attribute(("for", key_for)); + elem.push_attribute(("attr.name", key.name.as_str())); + let ty: &str = key.ty.into(); + elem.push_attribute(("attr.type", ty)); + writer.write_event(Event::Start(elem.borrow()))?; + if let Some(contents) = key.default.serialize() { + let elem = BytesStart::new("default"); + writer.write_event(Event::Start(elem.borrow()))?; + writer.write_event(Event::Text(BytesText::new(contents.borrow())))?; + writer.write_event(Event::End(elem.to_end()))?; + }; + writer.write_event(Event::End(elem.to_end()))?; + } + Ok(()) + } + + fn write_graph_to_writer( + &self, + writer: &mut Writer, + ) -> Result<(), Error> { + writer.write_event(Event::Decl(BytesDecl::new("1.0", Some("UTF-8"), None)))?; + let mut elem = BytesStart::new("graphml"); + elem.push_attribute(("xmlns", "http://graphml.graphdrawing.org/xmlns")); + elem.push_attribute(("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance")); + elem.push_attribute(( + "xsi:schemaLocation", + "http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd", + )); + writer.write_event(Event::Start(elem.borrow()))?; + Self::write_keys(writer, "node", &self.key_for_nodes)?; + Self::write_keys(writer, "edge", &self.key_for_edges)?; + Self::write_keys(writer, "graph", &self.key_for_graph)?; + Self::write_keys(writer, "all", &self.key_for_all)?; + let graph_keys: HashMap = + build_key_name_map(&self.key_for_graph, &self.key_for_all); + let node_keys: HashMap = + build_key_name_map(&self.key_for_nodes, &self.key_for_all); + let edge_keys: HashMap = + build_key_name_map(&self.key_for_edges, &self.key_for_all); + for graph in self.graphs.iter() { + let mut elem = 
BytesStart::new("graph"); + if let Some(id) = &graph.id { + elem.push_attribute(("id", id.as_str())); + } + let edgedefault = match graph.dir { + Direction::Directed => "directed", + Direction::UnDirected => "undirected", + }; + elem.push_attribute(("edgedefault", edgedefault)); + writer.write_event(Event::Start(elem.borrow()))?; + Self::write_data(writer, &graph_keys, &graph.attributes)?; + for node in &graph.nodes { + let mut elem = BytesStart::new("node"); + elem.push_attribute(("id", node.id.as_str())); + Self::write_elem_data(writer, &node_keys, elem, &node.data)?; + } + for edge in &graph.edges { + let mut elem = BytesStart::new("edge"); + if let Some(id) = &edge.id { + elem.push_attribute(("id", id.as_str())); + } + elem.push_attribute(("source", edge.source.as_str())); + elem.push_attribute(("target", edge.target.as_str())); + Self::write_elem_data(writer, &edge_keys, elem, &edge.data)?; + } + writer.write_event(Event::End(elem.to_end()))?; + } + writer.write_event(Event::End(elem.to_end()))?; + Ok(()) + } + + fn to_file(&self, path: impl AsRef, compression: &str) -> Result<(), Error> { + let extension = path.as_ref().extension().unwrap_or(OsStr::new("")); + if extension.eq("graphmlz") || extension.eq("gz") || compression.eq("gzip") { + let file = File::create(path)?; + let buf_writer = BufWriter::new(file); + let gzip_encoder = GzEncoder::new(buf_writer, Compression::default()); + let mut writer = Writer::new(gzip_encoder); + self.write_graph_to_writer(&mut writer)?; + writer.into_inner().finish()?; + } else { + let file = File::create(path)?; + let mut writer = Writer::new(file); + self.write_graph_to_writer(&mut writer)?; + } + Ok(()) + } } /// Read a list of graphs from a file in GraphML format. @@ -756,3 +1177,63 @@ pub fn read_graphml<'py>( Ok(out) } + +/// Read a list of graphs from a file in GraphML format and return the pair containing the list of key definitions and the graph. 
+/// +/// Each key definition is a tuple: id, domain, name of the key, type, default value. +#[pyfunction] +#[pyo3(signature=(path, compression=None),text_signature = "(path, /, compression=None)")] +pub fn read_graphml_with_keys<'py>( + py: Python<'py>, + path: &str, + compression: Option, +) -> PyResult<( + Vec<(String, Domain, String, Type, Bound<'py, PyAny>)>, + Vec>, +)> { + let graphml = GraphML::from_file(path, &compression.unwrap_or_default())?; + + let mut keys = Vec::new(); + for domain in [Domain::Node, Domain::Edge, Domain::Graph, Domain::All] { + for (id, key) in graphml.get_keys(domain) { + let default = key.default.clone().into_pyobject(py)?.into_any(); + keys.push((id.clone(), domain, key.name.clone(), key.ty, default)); + } + } + + let mut out = Vec::new(); + for graph in graphml.graphs { + out.push(graph.into_pyobject(py)?) + } + + Ok((keys, out)) +} + +/// Write a list of graphs to a file in GraphML format given the list of key definitions. +#[pyfunction] +#[pyo3(signature=(graphs, keys, path, compression=None),text_signature = "(graphs, keys, path, /, compression=None)")] +pub fn write_graphml<'py>( + py: Python<'py>, + graphs: Vec>, + keys: Vec<(String, Domain, String, Type, Py)>, + path: &str, + compression: Option, +) -> PyResult<()> { + let mut graphml = GraphML::default(); + for (id, domain, name, ty, default) in keys { + let bound_default = default.bind(py); + let default = if bound_default.is_none() { + Value::UnDefined + } else { + Value::from_pyobject(bound_default, ty)? + }; + graphml + .get_keys_mut(domain) + .insert(id, Key { name, ty, default }); + } + for graph in graphs { + graphml.graphs.push(Graph::extract_bound(graph.bind(py))?) 
+ } + graphml.to_file(path, &compression.unwrap_or_default())?; + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index dad6d95648..77ef011015 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -669,6 +669,8 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(connected_subgraphs))?; m.add_wrapped(wrap_pyfunction!(is_planar))?; m.add_wrapped(wrap_pyfunction!(read_graphml))?; + m.add_wrapped(wrap_pyfunction!(read_graphml_with_keys))?; + m.add_wrapped(wrap_pyfunction!(write_graphml))?; m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(from_node_link_json_file))?; @@ -702,6 +704,8 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pymodule!(generators::generators))?; Ok(()) } diff --git a/tests/test_graphml.py b/tests/test_graphml.py index 73b8c89289..539aad595d 100644 --- a/tests/test_graphml.py +++ b/tests/test_graphml.py @@ -10,6 +10,7 @@ # License for the specific language governing permissions and limitations # under the License. 
+import math import unittest import tempfile import gzip @@ -98,7 +99,42 @@ def test_simple(self): ("n0", "n1", {"fidelity": 0.98}), ("n0", "n2", {"fidelity": 0.95}), ] - self.assertGraphEqual(graph, nodes, edges, directed=False) + self.assertGraphEqual(graph, nodes, edges, attrs={"id": "G"}, directed=False) + + def test_write(self): + graph_xml = self.graphml_xml_example() + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + graphml = rustworkx.read_graphml(fd.name) + graph = graphml[0] + with tempfile.NamedTemporaryFile("wt") as fd: + keys = [ + ("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), + ("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), + ] + rustworkx.write_graphml([graph], keys, fd.name) + graphml = rustworkx.read_graphml(fd.name) + graph_reread = graphml[0] + edges = [(graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list()] + self.assertGraphEqual(graph_reread, graph.nodes(), edges, attrs={"id": "G"}, directed=False) + + def test_write_with_keys(self): + graph_xml = self.graphml_xml_example() + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + keys, graphml = rustworkx.read_graphml_with_keys(fd.name) + assert keys[0] == ("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow") + assert keys[1][0:4] == ("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float) + assert math.isclose(keys[1][4], 0.95, rel_tol=1e-7) + graph = graphml[0] + with tempfile.NamedTemporaryFile("wt") as fd: + rustworkx.write_graphml([graph], keys, fd.name) + graphml = rustworkx.read_graphml(fd.name) + graph_reread = graphml[0] + edges = [(graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list()] + self.assertGraphEqual(graph_reread, graph.nodes(), edges, attrs={"id": "G"}, directed=False) def test_gzipped(self): graph_xml = self.graphml_xml_example() @@ -121,7 +157,7 @@ def 
test_gzipped(self): ("n0", "n1", {"fidelity": 0.98}), ("n0", "n2", {"fidelity": 0.95}), ] - self.assertGraphEqual(graph, nodes, edges, directed=False) + self.assertGraphEqual(graph, nodes, edges, attrs={"id": "G"}, directed=False) def test_gzipped_force(self): graph_xml = self.graphml_xml_example() @@ -145,10 +181,29 @@ def test_gzipped_force(self): ("n0", "n1", {"fidelity": 0.98}), ("n0", "n2", {"fidelity": 0.95}), ] - self.assertGraphEqual(graph, nodes, edges, directed=False) + self.assertGraphEqual(graph, nodes, edges, attrs={"id": "G"}, directed=False) - def test_multiple_graphs_in_single_file(self): - graph_xml = self.HEADER.format( + def test_write_gzipped(self): + graph_xml = self.graphml_xml_example() + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + graphml = rustworkx.read_graphml(fd.name) + graph = graphml[0] + with tempfile.NamedTemporaryFile("wt") as fd: + keys = [ + ("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), + ("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), + ] + newname = f"{fd.name}.gz" + rustworkx.write_graphml([graph], keys, newname) + graphml = rustworkx.read_graphml(newname) + graph_reread = graphml[0] + edges = [(graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list()] + self.assertGraphEqual(graph_reread, graph.nodes(), edges, attrs={"id": "G"}, directed=False) + + def graphml_xml_example_multiple_graphs(self): + return self.HEADER.format( """ yellow @@ -175,6 +230,9 @@ def test_multiple_graphs_in_single_file(self): """ ) + def test_multiple_graphs_in_single_file(self): + graph_xml = self.graphml_xml_example_multiple_graphs() + with tempfile.NamedTemporaryFile("wt") as fd: fd.write(graph_xml) fd.flush() @@ -188,7 +246,7 @@ def test_multiple_graphs_in_single_file(self): edges = [ ("n0", "n1", {"id": "e01", "fidelity": 0.98}), ] - self.assertGraphEqual(graph, nodes, edges, directed=False) + self.assertGraphEqual(graph, nodes, 
edges, attrs={"id": "G"}, directed=False) graph = graphml[1] nodes = [ {"id": "n0", "color": "red"}, @@ -197,7 +255,32 @@ def test_multiple_graphs_in_single_file(self): edges = [ ("n0", "n1", {"id": "e01", "fidelity": 0.95}), ] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, attrs={"id": "H"}, directed=True) + + def test_write_multiple_graphs(self): + graph_xml = self.graphml_xml_example_multiple_graphs() + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + graphml = rustworkx.read_graphml(fd.name) + with tempfile.NamedTemporaryFile("wt") as fd: + keys = [ + ("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), + ("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), + ] + rustworkx.write_graphml(graphml, keys, fd.name) + graphml_reread = rustworkx.read_graphml(fd.name) + for graph, graph_reread in zip(graphml, graphml_reread): + edges = [ + (graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list() + ] + self.assertGraphEqual( + graph_reread, + graph.nodes(), + edges, + attrs=graph.attrs, + directed=isinstance(graph, rustworkx.PyDiGraph), + ) def test_key_for_graph(self): graph_xml = self.HEADER.format( @@ -217,7 +300,31 @@ def test_key_for_graph(self): graph = graphml[0] nodes = [{"id": "n0"}] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"test": True}) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G", "test": True}) + + def test_write_key_for_graph(self): + graph_xml = self.HEADER.format( + """ + + + true + + + """ + ) + + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + keys, graphml = rustworkx.read_graphml_with_keys(fd.name) + assert keys[0] == ("d0", rustworkx.Domain.Graph, "test", rustworkx.Type.Boolean, None) + with tempfile.NamedTemporaryFile("wt") as fd: + rustworkx.write_graphml(graphml, keys, fd.name) + 
graphml = rustworkx.read_graphml(fd.name) + graph = graphml[0] + nodes = [{"id": "n0"}] + edges = [] + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G", "test": True}) def test_key_for_all(self): graph_xml = self.HEADER.format( @@ -249,9 +356,46 @@ def test_key_for_all(self): ] edges = [("n0", "n1", {"test": "I'm an edge."})] self.assertGraphEqual( - graph, nodes, edges, directed=True, attrs={"test": "I'm a graph."} + graph, nodes, edges, directed=True, attrs={"id": "G", "test": "I'm a graph."} ) + def test_write_key_for_all(self): + graph_xml = self.HEADER.format( + """ + + + I'm a graph. + + I'm a node. + + + I'm a node. + + + I'm an edge. + + + """ + ) + + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + keys, graphml = rustworkx.read_graphml_with_keys(fd.name) + assert keys[0] == ("d0", rustworkx.Domain.All, "test", rustworkx.Type.String, None) + with tempfile.NamedTemporaryFile("wt") as fd: + rustworkx.write_graphml(graphml, keys, fd.name) + graphml = rustworkx.read_graphml(fd.name) + graph = graphml[0] + nodes = [ + {"id": "n0", "test": "I'm a node."}, + {"id": "n1", "test": "I'm a node."}, + ] + edges = [("n0", "n1", {"test": "I'm an edge."})] + self.assertGraphEqual( + graph, nodes, edges, directed=True, attrs={"id": "G", "test": "I'm a graph."} + ) + def test_key_default_undefined(self): graph_xml = self.HEADER.format( """ @@ -275,7 +419,36 @@ def test_key_default_undefined(self): {"id": "n1", "test": None}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) + + def test_write_key_undefined(self): + graph_xml = self.HEADER.format( + """ + + + + true + + + + """ + ) + + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + keys, graphml = rustworkx.read_graphml_with_keys(fd.name) + assert keys[0] == ("d0", rustworkx.Domain.Node, "test", rustworkx.Type.Boolean, None) + 
with tempfile.NamedTemporaryFile("wt") as fd: + rustworkx.write_graphml(graphml, keys, fd.name) + graphml = rustworkx.read_graphml(fd.name) + graph = graphml[0] + nodes = [ + {"id": "n0", "test": True}, + {"id": "n1", "test": None}, + ] + edges = [] + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_bool(self): graph_xml = self.HEADER.format( @@ -306,7 +479,7 @@ def test_bool(self): {"id": "n2", "test": False}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_int(self): graph_xml = self.HEADER.format( @@ -337,7 +510,7 @@ def test_int(self): {"id": "n2", "test": 42}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_float(self): graph_xml = self.HEADER.format( @@ -368,7 +541,7 @@ def test_float(self): {"id": "n2", "test": 4.2}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_double(self): graph_xml = self.HEADER.format( @@ -399,7 +572,7 @@ def test_double(self): {"id": "n2", "test": 4.2}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_string(self): graph_xml = self.HEADER.format( @@ -430,7 +603,7 @@ def test_string(self): {"id": "n2", "test": "yellow"}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_long(self): graph_xml = self.HEADER.format( @@ -461,7 +634,7 @@ def test_long(self): {"id": "n2", "test": 42}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, attrs={"id": "G"}, directed=True) def 
test_convert_error(self): graph_xml = self.HEADER.format( From 945b588543b1c25386ec9def4a3f1bd305eb9aa3 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Tue, 10 Jun 2025 23:50:51 +0200 Subject: [PATCH 02/17] Fix clippy comments --- rustworkx/rustworkx.pyi | 12 ++++- src/graphml.rs | 100 +++++++++++++++++++++++++++------------- src/lib.rs | 1 + tests/test_graphml.py | 43 ++++++++++++----- 4 files changed, 110 insertions(+), 46 deletions(-) diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index a66a3c4b80..8d6387ce80 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -82,6 +82,14 @@ class Type: Float: Type Long: Type +@final +class KeySpec: + id: str + domain: Domain + name: str + ty: Type + default: Any + # Cartesian product def digraph_cartesian_product( @@ -698,7 +706,7 @@ def read_graphml_with_keys( path: str, /, compression: str | None = ..., -) -> tuple[list[tuple[str, Domain, str, Type, Any]], list[PyGraph | PyDiGraph]]: ... +) -> tuple[list[KeySpec], list[PyGraph | PyDiGraph]]: ... def read_graphml( path: str, /, @@ -706,7 +714,7 @@ def read_graphml( ) -> list[PyGraph | PyDiGraph]: ... 
def write_graphml( graphs: list[PyGraph | PyDiGraph], - keys: list[tuple[str, Domain, str, Type, Any]], + keys: list[KeySpec], path: str, /, compression: str | None = ..., diff --git a/src/graphml.rs b/src/graphml.rs index a0c459c870..066f031baf 100644 --- a/src/graphml.rs +++ b/src/graphml.rs @@ -158,9 +158,9 @@ pub enum Type { Long, } -impl Into<&'static str> for Type { - fn into(self) -> &'static str { - match self { +impl From for &'static str { + fn from(ty: Type) -> Self { + match ty { Type::Boolean => "boolean", Type::Int => "int", Type::Float => "float", @@ -222,7 +222,7 @@ impl<'py> IntoPyObject<'py> for Value { } impl Value { - fn from_pyobject<'py>(ob: &Bound<'py, PyAny>, ty: Type) -> PyResult { + fn from_pyobject(ob: &Bound<'_, PyAny>, ty: Type) -> PyResult { let value = match ty { Type::Boolean => Value::Boolean(ob.extract::()?), Type::Int => Value::Int(ob.extract::()?), @@ -471,12 +471,7 @@ impl GraphElementInfos { } } - fn insert<'py>( - &mut self, - py: Python<'py>, - index: Index, - weight: Option<&Py>, - ) -> PyResult<()> { + fn insert(&mut self, py: Python<'_>, index: Index, weight: Option<&Py>) -> PyResult<()> { let element_info = weight .and_then(|data| { data.extract::>(py) @@ -508,8 +503,8 @@ impl GraphElementInfos { } impl Graph { - fn try_from_stable<'py, Ty: EdgeType>( - py: Python<'py>, + fn try_from_stable( + py: Python<'_>, dir: Direction, pygraph: &StablePyGraph, attrs: &PyObject, @@ -578,12 +573,12 @@ impl<'py> TryFrom<&Bound<'py, PyGraph>> for Graph { fn try_from(value: &Bound<'py, PyGraph>) -> PyResult { let pygraph = value.borrow(); - return Graph::try_from_stable( + Graph::try_from_stable( value.py(), Direction::UnDirected, &pygraph.graph, &pygraph.attrs, - ); + ) } } @@ -592,12 +587,12 @@ impl<'py> TryFrom<&Bound<'py, PyDiGraph>> for Graph { fn try_from(value: &Bound<'py, PyDiGraph>) -> PyResult { let pygraph = value.borrow(); - return Graph::try_from_stable( + Graph::try_from_stable( value.py(), Direction::Directed, 
&pygraph.graph, &pygraph.attrs, - ); + ) } } @@ -1178,26 +1173,61 @@ pub fn read_graphml<'py>( Ok(out) } +/// Key definition: id, domain, name of the key, type, default value. +#[pyclass] +pub struct KeySpec { + #[pyo3(get)] + id: String, + #[pyo3(get)] + domain: Domain, + #[pyo3(get)] + name: String, + #[pyo3(get)] + ty: Type, + #[pyo3(get)] + default: Py, +} + +#[pymethods] +impl KeySpec { + #[new] + pub fn new(id: String, domain: Domain, name: String, ty: Type, default: Py) -> Self { + KeySpec { + id, + domain, + name, + ty, + default, + } + } +} + +type GraphMLWithKeys<'py> = PyResult<(Vec>, Vec>)>; + /// Read a list of graphs from a file in GraphML format and return the pair containing the list of key definitions and the graph. -/// -/// Each key definition is a tuple: id, domain, name of the key, type, default value. #[pyfunction] #[pyo3(signature=(path, compression=None),text_signature = "(path, /, compression=None)")] pub fn read_graphml_with_keys<'py>( py: Python<'py>, path: &str, compression: Option, -) -> PyResult<( - Vec<(String, Domain, String, Type, Bound<'py, PyAny>)>, - Vec>, -)> { +) -> GraphMLWithKeys<'py> { let graphml = GraphML::from_file(path, &compression.unwrap_or_default())?; let mut keys = Vec::new(); for domain in [Domain::Node, Domain::Edge, Domain::Graph, Domain::All] { for (id, key) in graphml.get_keys(domain) { let default = key.default.clone().into_pyobject(py)?.into_any(); - keys.push((id.clone(), domain, key.name.clone(), key.ty, default)); + keys.push(Py::new( + py, + KeySpec { + id: id.clone(), + domain, + name: key.name.clone(), + ty: key.ty, + default: default.into(), + }, + )?); } } @@ -1212,24 +1242,30 @@ pub fn read_graphml_with_keys<'py>( /// Write a list of graphs to a file in GraphML format given the list of key definitions. 
#[pyfunction] #[pyo3(signature=(graphs, keys, path, compression=None),text_signature = "(graphs, keys, path, /, compression=None)")] -pub fn write_graphml<'py>( - py: Python<'py>, +pub fn write_graphml( + py: Python<'_>, graphs: Vec>, - keys: Vec<(String, Domain, String, Type, Py)>, + keys: Vec>, path: &str, compression: Option, ) -> PyResult<()> { let mut graphml = GraphML::default(); - for (id, domain, name, ty, default) in keys { - let bound_default = default.bind(py); + for pykey in keys { + let key = pykey.borrow(py); + let bound_default = key.default.bind(py); let default = if bound_default.is_none() { Value::UnDefined } else { - Value::from_pyobject(bound_default, ty)? + Value::from_pyobject(bound_default, key.ty)? }; - graphml - .get_keys_mut(domain) - .insert(id, Key { name, ty, default }); + graphml.get_keys_mut(key.domain).insert( + key.id.clone(), + Key { + name: key.name.clone(), + ty: key.ty, + default, + }, + ); } for graph in graphs { graphml.graphs.push(Graph::extract_bound(graph.bind(py))?) 
diff --git a/src/lib.rs b/src/lib.rs index be62aafb08..929efd14a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -708,6 +708,7 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pymodule!(generators::generators))?; Ok(()) } diff --git a/tests/test_graphml.py b/tests/test_graphml.py index 539aad595d..48d534626c 100644 --- a/tests/test_graphml.py +++ b/tests/test_graphml.py @@ -110,8 +110,8 @@ def test_write(self): graph = graphml[0] with tempfile.NamedTemporaryFile("wt") as fd: keys = [ - ("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), - ("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), + rustworkx.KeySpec("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), + rustworkx.KeySpec("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), ] rustworkx.write_graphml([graph], keys, fd.name) graphml = rustworkx.read_graphml(fd.name) @@ -125,9 +125,16 @@ def test_write_with_keys(self): fd.write(graph_xml) fd.flush() keys, graphml = rustworkx.read_graphml_with_keys(fd.name) - assert keys[0] == ("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow") - assert keys[1][0:4] == ("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float) - assert math.isclose(keys[1][4], 0.95, rel_tol=1e-7) + assert keys[0].id == "d0" + assert keys[0].domain == rustworkx.Domain.Node + assert keys[0].name == "color" + assert keys[0].ty == rustworkx.Type.String + assert keys[0].default == "yellow" + assert keys[1].id == "d1" + assert keys[1].domain == rustworkx.Domain.Edge + assert keys[1].name == "fidelity" + assert keys[1].ty == rustworkx.Type.Float + assert math.isclose(keys[1].default, 0.95, rel_tol=1e-7) graph = graphml[0] with tempfile.NamedTemporaryFile("wt") as fd: rustworkx.write_graphml([graph], keys, fd.name) @@ -192,8 +199,8 @@ def test_write_gzipped(self): graph = graphml[0] with 
tempfile.NamedTemporaryFile("wt") as fd: keys = [ - ("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), - ("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), + rustworkx.KeySpec("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), + rustworkx.KeySpec("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), ] newname = f"{fd.name}.gz" rustworkx.write_graphml([graph], keys, newname) @@ -265,8 +272,8 @@ def test_write_multiple_graphs(self): graphml = rustworkx.read_graphml(fd.name) with tempfile.NamedTemporaryFile("wt") as fd: keys = [ - ("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), - ("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), + rustworkx.KeySpec("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), + rustworkx.KeySpec("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), ] rustworkx.write_graphml(graphml, keys, fd.name) graphml_reread = rustworkx.read_graphml(fd.name) @@ -317,7 +324,11 @@ def test_write_key_for_graph(self): fd.write(graph_xml) fd.flush() keys, graphml = rustworkx.read_graphml_with_keys(fd.name) - assert keys[0] == ("d0", rustworkx.Domain.Graph, "test", rustworkx.Type.Boolean, None) + assert keys[0].id == "d0" + assert keys[0].domain == rustworkx.Domain.Graph + assert keys[0].name == "test" + assert keys[0].ty == rustworkx.Type.Boolean + assert keys[0].default is None with tempfile.NamedTemporaryFile("wt") as fd: rustworkx.write_graphml(graphml, keys, fd.name) graphml = rustworkx.read_graphml(fd.name) @@ -382,7 +393,11 @@ def test_write_key_for_all(self): fd.write(graph_xml) fd.flush() keys, graphml = rustworkx.read_graphml_with_keys(fd.name) - assert keys[0] == ("d0", rustworkx.Domain.All, "test", rustworkx.Type.String, None) + assert keys[0].id == "d0" + assert keys[0].domain == rustworkx.Domain.All + assert keys[0].name == "test" + assert keys[0].ty == rustworkx.Type.String + assert 
keys[0].default is None with tempfile.NamedTemporaryFile("wt") as fd: rustworkx.write_graphml(graphml, keys, fd.name) graphml = rustworkx.read_graphml(fd.name) @@ -438,7 +453,11 @@ def test_write_key_undefined(self): fd.write(graph_xml) fd.flush() keys, graphml = rustworkx.read_graphml_with_keys(fd.name) - assert keys[0] == ("d0", rustworkx.Domain.Node, "test", rustworkx.Type.Boolean, None) + assert keys[0].id == "d0" + assert keys[0].domain == rustworkx.Domain.Node + assert keys[0].name == "test" + assert keys[0].ty == rustworkx.Type.Boolean + assert keys[0].default is None with tempfile.NamedTemporaryFile("wt") as fd: rustworkx.write_graphml(graphml, keys, fd.name) graphml = rustworkx.read_graphml(fd.name) From c350943743f32027c2b463d17d38a6b93690fb86 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Tue, 10 Jun 2025 23:59:10 +0200 Subject: [PATCH 03/17] Prefix types with GraphML Suggested by @IvanIsCoding: https://github.com/Qiskit/rustworkx/pull/1464#discussion_r2137676829 --- rustworkx/rustworkx.pyi | 30 +++++++++++++++--------------- src/graphml.rs | 6 +++--- tests/test_graphml.py | 32 ++++++++++++++++---------------- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 8d6387ce80..6af8b963f5 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -69,25 +69,25 @@ class ColoringStrategy: IndependentSet: Any @final -class Domain: - Node: Domain - Edge: Domain - Graph: Domain - All: Domain +class GraphMLDomain: + Node: GraphMLDomain + Edge: GraphMLDomain + Graph: GraphMLDomain + All: GraphMLDomain @final -class Type: - Boolean: Type - Int: Type - Float: Type - Long: Type +class GraphMLType: + Boolean: GraphMLType + Int: GraphMLType + Float: GraphMLType + Long: GraphMLType @final -class KeySpec: +class GraphMLKey: id: str - domain: Domain + domain: GraphMLDomain name: str - ty: Type + ty: GraphMLType default: Any # Cartesian product @@ -706,7 +706,7 @@ def 
read_graphml_with_keys( path: str, /, compression: str | None = ..., -) -> tuple[list[KeySpec], list[PyGraph | PyDiGraph]]: ... +) -> tuple[list[GraphMLKeySpec], list[PyGraph | PyDiGraph]]: ... def read_graphml( path: str, /, @@ -714,7 +714,7 @@ def read_graphml( ) -> list[PyGraph | PyDiGraph]: ... def write_graphml( graphs: list[PyGraph | PyDiGraph], - keys: list[KeySpec], + keys: list[GraphMLKeySpec], path: str, /, compression: str | None = ..., diff --git a/src/graphml.rs b/src/graphml.rs index 066f031baf..8572f4edd2 100644 --- a/src/graphml.rs +++ b/src/graphml.rs @@ -124,7 +124,7 @@ fn xml_attribute<'a>(element: &'a BytesStart<'a>, key: &[u8]) -> Result for Domain { } } -#[pyclass(eq)] +#[pyclass(eq, name = "GraphMLType")] #[derive(Clone, Copy, PartialEq)] pub enum Type { Boolean, @@ -1174,7 +1174,7 @@ pub fn read_graphml<'py>( } /// Key definition: id, domain, name of the key, type, default value. -#[pyclass] +#[pyclass(name = "GraphMLKey")] pub struct KeySpec { #[pyo3(get)] id: String, diff --git a/tests/test_graphml.py b/tests/test_graphml.py index 48d534626c..b626126cb9 100644 --- a/tests/test_graphml.py +++ b/tests/test_graphml.py @@ -110,8 +110,8 @@ def test_write(self): graph = graphml[0] with tempfile.NamedTemporaryFile("wt") as fd: keys = [ - rustworkx.KeySpec("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), - rustworkx.KeySpec("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), + rustworkx.GraphMLKey("d0", rustworkx.GraphMLDomain.Node, "color", rustworkx.GraphMLType.String, "yellow"), + rustworkx.GraphMLKey("d1", rustworkx.GraphMLDomain.Edge, "fidelity", rustworkx.GraphMLType.Float, 0.95), ] rustworkx.write_graphml([graph], keys, fd.name) graphml = rustworkx.read_graphml(fd.name) @@ -126,14 +126,14 @@ def test_write_with_keys(self): fd.flush() keys, graphml = rustworkx.read_graphml_with_keys(fd.name) assert keys[0].id == "d0" - assert keys[0].domain == rustworkx.Domain.Node + assert keys[0].domain == 
rustworkx.GraphMLDomain.Node assert keys[0].name == "color" - assert keys[0].ty == rustworkx.Type.String + assert keys[0].ty == rustworkx.GraphMLType.String assert keys[0].default == "yellow" assert keys[1].id == "d1" - assert keys[1].domain == rustworkx.Domain.Edge + assert keys[1].domain == rustworkx.GraphMLDomain.Edge assert keys[1].name == "fidelity" - assert keys[1].ty == rustworkx.Type.Float + assert keys[1].ty == rustworkx.GraphMLType.Float assert math.isclose(keys[1].default, 0.95, rel_tol=1e-7) graph = graphml[0] with tempfile.NamedTemporaryFile("wt") as fd: @@ -199,8 +199,8 @@ def test_write_gzipped(self): graph = graphml[0] with tempfile.NamedTemporaryFile("wt") as fd: keys = [ - rustworkx.KeySpec("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), - rustworkx.KeySpec("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), + rustworkx.GraphMLKey("d0", rustworkx.GraphMLDomain.Node, "color", rustworkx.GraphMLType.String, "yellow"), + rustworkx.GraphMLKey("d1", rustworkx.GraphMLDomain.Edge, "fidelity", rustworkx.GraphMLType.Float, 0.95), ] newname = f"{fd.name}.gz" rustworkx.write_graphml([graph], keys, newname) @@ -272,8 +272,8 @@ def test_write_multiple_graphs(self): graphml = rustworkx.read_graphml(fd.name) with tempfile.NamedTemporaryFile("wt") as fd: keys = [ - rustworkx.KeySpec("d0", rustworkx.Domain.Node, "color", rustworkx.Type.String, "yellow"), - rustworkx.KeySpec("d1", rustworkx.Domain.Edge, "fidelity", rustworkx.Type.Float, 0.95), + rustworkx.GraphMLKey("d0", rustworkx.GraphMLDomain.Node, "color", rustworkx.GraphMLType.String, "yellow"), + rustworkx.GraphMLKey("d1", rustworkx.GraphMLDomain.Edge, "fidelity", rustworkx.GraphMLType.Float, 0.95), ] rustworkx.write_graphml(graphml, keys, fd.name) graphml_reread = rustworkx.read_graphml(fd.name) @@ -325,9 +325,9 @@ def test_write_key_for_graph(self): fd.flush() keys, graphml = rustworkx.read_graphml_with_keys(fd.name) assert keys[0].id == "d0" - assert keys[0].domain 
== rustworkx.Domain.Graph + assert keys[0].domain == rustworkx.GraphMLDomain.Graph assert keys[0].name == "test" - assert keys[0].ty == rustworkx.Type.Boolean + assert keys[0].ty == rustworkx.GraphMLType.Boolean assert keys[0].default is None with tempfile.NamedTemporaryFile("wt") as fd: rustworkx.write_graphml(graphml, keys, fd.name) @@ -394,9 +394,9 @@ def test_write_key_for_all(self): fd.flush() keys, graphml = rustworkx.read_graphml_with_keys(fd.name) assert keys[0].id == "d0" - assert keys[0].domain == rustworkx.Domain.All + assert keys[0].domain == rustworkx.GraphMLDomain.All assert keys[0].name == "test" - assert keys[0].ty == rustworkx.Type.String + assert keys[0].ty == rustworkx.GraphMLType.String assert keys[0].default is None with tempfile.NamedTemporaryFile("wt") as fd: rustworkx.write_graphml(graphml, keys, fd.name) @@ -454,9 +454,9 @@ def test_write_key_undefined(self): fd.flush() keys, graphml = rustworkx.read_graphml_with_keys(fd.name) assert keys[0].id == "d0" - assert keys[0].domain == rustworkx.Domain.Node + assert keys[0].domain == rustworkx.GraphMLDomain.Node assert keys[0].name == "test" - assert keys[0].ty == rustworkx.Type.Boolean + assert keys[0].ty == rustworkx.GraphMLType.Boolean assert keys[0].default is None with tempfile.NamedTemporaryFile("wt") as fd: rustworkx.write_graphml(graphml, keys, fd.name) From 2604aa9afa0336200539ef09b9bacdfc1343e597 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Wed, 11 Jun 2025 10:29:07 +0200 Subject: [PATCH 04/17] Black --- rustworkx/rustworkx.pyi | 4 +-- tests/test_graphml.py | 67 ++++++++++++++++++++++++++++++++++------- 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 6af8b963f5..df3e574a32 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -706,7 +706,7 @@ def read_graphml_with_keys( path: str, /, compression: str | None = ..., -) -> tuple[list[GraphMLKeySpec], list[PyGraph | PyDiGraph]]: ... 
+) -> tuple[list[GraphMLKey], list[PyGraph | PyDiGraph]]: ... def read_graphml( path: str, /, @@ -714,7 +714,7 @@ def read_graphml( ) -> list[PyGraph | PyDiGraph]: ... def write_graphml( graphs: list[PyGraph | PyDiGraph], - keys: list[GraphMLKeySpec], + keys: list[GraphMLKey], path: str, /, compression: str | None = ..., diff --git a/tests/test_graphml.py b/tests/test_graphml.py index b626126cb9..0c791563d9 100644 --- a/tests/test_graphml.py +++ b/tests/test_graphml.py @@ -110,13 +110,27 @@ def test_write(self): graph = graphml[0] with tempfile.NamedTemporaryFile("wt") as fd: keys = [ - rustworkx.GraphMLKey("d0", rustworkx.GraphMLDomain.Node, "color", rustworkx.GraphMLType.String, "yellow"), - rustworkx.GraphMLKey("d1", rustworkx.GraphMLDomain.Edge, "fidelity", rustworkx.GraphMLType.Float, 0.95), + rustworkx.GraphMLKey( + "d0", + rustworkx.GraphMLDomain.Node, + "color", + rustworkx.GraphMLType.String, + "yellow", + ), + rustworkx.GraphMLKey( + "d1", + rustworkx.GraphMLDomain.Edge, + "fidelity", + rustworkx.GraphMLType.Float, + 0.95, + ), ] rustworkx.write_graphml([graph], keys, fd.name) graphml = rustworkx.read_graphml(fd.name) graph_reread = graphml[0] - edges = [(graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list()] + edges = [ + (graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list() + ] self.assertGraphEqual(graph_reread, graph.nodes(), edges, attrs={"id": "G"}, directed=False) def test_write_with_keys(self): @@ -140,7 +154,9 @@ def test_write_with_keys(self): rustworkx.write_graphml([graph], keys, fd.name) graphml = rustworkx.read_graphml(fd.name) graph_reread = graphml[0] - edges = [(graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list()] + edges = [ + (graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list() + ] self.assertGraphEqual(graph_reread, graph.nodes(), edges, attrs={"id": "G"}, directed=False) def test_gzipped(self): @@ 
-199,14 +215,28 @@ def test_write_gzipped(self): graph = graphml[0] with tempfile.NamedTemporaryFile("wt") as fd: keys = [ - rustworkx.GraphMLKey("d0", rustworkx.GraphMLDomain.Node, "color", rustworkx.GraphMLType.String, "yellow"), - rustworkx.GraphMLKey("d1", rustworkx.GraphMLDomain.Edge, "fidelity", rustworkx.GraphMLType.Float, 0.95), + rustworkx.GraphMLKey( + "d0", + rustworkx.GraphMLDomain.Node, + "color", + rustworkx.GraphMLType.String, + "yellow", + ), + rustworkx.GraphMLKey( + "d1", + rustworkx.GraphMLDomain.Edge, + "fidelity", + rustworkx.GraphMLType.Float, + 0.95, + ), ] newname = f"{fd.name}.gz" rustworkx.write_graphml([graph], keys, newname) graphml = rustworkx.read_graphml(newname) graph_reread = graphml[0] - edges = [(graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list()] + edges = [ + (graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list() + ] self.assertGraphEqual(graph_reread, graph.nodes(), edges, attrs={"id": "G"}, directed=False) def graphml_xml_example_multiple_graphs(self): @@ -272,14 +302,27 @@ def test_write_multiple_graphs(self): graphml = rustworkx.read_graphml(fd.name) with tempfile.NamedTemporaryFile("wt") as fd: keys = [ - rustworkx.GraphMLKey("d0", rustworkx.GraphMLDomain.Node, "color", rustworkx.GraphMLType.String, "yellow"), - rustworkx.GraphMLKey("d1", rustworkx.GraphMLDomain.Edge, "fidelity", rustworkx.GraphMLType.Float, 0.95), + rustworkx.GraphMLKey( + "d0", + rustworkx.GraphMLDomain.Node, + "color", + rustworkx.GraphMLType.String, + "yellow", + ), + rustworkx.GraphMLKey( + "d1", + rustworkx.GraphMLDomain.Edge, + "fidelity", + rustworkx.GraphMLType.Float, + 0.95, + ), ] rustworkx.write_graphml(graphml, keys, fd.name) graphml_reread = rustworkx.read_graphml(fd.name) for graph, graph_reread in zip(graphml, graphml_reread): edges = [ - (graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list() + (graph[s]["id"], graph[t]["id"], weight) + for 
s, t, weight in graph.weighted_edge_list() ] self.assertGraphEqual( graph_reread, @@ -307,7 +350,9 @@ def test_key_for_graph(self): graph = graphml[0] nodes = [{"id": "n0"}] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G", "test": True}) + self.assertGraphEqual( + graph, nodes, edges, directed=True, attrs={"id": "G", "test": True} + ) def test_write_key_for_graph(self): graph_xml = self.HEADER.format( From 40380e1bd602d69f877cd17388bc337d9a73be32 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Wed, 11 Jun 2025 10:29:17 +0200 Subject: [PATCH 05/17] Add release notes --- .../notes/write_graphml-624c10b6f7592ee1.yaml | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 releasenotes/notes/write_graphml-624c10b6f7592ee1.yaml diff --git a/releasenotes/notes/write_graphml-624c10b6f7592ee1.yaml b/releasenotes/notes/write_graphml-624c10b6f7592ee1.yaml new file mode 100644 index 0000000000..ca11a1b52e --- /dev/null +++ b/releasenotes/notes/write_graphml-624c10b6f7592ee1.yaml @@ -0,0 +1,13 @@ +--- +features: + - | + Added a new function :func:`~rustworkx.write_graphml` that writes + a list of rustworkx graph objects to a file in GraphML format. + - | + Added a new function :func:`~rustworkx.read_graphml_with_keys` + that reads a GraphML file and returns the list of defined keys + along with the list of rustworkx graph objects. +other: + - | + When graphs read with :func:`~rustworkx.read_graphml` include IDs, + these IDs are now stored in the graph attributes. 
From e53c5510de574cdd8ef24250c61a87d5d14e46fe Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Wed, 11 Jun 2025 16:49:13 +0200 Subject: [PATCH 06/17] Fix stubs error --- rustworkx/__init__.pyi | 2 ++ rustworkx/rustworkx.pyi | 2 ++ 2 files changed, 4 insertions(+) diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index c8bad24a99..6ce0696221 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -163,6 +163,8 @@ from .rustworkx import directed_barabasi_albert_graph as directed_barabasi_alber from .rustworkx import undirected_random_bipartite_graph as undirected_random_bipartite_graph from .rustworkx import directed_random_bipartite_graph as directed_random_bipartite_graph from .rustworkx import read_graphml as read_graphml +from .rustworkx import read_graphml_with_keys as read_graphml_with_keys +from .rustworkx import write_graphml as write_graphml from .rustworkx import digraph_node_link_json as digraph_node_link_json from .rustworkx import graph_node_link_json as graph_node_link_json from .rustworkx import from_node_link_json_file as from_node_link_json_file diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index df3e574a32..9b46ce1f21 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -80,6 +80,8 @@ class GraphMLType: Boolean: GraphMLType Int: GraphMLType Float: GraphMLType + Double: GraphMLType + String: GraphMLType Long: GraphMLType @final From ee1b220a41f0283ace52859b92bf724918b1d527 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Wed, 11 Jun 2025 16:49:24 +0200 Subject: [PATCH 07/17] Add documentation --- docs/source/api/serialization.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/api/serialization.rst b/docs/source/api/serialization.rst index 49c5e49e93..338246f153 100644 --- a/docs/source/api/serialization.rst +++ b/docs/source/api/serialization.rst @@ -8,5 +8,7 @@ Serialization rustworkx.node_link_json rustworkx.read_graphml + rustworkx.read_graphml_with_keys + 
rustworkx.write_graphml rustworkx.from_node_link_json_file rustworkx.parse_node_link_json From de6b1a49027018b95e64e2998031e7a6659cf153 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Mon, 16 Jun 2025 21:02:03 +0200 Subject: [PATCH 08/17] Remove read_graphml_with_keys / write_graphml for single graph only --- .../notes/write_graphml-624c10b6f7592ee1.yaml | 4 - rustworkx/__init__.py | 14 + rustworkx/__init__.pyi | 4 +- rustworkx/rustworkx.pyi | 16 +- src/graphml.rs | 241 +++++++++++------- src/lib.rs | 4 +- tests/test_graphml.py | 110 ++------ 7 files changed, 198 insertions(+), 195 deletions(-) diff --git a/releasenotes/notes/write_graphml-624c10b6f7592ee1.yaml b/releasenotes/notes/write_graphml-624c10b6f7592ee1.yaml index ca11a1b52e..e66a82aa16 100644 --- a/releasenotes/notes/write_graphml-624c10b6f7592ee1.yaml +++ b/releasenotes/notes/write_graphml-624c10b6f7592ee1.yaml @@ -3,10 +3,6 @@ features: - | Added a new function :func:`~rustworkx.write_graphml` that writes a list of rustworkx graph objects to a file in GraphML format. - - | - Added a new function :func:`~rustworkx.read_graphml_with_keys` - that reads a GraphML file and returns the list of defined keys - along with the list of rustworkx graph objects. other: - | When graphs read with :func:`~rustworkx.read_graphml` include IDs, diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index 9411c8c790..1ef878479d 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -2279,3 +2279,17 @@ def single_source_all_shortest_paths( For most use cases, consider using `dijkstra_shortest_paths` for a single shortest path, which runs much faster. 
""" raise TypeError(f"Invalid Input Type {type(graph)} for graph") + + +def write_graphml( + graph: PyGraph | PyDiGraph, + path: str, + /, + keys: list[GraphMLKey] | None = None, + compression: str | None = None, +) -> None: + """ """ + if isinstance(graph, PyGraph): + graph_write_graphml(graph, path, keys, compression) + return + digraph_write_graphml(graph, path, keys, compression) diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index 6ce0696221..55cda12902 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -163,8 +163,8 @@ from .rustworkx import directed_barabasi_albert_graph as directed_barabasi_alber from .rustworkx import undirected_random_bipartite_graph as undirected_random_bipartite_graph from .rustworkx import directed_random_bipartite_graph as directed_random_bipartite_graph from .rustworkx import read_graphml as read_graphml -from .rustworkx import read_graphml_with_keys as read_graphml_with_keys -from .rustworkx import write_graphml as write_graphml +from .rustworkx import graph_write_graphml as graph_write_graphml +from .rustworkx import digraph_write_graphml as digraph_write_graphml from .rustworkx import digraph_node_link_json as digraph_node_link_json from .rustworkx import graph_node_link_json as graph_node_link_json from .rustworkx import from_node_link_json_file as from_node_link_json_file diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 9b46ce1f21..dc4a1d99fd 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -704,21 +704,23 @@ def directed_random_bipartite_graph( # Read Write -def read_graphml_with_keys( +def read_graphml( path: str, /, compression: str | None = ..., -) -> tuple[list[GraphMLKey], list[PyGraph | PyDiGraph]]: ... -def read_graphml( +) -> list[PyGraph | PyDiGraph]: ... +def graph_write_graphml( + graph: PyGraph, path: str, /, + keys: list[GraphMLKey] | None = ..., compression: str | None = ..., -) -> list[PyGraph | PyDiGraph]: ... 
-def write_graphml( - graphs: list[PyGraph | PyDiGraph], - keys: list[GraphMLKey], +) -> None: ... +def digraph_write_graphml( + graph: PyDiGraph, path: str, /, + keys: list[GraphMLKey] | None = ..., compression: str | None = ..., ) -> None: ... def digraph_node_link_json( diff --git a/src/graphml.rs b/src/graphml.rs index 8572f4edd2..64d3a03ba3 100644 --- a/src/graphml.rs +++ b/src/graphml.rs @@ -17,7 +17,7 @@ use std::convert::From; use std::ffi::OsStr; use std::fs::File; use std::io::{BufRead, BufReader, BufWriter}; -use std::iter::FromIterator; +use std::iter::{FromIterator, Iterator}; use std::num::{ParseFloatError, ParseIntError}; use std::path::Path; use std::str::ParseBoolError; @@ -26,7 +26,8 @@ use flate2::bufread::GzDecoder; use flate2::write::GzEncoder; use flate2::Compression; use hashbrown::{HashMap, HashSet}; -use indexmap::IndexMap; + +use indexmap::map::Entry; use quick_xml::events::{BytesDecl, BytesStart, BytesText, Event}; use quick_xml::name::QName; @@ -41,6 +42,8 @@ use pyo3::prelude::*; use pyo3::IntoPyObjectExt; use pyo3::PyErr; +use rustworkx_core::dictmap::{DictMap, InitWithHasher}; + use crate::{digraph::PyDiGraph, graph::PyGraph, StablePyGraph}; pub enum Error { @@ -148,7 +151,7 @@ impl TryFrom<&[u8]> for Domain { } #[pyclass(eq, name = "GraphMLType")] -#[derive(Clone, Copy, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum Type { Boolean, Int, @@ -201,6 +204,18 @@ impl Value { _ => Err(PyException::new_err("Expected string value for id")), } } + + fn ty(&self) -> Option { + match self { + Value::Boolean(_) => Some(Type::Boolean), + Value::Int(_) => Some(Type::Int), + Value::Float(_) => Some(Type::Float), + Value::Double(_) => Some(Type::Double), + Value::String(_) => Some(Type::String), + Value::Long(_) => Some(Type::Long), + Value::UnDefined => None, + } + } } impl<'py> IntoPyObject<'py> for Value { @@ -596,15 +611,6 @@ impl<'py> TryFrom<&Bound<'py, PyDiGraph>> for Graph { } } -impl<'py> FromPyObject<'py> for Graph { - fn 
extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { - match ob.downcast::() { - Ok(graph) => Graph::try_from(graph), - Err(_) => Graph::try_from(ob.downcast::()?), - } - } -} - enum State { Start, Graph, @@ -633,28 +639,28 @@ macro_rules! matches { struct GraphML { graphs: Vec, - key_for_nodes: IndexMap, - key_for_edges: IndexMap, - key_for_graph: IndexMap, - key_for_all: IndexMap, + key_for_nodes: DictMap, + key_for_edges: DictMap, + key_for_graph: DictMap, + key_for_all: DictMap, } impl Default for GraphML { fn default() -> Self { Self { graphs: Vec::new(), - key_for_nodes: IndexMap::new(), - key_for_edges: IndexMap::new(), - key_for_graph: IndexMap::new(), - key_for_all: IndexMap::new(), + key_for_nodes: DictMap::new(), + key_for_edges: DictMap::new(), + key_for_graph: DictMap::new(), + key_for_all: DictMap::new(), } } } /// Given maps from ids to keys, return a map from key name to ids and keys. fn build_key_name_map<'a>( - key_for_items: &'a IndexMap, - key_for_all: &'a IndexMap, + key_for_items: &'a DictMap, + key_for_all: &'a DictMap, ) -> HashMap { // `key_for_items` is iterated before `key_for_all` since last // items take precedence in the collected map. 
Similarly, @@ -668,6 +674,42 @@ fn build_key_name_map<'a>( .collect() } +fn infer_keys_for_attributes<'a>( + target: &mut DictMap, + attributes: impl Iterator, +) -> Result<(), Error> { + let mut inferred = DictMap::new(); + let mut counter = 0; + for (name, value) in attributes { + if let Some(ty) = value.ty() { + match inferred.entry(name.clone()) { + Entry::Vacant(entry) => { + counter += 1; + let id = format!("d{counter}"); + entry.insert(ty); + target.insert( + id, + Key { + name: name.to_string(), + ty, + default: Value::UnDefined, + }, + ); + } + Entry::Occupied(entry) => { + let other_ty = entry.get(); + if *other_ty != ty { + return Err(Error::InvalidDoc(format!( + "Mismatch type for key {name}: {ty:?} and {other_ty:?}" + ))); + } + } + } + } + } + Ok(()) +} + impl GraphML { fn create_graph<'a>(&mut self, element: &'a BytesStart<'a>) -> Result<(), Error> { let dir = match xml_attribute(element, b"edgedefault")?.as_bytes() { @@ -711,16 +753,7 @@ impl GraphML { Ok(()) } - fn get_keys(&self, domain: Domain) -> &IndexMap { - match domain { - Domain::Node => &self.key_for_nodes, - Domain::Edge => &self.key_for_edges, - Domain::Graph => &self.key_for_graph, - Domain::All => &self.key_for_all, - } - } - - fn get_keys_mut(&mut self, domain: Domain) -> &mut IndexMap { + fn get_keys_mut(&mut self, domain: Domain) -> &mut DictMap { match domain { Domain::Node => &mut self.key_for_nodes, Domain::Edge => &mut self.key_for_edges, @@ -1039,7 +1072,7 @@ impl GraphML { fn write_keys( writer: &mut Writer, key_for: &str, - map: &IndexMap, + map: &DictMap, ) -> Result<(), quick_xml::Error> { for (id, key) in map { let mut elem = BytesStart::new("key"); @@ -1131,6 +1164,72 @@ impl GraphML { } Ok(()) } + + fn infer_keys(&mut self) -> Result<(), Error> { + infer_keys_for_attributes( + &mut self.key_for_graph, + self.graphs + .iter() + .map(|graph| graph.attributes.iter()) + .flatten(), + )?; + infer_keys_for_attributes( + &mut self.key_for_nodes, + self.graphs + .iter() + 
.map(|graph| graph.nodes.iter()) + .flatten() + .map(|nodes| nodes.data.iter()) + .flatten(), + )?; + infer_keys_for_attributes( + &mut self.key_for_edges, + self.graphs + .iter() + .map(|graph| graph.edges.iter()) + .flatten() + .map(|edges| edges.data.iter()) + .flatten(), + )?; + Ok(()) + } + + fn set_keys<'py>( + &mut self, + py: Python<'py>, + keys: Vec>, + ) -> Result<(), pyo3::PyErr> { + for pykey in keys { + let key = pykey.borrow(py); + let bound_default = key.default.bind(py); + let default = if bound_default.is_none() { + Value::UnDefined + } else { + Value::from_pyobject(bound_default, key.ty)? + }; + self.get_keys_mut(key.domain).insert( + key.id.clone(), + Key { + name: key.name.clone(), + ty: key.ty, + default, + }, + ); + } + Ok(()) + } + + fn set_or_infer_keys<'py>( + &mut self, + py: Python<'py>, + keys: Option>>, + ) -> Result<(), pyo3::PyErr> { + match keys { + None => self.infer_keys()?, + Some(keys) => self.set_keys(py, keys)?, + } + Ok(()) + } } /// Read a list of graphs from a file in GraphML format. @@ -1202,74 +1301,36 @@ impl KeySpec { } } -type GraphMLWithKeys<'py> = PyResult<(Vec>, Vec>)>; - -/// Read a list of graphs from a file in GraphML format and return the pair containing the list of key definitions and the graph. +/// Write a graph to a file in GraphML format given the list of key definitions. 
#[pyfunction] -#[pyo3(signature=(path, compression=None),text_signature = "(path, /, compression=None)")] -pub fn read_graphml_with_keys<'py>( - py: Python<'py>, +#[pyo3(signature=(graph, path, keys, compression=None),text_signature = "(graph, path, /, keys=None, compression=None)")] +pub fn graph_write_graphml( + py: Python<'_>, + graph: Py, path: &str, + keys: Option>>, compression: Option, -) -> GraphMLWithKeys<'py> { - let graphml = GraphML::from_file(path, &compression.unwrap_or_default())?; - - let mut keys = Vec::new(); - for domain in [Domain::Node, Domain::Edge, Domain::Graph, Domain::All] { - for (id, key) in graphml.get_keys(domain) { - let default = key.default.clone().into_pyobject(py)?.into_any(); - keys.push(Py::new( - py, - KeySpec { - id: id.clone(), - domain, - name: key.name.clone(), - ty: key.ty, - default: default.into(), - }, - )?); - } - } - - let mut out = Vec::new(); - for graph in graphml.graphs { - out.push(graph.into_pyobject(py)?) - } - - Ok((keys, out)) +) -> PyResult<()> { + let mut graphml = GraphML::default(); + graphml.graphs.push(Graph::try_from(graph.bind(py))?); + graphml.set_or_infer_keys(py, keys)?; + graphml.to_file(path, &compression.unwrap_or_default())?; + Ok(()) } -/// Write a list of graphs to a file in GraphML format given the list of key definitions. +/// Write a digraph to a file in GraphML format given the list of key definitions. 
#[pyfunction] -#[pyo3(signature=(graphs, keys, path, compression=None),text_signature = "(graphs, keys, path, /, compression=None)")] -pub fn write_graphml( +#[pyo3(signature=(graph, path, keys, compression=None),text_signature = "(graph, path, /, keys=None, compression=None)")] +pub fn digraph_write_graphml( py: Python<'_>, - graphs: Vec>, - keys: Vec>, + graph: Py, path: &str, + keys: Option>>, compression: Option, ) -> PyResult<()> { let mut graphml = GraphML::default(); - for pykey in keys { - let key = pykey.borrow(py); - let bound_default = key.default.bind(py); - let default = if bound_default.is_none() { - Value::UnDefined - } else { - Value::from_pyobject(bound_default, key.ty)? - }; - graphml.get_keys_mut(key.domain).insert( - key.id.clone(), - Key { - name: key.name.clone(), - ty: key.ty, - default, - }, - ); - } - for graph in graphs { - graphml.graphs.push(Graph::extract_bound(graph.bind(py))?) - } + graphml.graphs.push(Graph::try_from(graph.bind(py))?); + graphml.set_or_infer_keys(py, keys)?; graphml.to_file(path, &compression.unwrap_or_default())?; Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 929efd14a3..6fff09c277 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -671,8 +671,8 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(connected_subgraphs))?; m.add_wrapped(wrap_pyfunction!(is_planar))?; m.add_wrapped(wrap_pyfunction!(read_graphml))?; - m.add_wrapped(wrap_pyfunction!(read_graphml_with_keys))?; - m.add_wrapped(wrap_pyfunction!(write_graphml))?; + m.add_wrapped(wrap_pyfunction!(graph_write_graphml))?; + m.add_wrapped(wrap_pyfunction!(digraph_write_graphml))?; m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(from_node_link_json_file))?; diff --git a/tests/test_graphml.py b/tests/test_graphml.py index 0c791563d9..83047c69c0 100644 --- a/tests/test_graphml.py +++ b/tests/test_graphml.py @@ -125,7 +125,7 @@ def 
test_write(self): 0.95, ), ] - rustworkx.write_graphml([graph], keys, fd.name) + rustworkx.write_graphml(graph, fd.name, keys=keys) graphml = rustworkx.read_graphml(fd.name) graph_reread = graphml[0] edges = [ @@ -133,25 +133,15 @@ def test_write(self): ] self.assertGraphEqual(graph_reread, graph.nodes(), edges, attrs={"id": "G"}, directed=False) - def test_write_with_keys(self): + def test_write_without_keys(self): graph_xml = self.graphml_xml_example() with tempfile.NamedTemporaryFile("wt") as fd: fd.write(graph_xml) fd.flush() - keys, graphml = rustworkx.read_graphml_with_keys(fd.name) - assert keys[0].id == "d0" - assert keys[0].domain == rustworkx.GraphMLDomain.Node - assert keys[0].name == "color" - assert keys[0].ty == rustworkx.GraphMLType.String - assert keys[0].default == "yellow" - assert keys[1].id == "d1" - assert keys[1].domain == rustworkx.GraphMLDomain.Edge - assert keys[1].name == "fidelity" - assert keys[1].ty == rustworkx.GraphMLType.Float - assert math.isclose(keys[1].default, 0.95, rel_tol=1e-7) + graphml = rustworkx.read_graphml(fd.name) graph = graphml[0] with tempfile.NamedTemporaryFile("wt") as fd: - rustworkx.write_graphml([graph], keys, fd.name) + rustworkx.write_graphml(graph, fd.name) graphml = rustworkx.read_graphml(fd.name) graph_reread = graphml[0] edges = [ @@ -214,24 +204,8 @@ def test_write_gzipped(self): graphml = rustworkx.read_graphml(fd.name) graph = graphml[0] with tempfile.NamedTemporaryFile("wt") as fd: - keys = [ - rustworkx.GraphMLKey( - "d0", - rustworkx.GraphMLDomain.Node, - "color", - rustworkx.GraphMLType.String, - "yellow", - ), - rustworkx.GraphMLKey( - "d1", - rustworkx.GraphMLDomain.Edge, - "fidelity", - rustworkx.GraphMLType.Float, - 0.95, - ), - ] newname = f"{fd.name}.gz" - rustworkx.write_graphml([graph], keys, newname) + rustworkx.write_graphml(graph, newname) graphml = rustworkx.read_graphml(newname) graph_reread = graphml[0] edges = [ @@ -294,44 +268,6 @@ def test_multiple_graphs_in_single_file(self): ] 
self.assertGraphEqual(graph, nodes, edges, attrs={"id": "H"}, directed=True) - def test_write_multiple_graphs(self): - graph_xml = self.graphml_xml_example_multiple_graphs() - with tempfile.NamedTemporaryFile("wt") as fd: - fd.write(graph_xml) - fd.flush() - graphml = rustworkx.read_graphml(fd.name) - with tempfile.NamedTemporaryFile("wt") as fd: - keys = [ - rustworkx.GraphMLKey( - "d0", - rustworkx.GraphMLDomain.Node, - "color", - rustworkx.GraphMLType.String, - "yellow", - ), - rustworkx.GraphMLKey( - "d1", - rustworkx.GraphMLDomain.Edge, - "fidelity", - rustworkx.GraphMLType.Float, - 0.95, - ), - ] - rustworkx.write_graphml(graphml, keys, fd.name) - graphml_reread = rustworkx.read_graphml(fd.name) - for graph, graph_reread in zip(graphml, graphml_reread): - edges = [ - (graph[s]["id"], graph[t]["id"], weight) - for s, t, weight in graph.weighted_edge_list() - ] - self.assertGraphEqual( - graph_reread, - graph.nodes(), - edges, - attrs=graph.attrs, - directed=isinstance(graph, rustworkx.PyDiGraph), - ) - def test_key_for_graph(self): graph_xml = self.HEADER.format( """ @@ -368,14 +304,9 @@ def test_write_key_for_graph(self): with tempfile.NamedTemporaryFile("wt") as fd: fd.write(graph_xml) fd.flush() - keys, graphml = rustworkx.read_graphml_with_keys(fd.name) - assert keys[0].id == "d0" - assert keys[0].domain == rustworkx.GraphMLDomain.Graph - assert keys[0].name == "test" - assert keys[0].ty == rustworkx.GraphMLType.Boolean - assert keys[0].default is None + graphml = rustworkx.read_graphml(fd.name) with tempfile.NamedTemporaryFile("wt") as fd: - rustworkx.write_graphml(graphml, keys, fd.name) + rustworkx.write_graphml(graphml[0], fd.name) graphml = rustworkx.read_graphml(fd.name) graph = graphml[0] nodes = [{"id": "n0"}] @@ -437,14 +368,18 @@ def test_write_key_for_all(self): with tempfile.NamedTemporaryFile("wt") as fd: fd.write(graph_xml) fd.flush() - keys, graphml = rustworkx.read_graphml_with_keys(fd.name) - assert keys[0].id == "d0" - assert 
keys[0].domain == rustworkx.GraphMLDomain.All - assert keys[0].name == "test" - assert keys[0].ty == rustworkx.GraphMLType.String - assert keys[0].default is None + graphml = rustworkx.read_graphml(fd.name) + keys = [ + rustworkx.GraphMLKey( + "d0", + rustworkx.GraphMLDomain.All, + "test", + rustworkx.GraphMLType.String, + None, + ) + ] with tempfile.NamedTemporaryFile("wt") as fd: - rustworkx.write_graphml(graphml, keys, fd.name) + rustworkx.write_graphml(graphml[0], fd.name, keys=keys) graphml = rustworkx.read_graphml(fd.name) graph = graphml[0] nodes = [ @@ -497,14 +432,9 @@ def test_write_key_undefined(self): with tempfile.NamedTemporaryFile("wt") as fd: fd.write(graph_xml) fd.flush() - keys, graphml = rustworkx.read_graphml_with_keys(fd.name) - assert keys[0].id == "d0" - assert keys[0].domain == rustworkx.GraphMLDomain.Node - assert keys[0].name == "test" - assert keys[0].ty == rustworkx.GraphMLType.Boolean - assert keys[0].default is None + graphml = rustworkx.read_graphml(fd.name) with tempfile.NamedTemporaryFile("wt") as fd: - rustworkx.write_graphml(graphml, keys, fd.name) + rustworkx.write_graphml(graphml[0], fd.name) graphml = rustworkx.read_graphml(fd.name) graph = graphml[0] nodes = [ From 5ef3e4f773e269c8df7a059fa57b16cac67df126 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Mon, 16 Jun 2025 21:27:45 +0200 Subject: [PATCH 09/17] Use `DictMap` everywhere Suggested by @IvanIsCoding: https://github.com/Qiskit/rustworkx/pull/1464#issuecomment-2970274047 --- src/graphml.rs | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/graphml.rs b/src/graphml.rs index 64d3a03ba3..0fbde600a8 100644 --- a/src/graphml.rs +++ b/src/graphml.rs @@ -25,7 +25,7 @@ use std::str::ParseBoolError; use flate2::bufread::GzDecoder; use flate2::write::GzEncoder; use flate2::Compression; -use hashbrown::{HashMap, HashSet}; +use hashbrown::HashSet; use indexmap::map::Entry; @@ -297,14 +297,14 @@ impl Key { 
struct Node { id: String, - data: HashMap, + data: DictMap, } struct Edge { id: Option, source: String, target: String, - data: HashMap, + data: DictMap, } enum Direction { @@ -317,7 +317,7 @@ struct Graph { dir: Direction, nodes: Vec, edges: Vec, - attributes: HashMap, + attributes: DictMap, } impl Graph { @@ -330,7 +330,7 @@ impl Graph { dir, nodes: Vec::new(), edges: Vec::new(), - attributes: HashMap::from_iter( + attributes: DictMap::from_iter( default_attrs.map(|key| (key.name.clone(), key.default.clone())), ), } @@ -342,7 +342,7 @@ impl Graph { { self.nodes.push(Node { id: xml_attribute(element, b"id")?, - data: HashMap::from_iter( + data: DictMap::from_iter( default_data.map(|key| (key.name.clone(), key.default.clone())), ), }); @@ -358,7 +358,7 @@ impl Graph { id: xml_attribute(element, b"id").ok(), source: xml_attribute(element, b"source")?, target: xml_attribute(element, b"target")?, - data: HashMap::from_iter( + data: DictMap::from_iter( default_data.map(|key| (key.name.clone(), key.default.clone())), ), }); @@ -395,7 +395,7 @@ impl<'py> IntoPyObject<'py> for Graph { if let Some(id) = self.id { self.attributes.insert(String::from("id"), Value::String(id.clone())); } - let mut mapping = HashMap::with_capacity(self.nodes.len()); + let mut mapping = DictMap::with_capacity(self.nodes.len()); for mut node in self.nodes { // Write the node id from GraphML doc into the node data payload // since in rustworkx nodes are indexed by an unsigned integer and @@ -460,14 +460,14 @@ impl<'py> IntoPyObject<'py> for Graph { } struct GraphElementInfo { - attributes: HashMap, + attributes: DictMap, id: Option, } impl Default for GraphElementInfo { fn default() -> Self { Self { - attributes: HashMap::new(), + attributes: DictMap::new(), id: None, } } @@ -489,11 +489,11 @@ impl GraphElementInfos { fn insert(&mut self, py: Python<'_>, index: Index, weight: Option<&Py>) -> PyResult<()> { let element_info = weight .and_then(|data| { - data.extract::>(py) + data.extract::>(py) 
.ok() .map(|mut attributes| -> PyResult { let id = attributes - .remove_entry("id") + .shift_remove_entry("id") .map(|(id, value)| -> PyResult> { let value_str = value.to_id()?; if self.id_taken.contains(value_str) { @@ -524,12 +524,12 @@ impl Graph { pygraph: &StablePyGraph, attrs: &PyObject, ) -> PyResult { - let mut attrs: Option> = attrs.extract(py).ok(); + let mut attrs: Option> = attrs.extract(py).ok(); let id = attrs .as_mut() .and_then(|attributes| { attributes - .remove("id") + .shift_remove("id") .map(|v| v.to_id().map(|id| id.to_string())) }) .transpose()?; @@ -545,7 +545,7 @@ impl Graph { for edge_index in pygraph.edge_indices() { edge_infos.insert(py, edge_index, pygraph.edge_weight(edge_index))?; } - let mut node_ids = HashMap::new(); + let mut node_ids = DictMap::new(); let mut fresh_index_counter = 0; for (node_index, element_info) in node_infos.vec { let id = element_info.id.unwrap_or_else(|| loop { @@ -661,7 +661,7 @@ impl Default for GraphML { fn build_key_name_map<'a>( key_for_items: &'a DictMap, key_for_all: &'a DictMap, -) -> HashMap { +) -> DictMap { // `key_for_items` is iterated before `key_for_all` since last // items take precedence in the collected map. 
Similarly, // the map `for_all` take precedence over kind-specific maps in @@ -1032,8 +1032,8 @@ impl GraphML { fn write_data( writer: &mut Writer, - keys: &HashMap, - data: &HashMap, + keys: &DictMap, + data: &DictMap, ) -> Result<(), Error> { for (key_name, value) in data { let (id, key) = keys @@ -1055,9 +1055,9 @@ impl GraphML { fn write_elem_data( writer: &mut Writer, - keys: &HashMap, + keys: &DictMap, elem: BytesStart, - data: &HashMap, + data: &DictMap, ) -> Result<(), Error> { if data.is_empty() { writer.write_event(Event::Empty(elem))?; @@ -1110,11 +1110,11 @@ impl GraphML { Self::write_keys(writer, "edge", &self.key_for_edges)?; Self::write_keys(writer, "graph", &self.key_for_graph)?; Self::write_keys(writer, "all", &self.key_for_all)?; - let graph_keys: HashMap = + let graph_keys: DictMap = build_key_name_map(&self.key_for_graph, &self.key_for_all); - let node_keys: HashMap = + let node_keys: DictMap = build_key_name_map(&self.key_for_nodes, &self.key_for_all); - let edge_keys: HashMap = + let edge_keys: DictMap = build_key_name_map(&self.key_for_edges, &self.key_for_all); for graph in self.graphs.iter() { let mut elem = BytesStart::new("graph"); From 2d51ff352eb6a7911b2000e3c3442b6213f44dd2 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Mon, 16 Jun 2025 22:38:31 +0200 Subject: [PATCH 10/17] rustfmt and clippy --- src/graphml.rs | 35 ++++++++++++----------------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/src/graphml.rs b/src/graphml.rs index 0fbde600a8..454a6adfea 100644 --- a/src/graphml.rs +++ b/src/graphml.rs @@ -489,9 +489,8 @@ impl GraphElementInfos { fn insert(&mut self, py: Python<'_>, index: Index, weight: Option<&Py>) -> PyResult<()> { let element_info = weight .and_then(|data| { - data.extract::>(py) - .ok() - .map(|mut attributes| -> PyResult { + data.extract::>(py).ok().map( + |mut attributes| -> PyResult { let id = attributes .shift_remove_entry("id") .map(|(id, value)| -> PyResult> { @@ -509,7 +508,8 @@ 
impl GraphElementInfos { attributes: attributes.into_iter().collect(), id, }) - }) + }, + ) }) .unwrap_or_else(|| Ok(GraphElementInfo::default()))?; self.vec.push((index, element_info)); @@ -1168,37 +1168,26 @@ impl GraphML { fn infer_keys(&mut self) -> Result<(), Error> { infer_keys_for_attributes( &mut self.key_for_graph, - self.graphs - .iter() - .map(|graph| graph.attributes.iter()) - .flatten(), + self.graphs.iter().flat_map(|graph| graph.attributes.iter()), )?; infer_keys_for_attributes( &mut self.key_for_nodes, self.graphs .iter() - .map(|graph| graph.nodes.iter()) - .flatten() - .map(|nodes| nodes.data.iter()) - .flatten(), + .flat_map(|graph| graph.nodes.iter()) + .flat_map(|nodes| nodes.data.iter()), )?; infer_keys_for_attributes( &mut self.key_for_edges, self.graphs .iter() - .map(|graph| graph.edges.iter()) - .flatten() - .map(|edges| edges.data.iter()) - .flatten(), + .flat_map(|graph| graph.edges.iter()) + .flat_map(|edges| edges.data.iter()), )?; Ok(()) } - fn set_keys<'py>( - &mut self, - py: Python<'py>, - keys: Vec>, - ) -> Result<(), pyo3::PyErr> { + fn set_keys(&mut self, py: Python<'_>, keys: Vec>) -> Result<(), pyo3::PyErr> { for pykey in keys { let key = pykey.borrow(py); let bound_default = key.default.bind(py); @@ -1219,9 +1208,9 @@ impl GraphML { Ok(()) } - fn set_or_infer_keys<'py>( + fn set_or_infer_keys( &mut self, - py: Python<'py>, + py: Python<'_>, keys: Option>>, ) -> Result<(), pyo3::PyErr> { match keys { From dacb57020bcfc56f78c4b408ab21ef14be2293a9 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Mon, 16 Jun 2025 23:08:50 +0200 Subject: [PATCH 11/17] Remove unused math module (ruff check) --- tests/test_graphml.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_graphml.py b/tests/test_graphml.py index 83047c69c0..ccad0b44f2 100644 --- a/tests/test_graphml.py +++ b/tests/test_graphml.py @@ -10,7 +10,6 @@ # License for the specific language governing permissions and limitations # under the License. 
-import math import unittest import tempfile import gzip From afe7e9bb1d4726952818c288388d20da1b0e998c Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Tue, 17 Jun 2025 00:33:45 +0200 Subject: [PATCH 12/17] Use `from __future__ import annotations` for Python <3.10 --- rustworkx/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index 1ef878479d..0552e88083 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -6,6 +6,7 @@ # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. +from __future__ import annotations import importlib import sys From f45926d5194bbdf0d3171e35339b0aa89d165527 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Tue, 17 Jun 2025 07:51:13 +0200 Subject: [PATCH 13/17] Add stub for `write_graphml` --- rustworkx/__init__.pyi | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index 55cda12902..a180bd9f72 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -165,6 +165,7 @@ from .rustworkx import directed_random_bipartite_graph as directed_random_bipart from .rustworkx import read_graphml as read_graphml from .rustworkx import graph_write_graphml as graph_write_graphml from .rustworkx import digraph_write_graphml as digraph_write_graphml +from .rustworkx import GraphMLKey as GraphMLKey from .rustworkx import digraph_node_link_json as digraph_node_link_json from .rustworkx import graph_node_link_json as graph_node_link_json from .rustworkx import from_node_link_json_file as from_node_link_json_file @@ -664,3 +665,10 @@ def is_bipartite(graph: PyGraph[_S, _T] | PyDiGraph[_S, _T]) -> bool: ... def condensation( graph: PyDiGraph | PyGraph, /, sccs: list[int] | None = ... ) -> PyDiGraph | PyGraph: ... 
+def write_graphml( + graph: PyGraph | PyDiGraph, + path: str, + /, + keys: list[GraphMLKey] | None = ..., + compression: str | None = ..., +) -> None: ... From 6ca5c1a2f9ea4abc681c9bfd8793d5987f808ecd Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Tue, 17 Jun 2025 07:51:28 +0200 Subject: [PATCH 14/17] Remove `read_graphml_with_keys` from documentation --- docs/source/api/serialization.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/api/serialization.rst b/docs/source/api/serialization.rst index 338246f153..88cbeac9c3 100644 --- a/docs/source/api/serialization.rst +++ b/docs/source/api/serialization.rst @@ -8,7 +8,6 @@ Serialization rustworkx.node_link_json rustworkx.read_graphml - rustworkx.read_graphml_with_keys rustworkx.write_graphml rustworkx.from_node_link_json_file rustworkx.parse_node_link_json From 915769c2bdd1a37c663a87ab846e4286e1a70605 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Tue, 17 Jun 2025 07:54:05 -0400 Subject: [PATCH 15/17] Apply suggestions from code review --- rustworkx/__init__.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index 0552e88083..d00fdb81cf 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -6,7 +6,6 @@ # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. 
-from __future__ import annotations import importlib import sys @@ -2282,15 +2281,7 @@ def single_source_all_shortest_paths( raise TypeError(f"Invalid Input Type {type(graph)} for graph") -def write_graphml( - graph: PyGraph | PyDiGraph, - path: str, - /, - keys: list[GraphMLKey] | None = None, - compression: str | None = None, -) -> None: +@_rustworkx_dispatch +def write_graphml(graph, path, /, keys, compression): """ """ - if isinstance(graph, PyGraph): - graph_write_graphml(graph, path, keys, compression) - return - digraph_write_graphml(graph, path, keys, compression) + raise TypeError(f"Invalid Input Type {type(graph)} for graph") From 019d4bd2e168d85c568d5202fccad6795e853a50 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Tue, 17 Jun 2025 08:22:27 -0400 Subject: [PATCH 16/17] Update rustworkx/__init__.py --- rustworkx/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index d00fdb81cf..a8611e7c24 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -2282,6 +2282,6 @@ def single_source_all_shortest_paths( @_rustworkx_dispatch -def write_graphml(graph, path, /, keys, compression): +def write_graphml(graph, path, /, keys=None, compression=None): """ """ raise TypeError(f"Invalid Input Type {type(graph)} for graph") From ad5ae6cbd84984558c8937ff0f0786ad7e2e2534 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Tue, 17 Jun 2025 08:44:03 -0400 Subject: [PATCH 17/17] Apply suggestions from code review --- src/graphml.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphml.rs b/src/graphml.rs index 454a6adfea..5531391b45 100644 --- a/src/graphml.rs +++ b/src/graphml.rs @@ -1292,7 +1292,7 @@ impl KeySpec { /// Write a graph to a file in GraphML format given the list of key definitions. 
#[pyfunction] -#[pyo3(signature=(graph, path, keys, compression=None),text_signature = "(graph, path, /, keys=None, compression=None)")] +#[pyo3(signature=(graph, path, keys=None, compression=None),text_signature = "(graph, path, /, keys=None, compression=None)")] pub fn graph_write_graphml( py: Python<'_>, graph: Py, @@ -1309,7 +1309,7 @@ pub fn graph_write_graphml( /// Write a digraph to a file in GraphML format given the list of key definitions. #[pyfunction] -#[pyo3(signature=(graph, path, keys, compression=None),text_signature = "(graph, path, /, keys=None, compression=None)")] +#[pyo3(signature=(graph, path, keys=None, compression=None),text_signature = "(graph, path, /, keys=None, compression=None)")] pub fn digraph_write_graphml( py: Python<'_>, graph: Py,