Skip to content

Commit

Permalink
small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
icedoom888 committed Nov 28, 2024
1 parent 0779032 commit 3df353f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 18 deletions.
23 changes: 16 additions & 7 deletions notebooks/dynamic_vis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"outputs": [],
"source": [
"import torch\n",
"from anemoi.graphs.plotting.interactive_html import plot_downscale, plot_upscale, plot_level"
"from anemoi.graphs.plotting.interactive.graph_3d import plot_downscale, plot_upscale, plot_level"
]
},
{
Expand Down Expand Up @@ -64,12 +64,21 @@
" hidden_to_data_edges = hetero_data[('hidden_1', 'to', 'data')].edge_index\n",
"\n",
"else:\n",
" data_to_hidden_edges = hetero_data[('data', 'to', 'hidden')].edge_index\n",
" hidden_nodes.append(hetero_data['hidden'].x)\n",
" hidden_edges.append(hetero_data[('hidden', 'to', 'hidden')].edge_index)\n",
" downscale_edges.append(hetero_data[('data', 'to', 'hidden')].edge_index)\n",
" upscale_edges.append(hetero_data[('hidden', 'to', 'data')].edge_index)\n",
" hidden_to_data_edges = hetero_data[('hidden', 'to', 'data')].edge_index\n",
" try:\n",
" data_to_hidden_edges = hetero_data[('data', 'to', 'hidden_1')].edge_index\n",
" hidden_nodes.append(hetero_data['hidden_1'].x)\n",
" hidden_edges.append(hetero_data[('hidden_1', 'to', 'hidden_1')].edge_index)\n",
" downscale_edges.append(hetero_data[('data', 'to', 'hidden_1')].edge_index)\n",
" upscale_edges.append(hetero_data[('hidden_1', 'to', 'data')].edge_index)\n",
" hidden_to_data_edges = hetero_data[('hidden_1', 'to', 'data')].edge_index\n",
" \n",
" except Exception:\n",
" data_to_hidden_edges = hetero_data[('data', 'to', 'hidden')].edge_index\n",
" hidden_nodes.append(hetero_data['hidden'].x)\n",
" hidden_edges.append(hetero_data[('hidden', 'to', 'hidden')].edge_index)\n",
" downscale_edges.append(hetero_data[('data', 'to', 'hidden')].edge_index)\n",
" upscale_edges.append(hetero_data[('hidden', 'to', 'data')].edge_index)\n",
" hidden_to_data_edges = hetero_data[('hidden', 'to', 'data')].edge_index\n",
"\n",
"print(f'Lat Lon grid has: {len(data_nodes)} points.')\n",
"for i in range(num_hidden):\n",
Expand Down
13 changes: 7 additions & 6 deletions src/anemoi/graphs/plotting/interactive/graph_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def plot_downscale(
)

# Node trace
node_trace_data, graph_data, coords_data = convert_and_plot_nodes(
node_trace_data, _, coords_data = convert_and_plot_nodes(
g_data, data_nodes, x_range, y_range, z_range, scale=1.0, color="darkgrey"
)
node_trace_hidden = [node_trace_data]
Expand Down Expand Up @@ -103,7 +103,7 @@ def plot_downscale(
coords_hidden[0],
1.0,
1.0 - scale_increment,
"yellowgreen",
colorscale[i],
x_range,
y_range,
z_range,
Expand Down Expand Up @@ -192,7 +192,7 @@ def plot_upscale(
)

# Node trace
node_trace_data, graph_data, coords_data = convert_and_plot_nodes(
node_trace_data, _, coords_data = convert_and_plot_nodes(
g_data, data_nodes, x_range, y_range, z_range, scale=1.0, color="darkgrey"
)
node_trace_hidden = [node_trace_data]
Expand Down Expand Up @@ -239,7 +239,7 @@ def plot_upscale(
coords_data,
1 - scale_increment,
1.0,
"yellowgreen",
colorscale[-1 - i],
x_range,
y_range,
z_range,
Expand Down Expand Up @@ -312,7 +312,7 @@ def plot_level(
)

# Node trace
node_trace_data, graph_data, coords_data = convert_and_plot_nodes(
node_trace_data, _, _ = convert_and_plot_nodes(
g_data, data_nodes, x_range, y_range, z_range, scale=1.0, color="darkgrey"
)
node_trace_hidden = [node_trace_data]
Expand Down Expand Up @@ -354,9 +354,10 @@ def plot_level(
edge_traces = sum(edge_traces, [])
# Combine traces and layout into a figure
fig = go.Figure(data=node_trace_hidden + edge_traces, layout=layout)

return fig


def plot_3d_graph(
graph: HeteroData, nodes_coord: Tuple[List[float], List[float]], title: str = None, show_edges: bool = True
):
Expand Down
5 changes: 0 additions & 5 deletions src/anemoi/graphs/plotting/interactive/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.utils.convert import to_networkx

from anemoi.graphs.plotting.prepare import compute_isolated_nodes
from anemoi.graphs.plotting.style import *
Expand Down Expand Up @@ -66,6 +64,3 @@ def plot_isolated_nodes(graph: HeteroData, out_file: Optional[Union[str, Path]]
fig.write_html(out_file)
else:
fig.show()


def plot_interactive_nodes(graph: HeteroData, nodes_name: str, out_file: Optional[str] = None) -> None:

0 comments on commit 3df353f

Please sign in to comment.