Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
806be1e
Added a run_coupled! function to implement checkpointers in coupled s…
taimoorsohail Mar 11, 2025
46b7bb2
Added a run_coupled! function to implement checkpointers in coupled s…
taimoorsohail Mar 11, 2025
75b8182
Simplified the near_global_ocean example to see why the checkpointing…
taimoorsohail Mar 12, 2025
55106ba
Merge remote-tracking branch 'origin/main' into ts/checkpointer-for-c…
taimoorsohail Mar 12, 2025
61f91f5
merge main;
taimoorsohail Mar 13, 2025
fb5d35d
Testing the set! function
navidcy Mar 13, 2025
e39c87a
merge main
navidcy Mar 13, 2025
d15439f
managed to pick up!
navidcy Mar 13, 2025
a0dca49
Changed checkointer_mwe.jl
taimoorsohail Mar 13, 2025
8ff3748
Cleaning up checkpointer_mwe.jl file
taimoorsohail Mar 13, 2025
3fdf785
extends methods to work with OSIM and OSIMSIM
navidcy Mar 13, 2025
2328bb5
mwe
navidcy Mar 13, 2025
9ab0df1
simplify
navidcy Mar 13, 2025
5bd34f8
tidying up
navidcy Mar 13, 2025
9e2af01
bit cleaner
navidcy Mar 13, 2025
4f5ff2e
set!(sim::OSIMSIM{PrescribedAtmosphere})
navidcy Mar 13, 2025
ba2521a
cleaner mwe
navidcy Mar 13, 2025
b75d6e0
Changed the function to set!
taimoorsohail Mar 13, 2025
f1cae4f
Merge NCC changes
taimoorsohail Mar 13, 2025
dfb57fa
Merge NCC changes
taimoorsohail Mar 14, 2025
76dd720
reverting near_global_ocean.jl example
taimoorsohail Mar 14, 2025
3177cf2
Update near_global_ocean_simulation.jl
taimoorsohail Mar 14, 2025
868c870
Update Project.toml
taimoorsohail Mar 14, 2025
5f02123
Added checkpointing test; integrated checkpointer into one_degree exa…
taimoorsohail Mar 14, 2025
f9087aa
set clock method for each simulation type
navidcy Mar 14, 2025
4b0ebeb
Apply suggestions from code review
navidcy Mar 14, 2025
3f0b9d5
Apply suggestions from code review
navidcy Mar 14, 2025
f4aa21e
don't pickup by default; add explanation
navidcy Mar 14, 2025
0e0a168
merge
navidcy Mar 14, 2025
f56ca8e
move set_clock! for PrescribedAtmosphere to where it belongs
navidcy Mar 14, 2025
f8c1444
move set_clock! for PrescribedAtmosphere to where it belongs
navidcy Mar 14, 2025
cfba574
only pickup in the second time
navidcy Mar 14, 2025
9f811df
Merge branch 'main' into ts/checkpointer-for-coupled-model
navidcy Mar 14, 2025
4110aca
extend set!(model::OSIM,...) instead of set!(sim:OSIMSIM,...)
navidcy Mar 16, 2025
e56515f
Merge branch 'main' into ts/checkpointer-for-coupled-model
navidcy Mar 16, 2025
a9be97c
use default radiation
navidcy Mar 16, 2025
afd7a66
add set!(::PrescribedAtmospher, checkopoint_file_path)
navidcy Mar 16, 2025
1fddbf2
use set!(::PrescribedAtmosphere,...)
navidcy Mar 16, 2025
8699207
check all clocks are aligned
navidcy Mar 16, 2025
9f9ca89
drop unused aliases
navidcy Mar 16, 2025
ba80b56
Merge branch 'ts/checkpointer-for-coupled-model' of github.com:CliMA/…
taimoorsohail Mar 19, 2025
90abea4
Removed bottom line
taimoorsohail Mar 19, 2025
9efd6b8
try to generalize
navidcy Mar 19, 2025
006574f
Merge branch 'ts/checkpointer-for-coupled-model' of github.com:CliMA/…
navidcy Mar 19, 2025
e3029cb
Merge branch 'main' into ts/checkpointer-for-coupled-model
navidcy Mar 19, 2025
1e8d638
don't assume ocean component is special
navidcy Mar 20, 2025
5a69a52
bump Oceanigans compat
navidcy Mar 20, 2025
a74499f
remove commented code
navidcy Mar 20, 2025
0a46067
properties in write_output! is kwarg
navidcy Mar 20, 2025
98ec18b
undo changes
navidcy Mar 20, 2025
f9e9e60
Delete examples/generate_atmos_dataset.jl
navidcy Mar 20, 2025
39c347a
Delete src/CoupledSimulation.jl
navidcy Mar 20, 2025
5281eab
Delete test/test_ocean_sea_ice_model_parameter_space.jl
navidcy Mar 20, 2025
368358e
Delete test/test_simulations.jl
navidcy Mar 20, 2025
021c306
Delete src/DataWrangling/JRA55.jl
navidcy Mar 20, 2025
48dd1b1
Delete src/DistributedUtils.jl
navidcy Mar 20, 2025
fdf45d7
don't import things we don't need
navidcy Mar 20, 2025
875d045
cleanup
navidcy Mar 20, 2025
1e34bb5
validate_properties -> validate_checkpointed_properties
navidcy Mar 20, 2025
1f3a46b
no need to duplicate validation
navidcy Mar 20, 2025
4f9425a
fix initialize! and update_state! + add set_clock!
navidcy Mar 25, 2025
bf96083
reorganize imports
navidcy Mar 25, 2025
cd7f0c5
wip
navidcy Apr 24, 2025
47d1121
merge main and resolve conflicts
navidcy Apr 24, 2025
8bd0d36
Update ClimaOcean.jl
navidcy Apr 25, 2025
ed8c6df
Merge branch 'main' into ts/checkpointer-for-coupled-model
navidcy Apr 28, 2025
72efbc4
Merge branch 'main' into ts/checkpointer-for-coupled-model
taimoorsohail Apr 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions examples/checkpointer_mwe.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using ClimaOcean
using Oceananigans
using Oceananigans.Units
using CFTime
using Dates
using Printf

arch = CPU()

Nx = 144
Ny = 60
Nz = 40

grid = LatitudeLongitudeGrid(arch;
size = (Nx, Ny, Nz),
halo = (7, 7, 7),
z = (-6000, 0),
latitude = (-75, 75),
longitude = (0, 360))

ocean = ocean_simulation(grid)

# date = DateTimeProlepticGregorian(1993, 1, 1)
# set!(ocean.model, T=ECCOMetadata(:temperature; dates=date),
# S=ECCOMetadata(:salinity; dates=date))

radiation = Radiation(arch)

atmosphere = JRA55PrescribedAtmosphere(arch; backend=JRA55NetCDFBackend(41))

coupled_model = OceanSeaIceModel(ocean; atmosphere, radiation)

simulation = Simulation(coupled_model; Δt=10, stop_iteration=8)

wall_time = Ref(time_ns())

function progress(sim)
ocean = sim.model.ocean
atmosphere = sim.model.atmosphere

u, v, w = ocean.model.velocities
T = ocean.model.tracers.T

Tmax = maximum(interior(T))
Tmin = minimum(interior(T))

umax = (maximum(abs, interior(u)),
maximum(abs, interior(v)),
maximum(abs, interior(w)))

step_time = 1e-9 * (time_ns() - wall_time[])

msg = @sprintf("Iter: %d, sim time: %s, atmos time: %s, ocean time: %s", iteration(sim), sim.model.clock.time, atmosphere.clock.time, ocean.model.clock.time)
msg *= @sprintf(", max|u|: (%.2e, %.2e, %.2e) m s⁻¹, extrema(T): (%.2f, %.2f) ᵒC, wall time: %s",
umax..., Tmax, Tmin, prettytime(step_time))

@info msg

wall_time[] = time_ns()
end

simulation.callbacks[:progress] = Callback(progress, IterationInterval(1))

outputs = merge(ocean.model.tracers, ocean.model.velocities)

simulation.output_writers[:surface] = JLD2OutputWriter(ocean.model, outputs;
schedule = IterationInterval(2),
filename = "surface",
indices = (:, :, grid.Nz),
with_halos = true,
overwrite_existing = true,
array_type = Array{Float32})

output_dir = "."
prefix = "checkpointer_mwe"

simulation.output_writers[:checkpoint] = Checkpointer(ocean.model;
schedule = IterationInterval(3),
prefix = prefix,
dir = output_dir,
verbose = true,
overwrite_existing = true)

run!(simulation)

simulation.stop_iteration += 5

run!(simulation, pickup=true)
64 changes: 38 additions & 26 deletions examples/near_global_ocean_simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ using Printf
# The total depth of the domain is set to 6000 meters.
# Finally, we specify the architecture for the simulation, which in this case is a GPU.

arch = GPU()
arch = CPU()

Nx = 1440
Ny = 600
Nx = 144
Ny = 60
Nz = 40

depth = 6000meters
Expand All @@ -53,21 +53,21 @@ grid = LatitudeLongitudeGrid(arch;
# (i) all the minor enclosed basins except the 3 largest `major_basins`, as well as
# (ii) regions that are shallower than `minimum_depth`.

bottom_height = regrid_bathymetry(grid;
minimum_depth = 10meters,
interpolation_passes = 5,
major_basins = 3)
# bottom_height = regrid_bathymetry(grid;
# minimum_depth = 10meters,
# interpolation_passes = 5,
# major_basins = 3)

grid = ImmersedBoundaryGrid(grid, GridFittedBottom(bottom_height); active_cells_map=true)
# grid = ImmersedBoundaryGrid(grid, GridFittedBottom(bottom_height); active_cells_map=true)

# Let's see what the bathymetry looks like:

h = grid.immersed_boundary.bottom_height
# h = grid.immersed_boundary.bottom_height

fig, ax, hm = heatmap(h, colormap=:deep, colorrange=(-depth, 0))
Colorbar(fig[0, 1], hm, label="Bottom height (m)", vertical=false)
save("bathymetry.png", fig)
nothing #hide
# fig, ax, hm = heatmap(h, colormap=:deep, colorrange=(-depth, 0))
# Colorbar(fig[0, 1], hm, label="Bottom height (m)", vertical=false)
# save("bathymetry.png", fig)
# nothing #hide

# ![](bathymetry.png)

Expand All @@ -83,9 +83,9 @@ ocean.model

# We initialize the ocean model with ECCO2 temperature and salinity for January 1, 1993.

date = DateTimeProlepticGregorian(1993, 1, 1)
set!(ocean.model, T=Metadata(:temperature; dates=date, dataset=ECCO4Monthly()),
S=Metadata(:salinity; dates=date, dataset=ECCO4Monthly()))
# date = DateTimeProlepticGregorian(1993, 1, 1)
# set!(ocean.model, T=ECCOMetadata(:temperature; dates=date),
# S=ECCOMetadata(:salinity; dates=date))

# ### Prescribed atmosphere and radiation
#
Expand Down Expand Up @@ -117,7 +117,7 @@ coupled_model = OceanSeaIceModel(ocean; atmosphere, radiation)
# We then create a coupled simulation. We start with a small-ish time step of 90 seconds.
# We run the simulation for 10 days with this small-ish time step.

simulation = Simulation(coupled_model; Δt=90, stop_time=10days)
simulation = Simulation(coupled_model; Δt=90, stop_iteration=10)

# We define a callback function to monitor the simulation's progress,

Expand Down Expand Up @@ -146,7 +146,7 @@ function progress(sim)
wall_time[] = time_ns()
end

simulation.callbacks[:progress] = Callback(progress, TimeInterval(5days))
simulation.callbacks[:progress] = Callback(progress, IterationInterval(1))

# ### Set up output writers
#
Expand All @@ -156,13 +156,23 @@ simulation.callbacks[:progress] = Callback(progress, TimeInterval(5days))
# Below, we use `indices` to save only the values of the variables at the surface, which corresponds to `k = grid.Nz`

outputs = merge(ocean.model.tracers, ocean.model.velocities)
ocean.output_writers[:surface] = JLD2OutputWriter(ocean.model, outputs;
schedule = TimeInterval(1days),
filename = "near_global_surface_fields",
indices = (:, :, grid.Nz),
with_halos = true,
overwrite_existing = true,
array_type = Array{Float32})
# ocean.output_writers[:surface] = JLD2OutputWriter(ocean.model, outputs;
# schedule = IterationInterval(1),
# filename = "near_global_surface_fields",
# indices = (:, :, grid.Nz),
# with_halos = true,
# overwrite_existing = true,
# array_type = Array{Float32})

output_dir = "."
prefix = "near_global"

ocean.output_writers[:checkpoint] = Checkpointer(ocean.model;
schedule = TimeInterval(3minutes),
prefix = prefix,
dir = output_dir)#,
# verbose = true,
# overwrite_existing = true)

# ### Spinning up the simulation
#
Expand All @@ -175,11 +185,12 @@ run!(simulation)
# ### Running the simulation for real

# After the initial spin up of 10 days, we can increase the time-step and run for longer.

#=
simulation.stop_time = 60days
simulation.Δt = 10minutes
run!(simulation)


# ## A pretty movie
#
# It's time to make a pretty movie of the simulation. First we load the output we've been saving on
Expand Down Expand Up @@ -259,3 +270,4 @@ end
nothing #hide

# ![](near_global_ocean_surface.mp4)
=#
4 changes: 2 additions & 2 deletions src/ClimaOcean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export
Metadata,
all_dates,
JRA55FieldTimeSeries,
ECCO_field,
ECCO_field,
ECCORestoring,
LinearlyTaperedPolarMask,
ocean_simulation,
Expand All @@ -51,7 +51,7 @@ const SKOFTS = SomeKindOfFieldTimeSeries
@inline stateindex(a::SKOFTS, i, j, k, grid, time, args...) = @inbounds a[i, j, k, time]

@inline function stateindex(a::Function, i, j, k, grid, time, loc)
LX, LY, LZ = loc
LX, LY, LZ = loc
λ, φ, z = node(i, j, k, grid, LX(), LY(), LZ())
return a(λ, φ, z, time)
end
Expand Down
39 changes: 30 additions & 9 deletions src/OceanSeaIceModels/ocean_sea_ice_model.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
using Oceananigans
using Oceananigans.OutputWriters: checkpoint_path
using Oceananigans.TimeSteppers: Clock
using Oceananigans: SeawaterBuoyancy
using ClimaSeaIce.SeaIceThermodynamics: melting_temperature
using KernelAbstractions: @kernel, @index

using SeawaterPolynomials: TEOS10EquationOfState

import Thermodynamics as AtmosphericThermodynamics
Expand All @@ -13,11 +13,11 @@ import Oceananigans: fields, prognostic_fields
import Oceananigans.Architectures: architecture
import Oceananigans.Fields: set!
import Oceananigans.Models: timestepper, NaNChecker, default_nan_checker
import Oceananigans.OutputWriters: default_included_properties
import Oceananigans.Simulations: reset!, initialize!, iteration
import Oceananigans.OutputWriters: default_included_properties, checkpointer_address,
write_output!, initialize_jld2_file!
import Oceananigans.Simulations: reset!, initialize!, iteration, run!
import Oceananigans.TimeSteppers: time_step!, update_state!, time
import Oceananigans.Utils: prettytime
import Oceananigans.Models: timestepper, NaNChecker, default_nan_checker

struct OceanSeaIceModel{I, A, O, F, C, Arch} <: AbstractModel{Nothing, Arch}
architecture :: Arch
Expand All @@ -29,6 +29,7 @@ struct OceanSeaIceModel{I, A, O, F, C, Arch} <: AbstractModel{Nothing, Arch}
end

const OSIM = OceanSeaIceModel
const OSIMSIM = Simulation{<:OceanSeaIceModel}

function Base.summary(model::OSIM)
A = nameof(typeof(architecture(model)))
Expand Down Expand Up @@ -59,9 +60,11 @@ prettytime(model::OSIM) = prettytime(model.clock.time)
iteration(model::OSIM) = model.clock.iteration
timestepper(::OSIM) = nothing
default_included_properties(::OSIM) = tuple()
prognostic_fields(cm::OSIM) = nothing
prognostic_fields(::OSIM) = nothing
fields(::OSIM) = NamedTuple()
default_clock(TT) = Oceananigans.TimeSteppers.Clock{TT}(0, 0, 1)
time(model::OSIM) = model.clock.time
checkpointer_address(::OSIM) = "HydrostaticFreeSurfaceModel"

function reset!(model::OSIM)
reset!(model.ocean)
Expand All @@ -73,6 +76,26 @@ function initialize!(model::OSIM)
return nothing
end

initialize_jld2_file!(filepath, init, jld2_kw, including, outputs, model::OSIM) =
initialize_jld2_file!(filepath, init, jld2_kw, including, outputs, model.ocean.model)

write_output!(c::Checkpointer, model::OSIM) = write_output!(c, model.ocean.model)

function set!(sim::OSIMSIM, pickup::Union{Bool, Integer, String})
checkpoint_file_path = checkpoint_path(pickup, sim.output_writers)

set!(sim.model.ocean.model, checkpoint_file_path)

sim.model.clock.iteration = sim.model.ocean.model.clock.iteration
sim.model.clock.time = sim.model.ocean.model.clock.time

# Setting the atmosphere time to the ocean time
sim.model.atmosphere.clock.iteration = sim.model.ocean.model.clock.iteration
sim.model.atmosphere.clock.time = sim.model.ocean.model.clock.time
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than extending set! For every atmos model, it might be better to write a function for setting the clock and then extend that appropriately

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment for ocean and sea ice

Copy link
Member

@glwagner glwagner Mar 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As in:

set_clock!(atmos, time, iter)

Copy link
Member

@navidcy navidcy Mar 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point regarding defining a set_clock! method.

Just slightly confused: this doesn't alleviate the need to extend set!, right? Given that

https://github.com/CliMA/Oceananigans.jl/blob/c310123be943467a6d7813052c36ef9cad1589e5/src/Simulations/run.jl#L96-L98

uses set! then the only way is to extend set! so that it "sets" for every component of the coupled simulation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so @navidcy - currently the set! function in Oceananigans does set!(sim.model, checkpoint_file_path) when we want it to do set!(sim.model.ocean.model, checkpoint_file_path) because of the coupled nature of the model.

Copy link
Collaborator Author

@taimoorsohail taimoorsohail Mar 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so you are saying that we can do this without extending set!?

No, I think we need to extend set! and create the function set_clock! as @glwagner suggests. See my latest push which has added that function.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One problem with set!(::OSIMSIM, filepath) is that, when we eventually support checkpointing diagnostics / callbacks / output writer states, that support will have to be copy/pasted here. So it'd be better to use set!(::OceanSeaIceModel, filepath).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. But given then the output writers/checkpointers are included in the outer-outer simulations, then if we define set!(::OSIM, filepath) that won't have access to the checkpointer. Or? Am I missing something?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise, I like the suggestion!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set!(::OceanSeaIceModel, filepath)

done via 4110aca


return nothing
end

reference_density(unsupported) =
throw(ArgumentError("Cannot extract reference density from $(typeof(unsupported))"))

Expand Down Expand Up @@ -114,7 +137,7 @@ function OceanSeaIceModel(ocean, sea_ice=FreezingLimitedOceanTemperature(eltype(
pop!(ocean.callbacks, :wall_time_limit_exceeded, nothing)
pop!(ocean.callbacks, :nan_checker, nothing)
end

if sea_ice isa SeaIceSimulation
if !isnothing(sea_ice.callbacks)
pop!(sea_ice.callbacks, :stop_time_exceeded, nothing)
Expand Down Expand Up @@ -151,10 +174,8 @@ function OceanSeaIceModel(ocean, sea_ice=FreezingLimitedOceanTemperature(eltype(
return ocean_sea_ice_model
end

time(coupled_model::OceanSeaIceModel) = coupled_model.clock.time

# Check for NaNs in the first prognostic field (generalizes to prescribed velocities).
function default_nan_checker(model::OceanSeaIceModel)
function default_nan_checker(model::OSIM)
u_ocean = model.ocean.model.velocities.u
nan_checker = NaNChecker((; u_ocean))
return nan_checker
Expand Down