diff --git a/src/Devito.jl b/src/Devito.jl index d1879fc..e0b008e 100644 --- a/src/Devito.jl +++ b/src/Devito.jl @@ -1873,6 +1873,17 @@ function SubDomain(name::String, instructions...) return SubDomain{N}(subdom(name,instructions)) end +struct Buffer + o::PyObject +end + +""" + Buffer(value::Int) +Construct a devito buffer. This may be used as a save= keyword argument in the construction of TimeFunctions. +""" +Buffer(value::Int) = Buffer(pycall(devito.Buffer, PyObject, value)) +PyCall.PyObject(x::Buffer) = x.o + """ nsimplify(expr::PyObject; constants=(), tolerance=none, full=false, rational=none, rational_conversion="base10") @@ -1938,6 +1949,6 @@ Base.isequal(x::Union{SubDomain, DiscreteFunction, Constant, AbstractDimension, Base.hash(x::Union{SubDomain, DiscreteFunction, Constant, AbstractDimension, Operator, Grid, Eq, Injection}) = hash(PyObject(x)) -export Constant, DiscreteFunction, Grid, Function, SparseFunction, SparseTimeFunction, SubDomain, TimeFunction, apply, backward, ccode, configuration, configuration!, coordinates, coordinates_data, data, data_allocated, data_with_halo, data_with_inhalo, dimension, dimensions, dx, dy, dz, evaluate, extent, forward, grid, halo, inject, interpolate, localindices, localindices_with_halo, localindices_with_inhalo, localsize, name, nsimplify, origin, size_with_halo, simplify, solve, spacing, spacing_map, step, subdomains, subs, thickness, value, value! +export Buffer, Constant, DiscreteFunction, Grid, Function, SparseFunction, SparseTimeFunction, SubDomain, TimeFunction, apply, backward, ccode, configuration, configuration!, coordinates, coordinates_data, data, data_allocated, data_with_halo, data_with_inhalo, dimension, dimensions, dx, dy, dz, evaluate, extent, forward, grid, halo, inject, interpolate, localindices, localindices_with_halo, localindices_with_inhalo, localsize, name, nsimplify, origin, size_with_halo, simplify, solve, spacing, spacing_map, step, subdomains, subs, thickness, value, value! end diff --git a/test/serialtests.jl b/test/serialtests.jl index 3931df5..bac71ef 100644 --- a/test/serialtests.jl +++ b/test/serialtests.jl @@ -1026,3 +1026,12 @@ end @test data_with_inhalo(sf) ≈ ones(Float32, npoint) @test data_with_inhalo(stf) ≈ ones(Float32, npoint, nt) end + +@testset "Buffer construction and use, buffer size = $value" for value in (1,2,4) + b = Buffer(value) + @test typeof(b) == Buffer + shp = (5,6) + grd = Grid(shape=shp) + u = TimeFunction(name="u", grid=grd, save=b) + @test size(u) == (shp...,value) +end