Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 98 additions & 15 deletions tests/PinnedH2O_3DOF/compare_energy_and_forces.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

def process_data(s_l1, s_l2, s_theta, N_l, N_theta, rdim):
try:
data_dir = f'data/{s_l1}_{s_l2}_{s_theta}'
data_dir = f'/usr/workspace/nlrom/MGmol/PinnedH2O_3DOF/data_8/{s_l1}_{s_l2}_{s_theta}'
offline_file = os.path.join(data_dir, 'offline_PinnedH2O.out')
log_file = f'{output_dir}/{s_l1}_{s_l2}_{s_theta}.log'

Expand Down Expand Up @@ -185,42 +185,125 @@ def calculate_differences(fom, rom, name):
else:
print("Error occurred during data processing. Differences cannot be calculated.", file=outfile)

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
# (Assuming the rest of your imports and script setup remain the same)

def plot_histogram(data, quantity, test_case):
plt.figure(figsize=(8, 6))
plt.hist(data, bins=20, color='skyblue', edgecolor='black')
# Increased figure size for better spacing (Retained)
plt.figure(figsize=(10, 7.5))

# Calculate the histogram data first
counts, bin_edges, _ = plt.hist(data, bins=20, color='skyblue', edgecolor='black')

if quantity == "Eks":
quantity_name = "absolute difference in total energy"
elif quantity.startswith("f_H1"):
quantity_name = r"$\| \mathbf{F}_1 - \widetilde{\mathbf{F}}_1 \|_2$"
elif quantity.startswith("f_H2"):
quantity_name = r"$\| \mathbf{F}_2 - \widetilde{\mathbf{F}}_2 \|_2$"
elif quantity.startswith("f_"):
quantity_name = f"magnitude of difference in force on {quantity[2:]}"
else:
raise ValueError("Invalid input quantity")

plt.title(f'Histogram of {quantity_name}')
plt.xlabel('Difference')
plt.ylabel('Frequency')

plt.title(f'Histogram of {quantity_name}', fontsize=22)
plt.xlabel('Difference', fontsize=22)
plt.ylabel('Frequency', fontsize=22)

# 1. Y-axis: Max 5 labels and starting at 0
max_frequency = np.max(counts)

# Calculate the ideal step size to yield at most 5 labels
# The ceiling function ensures we cover the max value.
# The number of steps will be up to 5 (including 0).
max_y = np.ceil(max_frequency / 5.0) * 5.0 # Round up to the nearest multiple of 5
if max_y == 0:
max_y = 1 # Handle case where all counts are 0

# Calculate the step size to get at most 5 ticks (excluding 0)
# The target number of intervals is 4-5.
num_intervals = 4
y_step = np.ceil(max_frequency / num_intervals)

# We round the step size to a "nice" number (e.g., 1, 2, 5, 10, 20, 50, 100...)
# This is a common plotting requirement. We'll use a simple heuristic for large numbers.
if y_step <= 5:
y_step = max(1, y_step)
elif y_step <= 10:
y_step = 10
elif y_step <= 25:
y_step = 25
elif y_step <= 50:
y_step = 50
elif y_step <= 100:
y_step = 100
else:
# For very high numbers, round to nearest 50 or 100
y_step = np.ceil(y_step / 100.0) * 100.0

y_step = int(y_step) # Ensure it's an integer step

# Generate y-axis ticks
# Use arange to get ticks starting at 0 up to max_frequency plus the step
y_ticks = np.arange(0, np.ceil(max_frequency) + y_step, y_step)

# Adjust to ensure the highest tick is not far above the max bar height
y_ticks = y_ticks[y_ticks <= np.ceil(max_frequency) + y_step * 0.5]
if y_ticks[-1] < np.ceil(max_frequency):
y_ticks = np.append(y_ticks, y_ticks[-1] + y_step)

# Ensure there is a 0 tick if it was somehow missed
if y_ticks[0] != 0:
y_ticks = np.insert(y_ticks, 0, 0)

# Use unique ticks and convert to int for cleaner labels
y_ticks = np.unique(y_ticks.astype(int))

plt.yticks(y_ticks, fontsize=22)
# Set y-limits from 0 up to the final highest tick
plt.ylim(0, y_ticks[-1] + 0.5)

# 2. X-axis tick and scientific notation font control
min_val, max_val = np.min(data), np.max(data)
plt.xlim(min_val, max_val)
num_ticks = 8
plt.xlim(min_val - 0.05 * (max_val - min_val), max_val + 0.05 * (max_val - min_val))

# Use 6 ticks for less dense x-axis
num_ticks = 6
xticks = np.linspace(min_val, max_val, num_ticks)
plt.xticks(xticks)
plt.xticks(xticks, fontsize=22)

formatter = ticker.ScalarFormatter(useMathText=True)
formatter.set_scientific(True)
formatter.set_powerlimits((0, 0))
formatter.set_powerlimits((0, 0))
plt.gca().xaxis.set_major_formatter(formatter)

# Adjusting the scientific notation font size (the 'x10^-4' part)
plt.gca().ticklabel_format(axis='x', style='sci', scilimits=(0,0))
ax = plt.gca()

# Explicitly set the font size of the scientific notation exponent
try:
# Matplotlib's way to find and set the exponent text's font size
# This is a more robust way to increase the font size of the exponent
ax.xaxis.get_offset_text().set_fontsize(22)
except Exception as e:
print(f"Error setting scientific notation exponent font size: {e}")

# Stats box positioning (Retained)
total_count = len(data)
mean_val = np.mean(data)
max_val = np.max(data)
stats_text = (f"Total {test_case} cases: {total_count}\n"
f"Mean: {mean_val:.3e}")

# Placed in the top right corner
plt.text(0.95, 0.95, stats_text, transform=plt.gca().transAxes,
fontsize=12, verticalalignment='top', horizontalalignment='right',
fontsize=22, verticalalignment='top', horizontalalignment='right',
bbox=dict(facecolor='white', alpha=0.7, edgecolor='black'))

plt.tight_layout()
plt.savefig(f"{output_dir}/{quantity}_difference_histogram_{test_case}.png")
plt.tight_layout()
plt.savefig(f"{output_dir}/{quantity}_difference_histogram_{test_case}.png") # Uncomment in final script

plot_histogram(Eks_diff_reproductive, "Eks", "reproductive")
plot_histogram(f_O1_diff_reproductive, "f_O1", "reproductive")
Expand Down