
Commit 2e36f69

Merge pull request #3290 from shnizzedy/fix/gantt-chart
FIX: Restore generate_gantt_chart functionality
2 parents 5dc8701 + 7223914 · commit 2e36f69

File tree: 3 files changed (+111, -17 lines)

nipype/info.py (+1)
@@ -153,6 +153,7 @@ def get_nipype_gitversion():
 
 TESTS_REQUIRES = [
     "coverage >= 5.2.1",
+    "pandas >= 1.5.0",
     "pytest >= 6",
     "pytest-cov >=2.11",
     "pytest-env",

nipype/pipeline/plugins/tests/test_callback.py (+51, -2)
@@ -1,8 +1,9 @@
 # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
 # vi: set ft=python sts=4 ts=4 sw=4 et:
-"""Tests for workflow callbacks
-"""
+"""Tests for workflow callbacks."""
+from pathlib import Path
 from time import sleep
+import json
 import pytest
 import nipype.interfaces.utility as niu
 import nipype.pipeline.engine as pe
@@ -60,3 +61,51 @@ def test_callback_exception(tmpdir, plugin, stop_on_first_crash):
 
     sleep(0.5) # Wait for callback to be called (python 2.7)
     assert so.statuses == [("f_node", "start"), ("f_node", "exception")]
+
+
+@pytest.mark.parametrize("plugin", ["Linear", "MultiProc", "LegacyMultiProc"])
+def test_callback_gantt(tmp_path: Path, plugin: str) -> None:
+    import logging
+
+    from os import path
+
+    from nipype.utils.profiler import log_nodes_cb
+    from nipype.utils.draw_gantt_chart import generate_gantt_chart
+
+    log_filename = tmp_path / "callback.log"
+    logger = logging.getLogger("callback")
+    logger.setLevel(logging.DEBUG)
+    handler = logging.FileHandler(log_filename)
+    logger.addHandler(handler)
+
+    # create workflow
+    wf = pe.Workflow(name="test", base_dir=str(tmp_path))
+    f_node = pe.Node(
+        niu.Function(function=func, input_names=[], output_names=[]), name="f_node"
+    )
+    wf.add_nodes([f_node])
+    wf.config["execution"] = {"crashdump_dir": wf.base_dir, "poll_sleep_duration": 2}
+
+    plugin_args = {"status_callback": log_nodes_cb}
+    if plugin != "Linear":
+        plugin_args["n_procs"] = 8
+    wf.run(plugin=plugin, plugin_args=plugin_args)
+
+    with open(log_filename, "r") as _f:
+        loglines = _f.readlines()
+
+    # test missing duration
+    first_line = json.loads(loglines[0])
+    if "duration" in first_line:
+        del first_line["duration"]
+    loglines[0] = f"{json.dumps(first_line)}\n"
+
+    # test duplicate timestamp warning
+    loglines.append(loglines[-1])
+
+    with open(log_filename, "w") as _f:
+        _f.write("".join(loglines))
+
+    with pytest.warns(Warning):
+        generate_gantt_chart(str(log_filename), 1 if plugin == "Linear" else 8)
+    assert (tmp_path / "callback.log.html").exists()

nipype/utils/draw_gantt_chart.py (+59, -15)
@@ -8,8 +8,10 @@
 import random
 import datetime
 import simplejson as json
+from typing import Union
 
 from collections import OrderedDict
+from warnings import warn
 
 # Pandas
 try:
@@ -66,9 +68,9 @@ def create_event_dict(start_time, nodes_list):
         finish_delta = (node["finish"] - start_time).total_seconds()
 
         # Populate dictionary
-        if events.get(start_delta) or events.get(finish_delta):
+        if events.get(start_delta):
             err_msg = "Event logged twice or events started at exact same time!"
-            raise KeyError(err_msg)
+            warn(err_msg, category=Warning)
         events[start_delta] = start_node
         events[finish_delta] = finish_node
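
Not part of the diff, but a minimal standalone sketch of what the change above means for callers: a duplicate start offset now emits a Warning (which the new test catches with pytest.warns) instead of aborting with KeyError. The events dict and offset values here are made up purely for illustration.

import warnings

events = {0.0: {"event": "start"}}  # hypothetical events dict keyed by seconds offset
start_delta = 0.0  # a second node "starting" at exactly the same offset

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    if events.get(start_delta):
        warnings.warn("Event logged twice or events started at exact same time!", Warning)
print(f"{len(caught)} warning(s) recorded; chart generation can continue")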

@@ -101,15 +103,25 @@ def log_to_dict(logfile):
 
     nodes_list = [json.loads(l) for l in lines]
 
-    def _convert_string_to_datetime(datestring):
-        try:
+    def _convert_string_to_datetime(
+        datestring: Union[str, datetime.datetime],
+    ) -> datetime.datetime:
+        """Convert a date string to a datetime object."""
+        if isinstance(datestring, datetime.datetime):
+            datetime_object = datestring
+        elif isinstance(datestring, str):
+            date_format = (
+                "%Y-%m-%dT%H:%M:%S.%f%z"
+                if "+" in datestring
+                else "%Y-%m-%dT%H:%M:%S.%f"
+            )
             datetime_object: datetime.datetime = datetime.datetime.strptime(
-                datestring, "%Y-%m-%dT%H:%M:%S.%f"
+                datestring, date_format
             )
-            return datetime_object
-        except Exception as _:
-            pass
-        return datestring
+        else:
+            msg = f"{datestring} is not a string or datetime object."
+            raise TypeError(msg)
+        return datetime_object
 
     date_object_node_list: list = list()
     for n in nodes_list:
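
As an aside (not part of the diff), the two format strings the rewritten parser chooses between cover naive and timezone-aware timestamps respectively. The sample values below are illustrative; parsing an offset that contains a colon with %z requires Python 3.7+.

import datetime

naive = "2024-01-01T12:00:00.000000"        # no "+" -> "%Y-%m-%dT%H:%M:%S.%f"
aware = "2024-01-01T12:00:00.000000+00:00"  # has "+" -> "%Y-%m-%dT%H:%M:%S.%f%z"

print(datetime.datetime.strptime(naive, "%Y-%m-%dT%H:%M:%S.%f"))
print(datetime.datetime.strptime(aware, "%Y-%m-%dT%H:%M:%S.%f%z"))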
@@ -154,12 +166,18 @@ def calculate_resource_timeseries(events, resource):
     # Iterate through the events
     for _, event in sorted(events.items()):
         if event["event"] == "start":
-            if resource in event and event[resource] != "Unknown":
-                all_res += float(event[resource])
+            if resource in event:
+                try:
+                    all_res += float(event[resource])
+                except ValueError:
+                    continue
             current_time = event["start"]
         elif event["event"] == "finish":
-            if resource in event and event[resource] != "Unknown":
-                all_res -= float(event[resource])
+            if resource in event:
+                try:
+                    all_res -= float(event[resource])
+                except ValueError:
+                    continue
             current_time = event["finish"]
         res[current_time] = all_res

@@ -284,7 +302,14 @@ def draw_nodes(start, nodes_list, cores, minute_scale, space_between_minutes, co
         # Left
         left = 60
         for core in range(len(end_times)):
-            if end_times[core] < node_start:
+            try:
+                end_time_condition = end_times[core] < node_start
+            except TypeError:
+                # if one has a timezone and one does not
+                end_time_condition = end_times[core].replace(
+                    tzinfo=None
+                ) < node_start.replace(tzinfo=None)
+            if end_time_condition:
                 left += core * 30
                 end_times[core] = datetime.datetime(
                     node_finish.year,
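
For context (again not part of the diff), the TypeError this hunk guards against comes from Python refusing to order naive and timezone-aware datetimes; stripping tzinfo, as the except branch does, makes the comparison possible. The values below are illustrative only.

import datetime

naive = datetime.datetime(2024, 1, 1, 12, 0, 0)
aware = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc)

try:
    naive < aware
except TypeError as exc:
    print(exc)  # can't compare offset-naive and offset-aware datetimes

# Dropping tzinfo lets the comparison run; here the two instants are equal
print(naive < aware.replace(tzinfo=None))  # False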
@@ -307,7 +332,7 @@ def draw_nodes(start, nodes_list, cores, minute_scale, space_between_minutes, co
             "offset": offset,
             "scale_duration": scale_duration,
             "color": color,
-            "node_name": node["name"],
+            "node_name": node.get("name", node.get("id", "")),
             "node_dur": node["duration"] / 60.0,
             "node_start": node_start.strftime("%Y-%m-%d %H:%M:%S"),
             "node_finish": node_finish.strftime("%Y-%m-%d %H:%M:%S"),
@@ -527,6 +552,25 @@ def generate_gantt_chart(
     # Read in json-log to get list of node dicts
     nodes_list = log_to_dict(logfile)
 
+    # Only include nodes with timing information, and convert timestamps
+    # from strings to datetimes
+    nodes_list = [
+        {
+            k: (
+                datetime.datetime.strptime(i[k], "%Y-%m-%dT%H:%M:%S.%f")
+                if k in {"start", "finish"} and isinstance(i[k], str)
+                else i[k]
+            )
+            for k in i
+        }
+        for i in nodes_list
+        if "start" in i and "finish" in i
+    ]
+
+    for node in nodes_list:
+        if "duration" not in node:
+            node["duration"] = (node["finish"] - node["start"]).total_seconds()
+
     # Create the header of the report with useful information
     start_node = nodes_list[0]
     last_node = nodes_list[-1]
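
Putting the pieces together, a rough end-to-end sketch of the restored flow is below. The node function, log path, and core count are illustrative assumptions that mirror the new test rather than a prescribed API; pandas (added to TESTS_REQUIRES above) should be installed.

import logging

import nipype.interfaces.utility as niu
import nipype.pipeline.engine as pe
from nipype.utils.draw_gantt_chart import generate_gantt_chart
from nipype.utils.profiler import log_nodes_cb


def _noop():  # placeholder node function for the sketch
    pass


# log_nodes_cb emits JSON records through the "callback" logger
logger = logging.getLogger("callback")
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.FileHandler("callback.log"))

wf = pe.Workflow(name="demo", base_dir=".")
wf.add_nodes(
    [pe.Node(niu.Function(function=_noop, input_names=[], output_names=[]), name="noop")]
)
wf.run(plugin="MultiProc", plugin_args={"status_callback": log_nodes_cb, "n_procs": 2})

# Writes callback.log.html next to the log, as the new test asserts
generate_gantt_chart("callback.log", 2)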

0 commit comments
