Skip to content

Commit

Permalink
Merge pull request #20 from gsd-authors/soft
Browse files Browse the repository at this point in the history
Soft
krzysztofrusek authored Jan 22, 2024
2 parents 7ec8ae9 + 3111bc0 commit f64b539
Showing 8 changed files with 393 additions and 12 deletions.
69 changes: 69 additions & 0 deletions discussion/softvmin.wl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
(* ExecuteFile["softvmin.wl"] *)

Clear["Global`*"]

fa[x_] := (2-x)(x-1)
fb[x_] := (3-x)(x-2)

Plot[{fa[x],fb[x]},{x,1,3}] //Export["fafb.pdf", # ]&


pows = {0,2,4}
vars=Subscript[b,#]&/@pows

appf[x_] := Total[Subscript[b,#] (x-2)^# &/@pows]

eqs={
appf[2-d]==fa[2-d],
appf[2+d]==fb[2+d],
D[fa[x],x]==D[appf[x],x]/.{x->2-d},
D[fb[x],x]==D[appf[x],x]/.{x->2+d},
D[fa[x],{x,2}]==D[appf[x],{x,2}]/.{x->2-d},
D[fb[x],{x,2}]==D[appf[x],{x,2}]/.{x->2+d},
D[appf[x],x]==0/.{x->2}
}

sol=Solve[
eqs,
vars
]

(* sol = vars/.Solve[
eqs/.{d->0.1},
vars
] *)

sol=sol[[1]]


Plot[{(appf[x]/.sol)/.{d->1/50}, fa[x],fb[x]},{x,1.8,2.2}, PlotRange->{0,1/4}]//Export["appf.pdf", # ]&

Export["sol.txt",(appf[x]/.sol)]

(* Needs["CCodeGenerator`"]
CCodeGenerator[]
c = Compile[ {{x},{d}}, appf[x]/.sol];
file = CCodeStringGenerate[c, "fun"] *)

(* Test cases *)

(appf[x]/.sol)/.{d->1/50, x->1.99}

(appf[x]/.sol)/.{d->1/10, x->2.05}

p = (n+1)/(n+2)
ep = p x + (1-p)(x+1)

v = Simplify[p (x-ep)^2 + (1-p)(x+1-ep)^2]

ExportString[v,"tex"]

sol = Solve[((appf[x]/.sol)/.{x->2})==v,d]

ExportString[sol,"tex"]

N[(d/.sol[[1]])/.{n->24}]

202 changes: 202 additions & 0 deletions examples/softvmin.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ classifiers = [

# HATCH_PYTHON=python3.10
requires-python = ">=3.10"
dependencies=["jax>=0.4.6"]
dependencies=["jax>=0.4.23"]

[project.urls]
Homepage = "https://github.com/gsd-authors/gsd"
@@ -46,7 +46,7 @@ include = [
]

[tool.hatch.envs.default]
dependencies=["jaxlib>=0.4.6"]
dependencies=["jaxlib>=0.4.23"]

[project.optional-dependencies]
experimental = [
2 changes: 1 addition & 1 deletion src/gsd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.2.2dev'
__version__ = '0.2.2'
from gsd.fit import GSDParams as GSDParams
from gsd.fit import fit_moments as fit_moments
from gsd.gsd import (log_prob as log_prob,
10 changes: 5 additions & 5 deletions src/gsd/experimental/max_entropy.py
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@ def _explicit_log_probs(dist: 'MaxEntropyGSD'):

lgr = jax.tree_util.tree_map(jnp.asarray, (-0.01, -0.01, -0.01))
sol = optx.root_find(_implicit_log_probs, solver, lgr, args=dist,
max_steps=int(1e4), throw=False)
max_steps=int(1e4), throw=True)
return _lagrange_log_probs(sol.value, dist)


@@ -66,7 +66,6 @@ class MaxEntropyGSD(eqx.Module):
sigma: Float[Array, ""] # std
N: int = eqx.field(static=True)


def log_prob(self, x: Int[Array, ""]):
lp = _explicit_log_probs(self)
return lp[x - 1]
@@ -106,7 +105,7 @@ def sample(self, key: PRNGKeyArray, axis=-1, shape=None):
return jax.random.categorical(key, lp, axis, shape) + self.support[0]

@staticmethod
def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
def from_gsd(theta: GSDParams, N: int) -> 'MaxEntropyGSD':
"""Created maxentropy from GSD parameters.
:param theta: Parameters of a GSD distribution.
@@ -119,6 +118,7 @@ def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
N=N
)


MaxEntropyGSD.__init__.__doc__ = """Creates a MaxEntropyGSD
:param mean: Expectation value of the distribution.
@@ -127,6 +127,6 @@ def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
.. note::
An alternative way to construct this distribution is by use of
:ref:`from_gsd`
:meth:`from_gsd`
"""
"""
32 changes: 31 additions & 1 deletion src/gsd/gsd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from typing import Sequence, Callable

import jax
import jax.numpy as jnp
@@ -141,3 +141,33 @@ def sufficient_statistic(data: ArrayLike) -> Array:
bins = jnp.arange(0.5, N + 1.5, 1.)
c, _ = jnp.histogram(jnp.asarray(data), bins=bins)
return c


def softvmin_poly(x: Array, c: float, d: float) -> Array:
"""Smooths approximation to `vmin` function.
:param x: An argument, this would be psi
:param d: Cut point of approximation from `[0,0.5)`
:return: An approximated value `x` such that `abs(round(x)-x)<=d`
"""
sq1 = jnp.square(x - c)
sq2 = jnp.square(sq1)

return (3 * d) / 8 - ((-3 + 4 * d) * sq1) / (4 * d) - sq2 / (8 * d ** 3)


def make_softvmin(d: float) -> Callable[[Array], Array]:
"""Create a soft approximation to `vmin` function.
:param d: Cut point of approximation from `[0,0.5)`
:return: A callable returning n approximated value `vmin` for `x`
`abs(round(x)-x)<=d`
"""
def sofvmin(psi: ArrayLike):
psi = jnp.asarray(psi)
c = jax.lax.stop_gradient(jnp.round(psi))
return jnp.where(jnp.abs(psi - c) < d, softvmin_poly(psi, c, d),
vmin(psi)
)

return sofvmin
59 changes: 56 additions & 3 deletions tests/experimental_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from jax import config

from gsd.gsd import make_softvmin, vmax, vmin

config.update("jax_enable_x64", True)
from gsd.experimental.max_entropy import MaxEntropyGSD
import unittest # noqa: E402

import jax
import jax.numpy as jnp
import numpy as np

import gsd
@@ -14,6 +13,14 @@
from gsd.experimental.fit import GridEstimator
from gsd.fit import log_pmax, pairs, pmax, GSDParams, fit_moments

import equinox as eqx
import optimistix as optx

from gsd.experimental.max_entropy import MaxEntropyGSD, vmax

import jax
import jax.numpy as jnp


class FitTestCase(unittest.TestCase):
def test_pairs(self):
@@ -117,3 +124,49 @@ def test_probs(self):
lp = me.all_log_probs
p = np.exp(lp)
self.assertAlmostEqual(p.sum(), 1)


def test_fit(self):
def nll(d, x):
m, s = d
mean = 1.0 + 4.0 * jax.nn.sigmoid(m)
svmin = make_softvmin(0.1)
smin = jnp.sqrt(svmin(mean))
smax = jnp.sqrt(vmax(mean, N=5))
sigma = smin + (smax - smin) * jax.nn.sigmoid(s)
d = MaxEntropyGSD(mean, sigma, N=5)
return -jnp.mean(d.log_prob(x))

# x = jnp.asarray([2, 3, 2, 2, 3, 3, 4])
x = jnp.asarray([2, 2, 2, 2, 2, 2, 2])

eqx.tree_pprint(jax.grad(nll)((0.01, 2.0), x), short_arrays=False)

def fit(x):
solver = optx.BFGS(rtol=1e-2, atol=1e-4)

res = optx.minimise(nll, solver, (-0.0, .0),
args=x,
max_steps=int(1e6),
throw=True)
return res

res = jax.jit(fit)(x)
eqx.tree_pprint(res.value, short_arrays=False)

m, s = res.value
mean = 1.0 + 4.0 * jax.nn.sigmoid(m)
smin = jnp.sqrt(vmin(mean))
smax = jnp.sqrt(vmax(mean, N=5))
sigma = smin + (smax - smin) * jax.nn.sigmoid(s)
d = MaxEntropyGSD(mean, sigma, N=5)

self.assertAlmostEqual(d.mean,2., places=4)

eqx.tree_pprint(d, short_arrays=False)
eqx.tree_pprint(MaxEntropyGSD(jnp.mean(x), jnp.std(x), N=5),
short_arrays=False)




27 changes: 27 additions & 0 deletions tests/ref_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
from jax import config

from gsd.gsd import softvmin_poly, make_softvmin, vmin

config.update("jax_enable_x64", True)

import unittest
@@ -104,5 +107,29 @@ def test_sufficient_statistic4(self):
# 1, 2 3 4 5
self.assertTrue(np.allclose(ss,c))


class SoftTestCase(unittest.TestCase):
def test_poly(self):
v = softvmin_poly(x=1.99,c=2., d=1/50.)
self.assertAlmostEqual(v, 0.0109938)
v = softvmin_poly(x=2.05,c=2, d=1 / 10.)
self.assertAlmostEqual(v, 0.0529687)

def test_softvmin(self):
svmin = make_softvmin(0.1)
self.assertAlmostEqual(svmin(3.3), vmin(3.3))

for x in [1.5,1.9, 1.95, 2.05, 2.1, 2.2]:
gsvmin = jax.grad(svmin)
g = gsvmin(x)
print(g)
self.assertIsNotNone(g)

ggsvmin = jax.grad(gsvmin)
gg = ggsvmin(x)
print(gg)
self.assertIsNotNone(gg)


if __name__ == '__main__':
unittest.main()

0 comments on commit f64b539

Please sign in to comment.