import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

# Single font size shared by the title, axis labels, tick labels and stats box.
_HIST_FONT_SIZE = 22


def _quantity_label(quantity):
    """Return the title fragment describing ``quantity``.

    Raises:
        ValueError: if ``quantity`` is neither ``"Eks"`` nor an ``"f_*"`` name.
    """
    if quantity == "Eks":
        return "absolute difference in total energy"
    # The specific hydrogen cases must be tested before the generic "f_" prefix.
    if quantity.startswith("f_H1"):
        return r"$\| \mathbf{F}_1 - \widetilde{\mathbf{F}}_1 \|_2$"
    if quantity.startswith("f_H2"):
        return r"$\| \mathbf{F}_2 - \widetilde{\mathbf{F}}_2 \|_2$"
    if quantity.startswith("f_"):
        return f"magnitude of difference in force on {quantity[2:]}"
    raise ValueError("Invalid input quantity")


def _frequency_ticks(max_count):
    """Return integer y-axis ticks from 0 up to the first tick >= ``max_count``.

    The step targets roughly four intervals and is snapped upwards to
    10, 25, 50, 100 or a multiple of 100 once it exceeds 5.
    """
    max_count = int(np.ceil(max_count))
    raw_step = int(np.ceil(max_count / 4))
    if raw_step <= 5:
        # Small counts keep their exact step; never let the step collapse to 0.
        step = max(1, raw_step)
    else:
        step = next(
            (nice for nice in (10, 25, 50, 100) if raw_step <= nice),
            int(np.ceil(raw_step / 100)) * 100,
        )
    top = int(np.ceil(max_count / step)) * step
    return np.arange(0, top + step, step)


def plot_histogram(data, quantity, test_case):
    """Save a 20-bin histogram of ``data`` as a PNG under ``output_dir``.

    Args:
        data: non-empty sequence/array of differences to histogram.
        quantity: ``"Eks"`` or a name starting with ``"f_"``; selects the
            title text and is embedded in the output file name.
        test_case: label shown in the stats box and embedded in the file name.

    Raises:
        ValueError: if ``quantity`` is not recognised or ``data`` is empty.

    The output path is built from ``output_dir``, a name this function expects
    to find at module level.
    """
    # Validate first so that bad input never leaves a half-built figure behind.
    quantity_name = _quantity_label(quantity)
    if np.size(data) == 0:
        raise ValueError("Cannot plot a histogram of empty data")

    fig, ax = plt.subplots(figsize=(10, 7.5))
    try:
        counts, bin_edges, _ = ax.hist(data, bins=20, color='skyblue', edgecolor='black')

        ax.set_title(f'Histogram of {quantity_name}', fontsize=_HIST_FONT_SIZE)
        ax.set_xlabel('Difference', fontsize=_HIST_FONT_SIZE)
        ax.set_ylabel('Frequency', fontsize=_HIST_FONT_SIZE)

        # Y axis: a handful of integer ticks starting at 0.
        y_ticks = _frequency_ticks(np.max(counts))
        ax.set_yticks(y_ticks)
        # Small headroom so the top tick label is not clipped by the frame.
        ax.set_ylim(0, y_ticks[-1] + 0.5)

        # X axis: six evenly spaced ticks with a 5% margin on each side.
        min_val, max_val = np.min(data), np.max(data)
        span = max_val - min_val
        if span > 0:
            ax.set_xlim(min_val - 0.05 * span, max_val + 0.05 * span)
            x_ticks = np.linspace(min_val, max_val, 6)
        else:
            # All samples are equal: reuse the histogram's own bin range rather
            # than requesting identical limits and six duplicate ticks.
            ax.set_xlim(bin_edges[0], bin_edges[-1])
            x_ticks = [min_val]
        ax.set_xticks(x_ticks)
        ax.tick_params(axis='both', labelsize=_HIST_FONT_SIZE)

        # Force scientific notation with a shared exponent on the x axis.
        formatter = ticker.ScalarFormatter(useMathText=True)
        formatter.set_scientific(True)
        formatter.set_powerlimits((0, 0))
        ax.xaxis.set_major_formatter(formatter)
        # The shared exponent is a separate text artist with its own size.
        ax.xaxis.get_offset_text().set_fontsize(_HIST_FONT_SIZE)

        # Summary box anchored to the top-right corner of the axes.
        stats_text = (f"Total {test_case} cases: {len(data)}\n"
                      f"Mean: {np.mean(data):.3e}")
        ax.text(0.95, 0.95, stats_text, transform=ax.transAxes,
                fontsize=_HIST_FONT_SIZE, verticalalignment='top',
                horizontalalignment='right',
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='black'))

        fig.tight_layout()
        fig.savefig(f"{output_dir}/{quantity}_difference_histogram_{test_case}.png")
    finally:
        # Release the figure; repeated calls would otherwise accumulate them.
        plt.close(fig)