Skip to content

Commit a2d15f5

Browse files
author
dagou
committed
custom parser
1 parent b9cbf77 commit a2d15f5

File tree

8 files changed

+194
-29
lines changed

8 files changed

+194
-29
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "gtdb_tree"
3-
version = "0.1.8"
3+
version = "0.1.9"
44
edition = "2021"
55
description = "A library for parsing Newick format files, especially GTDB tree files."
66
homepage = "https://github.com/eric9n/gtdb_tree"

README.md

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Add this crate to your `Cargo.toml`:
1313

1414
```toml
1515
[dependencies]
16-
gtdb_tree = "0.1.0"
16+
gtdb_tree = "0.1.9"
1717
```
1818

1919
## Usage
@@ -48,3 +48,27 @@ result = gtdb_tree.parse_tree("((A:0.1,B:0.2):0.3,C:0.4);")
4848
print(result)
4949
```
5050

51+
## Advanced Usage
52+
### Custom Node Parser
53+
You can provide a custom parser function to handle special node formats:
54+
55+
```python
56+
import gtdb_tree
57+
58+
def custom_parser(node_str):
59+
# Custom parsing logic
60+
name, length = node_str.split(':')
61+
return name, 100.0, float(length) # name, bootstrap, length
62+
63+
result = gtdb_tree.parse_tree("((A:0.1,B:0.2):0.3,C:0.4);", custom_parser=custom_parser)
64+
print(result)
65+
```
66+
67+
## Working with Node Objects
68+
Each Node object in the result has the following attributes:
69+
70+
* id: Unique identifier for the node
71+
* name: Name of the node
72+
* bootstrap: Bootstrap value (if available)
73+
* length: Branch length
74+
* parent: ID of the parent node

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ features = ["python"]
99

1010
[project]
1111
name = "gtdb_tree"
12-
version = "0.1.8"
12+
version = "0.1.9"
1313
description = "A Python package for parsing GTDB trees using Rust"
1414
readme = "README.md"
1515
authors = [{ name = "dagou", email = "[email protected]" }]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
setup(
55
name="gtdb_tree",
6-
version="0.1.8",
6+
version="0.1.9",
77
rust_extensions=[RustExtension("gtdb_tree.gtdb_tree", binding=Binding.PyO3)],
88
packages=["gtdb_tree"],
99
# rust extensions are not zip safe, just like C-extensions.

src/node.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ pub enum ParseError {
2222
UnexpectedEndOfInput,
2323
#[allow(dead_code)]
2424
InvalidFormat(String),
25+
PythonError(String),
2526
}
2627

2728
// Formats each ParseError variant as a short human-readable message; the String-carrying variants append their message text.
impl std::fmt::Display for ParseError {
2829
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2930
match self {
3031
ParseError::UnexpectedEndOfInput => write!(f, "Unexpected end of input"),
3132
ParseError::InvalidFormat(msg) => write!(f, "Invalid format: {}", msg),
33+
ParseError::PythonError(msg) => write!(f, "Python error: {}", msg),
3234
}
3335
}
3436
}

src/python.rs

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
use crate::node::Node as RustNode;
2-
use crate::tree;
2+
use crate::node::ParseError;
3+
use crate::tree::{self, NodeParser};
4+
use std::convert::From;
5+
use std::sync::Arc;
6+
7+
// Conversion from PyErr to ParseError: wraps the error's string form in ParseError::PythonError.
// NOTE(review): unlike the pyo3 import below, this impl carries no #[cfg(feature = "python")] — confirm the module itself is feature-gated.
8+
impl From<PyErr> for ParseError {
9+
fn from(err: PyErr) -> Self {
10+
ParseError::PythonError(err.to_string())
11+
}
12+
}
313

414
#[cfg(feature = "python")]
515
use pyo3::prelude::*;
@@ -39,10 +49,94 @@ impl Node {
3949
}
4050
}
4151

52+
// #[cfg(feature = "python")]
53+
// #[pyfunction]
54+
// pub fn parse_tree(newick_str: &str) -> PyResult<Vec<Node>> {
55+
// tree::parse_tree(newick_str)
56+
// .map(|rust_nodes| rust_nodes.into_iter().map(|rn| Node { node: rn }).collect())
57+
// .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
58+
// }
59+
4260
#[cfg(feature = "python")]
4361
#[pyfunction]
44-
pub fn parse_tree(newick_str: &str) -> PyResult<Vec<Node>> {
45-
tree::parse_tree(newick_str)
62+
#[pyo3(signature = (newick_str, custom_parser = None))]
63+
#[pyo3(text_signature = "(newick_str, custom_parser=None)")]
64+
/// Parse a Newick format string into a list of Node objects.
65+
///
66+
/// This function takes a Newick format string and optionally a custom parser function,
67+
/// and returns a list of Node objects representing the phylogenetic tree.
68+
///
69+
/// Parameters:
70+
/// -----------
71+
/// newick_str : str
72+
/// The Newick format string representing the phylogenetic tree.
73+
/// custom_parser : callable, optional
74+
/// A custom parsing function for node information. If not provided, the default parser will be used.
75+
/// The custom parser should have the following signature:
76+
///
77+
/// def custom_parser(node_str: str) -> Tuple[str, float, float]:
78+
/// '''
79+
/// Parse a node string and return name, bootstrap, and length.
80+
///
81+
/// Parameters:
82+
/// -----------
83+
/// node_str : str
84+
/// The node string to parse.
85+
///
86+
/// Returns:
87+
/// --------
88+
/// Tuple[str, float, float]
89+
/// A tuple containing (name, bootstrap, length) for the node.
90+
/// '''
91+
/// # Your custom parsing logic here
92+
/// return name, bootstrap, length
93+
///
94+
/// Returns:
95+
/// --------
96+
/// List[Node]
97+
/// A list of Node objects representing the parsed phylogenetic tree.
98+
///
99+
/// Raises:
100+
/// -------
101+
/// ValueError
102+
/// If the Newick string is invalid or parsing fails.
103+
///
104+
/// Example:
105+
/// --------
106+
/// >>> newick_str = "(A:0.1,B:0.2,(C:0.3,D:0.4)70:0.5);"
107+
/// >>> nodes = parse_tree(newick_str)
108+
/// >>>
109+
/// >>> # Using a custom parser
110+
/// >>> def my_parser(node_str):
111+
/// ... parts = node_str.split(':')
112+
/// ... name = parts[0]
113+
/// ... length = float(parts[1]) if len(parts) > 1 else 0.0
114+
/// ... return name, 100.0, length # Always set bootstrap to 100.0
115+
/// >>>
116+
/// >>> nodes_custom = parse_tree(newick_str, custom_parser=my_parser)
117+
pub fn parse_tree(
118+
_py: Python,
119+
newick_str: &str,
120+
custom_parser: Option<PyObject>,
121+
) -> PyResult<Vec<Node>> {
122+
let parser = match custom_parser {
123+
Some(py_func) => {
124+
let py_func = Arc::new(py_func);
125+
NodeParser::Custom(Box::new(
126+
move |node_str: &str| -> Result<(String, f64, f64), ParseError> {
127+
Python::with_gil(|py| {
128+
let result = py_func.call1(py, (node_str,))?;
129+
let (name, bootstrap, length): (String, f64, f64) = result.extract(py)?;
130+
Ok((name, bootstrap, length))
131+
})
132+
.map_err(|e: PyErr| ParseError::PythonError(e.to_string()))
133+
},
134+
))
135+
}
136+
None => NodeParser::Default,
137+
};
138+
139+
tree::parse_tree(newick_str, parser)
46140
.map(|rust_nodes| rust_nodes.into_iter().map(|rn| Node { node: rn }).collect())
47141
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
48142
}

src/tree.rs

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,26 @@
11
use crate::node::{Node, ParseError};
22
use memchr::memchr2;
33

4+
// Selects how each node string is parsed: `Default` uses the built-in parser, `Custom` holds a boxed closure (trait object) returning (name, bootstrap, length).
5+
pub enum NodeParser {
6+
Default,
7+
Custom(Box<dyn Fn(&str) -> Result<(String, f64, f64), ParseError> + Send>),
8+
}
9+
10+
// `NodeParser::default()` yields the `NodeParser::Default` variant.
impl Default for NodeParser {
11+
fn default() -> Self {
12+
NodeParser::Default
13+
}
14+
}
15+
16+
/// Parse the label of a node from a Newick tree string.
17+
///
18+
/// This function takes a string slice representing a node in a Newick tree string,
19+
/// and returns the name and length of the node as a tuple.
20+
///
21+
/// # Arguments
22+
///
23+
/// * `label` - A string slice representing the node in a Newick tree string.
424
fn parse_label(label: &str) -> Result<(String, f64), ParseError> {
525
let label = label.trim_end_matches(";").trim_matches('\'').to_string();
626

@@ -31,27 +51,41 @@ fn parse_label(label: &str) -> Result<(String, f64), ParseError> {
3151
///
3252
/// # Arguments
3353
///
34-
/// * `node_bytes` - A byte slice representing the node in a Newick tree string.
54+
/// * `node_str` - A string slice representing the node in a Newick tree string.
3555
///
3656
/// # Returns
3757
///
38-
/// Returns a `Result` containing a tuple of the name and length on success,
58+
/// Returns a `Result` containing a tuple of the name, bootstrap, and length on success,
3959
/// or an `Err(ParseError)` on failure.
4060
///
4161
/// # Example
4262
///
4363
/// ```
44-
/// use gtdb_tree::tree::parse_node;
64+
/// use gtdb_tree::tree::parse_node_default;
4565
///
46-
/// let node_bytes = b"A:0.1";
47-
/// let (name, bootstrap, length) = parse_node(node_bytes).unwrap();
66+
/// let node_str = "A:0.1";
67+
/// let (name, bootstrap, length) = parse_node_default(node_str).unwrap();
4868
/// assert_eq!(name, "A");
4969
/// assert_eq!(bootstrap, 0.0);
5070
/// assert_eq!(length, 0.1);
5171
/// ```
52-
pub fn parse_node(node_bytes: &[u8]) -> Result<(String, f64, f64), ParseError> {
53-
let node_str = std::str::from_utf8(node_bytes).expect("UTF-8 sequence");
54-
// gtdb
72+
pub fn parse_node_default(node_str: &str) -> Result<(String, f64, f64), ParseError> {
73+
// Handle the "AD:0.03347[21.0]" format, i.e. name:length[bootstrap]
74+
if let Some((name_length, bootstrap_str)) = node_str.rsplit_once('[') {
75+
if let Some((name, length_str)) = name_length.rsplit_once(':') {
76+
let bootstrap = bootstrap_str
77+
.trim_end_matches(']')
78+
.parse::<f64>()
79+
.map_err(|_| {
80+
ParseError::InvalidFormat(format!("Invalid bootstrap value: {}", bootstrap_str))
81+
})?;
82+
let length = length_str.parse::<f64>().map_err(|_| {
83+
ParseError::InvalidFormat(format!("Invalid length value: {}", length_str))
84+
})?;
85+
return Ok((name.to_string(), bootstrap, length));
86+
}
87+
}
88+
5589
// Check if node_str contains single quotes and ensure they are together
5690
if node_str.matches('\'').count() % 2 != 0 {
5791
return Err(ParseError::InvalidFormat(format!(
@@ -102,12 +136,13 @@ pub fn parse_node(node_bytes: &[u8]) -> Result<(String, f64, f64), ParseError> {
102136
///
103137
/// ```
104138
/// use gtdb_tree::tree::parse_tree;
139+
/// use gtdb_tree::tree::NodeParser;
105140
///
106141
/// let newick_str = "((A:0.1,B:0.2):0.3,C:0.4);";
107-
/// let nodes = parse_tree(newick_str).unwrap();
142+
/// let nodes = parse_tree(newick_str, NodeParser::default()).unwrap();
108143
/// assert_eq!(nodes.len(), 5);
109144
/// ```
110-
pub fn parse_tree(newick_str: &str) -> Result<Vec<Node>, ParseError> {
145+
pub fn parse_tree(newick_str: &str, parser: NodeParser) -> Result<Vec<Node>, ParseError> {
111146
let mut nodes: Vec<Node> = Vec::new();
112147
let mut pos = 0;
113148

@@ -132,7 +167,16 @@ pub fn parse_tree(newick_str: &str) -> Result<Vec<Node>, ParseError> {
132167
let end_pos = memchr2(b',', b')', &bytes[pos..]).unwrap_or(bytes_len - pos);
133168
let node_end_pos = pos + end_pos;
134169
let node_bytes = &bytes[pos..node_end_pos];
135-
let (name, bootstrap, length) = parse_node(node_bytes)?;
170+
171+
let mut node_str = std::str::from_utf8(node_bytes).expect("UTF-8 sequence");
172+
if node_end_pos == bytes_len {
173+
node_str = node_str.trim_end_matches(';');
174+
}
175+
let (name, bootstrap, length) = match &parser {
176+
NodeParser::Default => parse_node_default(node_str)?,
177+
NodeParser::Custom(func) => func(node_str)?,
178+
};
179+
136180
let node_id = if &bytes[pos - 1] == &b')' {
137181
stack.pop().unwrap_or(0)
138182
} else {
@@ -161,8 +205,9 @@ mod tests {
161205
use super::*;
162206

163207
#[test]
164-
fn test_parse_tree() {
208+
fn test_parse_tree() -> Result<(), ParseError> {
165209
let test_cases = vec![
210+
"(A:0.1,B:0.2,(C:0.3,D:0.4)AD:0.03347[21.0]);",
166211
"((A:0.1,B:0.2)'56:F;H;':0.3,C:0.4);",
167212
"(,,(,));", // no nodes are named
168213
"(A,B,(C,D));", // leaf nodes are named
@@ -175,15 +220,15 @@ mod tests {
175220
];
176221

177222
for newick_str in test_cases {
178-
match parse_tree(newick_str) {
179-
Ok(nodes) => println!(
180-
"Parsed nodes for '{}': {:?}, len: {}",
181-
newick_str,
182-
nodes,
183-
nodes.len()
184-
),
185-
Err(e) => println!("Error parsing '{}': {:?}", newick_str, e),
186-
}
223+
let nodes = parse_tree(newick_str, NodeParser::default())?;
224+
println!(
225+
"Parsed nodes for '{}': {:?}, len: {}",
226+
newick_str,
227+
nodes,
228+
nodes.len()
229+
)
187230
}
231+
232+
Ok(())
188233
}
189234
}

0 commit comments

Comments
 (0)