Skip to content

Update to Static.jl v0.8 #115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 22 additions & 38 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ on:
push:
branches:
- master
tags: '*'
pull_request:


concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
Expand All @@ -19,59 +19,43 @@ jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
fail-fast: false
matrix:
version:
- '1.6'
- '1.7'
- '1.8'
- '1'
- 'nightly'
os:
- ubuntu-latest
arch:
- x64
include:
- version: 1
os: ubuntu-latest
arch: x86
- version: 1
os: macOS-latest
arch: x64
- version: 1
os: windows-latest
arch: x64
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
with:
coverage: ${{ matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.arch == 'x64' }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.arch == 'x64'
- uses: codecov/codecov-action@v3
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.arch == 'x64'
with:
file: lcov.info
# docs:
# name: Documentation
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v2
# - uses: julia-actions/setup-julia@v1
# with:
# version: '1'
# - run: |
# julia --project=docs -e '
# using Pkg
# Pkg.develop(PackageSpec(path=pwd()))
# Pkg.instantiate()'
# - run: |
# julia --project=docs -e '
# using Documenter: DocMeta, doctest
# using MeasureBase
# DocMeta.setdocmeta!(MeasureBase, :DocTestSetup, :(using MeasureBase); recursive=true)
# doctest(MeasureBase)'
# - run: julia --project=docs docs/make.jl
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}

4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MeasureBase"
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
authors = ["Chad Scherrer <[email protected]> and contributors"]
version = "0.14.5"
version = "0.14.6"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down Expand Up @@ -48,6 +48,6 @@ NaNMath = "0.3, 1"
PrettyPrinting = "0.3, 0.4"
Reexport = "1"
SpecialFunctions = "2"
Static = "0.5, 0.6"
Static = "0.8"
Tricks = "0.1"
julia = "1.3"
3 changes: 2 additions & 1 deletion src/MeasureBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ using PrettyPrinting
const Pretty = PrettyPrinting

using ChainRulesCore
using FillArrays
import FillArrays
using Static
using FunctionChains

Expand Down Expand Up @@ -106,6 +106,7 @@ using Compat

using IrrationalConstants

include("static.jl")
include("smf.jl")
include("getdof.jl")
include("transport.jl")
Expand Down
53 changes: 30 additions & 23 deletions src/combinators/power.jl
Original file line number Diff line number Diff line change
@@ -1,58 +1,68 @@
import Base
using FillArrays: Fill
# """
# A power measure is a product of a measure with itself. The number of elements in
# the product determines the dimensionality of the resulting support.

# Note that power measures are only well-defined for integer powers.
export PowerMeasure

# The nth power of a measure μ can be written μ^x.
# """
# PowerMeasure{M,N,D} = ProductMeasure{Fill{M,N,D}}
"""
struct PowerMeasure{M,...} <: AbstractProductMeasure

export PowerMeasure
A power measure is a product of a measure with itself. The number of elements in
the product determines the dimensionality of the resulting support.

Note that power measures are only well-defined for integer powers.

The nth power of a measure μ can be written μ^n.
"""
struct PowerMeasure{M,A} <: AbstractProductMeasure
parent::M
axes::A
end

maybestatic_length(μ::PowerMeasure) = prod(maybestatic_size(μ))
maybestatic_size(μ::PowerMeasure) = map(maybestatic_length, μ.axes)

function Pretty.tile(μ::PowerMeasure)
sz = length.(μ.axes)
arg1 = Pretty.tile(μ.parent)
arg2 = Pretty.tile(length(sz) == 1 ? only(sz) : sz)
return Pretty.pair_layout(arg1, arg2; sep = " ^ ")
end

# ToDo: Make rand return static arrays for statically-sized power measures.

function _cartidxs(axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where {N}
CartesianIndices(map(_dynamic, axs))
end

function Base.rand(
rng::AbstractRNG,
::Type{T},
d::PowerMeasure{M},
) where {T,M<:AbstractMeasure}
map(CartesianIndices(d.axes)) do _
map(_cartidxs(d.axes)) do _
rand(rng, T, d.parent)
end
end

function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T}
map(CartesianIndices(d.axes)) do _
map(_cartidxs(d.axes)) do _
rand(rng, d.parent)
end
end

@inline _pm_axes(sz::Tuple{Vararg{<:IntegerLike,N}}) where {N} = map(one_to, sz)
@inline _pm_axes(axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where {N} = axs

@inline function powermeasure(x::T, sz::Tuple{Vararg{<:Any,N}}) where {T,N}
a = axes(Fill{T,N}(x, sz))
A = typeof(a)
PowerMeasure{T,A}(x, a)
PowerMeasure(x, _pm_axes(sz))
end

marginals(d::PowerMeasure) = Fill(d.parent, d.axes)
marginals(d::PowerMeasure) = fill_with(d.parent, d.axes)

function Base.:^(μ::AbstractMeasure, dims::Tuple{Vararg{<:AbstractArray,N}}) where {N}
powermeasure(μ, dims)
end

Base.:^(μ::AbstractMeasure, dims::Tuple) = powermeasure(μ, Base.OneTo.(dims))
Base.:^(μ::AbstractMeasure, dims::Tuple) = powermeasure(μ, one_to.(dims))
Base.:^(μ::AbstractMeasure, n) = powermeasure(μ, (n,))

# Base.show(io::IO, d::PowerMeasure) = print(io, d.parent, " ^ ", size(d.xs))
Expand All @@ -75,18 +85,15 @@ end
end
end

@inline function logdensity_def(
d::PowerMeasure{M,Tuple{Base.OneTo{StaticInt{N}}}},
x,
) where {M,N}
@inline function logdensity_def(d::PowerMeasure{M,Tuple{Static.SOneTo{N}}}, x) where {M,N}
parent = d.parent
sum(1:N) do j
@inbounds logdensity_def(parent, x[j])
end
end

@inline function logdensity_def(
d::PowerMeasure{M,NTuple{N,Base.OneTo{StaticInt{0}}}},
d::PowerMeasure{M,NTuple{N,Static.SOneTo{0}}},
x,
) where {M,N}
static(0.0)
Expand All @@ -110,7 +117,7 @@ end

@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(map(length, μ.axes))

@inline function getdof(::PowerMeasure{<:Any,NTuple{N,Base.OneTo{StaticInt{0}}}}) where {N}
@inline function getdof(::PowerMeasure{<:Any,NTuple{N,Static.SOneTo{0}}}) where {N}
static(0)
end

Expand All @@ -135,7 +142,7 @@ logdensity_def(::PowerMeasure{P}, x) where {P<:PrimitiveMeasure} = static(0.0)

# To avoid ambiguities
function logdensity_def(
::PowerMeasure{P,Tuple{Vararg{Base.OneTo{Static.StaticInt{0}},N}}},
::PowerMeasure{P,Tuple{Vararg{Static.SOneTo{0},N}}},
x,
) where {P<:PrimitiveMeasure,N}
static(0.0)
Expand Down
2 changes: 1 addition & 1 deletion src/combinators/smart-constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
###############################################################################
# ProductMeasure

productmeasure(mar::Fill) = powermeasure(mar.value, mar.axes)
productmeasure(mar::FillArrays.Fill) = powermeasure(mar.value, mar.axes)

function productmeasure(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M}
return powermeasure(mar.f.value, axes(mar.data))
Expand Down
2 changes: 1 addition & 1 deletion src/domains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ struct Simplex <: CodimOne end

function zeroset(::Simplex)
f(x::AbstractArray{T}) where {T} = sum(x) - one(T)
∇f(x::AbstractArray{T}) where {T} = Fill(one(T), size(x))
∇f(x::AbstractArray{T}) where {T} = fill_with(one(T), size(x))
ZeroSet(f, ∇f)
end

Expand Down
6 changes: 3 additions & 3 deletions src/standard/stdmeasure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x)
end

function transport_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x)
return Fill(transport_def(ν.parent, μ, only(x)), map(length, ν.axes)...)
return fill_with(transport_def(ν.parent, μ, only(x)), map(length, ν.axes))
end

function transport_def(
Expand All @@ -35,7 +35,7 @@ end
# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}):

_std_measure(::Type{M}, ::StaticInt{1}) where {M<:StdMeasure} = M()
_std_measure(::Type{M}, dof::Integer) where {M<:StdMeasure} = M()^dof
_std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof
_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ))

function transport_to(::Type{NU}, μ) where {NU<:StdMeasure}
Expand Down Expand Up @@ -90,7 +90,7 @@ end
@inline _offset_cumsum(s, x) = (s,)
@inline _offset_cumsum(s) = ()

function _stdvar_viewranges(μs::Tuple, startidx::Integer)
function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike)
N = map(getdof, μs)
offs = _offset_cumsum(startidx, N...)
map((o, n) -> o:o+n-1, offs, N)
Expand Down
61 changes: 61 additions & 0 deletions src/static.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
MeasureBase.IntegerLike

Equivalent to `Union{Integer,Static.StaticInt}`.
"""
const IntegerLike = Union{Integer,Static.StaticInt}

"""
MeasureBase.one_to(n::IntegerLike)

Creates a range from one to n.

Returns an instance of `Base.OneTo` or `Static.SOneTo`, depending
on the type of `n`.
"""
@inline one_to(n::Integer) = Base.OneTo(n)
@inline one_to(::Static.StaticInt{N}) where {N} = Static.SOneTo{N}()

_dynamic(x::Number) = dynamic(x)
_dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N)
_dynamic(r::AbstractUnitRange) = minimum(r):maximum(r)

"""
MeasureBase.fill_with(x, sz::NTuple{N,<:IntegerLike}) where N

Creates an array of size `sz` filled with `x`.

Returns an instance of `FillArrays.Fill`.
"""
function fill_with end

@inline function fill_with(x::T, sz::Tuple{Vararg{<:IntegerLike,N}}) where {T,N}
fill_with(x, map(one_to, sz))
end

@inline function fill_with(x::T, axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where {T,N}
# While `FillArrays.Fill` (mostly?) works with axes that are static unit
# ranges, some operations that automatic differentiation requires do fail
# on such instances of `Fill` (e.g. `reshape` from dynamic to static size).
# So need to use standard ranges for the axes for now:
dyn_axs = map(_dynamic, axs)
FillArrays.Fill(x, dyn_axs)
end

"""
MeasureBase.maybestatic_length(x)::IntegerLike

Returns the length of `x` as a dynamic or static integer.
"""
maybestatic_length(x) = length(x)
maybestatic_length(x::AbstractUnitRange) = length(x)
function maybestatic_length(::Static.OptionallyStaticUnitRange{StaticInt{A},StaticInt{B}}) where {A,B}
StaticInt{B - A + 1}()
end

"""
MeasureBase.maybestatic_size(x)::Tuple{Vararg{IntegerLike}}

Returns the size of `x` as a tuple of dynamic or static integers.
"""
maybestatic_size(x) = size(x)
6 changes: 5 additions & 1 deletion test/transport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ using LogExpFunctions: logit
using ChainRulesTestUtils

@testset "transport_to" begin
test_rrule(MeasureBase._origin_depth, pushfwd(exp, StdUniform()))
test_rrule(
MeasureBase._origin_depth,
pushfwd(exp, StdUniform()),
output_tangent = static(0),
)

for (f, μ) in [
(logit, StdUniform())
Expand Down