Skip to content

Commit

Permalink
finished example, finished network
Browse files Browse the repository at this point in the history
  • Loading branch information
hzheng40 committed Oct 20, 2022
1 parent 2e4a69a commit ca297f5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 26 deletions.
Binary file added classification.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
62 changes: 36 additions & 26 deletions flax_rbf/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
# Classification example based on:
# https://github.com/JeremyLinux/PyTorch-Radial-Basis-Function-Layer/blob/master/Torch%20RBF/classification_demo.py

from flax_rbf import RBFNet
from flax_rbf import gaussian
# from flax.metrics import tensorboard
from flax.training import train_state
import optax
import jax.numpy as jnp
import flax.linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from flax.training import train_state

from flax_rbf import RBFNet, gaussian

# rng
key = jax.random.PRNGKey(0)
Expand All @@ -46,7 +46,7 @@
samples = 200
x1 = jax.random.uniform(k1, (samples, 1), minval=-1., maxval=1.)
x2_1 = jax.random.uniform(k2, (samples//2, 1), minval=-1., maxval=0.5*jnp.cos(jnp.pi*x1[:samples//2])+0.5*jnp.cos(4*jnp.pi*(x1[:samples//2]+1)))
x2_2 = jax.random.uniform(k3, (samples//2, 1), minval=0.5*jnp.cos(jnp.pi*x1[:samples//2])+0.5*jnp.cos(4*jnp.pi*(x1[:samples//2]+1)), maxval=1.)
x2_2 = jax.random.uniform(k3, (samples//2, 1), minval=0.5*jnp.cos(jnp.pi*x1[samples//2:])+0.5*jnp.cos(4*jnp.pi*(x1[samples//2:]+1)), maxval=1.)

# training set
tx = jnp.hstack((x1, jnp.vstack((x2_1, x2_2))))
Expand All @@ -72,37 +72,47 @@
rbf_net = RBFNet(in_features=2, out_features=1, num_kernels=40, basis_func=gaussian)
params = rbf_net.init(init_rng, jnp.ones((10, 2)))
optim = optax.adam(0.01)
tstate = train_state.TrainState.create(apply_fn=rbf_net.apply, params=params, tx=optim)


rbfnet = Network(layer_widths, layer_centres, basis_func)
rbfnet.fit(tx, ty, 5000, samples, 0.01, nn.BCEWithLogitsLoss())
rbfnet.eval()


state = train_state.TrainState.create(apply_fn=rbf_net.apply, params=params, tx=optim)

@jax.jit
def train_step(state, x, y):
def loss_fn(params):
logits = rbf_net.apply(params, x)
loss = optax.sigmoid_binary_cross_entropy(logits=logits, labels=y).mean()
return loss
grad_fn = jax.value_and_grad(loss_fn)
loss_, grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state, loss_

# nans after more than 3100 epochs TODO
epochs = 3100
for e in range(epochs):
state, current_loss = train_step(state, tx, ty)
print('Epoch: ', e, 'Train loss: ', current_loss)

# Plotting the ideal and learned decision boundaries
pred_logits = rbf_net.apply(state.params, values)
preds = nn.sigmoid(pred_logits)

with torch.no_grad():
preds = (torch.sigmoid(rbfnet(torch.from_numpy(values).float()))).data.numpy()
ideal_0 = values[np.where(values[:,1] <= 0.5*np.cos(np.pi*values[:,0]) + 0.5*np.cos(4*np.pi*(values[:,0]+1)))[0]]
ideal_1 = values[np.where(values[:,1] > 0.5*np.cos(np.pi*values[:,0]) + 0.5*np.cos(4*np.pi*(values[:,0]+1)))[0]]
area_0 = values[np.where(preds[:, 0] <= 0.5)[0]]
area_1 = values[np.where(preds[:, 0] > 0.5)[0]]
ideal_0 = values[jnp.where(values[:,1] <= 0.5*jnp.cos(jnp.pi*values[:,0]) + 0.5*jnp.cos(4*jnp.pi*(values[:,0]+1)))[0]]
ideal_1 = values[jnp.where(values[:,1] > 0.5*jnp.cos(jnp.pi*values[:,0]) + 0.5*jnp.cos(4*jnp.pi*(values[:,0]+1)))[0]]
area_0 = values[jnp.where(preds[:, 0] <= 0.5)[0]]
area_1 = values[jnp.where(preds[:, 0] > 0.5)[0]]

fig, ax = plt.subplots(figsize=(16,8), nrows=1, ncols=2)
ax[0].scatter(x[:samples//2,0], x[:samples//2,1], c='dodgerblue')
ax[0].scatter(x[samples//2:,0], x[samples//2:,1], c='orange', marker='x')
ax[0].scatter(x1[:samples//2], x2_1, c='dodgerblue')
ax[0].scatter(x1[samples//2:], x2_2, c='orange', marker='x')
ax[0].scatter(ideal_0[:, 0], ideal_0[:, 1], alpha=0.1, c='dodgerblue')
ax[0].scatter(ideal_1[:, 0], ideal_1[:, 1], alpha=0.1, c='orange')
ax[0].set_xlim([-1,1])
ax[0].set_ylim([-1,1])
ax[0].set_title('Ideal Decision Boundary')
ax[1].scatter(x[:samples//2,0], x[:samples//2,1], c='dodgerblue')
ax[1].scatter(x[samples//2:,0], x[samples//2:,1], c='orange', marker='x')
ax[1].scatter(x1[:samples//2], x2_1, c='dodgerblue')
ax[1].scatter(x1[samples//2:], x2_2, c='orange', marker='x')
ax[1].scatter(area_0[:, 0], area_0[:, 1], alpha=0.1, c='dodgerblue')
ax[1].scatter(area_1[:, 0], area_1[:, 1], alpha=0.1, c='orange')
ax[1].set_xlim([-1,1])
ax[1].set_ylim([-1,1])
ax[1].set_title('RBF Decision Boundary')
plt.show()
plt.show()

0 comments on commit ca297f5

Please sign in to comment.