Skip to content

Commit 76037f1

Browse files
committed
add heatmap
1 parent ad65264 commit 76037f1

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

examples/analysis.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from progpy.analysis import show_heatmap
2+
from progpy.datasets import nasa_cmapss
3+
4+
(training, testing, eol) = nasa_cmapss.load_data(1)
5+
6+
show_heatmap(training)
7+
8+
# Notice that some values have no color- this is because they are constant. Let's drop these
9+
for feature in ['setting3', 'sensor1', 'sensor5', 'sensor10', 'sensor16', 'sensor18', 'sensor19']:
10+
training.drop(feature, axis=1)
11+
show_heatmap(training)
12+
13+
# Here you can see high correlations between sensor 14 and 9

src/progpy/analysis/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright © 2021 United States Government as represented by the Administrator of the
2+
# National Aeronautics and Space Administration. All Rights Reserved.
3+
4+
from progpy.analysis.heatmap import show_heatmap

src/progpy/analysis/heatmap.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright © 2021 United States Government as represented by the Administrator of the
2+
# National Aeronautics and Space Administration. All Rights Reserved.
3+
4+
import matplotlib.pyplot as plt
5+
6+
def show_heatmap(data):
7+
"""
8+
Generate a heatmap showing correlation between parameters.
9+
10+
Code from: https://github.com/keras-team/keras-io/blob/13d513d7375656a14698ba4827ebbb4177efcf43/examples/timeseries/timeseries_weather_forecasting.py#L152
11+
12+
Args:
13+
data (np.ndarray): Array of data where each column is a variable.
14+
"""
15+
plt.matshow(data.corr())
16+
plt.xticks(range(data.shape[1]), data.columns, fontsize=14, rotation=90)
17+
plt.gca().xaxis.tick_bottom()
18+
plt.yticks(range(data.shape[1]), data.columns, fontsize=14)
19+
20+
cb = plt.colorbar()
21+
cb.ax.tick_params(labelsize=14)
22+
plt.title("Feature Correlation Heatmap", fontsize=14)
23+
plt.show()

0 commit comments

Comments
 (0)