diff --git a/Project.toml b/Project.toml index 8f7cadd..70a1ce9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,23 +1,34 @@ -authors = ["Gregory Wagner ", "Xin Kai Lee "] name = "ClimaOceanCalibration" uuid = "b5b0db0a-afc0-4eaa-813c-9b0f9c9ed209" +authors = ["Gregory Wagner ", "Xin Kai Lee "] version = "0.1.0" [deps] +ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" CFTime = "179af706-886a-5703-950a-314cd64e0468" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2" ClimaOcean = "0376089a-ecfe-4b0e-a64f-9c555d74d754" ClimaSeaIce = "6ba0ff68-24e6-4315-936c-2e99227c95a4" +ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d" +Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" +NaNStatistics = "b946abbf-3ea7-4610-9019-9858bfdeaf2d" Oceananigans = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PkgDev = "149e707d-584d-56d3-88ec-740c18e106ff" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +SeawaterPolynomials = "d496a93d-167e-4197-9f49-d3af4ff8fe40" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" XESMF = "2e0b0046-e7a1-486f-88de-807ee8ffabe5" [compat] @@ -35,4 +46,4 @@ julia = "1.10" [extras] CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" \ No newline at end of file +MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" diff --git a/examples/GM_calibration/average_ECCO_data.jl 
b/examples/GM_calibration/average_ECCO_data.jl new file mode 100644 index 0000000..c701a7e --- /dev/null +++ b/examples/GM_calibration/average_ECCO_data.jl @@ -0,0 +1,62 @@ +using ClimaOcean +using Oceananigans +using Oceananigans.Units +using Oceananigans.Architectures: on_architecture +using SeawaterPolynomials.TEOS10 +using ClimaOcean.DataWrangling +using Printf +using Dates +using CUDA +using ClimaOceanCalibration.DataWrangling +using JLD2 +using XESMF + +arch = GPU() + +grid = jldopen(joinpath(pwd(), "examples", "GM_calibration", "grids_and_regridder.jld2"), "r") do file + return on_architecture(arch, file["target_grid"]) +end + +dataset = ECCO4Monthly() + +dir = joinpath(homedir(), "ECCO_data") +mkpath(dir) +start_dates = [DateTime(year) for year in 1992:2017] +sampling_length = 1 # year +buoyancy_model = SeawaterBuoyancy(equation_of_state=TEOS10EquationOfState()) + +for start_date in start_dates + @info "Processing data starting from $(start_date)..." + end_date = start_date + Year(sampling_length) - Month(1) + + T = Metadata(:temperature; dataset, dir, start_date, end_date) + S = Metadata(:salinity; dataset, dir, start_date, end_date) + + T_data = FieldTimeSeries(T, grid, time_indices_in_memory=12) + S_data = FieldTimeSeries(S, grid, time_indices_in_memory=12) + + T_averaging = TimeAverageOperator(T_data) + T_averaged_fts = AveragedFieldTimeSeries(T_averaging(T_data), T_averaging, nothing) + + S_averaging = TimeAverageOperator(S_data) + S_averaged_fts = AveragedFieldTimeSeries(S_averaging(S_data), S_averaging, nothing) + + b_averaging = TimeAverageBuoyancyOperator(T_data) + b_averaged_fts = AveragedFieldTimeSeries(b_averaging(T_data, S_data, buoyancy_model), b_averaging, nothing) + + prefix = "$(sampling_length)yearaverage_2degree" + date_str = replace(string(start_date), ":" => "-") + + dirname = prefix * date_str + + SAVE_PATH = joinpath(pwd(), "calibration_data", "ECCO4Monthly", dirname) + mkpath(SAVE_PATH) + + T_filepath = joinpath(SAVE_PATH, "T.jld2") 
+ S_filepath = joinpath(SAVE_PATH, "S.jld2") + b_filepath = joinpath(SAVE_PATH, "b.jld2") + + save_averaged_fieldtimeseries(T_averaged_fts, T, filename=T_filepath, overwrite_existing=true) + save_averaged_fieldtimeseries(S_averaged_fts, S, filename=S_filepath, overwrite_existing=true) + save_averaged_fieldtimeseries(b_averaged_fts, nothing, filename=b_filepath, overwrite_existing=true) +end \ No newline at end of file diff --git a/examples/GM_calibration/average_EN4_data.jl b/examples/GM_calibration/average_EN4_data.jl new file mode 100644 index 0000000..aaf2557 --- /dev/null +++ b/examples/GM_calibration/average_EN4_data.jl @@ -0,0 +1,111 @@ +using ClimaOcean +using Oceananigans +using Oceananigans.Units +using Oceananigans.Architectures: on_architecture +using SeawaterPolynomials.TEOS10 +using ClimaOcean.DataWrangling +using Printf +using Dates +using CUDA +using ClimaOceanCalibration.DataWrangling +using JLD2 +using XESMF + +arch = GPU() + +grid = jldopen(joinpath(pwd(), "examples", "GM_calibration", "grids_and_regridder.jld2"), "r") do file + return on_architecture(arch, file["target_grid"]) +end + +dataset = EN4Monthly() + +dir = joinpath(homedir(), "EN4_data") +mkpath(dir) +start_dates = [DateTime(1902), DateTime(1912), DateTime(1922), DateTime(1942), + DateTime(1952), DateTime(1972), + DateTime(1992), DateTime(2002), DateTime(2012)] + +# seems that T fields for 1939 1971 1985 is problematic + +buoyancy_model = SeawaterBuoyancy(equation_of_state=TEOS10EquationOfState()) + +for start_date in start_dates + @info "Processing data starting from $(start_date)..." 
+ end_date = start_date + Year(10) - Month(1) + + T = Metadata(:temperature; dataset, dir, start_date, end_date) + S = Metadata(:salinity; dataset, dir, start_date, end_date) + + T_data = FieldTimeSeries(T, grid, time_indices_in_memory=20) + S_data = FieldTimeSeries(S, grid, time_indices_in_memory=20) + + T_averaging = TimeAverageOperator(T_data) + T_averaged_fts = AveragedFieldTimeSeries(T_averaging(T_data), T_averaging, nothing) + + S_averaging = TimeAverageOperator(S_data) + S_averaged_fts = AveragedFieldTimeSeries(S_averaging(S_data), S_averaging, nothing) + + b_averaging = TimeAverageBuoyancyOperator(T_data) + b_averaged_fts = AveragedFieldTimeSeries(b_averaging(T_data, S_data, buoyancy_model), b_averaging, nothing) + + prefix = "10yearaverage_2degree" + date_str = replace(string(start_date), ":" => "-") + + dirname = prefix * date_str + + SAVE_PATH = joinpath(pwd(), "calibration_data", "EN4Monthly", dirname) + mkpath(SAVE_PATH) + + T_filepath = joinpath(SAVE_PATH, "T.jld2") + S_filepath = joinpath(SAVE_PATH, "S.jld2") + b_filepath = joinpath(SAVE_PATH, "b.jld2") + + save_averaged_fieldtimeseries(T_averaged_fts, T, filename=T_filepath, overwrite_existing=true) + save_averaged_fieldtimeseries(S_averaged_fts, S, filename=S_filepath, overwrite_existing=true) + save_averaged_fieldtimeseries(b_averaged_fts, nothing, filename=b_filepath, overwrite_existing=true) +end + +# seems that T fields for 1939 1971 1985 is problematic +start_dates = [DateTime(1902), DateTime(1907), DateTime(1912), DateTime(1917), DateTime(1922), + DateTime(1927), DateTime(1932), DateTime(1942), + DateTime(1947), DateTime(1952), DateTime(1957), DateTime(1962), + DateTime(1972), DateTime(1977), + DateTime(1987), DateTime(1992), DateTime(1997), DateTime(2002), + DateTime(2007), DateTime(2012)] + +for start_date in start_dates + @info "Processing data starting from $(start_date)..." 
+ end_date = start_date + Year(5) - Month(1) + + T = Metadata(:temperature; dataset, dir, start_date, end_date) + S = Metadata(:salinity; dataset, dir, start_date, end_date) + + T_data = FieldTimeSeries(T, grid, time_indices_in_memory=20) + S_data = FieldTimeSeries(S, grid, time_indices_in_memory=20) + + T_averaging = TimeAverageOperator(T_data) + T_averaged_fts = AveragedFieldTimeSeries(T_averaging(T_data), T_averaging, nothing) + + S_averaging = TimeAverageOperator(S_data) + S_averaged_fts = AveragedFieldTimeSeries(S_averaging(S_data), S_averaging, nothing) + + b_averaging = TimeAverageBuoyancyOperator(T_data) + b_averaged_fts = AveragedFieldTimeSeries(b_averaging(T_data, S_data, buoyancy_model), b_averaging, nothing) + + prefix = "5yearaverage_2degree" + date_str = replace(string(start_date), ":" => "-") + + dirname = prefix * date_str + + SAVE_PATH = joinpath(pwd(), "calibration_data", "EN4Monthly", dirname) + mkpath(SAVE_PATH) + + T_filepath = joinpath(SAVE_PATH, "T.jld2") + S_filepath = joinpath(SAVE_PATH, "S.jld2") + b_filepath = joinpath(SAVE_PATH, "b.jld2") + + save_averaged_fieldtimeseries(T_averaged_fts, T, filename=T_filepath, overwrite_existing=true) + save_averaged_fieldtimeseries(S_averaged_fts, S, filename=S_filepath, overwrite_existing=true) + save_averaged_fieldtimeseries(b_averaged_fts, nothing, filename=b_filepath, overwrite_existing=true) +end + diff --git a/examples/GM_calibration/calibrate_gm_distributed.jl b/examples/GM_calibration/calibrate_gm_distributed.jl new file mode 100644 index 0000000..ea87362 --- /dev/null +++ b/examples/GM_calibration/calibrate_gm_distributed.jl @@ -0,0 +1,134 @@ +const ensemble_size = 5 +using Distributed +using ArgParse + +function parse_commandline() + s = ArgParseSettings() + + @add_arg_table! 
s begin + "--simulation_length" + help = "Length of calibration simulation in years" + arg_type = Int + default = 6 + "--sampling_length" + help = "Length of sampling period in years" + arg_type = Int + default = 1 + "--zonal_average" + help = "Whether to perform zonal averaging in loss function" + arg_type = Bool + default = false + "--observation_covariance" + help = "Type of covariance to use (observations vs predetermined)" + arg_type = Bool + default = true + "--pickup" + help = "Pickup files for simulation spinup" + arg_type = Bool + default = false + end + return parse_args(s) +end + +args = parse_commandline() + +# Add workers with pre-set environment variables +nprocs = ensemble_size +addprocs(nprocs) +@everywhere @info "Worker $(myid())" +@everywhere ENV["CUDA_VISIBLE_DEVICES"] = myid() - 1 + +# Now load CUDA on all workers +@everywhere using CUDA +# Verify each worker sees exactly one GPU +@everywhere println("Worker $(myid()) sees GPU: $(CUDA.NVML.index(CUDA.NVML.Device(CUDA.uuid(CUDA.device()))))") + +@everywhere begin + using ClimaCalibrate + using ClimaOcean + using ClimaOceanCalibration.DataWrangling + using Oceananigans + using EnsembleKalmanProcesses + using EnsembleKalmanProcesses.ParameterDistributions + using LinearAlgebra + using JLD2 + using Glob + using Statistics + import ClimaCalibrate: generate_sbatch_script + include(joinpath(pwd(), "examples", "GM_calibration", "data_processing.jl")) + include(joinpath(pwd(), "examples", "GM_calibration", "model_interface.jl")) + + args = $args + + const simulation_length = args["simulation_length"] + const sampling_length = args["sampling_length"] + const zonal_average = args["zonal_average"] + const observation_covariance = args["observation_covariance"] + const pickup = args["pickup"] ? 
Dict("ocean" => joinpath(pwd(), "pickups", "ocean_pickup.jld2"), "sea_ice" => joinpath(pwd(), "pickups", "seaice_pickup.jld2")) : nothing + + obl_closure = ClimaOcean.OceanSimulations.default_ocean_closure() + + if obl_closure isa RiBasedVerticalDiffusivity + obl_str = "RiBased" + else + obl_str = "CATKE" + end + + if observation_covariance + cov_str = "obscov" + else + cov_str = "diagcov" + end + + const output_dir = joinpath(pwd(), "calibration_runs", "gm_$(simulation_length)yr_$(sampling_length)yravg_ecco_$(obl_str)_$(cov_str)$(zonal_average ? "_zonalavg" : "")$(pickup !== nothing ? "_pickup" : "")") + ClimaCalibrate.forward_model(iteration, member) = gm_forward_model(iteration, member; simulation_length, sampling_length, obl_closure, pickup) + ClimaCalibrate.observation_map(iteration) = gm_construct_g_ensemble(iteration, zonal_average) +end + +n_iterations = 10 + +κ_skew_prior = constrained_gaussian("κ_skew", 5e2, 3e2, 0, Inf) +κ_symmetric_prior = constrained_gaussian("κ_symmetric", 5e2, 3e2, 0, Inf) + +priors = combine_distributions([κ_skew_prior, κ_symmetric_prior]) + +obs_paths = abspath.(glob("$(sampling_length)yearaverage_2degree*", joinpath("calibration_data", "ECCO4Monthly"))) + +calibration_target_obs_path = abspath(joinpath("calibration_data", "ECCO4Monthly", "$(sampling_length)yearaverage_2degree2007-01-01T00-00-00")) + +Y = hcat(process_observation.(obs_paths, no_tapering, zonal_average)...) + +const output_dim = size(Y, 1) + +n_trials = size(Y, 2) + +if observation_covariance + # the noise estimated from the samples (will have rank n_trials-1) + internal_cov = tsvd_cov_from_samples(Y) # SVD object + + # the "5%" model error (diagonal) + model_error_frac = 0.05 + data_mean = vec(mean(Y,dims=2)) + model_error_cov = Diagonal((model_error_frac*data_mean).^2) + + # regularize the model error diagonal (in case of zero entries) + model_error_cov += 1e-6*I + + # Combine... 
+ covariance = SVDplusD(internal_cov, model_error_cov) +else + T_variance = 0.7^2 + S_variance = 0.1^2 + N_data = output_dim ÷ 2 + covariance = Diagonal(vcat(fill(T_variance, N_data), fill(S_variance, N_data))) +end + +Y_obs = Observation(Dict("samples" => process_observation(calibration_target_obs_path, taper_interior_ocean, zonal_average), + "covariances" => covariance, + "names" => basename(calibration_target_obs_path))) + +utki = EnsembleKalmanProcess(Y_obs, TransformUnscented(priors)) + +backend = ClimaCalibrate.WorkerBackend + +ClimaCalibrate.calibrate(ClimaCalibrate.WorkerBackend, utki, n_iterations, priors, output_dir) \ No newline at end of file diff --git a/examples/GM_calibration/calibrate_gm_gcloud.jl b/examples/GM_calibration/calibrate_gm_gcloud.jl new file mode 100644 index 0000000..4e739b7 --- /dev/null +++ b/examples/GM_calibration/calibrate_gm_gcloud.jl @@ -0,0 +1,63 @@ +using ClimaCalibrate +using ClimaOceanCalibration.DataWrangling +using Oceananigans +using EnsembleKalmanProcesses +using EnsembleKalmanProcesses.ParameterDistributions +using LinearAlgebra +using JLD2 +using Glob +using Statistics +import ClimaCalibrate: generate_sbatch_script +include(joinpath(pwd(), "examples", "GM_calibration", "data_processing.jl")) +include(joinpath(pwd(), "examples", "GM_calibration", "gcloud_configuration.jl")) +include(joinpath(pwd(), "examples", "GM_calibration", "model_interface.jl")) + +const output_dir = joinpath(pwd(), "calibration_runs", "gm_20year_ecco") +const zonal_average = false + +n_iterations = 3 +κ_skew_prior = constrained_gaussian("κ_skew", 5e2, 3e2, 0, Inf) +κ_symmetric_prior = constrained_gaussian("κ_symmetric", 5e2, 3e2, 0, Inf) + +priors = combine_distributions([κ_skew_prior, κ_symmetric_prior]) + +obs_paths = abspath.(glob("10yearaverage_2degree*", joinpath("calibration_data", "ECCO4Monthly"))) +calibration_target_obs_path = obs_paths[findfirst(x -> occursin("2002", x), obs_paths)] + +synthetic_obs_paths = 
abspath.(glob("*500.0_500.0*20year*", joinpath("calibration_data", "synthetic_observations"))) + +Y = hcat(process_observation.(obs_paths, no_tapering, zonal_average)..., process_member_data.(synthetic_obs_paths, no_tapering, zonal_average)...) + +n_trials = size(Y, 2) +# the noise estimated from the samples (will have rank n_trials-1) +internal_cov = tsvd_cov_from_samples(Y) # SVD object + +# the "5%" model error (diagonal) +model_error_frac = 0.05 +data_mean = vec(mean(Y,dims=2)) +model_error_cov = Diagonal((model_error_frac*data_mean).^2) + +# regularize the model error diagonal (in case of zero entries) +model_error_cov += 1e-6*I + +# Combine... +covariance = SVDplusD(internal_cov, model_error_cov) + +Y_obs = Observation(Dict("samples" => process_observation(calibration_target_obs_path, taper_interior_ocean, zonal_average), + "covariances" => covariance, + "names" => basename(calibration_target_obs_path))) + +utki = EnsembleKalmanProcess(Y_obs, TransformUnscented(priors)) + +backend = ClimaOceanSingleGPUGCPBackend + +hpc_kwargs = hpc_kwargs = Dict(:ntasks => 1, + :cpus_per_task => 4, + :gpus_per_task => 1, + :mem => "128G", + :time => 120, + :partition => "a3mega") + +model_interface = abspath("./examples/GM_calibration/model_interface.jl") + +ClimaCalibrate.calibrate(backend, utki, n_iterations, priors, output_dir; hpc_kwargs, verbose=true, model_interface) \ No newline at end of file diff --git a/examples/GM_calibration/calibrate_gm_perfectmodel_distributed.jl b/examples/GM_calibration/calibrate_gm_perfectmodel_distributed.jl new file mode 100644 index 0000000..cd5ca71 --- /dev/null +++ b/examples/GM_calibration/calibrate_gm_perfectmodel_distributed.jl @@ -0,0 +1,99 @@ +const ensemble_size = 5 +using Distributed +using ArgParse + +function parse_commandline() + s = ArgParseSettings() + + @add_arg_table! 
s begin + "--simulation_length" + help = "Length of calibration simulation in years" + arg_type = Int + default = 15 + "--zonal_average" + help = "Whether to perform zonal averaging in loss function" + arg_type = Bool + default = false + end + return parse_args(s) +end + +args = parse_commandline() + +# Add workers with pre-set environment variables +nprocs = ensemble_size +addprocs(nprocs) +@everywhere @info "Worker $(myid())" +@everywhere ENV["CUDA_VISIBLE_DEVICES"] = myid() - 1 + +# Now load CUDA on all workers +@everywhere using CUDA +# Verify each worker sees exactly one GPU +@everywhere println("Worker $(myid()) sees GPU: $(CUDA.NVML.index(CUDA.NVML.Device(CUDA.uuid(CUDA.device()))))") + +@everywhere begin + using ClimaCalibrate + using Distributed + using ClimaOceanCalibration.DataWrangling + using Oceananigans + using EnsembleKalmanProcesses + using EnsembleKalmanProcesses.ParameterDistributions + using LinearAlgebra + using JLD2 + using Glob + using Statistics + import ClimaCalibrate: generate_sbatch_script + include(joinpath(pwd(), "examples", "GM_calibration", "data_processing.jl")) + include(joinpath(pwd(), "examples", "GM_calibration", "model_interface.jl")) + + args = $args + + const simulation_length = args["simulation_length"] + const sampling_length = simulation_length - 10 + const zonal_average = args["zonal_average"] + + const output_dir = joinpath(pwd(), "calibration_runs", "gm_$(simulation_length)year_perfectmodel_distributed$(zonal_average ? 
"_zonalavg" : "")") + ClimaCalibrate.forward_model(iteration, member) = gm_forward_model(iteration, member; simulation_length, sampling_length) + ClimaCalibrate.observation_map(iteration) = gm_construct_g_ensemble(iteration, zonal_average) +end + +n_iterations = 5 + +κ_skew_prior = constrained_gaussian("κ_skew", 5e2, 3e2, 0, Inf) +κ_symmetric_prior = constrained_gaussian("κ_symmetric", 5e2, 3e2, 0, Inf) + +priors = combine_distributions([κ_skew_prior, κ_symmetric_prior]) + +obs_paths = abspath.(glob("*750.0_750.0*$(sampling_length)yearsample*", joinpath("calibration_data", "synthetic_observations"))) + +calibration_target_obs_path = abspath(joinpath("calibration_data", "synthetic_observations", "halfdegree_RiBased_750.0_750.0_1992_15year_5yearsample_advectiveGM_multiyearjra55_calibrationsamples")) + +Y = hcat(process_member_data.(obs_paths, no_tapering, zonal_average)...) + +const output_dim = size(Y, 1) + +n_trials = size(Y, 2) + +# the noise estimated from the samples (will have rank n_trials-1) +internal_cov = tsvd_cov_from_samples(Y) # SVD object + +# the "5%" model error (diagonal) +model_error_frac = 0.05 +data_mean = vec(mean(Y,dims=2)) +model_error_cov = Diagonal((model_error_frac*data_mean).^2) + +# regularize the model error diagonal (in case of zero entries) +model_error_cov += 1e-6*I + +# Combine... 
+covariance = SVDplusD(internal_cov, model_error_cov) + +Y_obs = Observation(Dict("samples" => process_member_data(calibration_target_obs_path, taper_interior_ocean, zonal_average), + "covariances" => covariance, + "names" => basename(calibration_target_obs_path))) + +utki = EnsembleKalmanProcess(Y_obs, TransformUnscented(priors)) + +backend = ClimaCalibrate.WorkerBackend + +ClimaCalibrate.calibrate(ClimaCalibrate.WorkerBackend, utki, n_iterations, priors, output_dir) \ No newline at end of file diff --git a/examples/GM_calibration/data_plotting.jl b/examples/GM_calibration/data_plotting.jl new file mode 100644 index 0000000..e3a0ef0 --- /dev/null +++ b/examples/GM_calibration/data_plotting.jl @@ -0,0 +1,91 @@ +using CairoMakie +using Oceananigans +using Oceananigans.Grids: znodes, φnodes +using NaNStatistics +using ColorSchemes + +function plot_parameter_distribution(κs, error) + fig = Figure() + ax = Axis(fig[1, 1]; xlabel="κ skew (m²/s)", ylabel="κ symmetric (m²/s)", title="Parameter Distribution, Covariance-weighted loss at mean = $error") + scatter!(ax, κs[1, :], κs[2, :]) + return fig +end + +function exp_levels(min_val, max_val, n; factor=2) + uniform = range(0, 1, length=n) + transformed = (exp.(factor .* uniform) .- 1) ./ (exp(factor) - 1) + return min_val .+ transformed .* (max_val - min_val) +end + +function inverted_exp_levels(min_val, max_val, n; factor=2) + uniform = range(0, 1, length=n) + # This inverts the concentration from low to high values + transformed = 1 .- (exp.(factor .* (1 .- uniform)) .- 1) ./ (exp(factor) - 1) + return min_val .+ transformed .* (max_val - min_val) +end + +function plot_zonal_average(truth_data, model_data, field_name, κ_skew, κ_symmetric) + fig = Figure(size=(1920, 1080)) + axtruth = Axis(fig[1, 1]; xlabel="Latitude (°)", ylabel="Depth (m)", title="Target: ECCO4") + axmodel = Axis(fig[2, 1]; xlabel="Latitude (°)", ylabel="Depth (m)", title="Model Output") + axdiff = Axis(fig[3, 1]; xlabel="Latitude (°)", ylabel="Depth 
(m)", title="Anomaly (Model - Target)") + + LX, LY, LZ = location(truth_data) + zCs = znodes(truth_data.grid, LX(), LY(), LZ()) + + truth_φCs = φnodes(truth_data.grid, LX(), LY(), LZ()) + model_φCs = φnodes(model_data.grid, LX(), LY(), LZ()) + + truth_field = nanmean(interior(truth_data), dims=1)[1, :, :] + model_field = nanmean(interior(model_data), dims=1)[1, :, :] + diff_field = model_field .- truth_field + + fieldlim = (nanminimum([truth_field; model_field]), nanmaximum([truth_field; model_field])) + difflim = (-nanmaximum(abs.(diff_field)) / 2, nanmaximum(abs.(diff_field)) / 2) + + if field_name == "T" + @info "Plotting temperature zonal average..." + field_levels = exp_levels(fieldlim[1], fieldlim[2], 15) + field_colorbar_kwargs = (label="Temperature (°C)",) + diff_colorbar_kwargs = (label="Temperature Anomaly (°C)",) + difflim = (-0.7, 0.7) + elseif field_name == "S" + @info "Plotting salinity zonal average..." + field_colorbar_kwargs = (label="Salinity (psu)",) + diff_colorbar_kwargs = (label="Salinity Anomaly (psu)",) + # fieldlim = (nanminimum([truth_field; model_field[2:end, :]]), nanmaximum([truth_field; model_field[2:end, :]])) + # difflim = (-nanmaximum(abs.(diff_field[2:end, :])) / 2, nanmaximum(abs.(diff_field[2:end, :])) / 2) + fieldlim = (34, 36) + difflim = (-0.1, 0.1) + # field_levels = inverted_exp_levels(fieldlim[1], fieldlim[2], 15) + # field_levels = range(fieldlim[1], fieldlim[2], length=15) + field_levels = exp_levels(fieldlim[1], fieldlim[2], 15) + elseif field_name == "b" + @info "Plotting buoyancy zonal average..." 
+ field_levels = exp_levels(fieldlim[1], fieldlim[2], 15) + difflim = (-0.001, 0.001) + field_colorbar_kwargs = (label="Buoyancy (m/s²)",) + diff_colorbar_kwargs = (label="Buoyancy Anomaly (m/s²)",) + else + error("Unsupported field name: $field_name") + end + + diff_levels = range(difflim[1], difflim[2], length=15) + + cf_f = contourf!(axtruth, truth_φCs, zCs, truth_field, colormap=:turbo, levels = field_levels, extendhigh=:auto, extendlow=:auto) + contourf!(axmodel, model_φCs, zCs, model_field, colormap=:turbo, levels = field_levels, extendhigh=:auto, extendlow=:auto) + cf_d = contourf!(axdiff, model_φCs, zCs, diff_field, colormap=:balance, levels = diff_levels, extendhigh=:auto, extendlow=:auto) + + Colorbar(fig[1:2, 2], cf_f; field_colorbar_kwargs...) + Colorbar(fig[3, 2], cf_d; diff_colorbar_kwargs...) + + xlims!(axtruth, -84, 84) + xlims!(axmodel, -84, 84) + xlims!(axdiff, -84, 84) + ylims!(axtruth, -6000, 0) + ylims!(axmodel, -6000, 0) + ylims!(axdiff, -6000, 0) + + Label(fig[0, 1:2], "Zonal Average of $field_name (κ_skew=$(round(κ_skew, digits=1)), κ_symmetric=$(round(κ_symmetric)))", fontsize=25, font=:bold) + return fig +end \ No newline at end of file diff --git a/examples/GM_calibration/data_processing.jl b/examples/GM_calibration/data_processing.jl new file mode 100644 index 0000000..b9e7b42 --- /dev/null +++ b/examples/GM_calibration/data_processing.jl @@ -0,0 +1,112 @@ +using Oceananigans +using Oceananigans.Grids: znodes, φnodes +using Oceananigans.Fields: location +using Oceananigans.ImmersedBoundaries: mask_immersed_field! +using Oceananigans.Architectures: on_architecture +using XESMF +using JLD2 +using NaNStatistics +using Glob + +function regrid_model_data(simdir) + @info "Regridding model data in $(simdir)..." 
+ filepath = first(glob("*calibrationsample*", simdir)) + T_data = FieldTimeSeries(filepath, "T", backend=OnDisk()) + S_data = FieldTimeSeries(filepath, "S", backend=OnDisk()) + + source_grid = T_data.grid + LX, LY, LZ = location(T_data) + boundary_conditions = T_data.boundary_conditions + times = T_data.times + + Nx, Ny, Nz = (180, 84, 100) + z_faces = ExponentialDiscretization(Nz, -6000, 0; scale=1800) + + target_grid, regridder = jldopen(joinpath(pwd(), "examples", "GM_calibration", "grids_and_regridder.jld2"), "r") do file + return file["target_grid"], file["regridder"] + end + + T_target = FieldTimeSeries{LX, LY, LZ}(target_grid, times; boundary_conditions) + S_target = FieldTimeSeries{LX, LY, LZ}(target_grid, times; boundary_conditions) + + for t in 1:length(times) + regrid!(T_target[t], regridder, T_data[t]) + regrid!(S_target[t], regridder, S_data[t]) + end + return T_target, S_target +end + +taper_interior_ocean(z, z_scale=3500, width=1000) = 0.5 * (1 + tanh((z + z_scale) / width)) +no_tapering(z) = 1 + +function extract_field_section(fts::FieldTimeSeries, latitude_range; vertical_weighting=no_tapering) + fts = on_architecture(CPU(), fts) + LX, LY, LZ = location(fts) + grid = fts.grid + + φᶜ = φnodes(grid, LX(), LY(), LZ()) + zᶜ = znodes(grid, LX(), LY(), LZ()) + + φmin, φmax = latitude_range + + lat_indices = findfirst(x -> x >= φmin, φᶜ):findlast(x -> x <= φmax, φᶜ) + z_weights = vertical_weighting.(zᶜ) + + times = fts.times + + Nt = length(times) + for t in 1:length(times) + mask_immersed_field!(fts[t], NaN) + end + + field_section = reshape(z_weights, 1, 1, :) .* interior(fts[Nt], :, lat_indices, :) + + return field_section +end + +extract_southern_ocean_section(fts, vertical_weighting=no_tapering) = extract_field_section(fts, (-80, -50); vertical_weighting) +extract_ocean_section(fts, vertical_weighting=no_tapering) = extract_field_section(fts, (-80, 0); vertical_weighting) + +function process_observation(obs_path, vertical_weighting, zonal_average) + 
T_filepath = joinpath(obs_path, "T.jld2") + S_filepath = joinpath(obs_path, "S.jld2") + + T_afts = jldopen(T_filepath, "r") do file + return file["averaged_fieldtimeseries"] + end + + S_afts = jldopen(S_filepath, "r") do file + return file["averaged_fieldtimeseries"] + end + + T_data = T_afts.data + S_data = S_afts.data + + # T_section = extract_southern_ocean_section(T_data, vertical_weighting) + # S_section = extract_southern_ocean_section(S_data, vertical_weighting) + T_section = extract_ocean_section(T_data, vertical_weighting) + S_section = extract_ocean_section(S_data, vertical_weighting) + + if zonal_average + T_section = nanmean(T_section, dims=1) + S_section = nanmean(S_section, dims=1) + end + + return vcat(T_section[.!isnan.(T_section)], S_section[.!isnan.(S_section)]) +end + +function process_member_data(simdir, vertical_weighting, zonal_average) + T_target, S_target = regrid_model_data(simdir) + + # T_section = extract_southern_ocean_section(T_target, vertical_weighting) + # S_section = extract_southern_ocean_section(S_target, vertical_weighting) + T_section = extract_ocean_section(T_target, vertical_weighting) + S_section = extract_ocean_section(S_target, vertical_weighting) + + if zonal_average + T_section = nanmean(T_section, dims=1) + S_section = nanmean(S_section, dims=1) + end + + return vcat(T_section[.!isnan.(T_section)], S_section[.!isnan.(S_section)]) +end \ No newline at end of file diff --git a/examples/GM_calibration/gcloud_configuration.jl b/examples/GM_calibration/gcloud_configuration.jl new file mode 100644 index 0000000..4722202 --- /dev/null +++ b/examples/GM_calibration/gcloud_configuration.jl @@ -0,0 +1,101 @@ +using ClimaCalibrate +using ClimaCalibrate: generate_sbatch_directives +import ClimaCalibrate: generate_sbatch_script + +struct ClimaOceanSingleGPUGCPBackend <: ClimaCalibrate.SlurmBackend end + +function ClimaCalibrate.module_load_string(::Type{ClimaOceanSingleGPUGCPBackend}) + return """ +unset CUDA_HOME CUDA_PATH CUDA_ROOT 
NVHPC_CUDA_HOME CUDA_INC_DIR CPATH NVHPC_ROOT OPAL_PREFIX +export LD_LIBRARY_PATH=\$(echo \$LD_LIBRARY_PATH | tr ':' '\n' | grep -v cuda | grep -v ucx | tr '\n' ':' | sed 's/:\$//') +export PATH=/usr/bin:/bin:/usr/sbin:/sbin:\$HOME/cmake-3.28.1-linux-x86_64/bin:\$HOME/julia-1.10.10/bin + +export JULIA_CUDA_MEMORY_POOL=binned +export JULIA_NUM_THREADS=1 +export CUDA_VISIBLE_DEVICES=0 + +cd \$HOME/CES_oceananigans/ClimaOceanCalibration.jl + +# ============================================ +# LOAD API KEYS +# ============================================ +if [ -f ~/API_keys.sh ]; then + source ~/API_keys.sh +else + echo "Warning: API_keys.sh file not found in home directory" +fi + +# ============================================ +# CONFIGURE FOR SINGLE-GPU (NO UCX) +# ============================================ +echo "=== Checking existing configuration ===" +if ~/julia-1.10.10/bin/julia --project=. -e ' +using MPI, CUDA +ok = false +rt = try CUDA.runtime_version() catch; nothing end +if rt == v"12.4" && occursin("openmpi", lowercase(MPI.MPI_LIBRARY)) + println("✓ Already configured: OpenMPI + CUDA 12.4") + exit(0) +else + exit(1) +end +'; then + echo "Configuration looks correct; skipping reconfiguration." +else + echo "=== Configuring for single-GPU ===" + ~/julia-1.10.10/bin/julia --project=. 
-e ' + using MPIPreferences + MPIPreferences.use_jll_binary("OpenMPI_jll") + + using CUDA + CUDA.set_runtime_version!(v"12.4", local_toolkit=false) + + println("✓ Configured: OpenMPI_jll + CUDA artifacts") + ' +fi + +echo "=== Verify Configuration ===" +~/julia-1.10.10/bin/julia --project -e ' +using MPI, CUDA, Libdl, Oceananigans, ClimaOcean, ClimaSeaIce + +println("MPI: ", MPI.MPI_LIBRARY) +println("CUDA runtime: ", CUDA.runtime_version()) + +ucx_libs = filter(lib -> occursin("ucx", lowercase(lib)), Libdl.dllist()) +if isempty(ucx_libs) + println("✓ No UCX - safe to run!") +else + println("⚠️ WARNING: UCX detected:") + foreach(println, ucx_libs) + exit(1) +end +'""" +end + +ClimaCalibrate.backend_worker_kwargs(::Type{ClimaOceanSingleGPUGCPBackend}) = (; partition = "a3mega") + +function generate_sbatch_script(iter::Int, member::Int, output_dir, experiment_dir, model_interface, module_load_str, hpc_kwargs, exeflags = "") + member_log = path_to_model_log(output_dir, iter, member) + slurm_directives = generate_sbatch_directives(hpc_kwargs) + + sbatch_contents = """ + #!/bin/bash + #SBATCH --job-name=run_$(iter)_$(member) + #SBATCH --output=$member_log + $slurm_directives + + $module_load_str + + julia $exeflags --project=$experiment_dir -e ' + + import ClimaCalibrate as CAL + iteration = $iter; member = $member + model_interface = "$model_interface"; include(model_interface) + experiment_dir = "$experiment_dir" + CAL.forward_model(iteration, member) + CAL.write_model_completed("$output_dir", iteration, member) + ' + exit 0 + """ + return sbatch_contents +end \ No newline at end of file diff --git a/examples/GM_calibration/half_degree_omip.jl b/examples/GM_calibration/half_degree_omip.jl new file mode 100644 index 0000000..8793c54 --- /dev/null +++ b/examples/GM_calibration/half_degree_omip.jl @@ -0,0 +1,281 @@ +using ClimaOcean +using ClimaSeaIce +using Oceananigans +using Oceananigans.Grids +using Oceananigans.Units +using Oceananigans.OrthogonalSphericalShellGrids 
using Oceananigans.BuoyancyFormulations: buoyancy, buoyancy_frequency
using ClimaOcean.OceanSimulations
using ClimaOcean.ECCO
using ClimaOcean.JRA55
using ClimaOcean.DataWrangling
using ClimaSeaIce.SeaIceThermodynamics: IceWaterThermalEquilibrium
using Printf
using Dates
using CUDA
using JLD2
using ArgParse
using Oceananigans.TurbulenceClosures: ExplicitTimeDiscretization, AdvectiveFormulation, IsopycnalSkewSymmetricDiffusivity
using Oceananigans.TurbulenceClosures.TKEBasedVerticalDiffusivities: CATKEVerticalDiffusivity, CATKEMixingLength, CATKEEquation
using Oceananigans.Operators: Δx, Δy
using Statistics

import Oceananigans.OutputWriters: checkpointer_address

# UCX can conflict with the MPI + CUDA stack used here; warn loudly if it is loaded.
using Libdl
ucx_libs = filter(lib -> occursin("ucx", lowercase(lib)), Libdl.dllist())
if isempty(ucx_libs)
    @info "✓ No UCX - safe to run!"
else
    @warn "✗ UCX libraries detected! This can cause issues with MPI+CUDA. Detected libs:\n$(join(ucx_libs, "\n"))"
end

"""
    run_gm_calibration_omip(κ_skew, κ_symmetric, config_dict, obl_closure)

Run a half-degree global OMIP simulation (TripolarGrid ocean + sea ice, JRA55
prescribed atmosphere, ECCO4 initial conditions and salinity restoring) using the
Gent–McWilliams parameters `κ_skew` and `κ_symmetric` in an
`IsopycnalSkewSymmetricDiffusivity` closure, alongside the supplied OBL closure
`obl_closure`.

`config_dict` must provide `"output_dir"`, `"simulation_length"` (years),
`"sampling_length"` (years), and `"pickup"` (either `nothing` or a Dict with
`"ocean"` and `"sea_ice"` checkpoint paths). All stdout/stderr is redirected to
`output_dir/output.log` for the duration of the run; a time-averaged calibration
sample is written by the `:sample_average` output writer.
"""
function run_gm_calibration_omip(κ_skew, κ_symmetric, config_dict, obl_closure)
    output_dir = config_dict["output_dir"]
    pickup = config_dict["pickup"]

    if !(pickup isa Nothing)
        ocean_pickup_file = pickup["ocean"]
        seaice_pickup_file = pickup["sea_ice"]
    end

    logfile_path = joinpath(output_dir, "output.log")

    logfile = open(logfile_path, "w")
    original_stdout = stdout
    original_stderr = stderr

    redirect_stdout(logfile)
    redirect_stderr(logfile)

    # Background task that flushes the log every second; it terminates on its own
    # once `close(logfile)` runs in the `finally` block below.
    flusher = @async while isopen(logfile); flush(logfile); sleep(1); end

    try
        start_year = 2002
        simulation_length = config_dict["simulation_length"]
        sampling_length = config_dict["sampling_length"]

        @info "Using κ_skew = $(κ_skew) m²/s and κ_symmetric = $(κ_symmetric) m²/s, starting in $(start_year) for $(simulation_length) years with $(sampling_length)-year sampling window."
        @info "Saving output to $(config_dict["output_dir"])"

        arch = GPU()

        Nx = 720 # longitudinal direction
        Ny = 360 # meridional direction
        Nz = 100

        # Exponentially stretched vertical coordinate: finer near the surface.
        z_faces = ExponentialDiscretization(Nz, -6000, 0; scale=1800)
        z_surf = z_faces(Nz)

        grid = TripolarGrid(arch;
                            size = (Nx, Ny, Nz),
                            z = z_faces,
                            halo = (7, 7, 7))

        bottom_height = regrid_bathymetry(grid; minimum_depth=15, major_basins=1, interpolation_passes=55)
        grid = ImmersedBoundaryGrid(grid, GridFittedBottom(bottom_height); active_cells_map=true)

        momentum_advection = WENOVectorInvariant(order=5)
        tracer_advection = WENO(order=7)
        free_surface = SplitExplicitFreeSurface(grid; cfl=0.8, fixed_Δt=50minutes)

        # Grid-dependent biharmonic viscosity: harmonic mean of the squared grid
        # spacings, squared, divided by a damping timescale λ.
        @inline Δ²ᵃᵃᵃ(i, j, k, grid, lx, ly, lz) = 2 * (1 / (1 / Δx(i, j, k, grid, lx, ly, lz)^2 + 1 / Δy(i, j, k, grid, lx, ly, lz)^2))
        @inline geometric_νhb(i, j, k, grid, lx, ly, lz, clock, fields, λ) = Δ²ᵃᵃᵃ(i, j, k, grid, lx, ly, lz)^2 / λ

        eddy_closure = IsopycnalSkewSymmetricDiffusivity(; κ_skew, κ_symmetric, skew_flux_formulation=AdvectiveFormulation())
        visc_closure = HorizontalScalarBiharmonicDiffusivity(ν=geometric_νhb, discrete_form=true, parameters=25days)

        closure = (obl_closure, VerticalScalarDiffusivity(κ=1e-5, ν=3e-4), visc_closure, eddy_closure)

        dir = joinpath(homedir(), "ECCO_data")
        mkpath(dir)

        start_date = DateTime(start_year, 1, 1)
        end_date = start_date + Year(simulation_length)
        simulation_period = Dates.value(Second(end_date - start_date))
        sampling_start_date = end_date - Year(sampling_length)
        sampling_window = Dates.value(Second(end_date - sampling_start_date))

        @info "Setting up salinity restoring..."
        # Restore only the topmost cell (within 1 m of the surface layer top).
        @inline mask(x, y, z, t) = z ≥ z_surf - 1
        Smetadata = Metadata(:salinity; dataset=ECCO4Monthly(), dir, start_date, end_date)
        FS = DatasetRestoring(Smetadata, grid; rate = 1/30days, mask, time_indices_in_memory = 10)

        ocean = ocean_simulation(grid; Δt=1minutes,
                                 momentum_advection,
                                 tracer_advection,
                                 timestepper = :SplitRungeKutta3,
                                 free_surface,
                                 forcing = (; S = FS),
                                 closure)

        @info "Built ocean model $(ocean)"

        if pickup isa Nothing
            set!(ocean.model, T=Metadatum(:temperature; dataset=ECCO4Monthly(), date=start_date, dir),
                 S=Metadatum(:salinity; dataset=ECCO4Monthly(), date=start_date, dir))
            @info "Initialized T and S with ECCO data"
        else
            set!(ocean.model, pickup=ocean_pickup_file)
            @info "Initialized T and S from pickup file $(ocean_pickup_file)"
        end

        # Default sea-ice dynamics and salinity coupling are included in the defaults
        # sea_ice = sea_ice_simulation(grid, ocean; advection=WENO(order=7))
        sea_ice = sea_ice_simulation(grid, ocean; dynamics=nothing)
        @info "Built sea ice model $(sea_ice)"

        if pickup isa Nothing
            set!(sea_ice.model, h=Metadatum(:sea_ice_thickness; dataset=ECCO4Monthly(), date=start_date, dir),
                 ℵ=Metadatum(:sea_ice_concentration; dataset=ECCO4Monthly(), date=start_date, dir))
            @info "Initialized sea ice fields with ECCO data"
        else
            set!(sea_ice.model, pickup=seaice_pickup_file)
            @info "Initialized sea ice fields from pickup file $(seaice_pickup_file)"
        end

        jra55_dir = joinpath(homedir(), "JRA55_data")
        mkpath(jra55_dir)
        dataset = MultiYearJRA55()
        backend = JRA55NetCDFBackend(100)

        @info "Setting up prescribed atmosphere $(dataset)"
        atmosphere = JRA55PrescribedAtmosphere(arch; dir=jra55_dir, dataset, backend, include_rivers_and_icebergs=true, start_date, end_date)
        radiation = Radiation()

        @info "Built atmosphere model $(atmosphere)"

        omip = OceanSeaIceModel(ocean, sea_ice; atmosphere, radiation)

        @info "Built coupled model $(omip)"

        omip = Simulation(omip, Δt=30minutes, stop_time=simulation_period)
        @info "Built simulation $(omip)"

        FILE_DIR = config_dict["output_dir"]
        mkpath(FILE_DIR)

        b = Field(buoyancy(ocean.model))
        N² = Field(buoyancy_frequency(ocean.model))

        ocean_outputs = merge(ocean.model.tracers, ocean.model.velocities, (; b, N²))
        sea_ice_outputs = merge((h = sea_ice.model.ice_thickness,
                                 ℵ = sea_ice.model.ice_concentration,
                                 T = sea_ice.model.ice_thermodynamics.top_surface_temperature),
                                sea_ice.model.velocities)

        ocean.output_writers[:surface] = JLD2Writer(ocean.model, ocean_outputs;
                                                    schedule = TimeInterval(180days),
                                                    filename = "$(FILE_DIR)/ocean_surface_fields",
                                                    indices = (:, :, grid.Nz),
                                                    overwrite_existing = true)

        # NOTE(review): the writer model argument is `ocean.model` even though the
        # outputs are sea-ice fields — confirm this is intended (vs. sea_ice.model).
        sea_ice.output_writers[:surface] = JLD2Writer(ocean.model, sea_ice_outputs;
                                                      schedule = TimeInterval(180days),
                                                      filename = "$(FILE_DIR)/sea_ice_surface_fields",
                                                      overwrite_existing = true)

        # Time average over the final sampling window — this file is the
        # calibration sample consumed by the observation map.
        ocean.output_writers[:sample_average] = JLD2Writer(ocean.model, ocean_outputs;
                                                           schedule = AveragedTimeInterval(simulation_period, window=sampling_window),
                                                           filename = "$(FILE_DIR)/ocean_complete_fields_$(sampling_length)year_average_calibrationsample",
                                                           overwrite_existing = true)

        wall_time = Ref(time_ns())

        # Progress callback: logs simulation state and wall time per 100 iterations.
        function progress(sim)
            sea_ice = sim.model.sea_ice
            ocean = sim.model.ocean
            hmax = maximum(sea_ice.model.ice_thickness)
            ℵmax = maximum(sea_ice.model.ice_concentration)
            Tmax = maximum(sim.model.interfaces.atmosphere_sea_ice_interface.temperature)
            Tmin = minimum(sim.model.interfaces.atmosphere_sea_ice_interface.temperature)
            umax = maximum(ocean.model.velocities.u)
            vmax = maximum(ocean.model.velocities.v)
            wmax = maximum(ocean.model.velocities.w)

            step_time = 1e-9 * (time_ns() - wall_time[])

            msg1 = @sprintf("time: %s, iteration: %d, Δt: %s, ", prettytime(sim), iteration(sim), prettytime(sim.Δt))
            msg2 = @sprintf("max(h): %.2e m, max(ℵ): %.2e ", hmax, ℵmax)
            # Print as (min, max) to match the "extrema" label.
            msg4 = @sprintf("extrema(T): (%.2f, %.2f) ᵒC, ", Tmin, Tmax)
            msg5 = @sprintf("maximum(u): (%.2f, %.2f, %.2f) m/s, ", umax, vmax, wmax)
            msg6 = @sprintf("wall time: %s \n", prettytime(step_time))

            @info msg1 * msg2 * msg4 * msg5 * msg6

            wall_time[] = time_ns()

            return nothing
        end

        # And add it as a callback to the simulation.
        add_callback!(omip, progress, IterationInterval(100))

        run!(omip)
        return nothing
    catch e
        # Handle errors: interrupts are reported but swallowed; anything else is
        # reported and rethrown (rethrow() preserves the original backtrace).
        if e isa InterruptException
            println(stderr, "Interrupted by user")
        else
            println(stderr, "Error occurred: $e")
            rethrow()
        end
    finally
        # Cleanup - ALWAYS runs
        redirect_stdout(original_stdout)
        redirect_stderr(original_stderr)
        close(logfile)
        println("Log file closed") # This prints to console, not log
    end

end
using ClimaCalibrate
using TOML
using Dates: now     # `now()` is used in failure reporting below
using ClimaOceanCalibration.DataWrangling
using EnsembleKalmanProcesses
using Oceananigans
using Oceananigans.Architectures: on_architecture
using Oceananigans.ImmersedBoundaries: mask_immersed_field!
using Oceananigans.Fields: location
using JLD2
include("half_degree_omip.jl")
include("data_processing.jl")
include("data_plotting.jl")

"""
    gm_forward_model(iteration, member; simulation_length, sampling_length, obl_closure, pickup)

Run the OMIP forward model for one ensemble member: assemble the member's
configuration (output path, parameter TOML, run lengths, pickup info), read
`κ_skew`/`κ_symmetric` from the member's parameter file, and call
`run_gm_calibration_omip`. On failure, a `RUN_FAILED.err` file with the error and
exception stack is written into the member directory (consumed by
`gm_construct_g_ensemble`) and the error is logged — not rethrown — so a single
failed member does not abort the calibration.
"""
function gm_forward_model(iteration, member; simulation_length, sampling_length, obl_closure, pickup)
    config_dict = Dict{String, Any}()

    # Set the output path for the current member
    member_path = ClimaCalibrate.path_to_ensemble_member(output_dir, iteration, member)
    config_dict["output_dir"] = member_path

    # Parameter TOML for the current member. config_dict is freshly created above,
    # so the key cannot already exist — set it directly.
    parameter_path = ClimaCalibrate.parameter_path(output_dir, iteration, member)
    config_dict["toml"] = [parameter_path]

    config_dict["iteration"] = iteration
    config_dict["member"] = member
    config_dict["simulation_length"] = simulation_length
    config_dict["sampling_length"] = sampling_length
    config_dict["pickup"] = pickup

    params = TOML.parsefile(parameter_path)
    κ_skew = params["κ_skew"]
    κ_symmetric = params["κ_symmetric"]

    try
        # run_gm_calibration_omip_dry_run(κ_skew["value"], κ_symmetric["value"], config_dict)
        run_gm_calibration_omip(κ_skew["value"], κ_symmetric["value"], config_dict, obl_closure)
    catch e
        # Create a failure indicator file with error information
        error_file = joinpath(member_path, "RUN_FAILED.err")
        open(error_file, "w") do io
            println(io, "Run failed at $(now())")
            println(io, "Parameters: κ_skew = $(κ_skew), κ_symmetric = $(κ_symmetric)")
            println(io, "Error: $(e)")
            println(io, "Backtrace:")
            # current_exceptions() replaces the deprecated Base.catch_stack()
            for (exc, bt) in current_exceptions()
                showerror(io, exc, bt)
                println(io)
            end
        end

        @error "GM calibration failed (κ_skew = $(κ_skew), κ_symmetric = $(κ_symmetric))" exception=(e, catch_backtrace())
    end

    return nothing
end

"""
    gm_construct_g_ensemble(iteration, zonal_average)

Build the `output_dim × ensemble_size` forward-model output matrix for
`iteration`. Members whose directory contains `RUN_FAILED.err` are filled with
NaNs; all others are processed with `process_member_data` using
`taper_interior_ocean` vertical weighting.
"""
function gm_construct_g_ensemble(iteration, zonal_average)
    G_ensemble = zeros(output_dim, ensemble_size)

    for m in 1:ensemble_size
        member_path = ClimaCalibrate.path_to_ensemble_member(output_dir, iteration, m)

        if isfile(joinpath(member_path, "RUN_FAILED.err"))
            @warn "Skipping member $m for iteration $iteration due to failed run."
            G_ensemble[:, m] .= NaN
        else
            G_ensemble[:, m] .= process_member_data(member_path, taper_interior_ocean, zonal_average)
        end
    end

    return G_ensemble
end

"""
    ClimaCalibrate.analyze_iteration(ekp, g_ensemble, prior, output_dir, iteration)

Diagnostics hook run after each EKP iteration: saves the constrained parameters,
forward-model outputs, prior, and EKP object to a JLD2 file, then plots (a) the
parameter distribution and (b) per-member zonal-average comparisons of T, S, and
b against the ECCO4 target, regridded onto the diagnostic target grid. All
plotting is best-effort: failures are logged and do not abort the calibration.
"""
function ClimaCalibrate.analyze_iteration(ekp, g_ensemble, prior, output_dir, iteration)
    @info "Mean constrained parameter(s): $(get_ϕ_mean_final(prior, ekp))"
    @info "Covariance-weighted error: $(last(get_error(ekp)))"

    ϕs = get_ϕ(prior, ekp)
    model_error = get_error(ekp)

    jldopen(joinpath(output_dir, "ekp_diagnostics_iteration$(iteration).jld2"), "w") do file
        file["ϕs"] = ϕs
        file["g_ensemble"] = g_ensemble
        file["prior"] = prior
        file["ekp"] = ekp
    end

    # Constrained parameters for this iteration (ϕs is indexed from iteration 0).
    ϕ = ϕs[iteration + 1]

    plots_filepath = abspath(joinpath(output_dir, "diagnostics_output"))
    mkpath(plots_filepath)

    try
        fig = plot_parameter_distribution(ϕ, last(model_error))
        save(joinpath(plots_filepath, "iteration_$(iteration)_parameter_distribution.png"), fig)
    catch e
        @error "Failed to plot parameter distribution for iteration $(iteration)" exception=(e, catch_backtrace())
    end

    obs_path = joinpath(pwd(), "calibration_data", "ECCO4Monthly", "1yearaverage_2degree2007-01-01T00-00-00")

    T_truth_filepath = joinpath(obs_path, "T.jld2")
    S_truth_filepath = joinpath(obs_path, "S.jld2")
    b_truth_filepath = joinpath(obs_path, "b.jld2")

    T_truth_afts = jldopen(T_truth_filepath, "r") do file
        return file["averaged_fieldtimeseries"]
    end

    S_truth_afts = jldopen(S_truth_filepath, "r") do file
        return file["averaged_fieldtimeseries"]
    end

    b_truth_afts = jldopen(b_truth_filepath, "r") do file
        return file["averaged_fieldtimeseries"]
    end

    T_truth = on_architecture(CPU(), T_truth_afts.data)
    S_truth = on_architecture(CPU(), S_truth_afts.data)
    b_truth = on_architecture(CPU(), b_truth_afts.data)

    Nt_truth = length(T_truth.times)

    # NaN out immersed cells so they are excluded from zonal averages.
    for i in 1:Nt_truth
        mask_immersed_field!(T_truth[i], NaN)
        mask_immersed_field!(S_truth[i], NaN)
        mask_immersed_field!(b_truth[i], NaN)
    end

    target_grid, regridder = jldopen(joinpath(pwd(), "examples", "GM_calibration", "grids_and_regridder.jld2"), "r") do file
        return file["target_grid"], file["regridder"]
    end

    for m in 1:ensemble_size
        try
            @info "Plotting zonal averages for member $m"

            κ_skew, κ_symmetric = ϕ[:, m]
            member_path = ClimaCalibrate.path_to_ensemble_member(output_dir, iteration, m)
            model_filepath = joinpath(member_path, "ocean_complete_fields_1year_average_calibrationsample.jld2")

            T_model = FieldTimeSeries(model_filepath, "T", backend=InMemory())
            S_model = FieldTimeSeries(model_filepath, "S", backend=InMemory())
            b_model = FieldTimeSeries(model_filepath, "b", backend=InMemory())

            T_model_field = CenterField(target_grid)
            S_model_field = CenterField(target_grid)
            b_model_field = CenterField(target_grid)

            Nt_model = length(T_model.times)

            # Mask before AND after regridding: immersed cells on both grids differ.
            mask_immersed_field!(T_model[Nt_model], NaN)
            mask_immersed_field!(S_model[Nt_model], NaN)
            mask_immersed_field!(b_model[Nt_model], NaN)

            regrid!(T_model_field, regridder, T_model[Nt_model])
            regrid!(S_model_field, regridder, S_model[Nt_model])
            regrid!(b_model_field, regridder, b_model[Nt_model])

            mask_immersed_field!(T_model_field, NaN)
            mask_immersed_field!(S_model_field, NaN)
            mask_immersed_field!(b_model_field, NaN)

            T_fig = plot_zonal_average(T_truth[Nt_truth], T_model_field, "T", κ_skew, κ_symmetric)
            S_fig = plot_zonal_average(S_truth[Nt_truth], S_model_field, "S", κ_skew, κ_symmetric)
            b_fig = plot_zonal_average(b_truth[Nt_truth], b_model_field, "b", κ_skew, κ_symmetric)

            save(joinpath(plots_filepath, "iter$(iteration)_member$(m)_skew_$(κ_skew)_sym_$(κ_symmetric)_T_zonal_average.png"), T_fig)
            save(joinpath(plots_filepath, "iter$(iteration)_member$(m)_skew_$(κ_skew)_sym_$(κ_symmetric)_S_zonal_average.png"), S_fig)
            save(joinpath(plots_filepath, "iter$(iteration)_member$(m)_skew_$(κ_skew)_sym_$(κ_symmetric)_b_zonal_average.png"), b_fig)
        catch e
            @error "Failed to plot zonal averages for member $m in iteration $iteration" exception=(e, catch_backtrace())
        end
    end

end
# Kernel: for each column (i, j), locate the last (highest-k) NaN-marked cell and
# record a bottom height from it.
@kernel function _find_immersed_height!(bottom_height, grid, field)
    i, j = @index(Global, NTuple)
    nz = grid.Nz

    # Index of the highest-k NaN cell in the column; 0 when the column is NaN-free.
    last_nan_k = 0
    @inbounds for k in 1:nz
        last_nan_k = ifelse(isnan(field[i, j, k]), k, last_nan_k)
    end

    # NaN-free column: keep the deepest z face. Otherwise: the cell center of the
    # highest-k NaN cell becomes the bottom height.
    @inbounds bottom_height[i, j, 1] = ifelse(last_nan_k == 0, grid.z.cᵃᵃᶠ[1], grid.z.cᵃᵃᶜ[last_nan_k])
end

"""
    find_immersed_height!(bottom_height, grid, field)

Fill `bottom_height` column-by-column from the NaN pattern in `field` by
launching `_find_immersed_height!` over the grid's xy plane.
"""
function find_immersed_height!(bottom_height, grid, field)
    launch!(architecture(grid), grid, :xy, _find_immersed_height!, bottom_height, grid, field)
    return nothing
end
active_cells_map=true) + +new_field = CenterField(target_grid) +mask_immersed_field!(new_field, NaN) +@assert sum(isnan.(interior(new_field))) == sum(isnan.(interior(dst_field))) + +SAVE_PATH = joinpath(pwd(), "examples", "GM_calibration", "grids_and_regridder.jld2") +jldopen(SAVE_PATH, "w") do file + file["source_grid"] = on_architecture(CPU(), source_grid) + file["target_grid"] = on_architecture(CPU(), target_grid) + file["regridder"] = on_architecture(CPU(), regridder) +end \ No newline at end of file diff --git a/examples/GM_calibration_1degforwardmodel/average_ECCO_data.jl b/examples/GM_calibration_1degforwardmodel/average_ECCO_data.jl new file mode 100644 index 0000000..0f46319 --- /dev/null +++ b/examples/GM_calibration_1degforwardmodel/average_ECCO_data.jl @@ -0,0 +1,61 @@ +using ClimaOcean +using Oceananigans +using Oceananigans.Units +using Oceananigans.Architectures: on_architecture +using SeawaterPolynomials.TEOS10 +using ClimaOcean.DataWrangling +using Printf +using Dates +using CUDA +using ClimaOceanCalibration.DataWrangling +using JLD2 +using XESMF + +arch = GPU() + +grid = jldopen(joinpath(pwd(), "examples", "GM_calibration_1degforwardmodel", "grids_and_regridder.jld2"), "r") do file + return on_architecture(arch, file["target_grid"]) +end + +dataset = ECCO4Monthly() + +dir = joinpath(homedir(), "ECCO_data") +mkpath(dir) +start_dates = [DateTime(1997, 1, 1) DateTime(2007, 1, 1)] + +buoyancy_model = SeawaterBuoyancy(equation_of_state=TEOS10EquationOfState()) + +for start_date in start_dates + end_date = start_date + Year(10) - Month(1) + + T = Metadata(:temperature; dataset, dir, start_date, end_date) + S = Metadata(:salinity; dataset, dir, start_date, end_date) + + T_data = FieldTimeSeries(T, grid, time_indices_in_memory=20) + S_data = FieldTimeSeries(S, grid, time_indices_in_memory=20) + + T_averaging = TimeAverageOperator(T_data) + T_averaged_fts = AveragedFieldTimeSeries(T_averaging(T_data), T_averaging, nothing) + + S_averaging = 
TimeAverageOperator(S_data) + S_averaged_fts = AveragedFieldTimeSeries(S_averaging(S_data), S_averaging, nothing) + + b_averaging = TimeAverageBuoyancyOperator(T_data) + b_averaged_fts = AveragedFieldTimeSeries(b_averaging(T_data, S_data, buoyancy_model), b_averaging, nothing) + + prefix = "10yearaverage_1deggrid_2degree" + date_str = replace(string(start_date), ":" => "-") + + dirname = prefix * date_str + + SAVE_PATH = joinpath(pwd(), "calibration_data", "ECCO4Monthly", dirname) + mkpath(SAVE_PATH) + + T_filepath = joinpath(SAVE_PATH, "T.jld2") + S_filepath = joinpath(SAVE_PATH, "S.jld2") + b_filepath = joinpath(SAVE_PATH, "b.jld2") + + save_averaged_fieldtimeseries(T_averaged_fts, T, filename=T_filepath, overwrite_existing=true) + save_averaged_fieldtimeseries(S_averaged_fts, S, filename=S_filepath, overwrite_existing=true) + save_averaged_fieldtimeseries(b_averaged_fts, nothing, filename=b_filepath, overwrite_existing=true) +end \ No newline at end of file diff --git a/examples/GM_calibration_1degforwardmodel/average_EN4_data.jl b/examples/GM_calibration_1degforwardmodel/average_EN4_data.jl new file mode 100644 index 0000000..ef0b9a7 --- /dev/null +++ b/examples/GM_calibration_1degforwardmodel/average_EN4_data.jl @@ -0,0 +1,66 @@ +using ClimaOcean +using Oceananigans +using Oceananigans.Units +using Oceananigans.Architectures: on_architecture +using SeawaterPolynomials.TEOS10 +using ClimaOcean.DataWrangling +using Printf +using Dates +using CUDA +using ClimaOceanCalibration.DataWrangling +using JLD2 +using XESMF + +arch = GPU() + +grid = jldopen(joinpath(pwd(), "examples", "GM_calibration_1degforwardmodel", "grids_and_regridder.jld2"), "r") do file + return on_architecture(arch, file["target_grid"]) +end + +dataset = EN4Monthly() + +dir = joinpath(homedir(), "EN4_data") +mkpath(dir) +start_dates = [DateTime(1902), DateTime(1912), DateTime(1922), DateTime(1942), + DateTime(1952), DateTime(1972), + DateTime(1992), DateTime(2002), DateTime(2012)] + +# seems that T 
fields for 1939 1971 1985 is problematic + +buoyancy_model = SeawaterBuoyancy(equation_of_state=TEOS10EquationOfState()) + +for start_date in start_dates + @info "Processing data starting from $(start_date)..." + end_date = start_date + Year(10) - Month(1) + + T = Metadata(:temperature; dataset, dir, start_date, end_date) + S = Metadata(:salinity; dataset, dir, start_date, end_date) + + T_data = FieldTimeSeries(T, grid, time_indices_in_memory=20) + S_data = FieldTimeSeries(S, grid, time_indices_in_memory=20) + + T_averaging = TimeAverageOperator(T_data) + T_averaged_fts = AveragedFieldTimeSeries(T_averaging(T_data), T_averaging, nothing) + + S_averaging = TimeAverageOperator(S_data) + S_averaged_fts = AveragedFieldTimeSeries(S_averaging(S_data), S_averaging, nothing) + + b_averaging = TimeAverageBuoyancyOperator(T_data) + b_averaged_fts = AveragedFieldTimeSeries(b_averaging(T_data, S_data, buoyancy_model), b_averaging, nothing) + + prefix = "10yearaverage_1deggrid_2degree" + date_str = replace(string(start_date), ":" => "-") + + dirname = prefix * date_str + + SAVE_PATH = joinpath(pwd(), "calibration_data", "EN4Monthly", dirname) + mkpath(SAVE_PATH) + + T_filepath = joinpath(SAVE_PATH, "T.jld2") + S_filepath = joinpath(SAVE_PATH, "S.jld2") + b_filepath = joinpath(SAVE_PATH, "b.jld2") + + save_averaged_fieldtimeseries(T_averaged_fts, T, filename=T_filepath, overwrite_existing=true) + save_averaged_fieldtimeseries(S_averaged_fts, S, filename=S_filepath, overwrite_existing=true) + save_averaged_fieldtimeseries(b_averaged_fts, nothing, filename=b_filepath, overwrite_existing=true) +end \ No newline at end of file diff --git a/examples/GM_calibration_1degforwardmodel/calibrate_gm_distributed.jl b/examples/GM_calibration_1degforwardmodel/calibrate_gm_distributed.jl new file mode 100644 index 0000000..ff442f3 --- /dev/null +++ b/examples/GM_calibration_1degforwardmodel/calibrate_gm_distributed.jl @@ -0,0 +1,96 @@ +const ensemble_size = 5 +using Distributed +using ArgParse 
"""
    parse_commandline()

Parse this script's command-line options. Supports `--zonal_average`
(`Bool`, default `false`), which controls whether the loss function uses
zonally averaged sections. Returns the `Dict` produced by `parse_args`.
"""
function parse_commandline()
    settings = ArgParseSettings()

    @add_arg_table! settings begin
        "--zonal_average"
            help = "Whether to perform zonal averaging in loss function"
            arg_type = Bool
            default = false
    end

    return parse_args(settings)
end
"_zonalavg" : "")") + ClimaCalibrate.forward_model(iteration, member) = gm_forward_model(iteration, member; simulation_length, sampling_length) + ClimaCalibrate.observation_map(iteration) = gm_construct_g_ensemble(iteration, zonal_average) +end + +n_iterations = 10 + +κ_skew_prior = constrained_gaussian("κ_skew", 1e3, 5e2, 0, Inf) +κ_symmetric_prior = constrained_gaussian("κ_symmetric", 1e2, 5e2, 0, Inf) + +priors = combine_distributions([κ_skew_prior, κ_symmetric_prior]) + +obs_paths = abspath.(vcat(glob("$(sampling_length)yearaverage_1deggrid_2degree*", joinpath("calibration_data", "ECCO4Monthly")), + glob("$(sampling_length)yearaverage_1deggrid_2degree*", joinpath("calibration_data", "EN4Monthly")))) + +calibration_target_obs_path = abspath(joinpath("calibration_data", "ECCO4Monthly", "$(sampling_length)yearaverage_1deggrid_2degree2007-01-01T00-00-00")) + +Y = hcat(process_observation.(obs_paths, no_tapering, zonal_average)...) + +const output_dim = size(Y, 1) + +n_trials = size(Y, 2) + +# the noise estimated from the samples (will have rank n_trials-1) +internal_cov = tsvd_cov_from_samples(Y) # SVD object + +# the "5%" model error (diagonal) +model_error_frac = 0.05 +data_mean = vec(mean(Y,dims=2)) +model_error_cov = Diagonal((model_error_frac*data_mean).^2) + +# regularize the model error diagonal (in case of zero entries) +model_error_cov += 1e-6*I + +# Combine... 
"""
    plot_parameter_distribution(κs, error)

Scatter the ensemble's constrained parameter pairs — row 1 is κ_skew, row 2 is
κ_symmetric (matching how `analyze_iteration` unpacks `ϕ[:, m]`) — with the
covariance-weighted error in the title. Returns the `Figure`.
"""
function plot_parameter_distribution(κs, error)
    figure = Figure()
    axis = Axis(figure[1, 1]; xlabel="κ skew (m²/s)", ylabel="κ symmetric (m²/s)", title="Parameter Distribution, RMSE at mean = $error")
    scatter!(axis, κs[1, :], κs[2, :])
    return figure
end

"""
    exp_levels(min_val, max_val, n; factor=2)

Return `n` levels in `[min_val, max_val]` obtained by mapping a uniform grid
through `(exp(factor*u) - 1) / (exp(factor) - 1)`, which concentrates levels
near `min_val`.
"""
function exp_levels(min_val, max_val, n; factor=2)
    scale = exp(factor) - 1
    warped = [(exp(factor * u) - 1) / scale for u in range(0, 1, length=n)]
    return [min_val + w * (max_val - min_val) for w in warped]
end

"""
    inverted_exp_levels(min_val, max_val, n; factor=2)

Like [`exp_levels`](@ref) but with the exponential warp reversed, so levels
concentrate near `max_val` instead of `min_val`.
"""
function inverted_exp_levels(min_val, max_val, n; factor=2)
    scale = exp(factor) - 1
    # This inverts the concentration from low to high values
    warped = [1 - (exp(factor * (1 - u)) - 1) / scale for u in range(0, 1, length=n)]
    return [min_val + w * (max_val - min_val) for w in warped]
end
xlabel="Latitude (°)", ylabel="Depth (m)", title="Anomaly (Model - Target)") + + LX, LY, LZ = location(truth_data) + zCs = znodes(truth_data.grid, LX(), LY(), LZ()) + + truth_φCs = φnodes(truth_data.grid, LX(), LY(), LZ()) + model_φCs = φnodes(model_data.grid, LX(), LY(), LZ()) + + truth_field = nanmean(interior(truth_data), dims=1)[1, :, :] + model_field = nanmean(interior(model_data), dims=1)[1, :, :] + diff_field = model_field .- truth_field + + fieldlim = (nanminimum([truth_field; model_field]), nanmaximum([truth_field; model_field])) + difflim = (-nanmaximum(abs.(diff_field)) / 2, nanmaximum(abs.(diff_field)) / 2) + + if field_name == "T" + @info "Plotting temperature zonal average..." + field_levels = exp_levels(fieldlim[1], fieldlim[2], 15) + field_colorbar_kwargs = (label="Temperature (°C)",) + diff_colorbar_kwargs = (label="Temperature Anomaly (°C)",) + elseif field_name == "S" + @info "Plotting salinity zonal average..." + field_colorbar_kwargs = (label="Salinity (psu)",) + diff_colorbar_kwargs = (label="Salinity Anomaly (psu)",) + fieldlim = (nanminimum([truth_field; model_field[2:end, :]]), nanmaximum([truth_field; model_field[2:end, :]])) + difflim = (-nanmaximum(abs.(diff_field[2:end, :])) / 2, nanmaximum(abs.(diff_field[2:end, :])) / 2) + # field_levels = inverted_exp_levels(fieldlim[1], fieldlim[2], 15) + field_levels = range(fieldlim[1], fieldlim[2], length=15) + elseif field_name == "b" + @info "Plotting buoyancy zonal average..." 
+ field_levels = exp_levels(fieldlim[1], fieldlim[2], 15) + field_colorbar_kwargs = (label="Buoyancy (m/s²)",) + diff_colorbar_kwargs = (label="Buoyancy Anomaly (m/s²)",) + else + error("Unsupported field name: $field_name") + end + + diff_levels = range(difflim[1], difflim[2], length=15) + + cf_f = contourf!(axtruth, truth_φCs, zCs, truth_field, colormap=:turbo, levels = field_levels, extendhigh=:auto, extendlow=:auto) + contourf!(axmodel, model_φCs, zCs, model_field, colormap=:turbo, levels = field_levels, extendhigh=:auto, extendlow=:auto) + cf_d = contourf!(axdiff, model_φCs, zCs, diff_field, colormap=:balance, levels = diff_levels, extendhigh=:auto, extendlow=:auto) + + Colorbar(fig[1:2, 2], cf_f; field_colorbar_kwargs...) + Colorbar(fig[3, 2], cf_d; diff_colorbar_kwargs...) + + xlims!(axtruth, -84, 84) + xlims!(axmodel, -84, 84) + xlims!(axdiff, -84, 84) + ylims!(axtruth, -6000, 0) + ylims!(axmodel, -6000, 0) + ylims!(axdiff, -6000, 0) + + Label(fig[0, 1:2], "Zonal Average of $field_name (κ_skew=$(round(κ_skew, digits=1)), κ_symmetric=$(round(κ_symmetric))", fontsize=25, font=:bold) + return fig +end \ No newline at end of file diff --git a/examples/GM_calibration_1degforwardmodel/data_processing.jl b/examples/GM_calibration_1degforwardmodel/data_processing.jl new file mode 100644 index 0000000..65c8e6b --- /dev/null +++ b/examples/GM_calibration_1degforwardmodel/data_processing.jl @@ -0,0 +1,107 @@ +using Oceananigans +using Oceananigans.Grids: znodes, φnodes +using Oceananigans.Fields: location +using Oceananigans.ImmersedBoundaries: mask_immersed_field! +using Oceananigans.Architectures: on_architecture +using XESMF +using JLD2 +using NaNStatistics +using Glob + +function regrid_model_data(simdir) + @info "Regridding model data in $(simdir)..." 
+ filepath = first(glob("*calibrationsample*", simdir)) + T_data = FieldTimeSeries(filepath, "T", backend=OnDisk()) + S_data = FieldTimeSeries(filepath, "S", backend=OnDisk()) + + source_grid = T_data.grid + LX, LY, LZ = location(T_data) + boundary_conditions = T_data.boundary_conditions + times = T_data.times + + Nx, Ny, Nz = (180, 84, 100) + z_faces = ExponentialDiscretization(Nz, -6000, 0; scale=1800) + + target_grid, regridder = jldopen(joinpath(pwd(), "examples", "GM_calibration_1degforwardmodel", "grids_and_regridder.jld2"), "r") do file + return file["target_grid"], file["regridder"] + end + + T_target = FieldTimeSeries{LX, LY, LZ}(target_grid, times; boundary_conditions) + S_target = FieldTimeSeries{LX, LY, LZ}(target_grid, times; boundary_conditions) + + for t in 1:length(times) + regrid!(T_target[t], regridder, T_data[t]) + regrid!(S_target[t], regridder, S_data[t]) + end + return T_target, S_target +end + +taper_interior_ocean(z, z_scale=3500, width=1000) = 0.5 * (1 + tanh((z + z_scale) / width)) +no_tapering(z) = 1 + +function extract_field_section(fts::FieldTimeSeries, latitude_range; vertical_weighting=no_tapering) + fts = on_architecture(CPU(), fts) + LX, LY, LZ = location(fts) + grid = fts.grid + + φᶜ = φnodes(grid, LX(), LY(), LZ()) + zᶜ = znodes(grid, LX(), LY(), LZ()) + + φmin, φmax = latitude_range + + lat_indices = findfirst(x -> x >= φmin, φᶜ):findlast(x -> x <= φmax, φᶜ) + z_weights = vertical_weighting.(zᶜ) + + times = fts.times + + Nt = length(times) + for t in 1:length(times) + mask_immersed_field!(fts[t], NaN) + end + + field_section = reshape(z_weights, 1, 1, :) .* interior(fts[Nt], :, lat_indices, :) + + return field_section +end + +extract_southern_ocean_section(fts, vertical_weighting=no_tapering) = extract_field_section(fts, (-80, -50); vertical_weighting) + +function process_observation(obs_path, vertical_weighting, zonal_average) + T_filepath = joinpath(obs_path, "T.jld2") + S_filepath = joinpath(obs_path, "S.jld2") + + T_afts = 
jldopen(T_filepath, "r") do file + return file["averaged_fieldtimeseries"] + end + + S_afts = jldopen(S_filepath, "r") do file + return file["averaged_fieldtimeseries"] + end + + T_data = T_afts.data + S_data = S_afts.data + + T_section = extract_southern_ocean_section(T_data, vertical_weighting) + S_section = extract_southern_ocean_section(S_data, vertical_weighting) + + if zonal_average + T_section = nanmean(T_section, dims=1) + S_section = nanmean(S_section, dims=1) + end + + return vcat(T_section[.!isnan.(T_section)], S_section[.!isnan.(S_section)]) +end + +function process_member_data(simdir, vertical_weighting, zonal_average) + T_target, S_target = regrid_model_data(simdir) + + T_section = extract_southern_ocean_section(T_target, vertical_weighting) + S_section = extract_southern_ocean_section(S_target, vertical_weighting) + + if zonal_average + T_section = nanmean(T_section, dims=1) + S_section = nanmean(S_section, dims=1) + end + + return vcat(T_section[.!isnan.(T_section)], S_section[.!isnan.(S_section)]) +end \ No newline at end of file diff --git a/examples/GM_calibration_1degforwardmodel/model_interface.jl b/examples/GM_calibration_1degforwardmodel/model_interface.jl new file mode 100644 index 0000000..59b9a1d --- /dev/null +++ b/examples/GM_calibration_1degforwardmodel/model_interface.jl @@ -0,0 +1,183 @@ +using ClimaCalibrate +using TOML +using ClimaOceanCalibration.DataWrangling +using EnsembleKalmanProcesses +using Oceananigans +using Oceananigans.Architectures: on_architecture +using Oceananigans.ImmersedBoundaries: mask_immersed_field! 
using JLD2
include("one_degree_omip.jl")
include("data_processing.jl")
include("data_plotting.jl")

"""
    gm_forward_model(iteration, member; simulation_length, sampling_length)

Run the GM forward model for one ensemble `member` of one EKP `iteration`:
read the member's κ_skew/κ_symmetric from its ClimaCalibrate parameter TOML,
run the OMIP simulation, and on failure write a `RUN_FAILED.err` marker
(consumed by `gm_construct_g_ensemble`) instead of rethrowing.

NOTE(review): relies on the global `output_dir` defined by the calibration driver.
"""
function gm_forward_model(iteration, member; simulation_length, sampling_length)
    config_dict = Dict()

    # Set the output path for the current member
    member_path = ClimaCalibrate.path_to_ensemble_member(output_dir, iteration, member)
    config_dict["output_dir"] = member_path

    # Set the parameters for the current member
    parameter_path = ClimaCalibrate.parameter_path(output_dir, iteration, member)
    if haskey(config_dict, "toml")
        push!(config_dict["toml"], parameter_path)
    else
        config_dict["toml"] = [parameter_path]
    end

    config_dict["iteration"] = iteration
    config_dict["member"] = member
    config_dict["simulation_length"] = simulation_length
    config_dict["sampling_length"] = sampling_length

    params = TOML.parsefile(parameter_path)
    κ_skew = params["κ_skew"]
    κ_symmetric = params["κ_symmetric"]

    try
        # run_gm_calibration_omip_dry_run(κ_skew["value"], κ_symmetric["value"], config_dict)
        run_gm_calibration_omip(κ_skew["value"], κ_symmetric["value"], config_dict)
    catch e
        # Create a failure indicator file with error information
        error_file = joinpath(member_path, "RUN_FAILED.err")
        open(error_file, "w") do io
            println(io, "Run failed at $(now())")
            println(io, "Parameters: κ_skew = $(κ_skew), κ_symmetric = $(κ_symmetric)")
            println(io, "Error: $(e)")
            println(io, "Backtrace:")
            # FIX: Base.catch_stack() is deprecated since Julia 1.7 (the project targets
            # 1.10); current_exceptions() is the supported replacement and yields the
            # same (exception, backtrace) pairs.
            for (exc, bt) in current_exceptions()
                showerror(io, exc, bt)
                println(io)
            end
        end

        @error "GM calibration failed (κ_skew = $(κ_skew), κ_symmetric = $(κ_symmetric))" exception=(e, catch_backtrace())
    end

    return nothing
end

function gm_construct_g_ensemble(iteration, zonal_average)
    G_ensemble = zeros(output_dim, ensemble_size)

    for m in 1:ensemble_size
        member_path = ClimaCalibrate.path_to_ensemble_member(output_dir, iteration, m)

        if isfile(joinpath(member_path, "RUN_FAILED.err"))
            @warn "Skipping member $m for iteration 
$iteration due to failed run." + G_ensemble[:, m] .= NaN + else + G_ensemble[:, m] .= process_member_data(member_path, taper_interior_ocean, zonal_average) + end + end + + return G_ensemble +end + +function ClimaCalibrate.analyze_iteration(ekp, g_ensemble, prior, output_dir, iteration) + @info "Mean constrained parameter(s): $(get_ϕ_mean_final(prior, ekp))" + @info "Covariance-weighted error: $(last(get_error(ekp)))" + + ϕs = get_ϕ(prior, ekp) + + jldopen(joinpath(output_dir, "ekp_diagnostics_iteration$(iteration).jld2"), "w") do file + file["ϕs"] = ϕs + file["g_ensemble"] = g_ensemble + file["prior"] = prior + file["ekp"] = ekp + end + + try + ϕ = ϕs[iteration + 1] + catch + @info "iteration $(iteration) is the last iteration in ϕs." + ϕ = ϕs[iteration] + end + + plots_filepath = abspath(joinpath(output_dir, "diagnostics_output")) + mkpath(plots_filepath) + + try + fig = plot_parameter_distribution(ϕ, avg_rmse) + save(joinpath(plots_filepath, "iteration_$(iteration)_parameter_distribution.png"), fig) + catch e + @error "Failed to plot parameter distribution for iteration $iteration" exception=(e, catch_backtrace()) + end + + obs_path = joinpath(pwd(), "calibration_data", "ECCO4Monthly", "10yearaverage_2degree2002-01-01T00-00-00") + + T_truth_filepath = joinpath(obs_path, "T.jld2") + S_truth_filepath = joinpath(obs_path, "S.jld2") + b_truth_filepath = joinpath(obs_path, "b.jld2") + + T_truth_afts = jldopen(T_truth_filepath, "r") do file + return file["averaged_fieldtimeseries"] + end + + S_truth_afts = jldopen(S_truth_filepath, "r") do file + return file["averaged_fieldtimeseries"] + end + + b_truth_afts = jldopen(b_truth_filepath, "r") do file + return file["averaged_fieldtimeseries"] + end + + T_truth = on_architecture(CPU(), T_truth_afts.data) + S_truth = on_architecture(CPU(), S_truth_afts.data) + b_truth = on_architecture(CPU(), b_truth_afts.data) + + Nt_truth = length(T_truth.times) + + for i in 1:Nt_truth + mask_immersed_field!(T_truth[i], NaN) + 
mask_immersed_field!(S_truth[i], NaN) + mask_immersed_field!(b_truth[i], NaN) + end + + target_grid, regridder = jldopen(joinpath(pwd(), "examples", "GM_calibration", "grids_and_regridder.jld2"), "r") do file + return file["target_grid"], file["regridder"] + end + + for m in 1:ensemble_size + try + @info "Plotting zonal averages for member $m" + + κ_skew, κ_symmetric = ϕ[:, m] + member_path = ClimaCalibrate.path_to_ensemble_member(output_dir, iteration, m) + model_filepath = joinpath(member_path, "ocean_complete_fields_10year_average_calibrationsample.jld2") + + T_model = FieldTimeSeries(model_filepath, "T", backend=InMemory()) + S_model = FieldTimeSeries(model_filepath, "S", backend=InMemory()) + b_model = FieldTimeSeries(model_filepath, "b", backend=InMemory()) + + T_model_field = CenterField(target_grid) + S_model_field = CenterField(target_grid) + b_model_field = CenterField(target_grid) + + Nt_model = length(T_model.times) + + mask_immersed_field!(T_model[Nt_model], NaN) + mask_immersed_field!(S_model[Nt_model], NaN) + mask_immersed_field!(b_model[Nt_model], NaN) + + regrid!(T_model_field, regridder, T_model[Nt_model]) + regrid!(S_model_field, regridder, S_model[Nt_model]) + regrid!(b_model_field, regridder, b_model[Nt_model]) + + mask_immersed_field!(T_model_field, NaN) + mask_immersed_field!(S_model_field, NaN) + mask_immersed_field!(b_model_field, NaN) + + T_fig = plot_zonal_average(T_truth[Nt_truth], T_model_field, "T", κ_skew, κ_symmetric) + S_fig = plot_zonal_average(S_truth[Nt_truth], S_model_field, "S", κ_skew, κ_symmetric) + b_fig = plot_zonal_average(b_truth[Nt_truth], b_model_field, "b", κ_skew, κ_symmetric) + + save(joinpath(plots_filepath, "iter$(iteration)_member$(m)_skew_$(κ_skew)_sym_$(κ_symmetric)_T_zonal_average.png"), T_fig) + save(joinpath(plots_filepath, "iter$(iteration)_member$(m)_skew_$(κ_skew)_sym_$(κ_symmetric)_S_zonal_average.png"), S_fig) + save(joinpath(plots_filepath, 
"iter$(iteration)_member$(m)_skew_$(κ_skew)_sym_$(κ_symmetric)_b_zonal_average.png"), b_fig)
        catch e
            # FIX: the exception must be bound (`catch e`); previously the bare `catch`
            # left `e` undefined, so logging the failure would itself raise UndefVarError.
            @error "Failed to plot zonal averages for member $m in iteration $iteration" exception=(e, catch_backtrace())
        end
    end

end
\ No newline at end of file
diff --git a/examples/GM_calibration_1degforwardmodel/one_degree_omip.jl b/examples/GM_calibration_1degforwardmodel/one_degree_omip.jl
new file mode 100644
index 0000000..3a89778
--- /dev/null
+++ b/examples/GM_calibration_1degforwardmodel/one_degree_omip.jl
@@ -0,0 +1,437 @@
using ClimaOcean
using ClimaSeaIce
using Oceananigans
using Oceananigans.Grids
using Oceananigans.Units
using Oceananigans.OrthogonalSphericalShellGrids
using Oceananigans.BuoyancyFormulations: buoyancy, buoyancy_frequency
using ClimaOcean.OceanSimulations
using ClimaOcean.ECCO
using ClimaOcean.JRA55
using ClimaOcean.DataWrangling
using ClimaSeaIce.SeaIceThermodynamics: IceWaterThermalEquilibrium
using Printf
using Dates
using CUDA
using JLD2
using ArgParse
using Oceananigans.TurbulenceClosures: ExplicitTimeDiscretization, AdvectiveFormulation, IsopycnalSkewSymmetricDiffusivity
using Oceananigans.TurbulenceClosures.TKEBasedVerticalDiffusivities: CATKEVerticalDiffusivity, CATKEMixingLength, CATKEEquation
using Oceananigans.Operators: Δx, Δy
using Statistics

import Oceananigans.OutputWriters: checkpointer_address

# Warn early if a UCX-backed MPI is loaded, which can misbehave with MPI+CUDA.
using Libdl
ucx_libs = filter(lib -> occursin("ucx", lowercase(lib)), Libdl.dllist())
if isempty(ucx_libs)
    @info "✓ No UCX - safe to run!"
else
    @warn "✗ UCX libraries detected! This can cause issues with MPI+CUDA. 
Detected libs:\n$(join(ucx_libs, "\n"))" +end + +function run_gm_calibration_omip(κ_skew, κ_symmetric, config_dict) + output_dir = config_dict["output_dir"] + logfile_path = joinpath(output_dir, "output.log") + + logfile = open(logfile_path, "w") + original_stdout = stdout + original_stderr = stderr + + redirect_stdout(logfile) + redirect_stderr(logfile) + + flusher = @async while isopen(logfile); flush(logfile); sleep(1); end + + try + start_year = 1992 + simulation_length = config_dict["simulation_length"] + sampling_length = config_dict["sampling_length"] + + @info "1-degree simulation" + @info "Using κ_skew = $(κ_skew) m²/s and κ_symmetric = $(κ_symmetric) m²/s, starting in $(start_year) for $(simulation_length) years with $(sampling_length)-year sampling window." + @info "Saving output to $(config_dict["output_dir"])" + + arch = GPU() + + Nx = 360 # longitudinal direction + Ny = 180 # meridional direction + Nz = 100 + + z_faces = ExponentialDiscretization(Nz, -6000, 0; scale=1800) + z_surf = z_faces(Nz) + + grid = TripolarGrid(arch; + size = (Nx, Ny, Nz), + z = z_faces, + halo = (7, 7, 7)) + + bottom_height = regrid_bathymetry(grid; minimum_depth=15, major_basins=1, interpolation_passes=75) + grid = ImmersedBoundaryGrid(grid, GridFittedBottom(bottom_height); active_cells_map=true) + + momentum_advection = WENOVectorInvariant(order=5) + tracer_advection = WENO(order=5) + free_surface = SplitExplicitFreeSurface(grid; cfl=0.8, fixed_Δt=45minutes) + + eddy_closure = IsopycnalSkewSymmetricDiffusivity(; κ_skew, κ_symmetric, skew_flux_formulation=AdvectiveFormulation()) + # obl_closure = ClimaOcean.OceanSimulations.default_ocean_closure() + obl_closure = RiBasedVerticalDiffusivity() + + closure = (obl_closure, VerticalScalarDiffusivity(κ=1e-5, ν=3e-4), eddy_closure) + + dir = joinpath(homedir(), "ECCO_data") + mkpath(dir) + + start_date = DateTime(start_year, 1, 1) + end_date = start_date + Year(simulation_length) + simulation_period = Dates.value(Second(end_date - 
start_date)) + sampling_start_date = end_date - Year(sampling_length) + sampling_window = Dates.value(Second(end_date - sampling_start_date)) + + @info "Settting up salinity restoring..." + @inline mask(x, y, z, t) = z ≥ z_surf - 1 + Smetadata = Metadata(:salinity; dataset=ECCO4Monthly(), dir, start_date, end_date) + FS = DatasetRestoring(Smetadata, grid; rate = 1/18days, mask, time_indices_in_memory = 10) + + ocean = ocean_simulation(grid; Δt=1minutes, + momentum_advection, + tracer_advection, + timestepper = :SplitRungeKutta3, + free_surface, + forcing = (; S = FS), + closure) + + @info "Built ocean model $(ocean)" + + set!(ocean.model, T=Metadatum(:temperature; dataset=ECCO4Monthly(), date=start_date, dir), + S=Metadatum(:salinity; dataset=ECCO4Monthly(), date=start_date, dir)) + @info "Initialized T and S" + + # Default sea-ice dynamics and salinity coupling are included in the defaults + # sea_ice = sea_ice_simulation(grid, ocean; advection=WENO(order=7)) + sea_ice = sea_ice_simulation(grid, ocean; dynamics=nothing) + @info "Built sea ice model $(sea_ice)" + + set!(sea_ice.model, h=Metadatum(:sea_ice_thickness; dataset=ECCO4Monthly(), date=start_date, dir), + ℵ=Metadatum(:sea_ice_concentration; dataset=ECCO4Monthly(), date=start_date, dir)) + + @info "Initialized sea ice fields" + + jra55_dir = joinpath(homedir(), "JRA55_data") + mkpath(jra55_dir) + dataset = MultiYearJRA55() + backend = JRA55NetCDFBackend(100) + + @info "Setting up presctibed atmosphere $(dataset)" + atmosphere = JRA55PrescribedAtmosphere(arch; dir=jra55_dir, dataset, backend, include_rivers_and_icebergs=true, start_date, end_date) + radiation = Radiation() + + @info "Built atmosphere model $(atmosphere)" + + omip = OceanSeaIceModel(ocean, sea_ice; atmosphere, radiation) + + @info "Built coupled model $(omip)" + + omip = Simulation(omip, Δt=40minutes, stop_time=simulation_period) + @info "Built simulation $(omip)" + + FILE_DIR = config_dict["output_dir"] + mkpath(FILE_DIR) + + b = 
Field(buoyancy(ocean.model)) + N² = Field(buoyancy_frequency(ocean.model)) + + ocean_outputs = merge(ocean.model.tracers, ocean.model.velocities, (; b, N²)) + sea_ice_outputs = merge((h = sea_ice.model.ice_thickness, + ℵ = sea_ice.model.ice_concentration, + T = sea_ice.model.ice_thermodynamics.top_surface_temperature), + sea_ice.model.velocities) + + ocean.output_writers[:surface] = JLD2Writer(ocean.model, ocean_outputs; + schedule = TimeInterval(180days), + filename = "$(FILE_DIR)/ocean_surface_fields", + indices = (:, :, grid.Nz), + overwrite_existing = true) + + sea_ice.output_writers[:surface] = JLD2Writer(ocean.model, sea_ice_outputs; + schedule = TimeInterval(180days), + filename = "$(FILE_DIR)/sea_ice_surface_fields", + overwrite_existing = true) + + ocean.output_writers[:time_average] = JLD2Writer(ocean.model, ocean_outputs; + schedule = AveragedTimeInterval(3650days, window=3650days), + filename = "$(FILE_DIR)/ocean_complete_fields_10year_average", + overwrite_existing = true) + + sea_ice.output_writers[:time_average] = JLD2Writer(sea_ice.model, sea_ice_outputs; + schedule = AveragedTimeInterval(3650days, window=3650days), + filename = "$(FILE_DIR)/sea_ice_complete_fields_10year_average", + overwrite_existing = true) + + ocean.output_writers[:sample_decadal_average] = JLD2Writer(ocean.model, ocean_outputs; + schedule = AveragedTimeInterval(simulation_period, window=sampling_window), + filename = "$(FILE_DIR)/ocean_complete_fields_10year_average_calibrationsample", + overwrite_existing = true) + + wall_time = Ref(time_ns()) + + function progress(sim) + sea_ice = sim.model.sea_ice + ocean = sim.model.ocean + hmax = maximum(sea_ice.model.ice_thickness) + ℵmax = maximum(sea_ice.model.ice_concentration) + Tmax = maximum(sim.model.interfaces.atmosphere_sea_ice_interface.temperature) + Tmin = minimum(sim.model.interfaces.atmosphere_sea_ice_interface.temperature) + umax = maximum(ocean.model.velocities.u) + vmax = maximum(ocean.model.velocities.v) + wmax = 
maximum(ocean.model.velocities.w) + + step_time = 1e-9 * (time_ns() - wall_time[]) + + msg1 = @sprintf("time: %s, iteration: %d, Δt: %s, ", prettytime(sim), iteration(sim), prettytime(sim.Δt)) + msg2 = @sprintf("max(h): %.2e m, max(ℵ): %.2e ", hmax, ℵmax) + msg4 = @sprintf("extrema(T): (%.2f, %.2f) ᵒC, ", Tmax, Tmin) + msg5 = @sprintf("maximum(u): (%.2f, %.2f, %.2f) m/s, ", umax, vmax, wmax) + msg6 = @sprintf("wall time: %s \n", prettytime(step_time)) + + @info msg1 * msg2 * msg4 * msg5 * msg6 + + wall_time[] = time_ns() + + return nothing + end + + # And add it as a callback to the simulation. + add_callback!(omip, progress, IterationInterval(100)) + + run!(omip) + return nothing + catch e + # Handle errors + if e isa InterruptException + println(stderr, "Interrupted by user") + else + println(stderr, "Error occurred: $e") + # Optionally rethrow to propagate the error + rethrow(e) + end + finally + # Cleanup - ALWAYS runs + redirect_stdout(original_stdout) + redirect_stderr(original_stderr) + close(logfile) + println("Log file closed") # This prints to console, not log + end + +end + +function run_gm_calibration_omip_dry_run(κ_skew, κ_symmetric, config_dict) + output_dir = config_dict["output_dir"] + logfile_path = joinpath(output_dir, "output.log") + + logfile = open(logfile_path, "w") + original_stdout = stdout + original_stderr = stderr + + redirect_stdout(logfile) + redirect_stderr(logfile) + + flusher = @async while isopen(logfile); flush(logfile); sleep(1); end + + try + start_year = rand(1992:2011) + @info "Dry run: Using κ_skew = $(κ_skew) m²/s and κ_symmetric = $(κ_symmetric) m²/s, starting in year $(start_year)" + @info "Saving output to $(config_dict["output_dir"])" + + arch = GPU() + + Nx = 720 # longitudinal direction + Ny = 360 # meridional direction + Nz = 100 + + z_faces = ExponentialDiscretization(Nz, -6000, 0; scale=1800) + z_surf = z_faces(Nz) + + grid = TripolarGrid(arch; + size = (Nx, Ny, Nz), + z = z_faces, + halo = (7, 7, 7)) + + 
bottom_height = regrid_bathymetry(grid; minimum_depth=15, major_basins=1, interpolation_passes=55) + grid = ImmersedBoundaryGrid(grid, GridFittedBottom(bottom_height); active_cells_map=true) + + momentum_advection = WENOVectorInvariant(order=5) + tracer_advection = WENO(order=7) + free_surface = SplitExplicitFreeSurface(grid; cfl=0.8, fixed_Δt=50minutes) + + @inline Δ²ᵃᵃᵃ(i, j, k, grid, lx, ly, lz) = 2 * (1 / (1 / Δx(i, j, k, grid, lx, ly, lz)^2 + 1 / Δy(i, j, k, grid, lx, ly, lz)^2)) + @inline geometric_νhb(i, j, k, grid, lx, ly, lz, clock, fields, λ) = Δ²ᵃᵃᵃ(i, j, k, grid, lx, ly, lz)^2 / λ + + eddy_closure = IsopycnalSkewSymmetricDiffusivity(; κ_skew, κ_symmetric, skew_flux_formulation=AdvectiveFormulation()) + obl_closure = RiBasedVerticalDiffusivity() + visc_closure = HorizontalScalarBiharmonicDiffusivity(ν=geometric_νhb, discrete_form=true, parameters=25days) + + closure = (obl_closure, VerticalScalarDiffusivity(κ=1e-5, ν=3e-4), visc_closure, eddy_closure) + + dir = joinpath(homedir(), "forcing_data_half_degree") + mkpath(dir) + + start_date = DateTime(start_year, 1, 1) + end_date = start_date + Month(2) + simulation_period = Dates.value(Second(end_date - start_date)) + + @info "Settting up salinity restoring..." 
+ @inline mask(x, y, z, t) = z ≥ z_surf - 1 + Smetadata = Metadata(:salinity; dataset=ECCO4Monthly(), dir, start_date, end_date) + FS = DatasetRestoring(Smetadata, grid; rate = 1/30days, mask, time_indices_in_memory = 2) + + ocean = ocean_simulation(grid; Δt=1minutes, + momentum_advection, + tracer_advection, + timestepper = :SplitRungeKutta3, + free_surface, + forcing = (; S = FS), + closure) + + @info "Built ocean model $(ocean)" + + set!(ocean.model, T=Metadatum(:temperature; dataset=ECCO4Monthly(), date=start_date, dir), + S=Metadatum(:salinity; dataset=ECCO4Monthly(), date=start_date, dir)) + @info "Initialized T and S" + + sea_ice = sea_ice_simulation(grid, ocean; dynamics=nothing) + @info "Built sea ice model $(sea_ice)" + + set!(sea_ice.model, h=Metadatum(:sea_ice_thickness; dataset=ECCO4Monthly(), dir), + ℵ=Metadatum(:sea_ice_concentration; dataset=ECCO4Monthly(), dir)) + + @info "Initialized sea ice fields" + + jra55_dir = joinpath(homedir(), "JRA55_data") + mkpath(jra55_dir) + dataset = MultiYearJRA55() + backend = JRA55NetCDFBackend(100) + + @info "Setting up presctibed atmosphere $(dataset)" + atmosphere = JRA55PrescribedAtmosphere(arch; dir=jra55_dir, dataset, backend, include_rivers_and_icebergs=true, start_date, end_date) + radiation = Radiation() + + @info "Built atmosphere model $(atmosphere)" + + omip = OceanSeaIceModel(ocean, sea_ice; atmosphere, radiation) + + @info "Built coupled model $(omip)" + + omip = Simulation(omip, Δt=30minutes, stop_time=simulation_period) + @info "Built simulation $(omip)" + + FILE_DIR = config_dict["output_dir"] + mkpath(FILE_DIR) + + b = Field(buoyancy(ocean.model)) + N² = Field(buoyancy_frequency(ocean.model)) + + ocean_outputs = merge(ocean.model.tracers, ocean.model.velocities, (; b, N²)) + + ocean.output_writers[:sample_decadal_average] = JLD2Writer(ocean.model, ocean_outputs; + schedule = AveragedTimeInterval(simulation_period, window=30days), + filename = 
"$(FILE_DIR)/ocean_complete_fields_10year_average_calibrationsample", + overwrite_existing = true) + + wall_time = Ref(time_ns()) + + function progress(sim) + sea_ice = sim.model.sea_ice + ocean = sim.model.ocean + hmax = maximum(sea_ice.model.ice_thickness) + ℵmax = maximum(sea_ice.model.ice_concentration) + Tmax = maximum(sim.model.interfaces.atmosphere_sea_ice_interface.temperature) + Tmin = minimum(sim.model.interfaces.atmosphere_sea_ice_interface.temperature) + umax = maximum(ocean.model.velocities.u) + vmax = maximum(ocean.model.velocities.v) + wmax = maximum(ocean.model.velocities.w) + + step_time = 1e-9 * (time_ns() - wall_time[]) + + msg1 = @sprintf("time: %s, iteration: %d, Δt: %s, ", prettytime(sim), Oceananigans.iteration(sim), prettytime(sim.Δt)) + msg2 = @sprintf("max(h): %.2e m, max(ℵ): %.2e ", hmax, ℵmax) + msg4 = @sprintf("extrema(T): (%.2f, %.2f) ᵒC, ", Tmax, Tmin) + msg5 = @sprintf("maximum(u): (%.2f, %.2f, %.2f) m/s, ", umax, vmax, wmax) + msg6 = @sprintf("wall time: %s \n", prettytime(step_time)) + + @info msg1 * msg2 * msg4 * msg5 * msg6 + + wall_time[] = time_ns() + + return nothing + end + + add_callback!(omip, progress, IterationInterval(10)) + + run!(omip) + return nothing + catch e + # Handle errors + if e isa InterruptException + println(stderr, "Interrupted by user") + else + println(stderr, "Error occurred: $e") + # Optionally rethrow to propagate the error + rethrow(e) + end + finally + # Cleanup - ALWAYS runs + redirect_stdout(original_stdout) + redirect_stderr(original_stderr) + close(logfile) + println("Log file closed") # This prints to console, not log + end +end + +# function run_gm_calibration_omip_dry_run(κ_skew, κ_symmetric, config_dict) +# output_dir = config_dict["output_dir"] +# logfile_path = joinpath(output_dir, "output.log") + +# logfile = open(logfile_path, "w") +# original_stdout = stdout +# original_stderr = stderr + +# redirect_stdout(logfile) +# redirect_stderr(logfile) + +# flusher = @async while isopen(logfile); 
flush(logfile); sleep(1); end + +# try +# # ALL your main code goes here +# println("Starting work...") + +# start_year = rand(1992:2011) +# member = config_dict["member"] +# iteration = config_dict["iteration"] +# @info "Member $member, iter $iteration dry run: Using κ_skew = $(κ_skew) m²/s and κ_symmetric = $(κ_symmetric) m²/s, starting in year $(start_year)" +# @info "Saving output to $(config_dict["output_dir"])" +# FILE_DIR = config_dict["output_dir"] +# mkpath(FILE_DIR) + +# cp(joinpath(homedir(), "ocean_complete_fields_10year_average_calibrationsample.jld2"), "$(FILE_DIR)/ocean_complete_fields_10year_average_calibrationsample.jld2") + +# println("Finished successfully") + +# return nothing +# catch e +# # Handle errors +# if e isa InterruptException +# println(stderr, "Interrupted by user") +# else +# println(stderr, "Error occurred: $e") +# # Optionally rethrow to propagate the error +# rethrow(e) +# end +# finally +# # Cleanup - ALWAYS runs +# redirect_stdout(original_stdout) +# redirect_stderr(original_stderr) +# close(logfile) +# println("Log file closed") # This prints to console, not log +# end +# end \ No newline at end of file diff --git a/examples/GM_calibration_1degforwardmodel/precompute_regridder.jl b/examples/GM_calibration_1degforwardmodel/precompute_regridder.jl new file mode 100644 index 0000000..15213cd --- /dev/null +++ b/examples/GM_calibration_1degforwardmodel/precompute_regridder.jl @@ -0,0 +1,81 @@ +using ClimaOcean +using Oceananigans +using Oceananigans.Architectures: on_architecture, architecture +using Oceananigans.Utils: launch! +using Oceananigans.Grids: znodes +using Oceananigans.ImmersedBoundaries: mask_immersed_field! 
+using CUDA +using XESMF +using JLD2 +using KernelAbstractions: @index, @kernel + +import Oceananigans.Architectures: on_architecture + +Nz = 100 +z_faces = ExponentialDiscretization(Nz, -6000, 0; scale=1800) +Nx_target, Ny_target = (180, 84) + +minimum_depth = 15 +major_basins = 1 +interpolation_passes = 75 + +arch = GPU() +Nx_source, Ny_source = (360, 180) +source_grid = TripolarGrid(arch; + size = (Nx_source, Ny_source, Nz), + z = z_faces, + halo = (7, 7, 7)) + +bottom_height_source = regrid_bathymetry(source_grid; minimum_depth, major_basins, interpolation_passes) +source_grid = ImmersedBoundaryGrid(source_grid, GridFittedBottom(bottom_height_source); active_cells_map=true) + +target_grid = LatitudeLongitudeGrid(arch; size=(Nx_target, Ny_target, Nz), z = z_faces, + longitude=(0, 360), latitude=(-84, 84)) + +@kernel function _find_immersed_height!(bottom_height, grid, field) + i, j = @index(Global, NTuple) + Nz = grid.Nz + + kmax = 0 + @inbounds for k in 1:Nz + kmax = ifelse(isnan(field[i, j, k]), k, kmax) + end + + @inbounds bottom_height[i, j, 1] = ifelse(kmax == 0, grid.z.cᵃᵃᶠ[1], grid.z.cᵃᵃᶜ[kmax]) +end + +function find_immersed_height!(bottom_height, grid, field) + arch = architecture(grid) + launch!(arch, grid, :xy, _find_immersed_height!, bottom_height, grid, field) + return nothing +end + +src_field = CenterField(source_grid) +mask_immersed_field!(src_field, NaN) + +dst_field = CenterField(target_grid) + +regridder = XESMF.Regridder(dst_field, src_field, method="conservative") + +on_architecture(on, r::XESMF.Regridder) = XESMF.Regridder(on_architecture(on, r.method), + on_architecture(on, r.weights), + on_architecture(on, r.src_temp), + on_architecture(on, r.dst_temp)) + +regrid!(dst_field, regridder, src_field) + +bottom_height_target = Field{Center, Center, Nothing}(target_grid) +find_immersed_height!(bottom_height_target, target_grid, dst_field) + +target_grid = ImmersedBoundaryGrid(target_grid, GridFittedBottom(bottom_height_target); 
active_cells_map=true) + +new_field = CenterField(target_grid) +mask_immersed_field!(new_field, NaN) +@assert sum(isnan.(interior(new_field))) == sum(isnan.(interior(dst_field))) + +SAVE_PATH = joinpath(pwd(), "examples", "GM_calibration_1degforwardmodel", "grids_and_regridder.jld2") +jldopen(SAVE_PATH, "w") do file + file["source_grid"] = on_architecture(CPU(), source_grid) + file["target_grid"] = on_architecture(CPU(), target_grid) + file["regridder"] = on_architecture(CPU(), regridder) +end \ No newline at end of file diff --git a/src/DataWrangling/DataWrangling.jl b/src/DataWrangling/DataWrangling.jl index 9ddc6ac..7e78bcd 100644 --- a/src/DataWrangling/DataWrangling.jl +++ b/src/DataWrangling/DataWrangling.jl @@ -1,5 +1,7 @@ module DataWrangling -include("time_averaging.jl") +export TimeAverageOperator, TimeAverageBuoyancyOperator, AveragedFieldTimeSeries, spatial_averaging, save_averaged_fieldtimeseries + +include("fieldtimeseries_averaging.jl") end \ No newline at end of file diff --git a/src/DataWrangling/fieldtimeseries_averaging.jl b/src/DataWrangling/fieldtimeseries_averaging.jl new file mode 100644 index 0000000..a055a64 --- /dev/null +++ b/src/DataWrangling/fieldtimeseries_averaging.jl @@ -0,0 +1,343 @@ +using Oceananigans +using Oceananigans.Fields: location +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.BuoyancyFormulations: buoyancy_perturbationᶜᶜᶜ +using Oceananigans.Architectures: architecture +using Oceananigans.Utils: launch! +using Oceananigans.Fields: interpolate! +using ClimaOcean +using ClimaOcean.DataWrangling: DatasetFieldTimeSeries, native_grid +using Dates +using JLD2 +using KernelAbstractions +using XESMF + +""" + AveragedFieldTimeSeries{D, T, S} + +A container for field data that has been averaged in time and/or space. 
+ +# Fields +- `data`: The averaged field time series data +- `time_averaging`: Information about the time averaging operation applied +- `space_averaging`: Information about the space averaging operation applied + +This struct provides a way to track both the averaged data and the operations used to produce it. +""" +struct AveragedFieldTimeSeries{D, T, S} + data :: D + time_averaging :: T + space_averaging :: S +end + +""" + AbstractTimeAverageOperator + +An abstract type representing time averaging operators for field time series. +All concrete implementations should provide: + +- Storage for source and target times +- An implementation of the call method that transforms a field time series +""" +abstract type AbstractTimeAverageOperator end + +""" + TimeAverageOperator{N, ST, TT, SDT, TDT} + +An operator that performs time averaging on field time series data. + +# Fields +- `nsteps`: Number of time steps to combine in each averaging window +- `source_times`: Original times from the source data +- `target_times`: Times for the averaged data (subset of source times) +- `source_Δt`: Time intervals in the source data +- `target_Δt`: Time intervals in the averaged data + +This operator is used to reduce temporal resolution by averaging multiple time steps together. +""" +struct TimeAverageOperator{N, ST, TT, SDT, TDT} <: AbstractTimeAverageOperator + nsteps :: N + source_times :: ST + target_times :: TT + source_Δt :: SDT + target_Δt :: TDT +end + +floor_multiple(a, b) = a - rem(a, b) + +""" + TimeAverageOperator(fts::DatasetFieldTimeSeries, nsteps) + +Create a time averaging operator that averages every `nsteps` time steps in the field time series. +Note that the assumption is that fts[i] is the average field value over the interval [times[i], times[i+1]]. +For the last timestep, we assume it is averaged over the interval [times[end], times[end] + Δt], where Δt is the date step (which depends on the actual dates given by the metadata). 
+ +# Arguments +- `fts`: A `DatasetFieldTimeSeries` containing the time data to be averaged +- `nsteps`: Number of consecutive time steps to average together + +# Returns +- `TimeAverageOperator` that can be applied to a compatible field time series + +# Notes +- If `nsteps` is 1, no averaging will be performed +- The operator computes target times and appropriate time intervals for weighted averaging +- For dataset time series with dates, proper date-based time intervals are calculated +""" +function TimeAverageOperator(fts::DatasetFieldTimeSeries, nsteps) + fts.times isa Number && return TimeAverageOperator(1, nothing) + + source_dates = fts.backend.metadata.dates + source_datestep = source_dates |> step + source_enddate = last(source_dates) + source_datestep + + fts_times = Array(fts.times) + last_timestep = Dates.value(source_enddate - first(source_dates)) / 1000 + + times_inclusive = vcat(fts_times, last_timestep) + source_Δt = diff(times_inclusive) + + truncated_length = floor_multiple(length(fts_times), nsteps) + target_times = fts_times[1:truncated_length][1:nsteps:end] + target_Δt = diff(times_inclusive[1:truncated_length+1][1:nsteps:end]) + + return TimeAverageOperator(nsteps, fts_times, target_times, source_Δt, target_Δt) +end + +""" + TimeAverageOperator(fts::FieldTimeSeries, nsteps) + +Create a time averaging operator that averages every `nsteps` time steps in a regular field time series. +The assumption is that fts[i] is the average field value over the interval [times[i], times[i+1]]. +For the last timestep, we assume it extends one timestep beyond the final recorded time. 

# Arguments
- `fts`: A `FieldTimeSeries` containing the time data to be averaged
- `nsteps`: Number of consecutive time steps to average together

# Returns
- `TimeAverageOperator` that can be applied to a compatible field time series

# Notes
- If `nsteps` is 1, no averaging will be performed
- The operator requires uniform time spacing in the input field time series
- The operator truncates the data to ensure complete averaging windows
- The returned operator contains both source and target times and time intervals needed for weighted averaging

# Throws
- `ArgumentError` if non-uniform time steps are detected in the input field time series
"""
function TimeAverageOperator(fts::FieldTimeSeries, nsteps)
    # A scalar `times` means a single snapshot: return a trivial (identity) operator.
    # BUGFIX: the original returned `TimeAverageOperator(1, nothing)`, a 2-argument call
    # to a 5-field struct, which always threw a MethodError on this path.
    fts.times isa Number && return TimeAverageOperator(1, nothing, nothing, nothing, nothing)

    fts_times = Array(fts.times)

    # BUGFIX: a length-1 time vector previously hit `fts_times[2]` and threw a
    # BoundsError; a single time also cannot be averaged, so treat it like the scalar case.
    length(fts_times) < 2 && return TimeAverageOperator(1, nothing, nothing, nothing, nothing)

    timestep = fts_times[2] - fts_times[1] # assume uniform spacing!!
    if length(fts_times) > 2
        all_timesteps = diff(fts_times)
        # Input validation: @assert may be compiled out, so raise an explicit error instead.
        all(isapprox.(all_timesteps, timestep)) ||
            throw(ArgumentError("Non-uniform time steps detected in FieldTimeSeries. This implementation requires uniform time spacing."))
    end

    # fts[end] is assumed to represent the average over one more `timestep`
    # beyond the final recorded time, closing the last interval.
    last_timestep = fts_times[end] + timestep

    # times_inclusive brackets every interval, so diff() yields per-step durations.
    times_inclusive = vcat(fts_times, last_timestep)
    source_Δt = diff(times_inclusive)

    # Drop trailing steps that do not fill a complete window of `nsteps`.
    truncated_length = floor_multiple(length(fts_times), nsteps)
    target_times = fts_times[1:truncated_length][1:nsteps:end]
    target_Δt = diff(times_inclusive[1:truncated_length+1][1:nsteps:end])

    return TimeAverageOperator(nsteps, fts_times, target_times, source_Δt, target_Δt)
end

# Default: average the entire series into a single time window.
TimeAverageOperator(fts) = TimeAverageOperator(fts, length(fts))

"""
    (𝒯::TimeAverageOperator)(fts::FieldTimeSeries)

Apply time averaging to a field time series using the specified operator.

# Arguments
- `𝒯`: The time averaging operator
- `fts`: The field time series to which the operator is applied

# Returns
A new field time series with reduced temporal resolution, where each time step is an average of `nsteps` original time steps.

# Example

```julia
# Create a time averaging operator for 3 time steps
operator = TimeAverageOperator(fts, 3)

# Apply the operator to a field time series
averaged_fts = operator(fts)
```

"""
function (𝒯::TimeAverageOperator)(fts::FieldTimeSeries)
    nsteps = 𝒯.nsteps
    # Identity operator: nothing to do.
    nsteps == 1 && return fts

    LX, LY, LZ = location(fts)
    # The averaged series lives on the same grid with the same boundary
    # conditions, but at the coarser target times.
    averaged_fts = FieldTimeSeries{LX, LY, LZ}(fts.grid, 𝒯.target_times;
                                               boundary_conditions = fts.boundary_conditions)

    for window in eachindex(𝒯.target_times)
        averaged = averaged_fts[window]
        offset = nsteps * (window - 1)
        # Accumulate a Δt-weighted sum of the source fields in this window...
        for step_in_window in 1:nsteps
            source_index = offset + step_in_window
            averaged .+= 𝒯.source_Δt[source_index] .* fts[source_index]
        end
        # ...then normalize by the total duration of the window.
        averaged ./= 𝒯.target_Δt[window]
    end

    return averaged_fts
end

"""
    TimeAverageBuoyancyOperator{N, ST, TT, SDT, TDT} <: AbstractTimeAverageOperator

A concrete operator that performs time averaging specifically for buoyancy-related quantities.

# Fields
- `nsteps`: Number of time steps to combine in each averaging window
- `source_times`: Original times from the source data
- `target_times`: Times for the averaged data (subset of source times)
- `source_Δt`: Time intervals in the source data
- `target_Δt`: Time intervals in the averaged data
"""
struct TimeAverageBuoyancyOperator{N, ST, TT, SDT, TDT} <: AbstractTimeAverageOperator
    nsteps :: N
    source_times :: ST
    target_times :: TT
    source_Δt :: SDT
    target_Δt :: TDT
end

# Build the averaging windows by delegating to the matching TimeAverageOperator
# constructor, then copy its fields over. A single Union method replaces the two
# previously byte-identical constructors for DatasetFieldTimeSeries and FieldTimeSeries.
function TimeAverageBuoyancyOperator(fts::Union{DatasetFieldTimeSeries, FieldTimeSeries}, nsteps)
    operator = TimeAverageOperator(fts, nsteps)
    return TimeAverageBuoyancyOperator(operator.nsteps, operator.source_times,
                                       operator.target_times, operator.source_Δt,
                                       operator.target_Δt)
end

# Default: average the entire series into a single time window.
TimeAverageBuoyancyOperator(fts) = TimeAverageBuoyancyOperator(fts, length(fts))

# KernelAbstractions kernel: fill `b` with the buoyancy perturbation computed from
# the tracer fields in `C` (a named tuple with `T` and `S` entries).
@kernel function _compute_buoyancy!(b, grid, buoyancy_model, C)
    i, j, k = @index(Global, NTuple)
    @inbounds b[i, j, k] = buoyancy_perturbationᶜᶜᶜ(i, j, k, grid, buoyancy_model, C)
end

"""
    (𝒯::TimeAverageBuoyancyOperator)(T_fts::FieldTimeSeries, S_fts::FieldTimeSeries, buoyancy_model::SeawaterBuoyancy)

Compute the buoyancy field from `T_fts` (temperature) and `S_fts` (salinity) at every
source time using `buoyancy_model`, then apply the operator's time averaging.
"""
function (𝒯::TimeAverageBuoyancyOperator)(T_fts::FieldTimeSeries, S_fts::FieldTimeSeries, buoyancy_model::SeawaterBuoyancy)
    nsteps = 𝒯.nsteps

    grid = T_fts.grid
    arch = architecture(grid)
    LX, LY, LZ = location(T_fts)

    # Scratch field; fully overwritten by the kernel for every source time index.
    buoyancy_scratch = CenterField(grid)
    averaged_fts = FieldTimeSeries{LX, LY, LZ}(grid, 𝒯.target_times;
                                               boundary_conditions = T_fts.boundary_conditions)

    for window in eachindex(𝒯.target_times)
        averaged = averaged_fts[window]
        offset = nsteps * (window - 1)
        for step_in_window in 1:nsteps
            source_index = offset + step_in_window
            tracers = (T = T_fts[source_index], S = S_fts[source_index])

            # Evaluate buoyancy from this snapshot's temperature and salinity...
            launch!(arch, grid, :xyz, _compute_buoyancy!,
                    buoyancy_scratch, grid, buoyancy_model, tracers)

            # ...and fold it into the Δt-weighted sum for this window.
            averaged .+= 𝒯.source_Δt[source_index] .* buoyancy_scratch
        end
        # Normalize by the total duration of the window.
        averaged ./= 𝒯.target_Δt[window]
    end

    return averaged_fts
end

"""
    (𝒯::TimeAverageBuoyancyOperator)(T_metadata, S_metadata, grid, buoyancy_model::SeawaterBuoyancy, meta_indices_in_memory=20)

Apply time averaging specialized for buoyancy quantities using metadata to create field time series.
Buoyancy is computed from temperature and salinity fields at the native grid of the dataset before being interpolated to the target grid.
"""
function (𝒯::TimeAverageBuoyancyOperator)(T_metadata::Metadata, S_metadata::Metadata, grid, buoyancy_model::SeawaterBuoyancy, meta_indices_in_memory=20)
    nsteps = 𝒯.nsteps

    arch = architecture(grid)

    # Buoyancy is evaluated on the dataset's native grid, then interpolated onto `grid`.
    meta_grid = native_grid(T_metadata, arch)
    native_buoyancy = CenterField(meta_grid)

    T_fts = FieldTimeSeries(T_metadata, meta_grid, time_indices_in_memory=meta_indices_in_memory)
    S_fts = FieldTimeSeries(S_metadata, meta_grid, time_indices_in_memory=meta_indices_in_memory)

    LX, LY, LZ = location(T_fts)
    interpolated_buoyancy = CenterField(grid)
    averaged_fts = FieldTimeSeries{LX, LY, LZ}(grid, 𝒯.target_times;
                                               boundary_conditions = T_fts.boundary_conditions)

    for window in eachindex(𝒯.target_times)
        averaged = averaged_fts[window]
        offset = nsteps * (window - 1)
        for step_in_window in 1:nsteps
            source_index = offset + step_in_window
            tracers = (T = T_fts[source_index], S = S_fts[source_index])

            # Buoyancy on the native grid for this snapshot.
            launch!(arch, meta_grid, :xyz, _compute_buoyancy!,
                    native_buoyancy, meta_grid, buoyancy_model, tracers)

            # NaN out immersed (land) cells so they do not contaminate the interpolation.
            mask_immersed_field!(native_buoyancy, NaN)

            interpolate!(interpolated_buoyancy, native_buoyancy)

            # Δt-weighted accumulation into this window's average.
            averaged .+= 𝒯.source_Δt[source_index] .* interpolated_buoyancy
        end
        averaged ./= 𝒯.target_Δt[window]
    end

    return averaged_fts
end

"""
    spatial_averaging(fts::FieldTimeSeries, target_grid, spatial_average_operator::XESMF.Regridder)

Regrid every snapshot of `fts` onto `target_grid` with `spatial_average_operator`,
returning a new `FieldTimeSeries` with the same times, locations and boundary conditions.
"""
function spatial_averaging(fts::FieldTimeSeries, target_grid, spatial_average_operator::XESMF.Regridder)
    LX, LY, LZ = location(fts)
    regridded_fts = FieldTimeSeries{LX, LY, LZ}(target_grid, fts.times;
                                                boundary_conditions = fts.boundary_conditions)

    for t in eachindex(fts.times)
        regrid!(regridded_fts[t], spatial_average_operator, fts[t])
    end

    return regridded_fts
end

"""
    save_averaged_fieldtimeseries(afts::AveragedFieldTimeSeries, metadata; filename="averaged_fieldtimeseries", overwrite_existing=false)

Save `afts` together with its `metadata` to a JLD2 file. A `.jld2` extension is
appended to `filename` when missing; an existing file is only overwritten when
`overwrite_existing` is true.
"""
function save_averaged_fieldtimeseries(afts::AveragedFieldTimeSeries, metadata; filename::String="averaged_fieldtimeseries", overwrite_existing::Bool=false)
    # add .jld2 to filename if not present
    if !endswith(filename, ".jld2")
        filename *= ".jld2"
    end

    # only save if file doesn't exist or if overwrite_existing is
true
    if overwrite_existing || !isfile(filename)
        # "w+" creates/truncates the file for writing (JLD2 open mode).
        jldopen(filename, "w+") do file
            file["averaged_fieldtimeseries"] = afts
            file["metadata"] = metadata
        end
    end
    return nothing
end