-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathVisualization.py
More file actions
138 lines (113 loc) · 5.48 KB
/
Visualization.py
File metadata and controls
138 lines (113 loc) · 5.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Based on reading the book "Machine Learning for the Quantified Self: On the Art of Learning from Sensory Data"
import matplotlib.colors as cl
import matplotlib.pyplot as plt
import matplotlib.dates as md
import numpy as np
import pandas as pd
from pathlib import Path
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.patches as mpatches
import matplotlib.cm as cm
from scipy.cluster.hierarchy import dendrogram
import itertools
from scipy.optimize import curve_fit
import re
import math
import sys
from pathlib import Path
import dateutil
class Visualization:
    """Plotting helpers that store every produced figure on disk.

    Figures are written to ``figures/<subdir>``, where ``<subdir>`` is derived
    from the ``module_path`` given to the constructor. Saved files are
    prefixed with a running number (``plot_number``) that is increased after
    each call to :meth:`save`.
    """

    # Marker styles cycled over the columns drawn in one subplot ('points').
    point_displays = ['+', 'x']  # '*', 'd', 'o', 's', '<', '>']
    # Line styles cycled over the columns drawn in one subplot (otherwise).
    line_displays = ['-']  # , '--', ':', '-.']
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']

    def __init__(self, module_path='.py'):
        """Set the plot counter and create the directory figures are saved to.

        Args:
            module_path: Path of the calling module. The part of its file
                name before the first '.' becomes the sub-directory of
                ``figures``. With the default ('.py') that part is empty, so
                figures end up directly in ``figures``.
        """
        subdir = Path(module_path).name.split('.')[0]

        self.plot_number = 1
        self.figures_dir = Path('figures') / subdir
        self.figures_dir.mkdir(exist_ok=True, parents=True)

    def save(self, plot_obj, file_name, formats=('png',)):  # 'svg'
        """Save ``plot_obj`` once per requested format and bump the counter.

        Args:
            plot_obj: Any object with a ``savefig(path)`` method (the methods
                of this class pass ``matplotlib.pyplot``).
            file_name: Base name of the file; it is prefixed with
                ``<plot_number>__``.
            formats: File extensions to save the figure in.
        """
        for extension in formats:
            save_path = self.figures_dir / \
                f'{self.plot_number}__{file_name}.{extension}'
            plot_obj.savefig(save_path)
            print(f'Figure saved to {save_path}')

        # All formats of one figure share a number; advance it once per figure.
        self.plot_number += 1

    def plot_dataset(self, data_table, columns, match, display, file_name):
        """Plot dataset columns against the index, one subplot per entry.

        Args:
            data_table: pandas DataFrame; its index is used for the x-axis
                and is formatted as '%H:%M'.
            columns: Column names (or name prefixes) to plot, one subplot
                each.
            match: Per entry of ``columns`` either 'exact' (plot that column)
                or 'like' (plot every column whose name starts with it).
            display: Per entry of ``columns`` 'points' for markers; any other
                value draws lines.
            file_name: Base name handed to :meth:`save`.

        Raises:
            ValueError: If a ``match`` entry is neither 'exact' nor 'like',
                or a 'like' entry matches no column of ``data_table``.
        """
        names = list(data_table.columns)

        # Create subplots if more columns are specified.
        if len(columns) > 1:
            f, xar = plt.subplots(len(columns), sharex=True,
                                  sharey=False, figsize=(40, 40))
        else:
            f, xar = plt.subplots()
            xar = [xar]

        f.subplots_adjust(hspace=0.4)
        xfmt = md.DateFormatter('%H:%M')

        # Pass through the columns specified.
        for i, column in enumerate(columns):
            ax = xar[i]
            ax.xaxis.set_major_formatter(xfmt)
            ax.set_prop_cycle(color=['b', 'g', 'r', 'c', 'm', 'y', 'k'])

            # 'exact' selects the column with exactly this name, 'like'
            # selects all columns whose name starts with it.
            if match[i] == 'exact':
                relevant_cols = [column]
            elif match[i] == 'like':
                relevant_cols = [
                    name for name in names if name.startswith(column)]
                if not relevant_cols:
                    # Fail with a clear message instead of the opaque
                    # "min() arg is an empty sequence" further down.
                    raise ValueError(
                        f"No column name starts with '{column}'.")
            else:
                raise ValueError(
                    "Match should be 'exact' or 'like' for " + str(i) + ".")

            max_values = []
            min_values = []

            # Pass through the relevant columns.
            for j, col_name in enumerate(relevant_cols):
                series = data_table[col_name]
                # Mask out the NaN and Inf values when plotting.
                mask = series.replace([np.inf, -np.inf], np.nan).notnull()
                plottable = series[mask]
                max_values.append(plottable.max())
                min_values.append(plottable.min())

                # Display as points, or as a line.
                if display[i] == 'points':
                    styles = self.point_displays
                else:
                    styles = self.line_displays
                ax.plot(data_table.index[mask], plottable,
                        styles[j % len(styles)])

            ax.tick_params(axis='y', labelsize=10)
            # The font size comes from `prop`; matplotlib ignores `fontsize`
            # whenever `prop` is given, so it is not passed.
            ax.legend(relevant_cols, numpoints=1, loc='upper center',
                      bbox_to_anchor=(0.5, 1.3), ncol=len(relevant_cols),
                      fancybox=True, shadow=True, prop={'size': 25})

            # A column without plottable values yields NaN extremes. Drop
            # them so min()/max() are not order-dependent and no NaN limit
            # reaches set_ylim.
            lows = [value for value in min_values if pd.notnull(value)]
            highs = [value for value in max_values if pd.notnull(value)]
            if lows and highs:
                low, high = min(lows), max(highs)
                margin = 0.1 * (high - low)
                ax.set_ylim([low - margin, high + margin])

        # Make sure we get a nice figure with only a single x-axis and labels
        # there.
        for upper_ax in f.axes[:-1]:
            plt.setp(upper_ax.get_xticklabels(), visible=False)

        plt.xlabel('Time')
        self.save(plt, file_name)
        plt.show()

    def plot_xy(self, x, y, method='plot', xlabel=None, ylabel=None, xlim=None, ylim=None, names=None, line_styles=None, loc=None, title=None, file_name='test'):
        """Draw one or more x/y series in a single figure and save it.

        Args:
            x: List of x-coordinate sequences, one per series.
            y: List of y-coordinate sequences, one per series.
            method: Name of the ``matplotlib.pyplot`` function used to draw
                every series (e.g. 'plot', 'semilogx').
            xlabel, ylabel: Optional axis labels.
            xlim, ylim: Optional axis limits.
            names: Optional legend entries, one per series.
            line_styles: Optional format string per series, passed as third
                positional argument to the plotting function.
            loc: Optional legend location, used only when ``names`` is given.
            title: Optional figure title.
            file_name: Base name handed to :meth:`save`.

        Raises:
            TypeError: If ``x`` or ``y`` contains an element that is not
                iterable.
        """
        # Check every series (not only the first) before touching pyplot.
        for coordinates in (x, y):
            if any(not hasattr(series, '__iter__') for series in coordinates):
                raise TypeError(
                    'x/y should be given as a list of lists of coordinates')

        plot_method = getattr(plt, method)
        for i, (x_line, y_line) in enumerate(zip(x, y)):
            # Always honour `method`, with or without explicit line styles.
            if line_styles is not None:
                plot_method(x_line, y_line, line_styles[i])
            else:
                plot_method(x_line, y_line)

        if xlabel is not None:
            plt.xlabel(xlabel)
        if ylabel is not None:
            plt.ylabel(ylabel)
        if xlim is not None:
            plt.xlim(xlim)
        if ylim is not None:
            plt.ylim(ylim)
        if title is not None:
            plt.title(title)
        if names is not None:
            # loc=None lets matplotlib fall back to its default placement.
            plt.legend(names, loc=loc)

        self.save(plt, file_name)
        plt.show()