Data preparation (#35)
* style: pep 8 module name convention
* feat: 2d convolution transform
* feat: crop mask
* feat: custom dataset for location
* fix: custom dataset
* feat: unsqueeze transform
* feat: DoG and Gabor filters
* doc: string doc for helpers
* chore: updating authors
* feat: silent input at the end of each instance
saeedark authored Jun 17, 2023
1 parent 3e7cce3 commit c4f499a
Showing 22 changed files with 534 additions and 134 deletions.
2 changes: 2 additions & 0 deletions AUTHORS.rst
@@ -13,3 +13,5 @@ Contributors
 ------------
 
 * Amir Hosein Ebrahimi <[email protected]>
+
+* Saman Arzaghi <[email protected]>
5 changes: 3 additions & 2 deletions Example/test/mnist.py
@@ -3,8 +3,8 @@
 from conex import *
 from conex.nn.priorities import NEURON_PRIORITIES
 
-from conex.helpers.encoders import Poisson
-from conex.helpers.transforms import *
+from conex.helpers.transforms.encoders import Poisson
+from conex.helpers.transforms.misc import *
 
 from torchvision import transforms
 from torchvision.datasets import MNIST
@@ -67,6 +67,7 @@
         depth=1, height=SENSORY_SIZE_HEIGHT, width=SENSORY_SIZE_WIDTH
     ),
     sensory_trace=SENSORY_TRACE_TAU_S,
+    instance_duration=POISSON_TIME,
 )
 
 #############################
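
For orientation, a minimal sketch of the relocated encoder import in use. The time_window argument is an assumption inferred from the POISSON_TIME constant above; the encoder's exact signature is not shown in this diff, and the value is illustrative:

from torchvision import transforms
from conex.helpers.transforms.encoders import Poisson  # moved out of conex.helpers.encoders

POISSON_TIME = 30  # illustrative value, matching instance_duration above

# Assumed behavior: Poisson-encode each image into POISSON_TIME binary frames,
# so one dataset sample spans POISSON_TIME simulation steps.
sensory_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        Poisson(time_window=POISSON_TIME),
    ]
)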
4 changes: 2 additions & 2 deletions conex/__init__.py
@@ -12,6 +12,6 @@
 
 from conex.behaviors.layer.dataset import *
 
-from conex.nn.Structure import *
-from conex.nn.Modules import *
+from conex.nn.structure import *
+from conex.nn.modules import *
 from conex.nn.config import *
38 changes: 36 additions & 2 deletions conex/behaviors/layer/dataset.py
@@ -1,3 +1,5 @@
+import torch
+
 """
 Behaviors to load datasets
 """
@@ -17,6 +19,8 @@ class SpikeNdDataset(Behavior):
         have_location (bool): Whether dataloader returns location input.
         have_sensory (bool): Whether dataloader returns sensory input.
         have_label (bool): Whether dataloader returns label of input.
+        silent_interval (int): The interval of silent activity between two different inputs.
+        instance_duration (int): The duration of each input instance with the same target value.
         loop (bool): If True, dataloader repeats.
     """

@@ -27,10 +31,14 @@ def initialize(self, layer):
         self.have_location = self.parameter("have_location", False)
         self.have_sensory = self.parameter("have_sensory", True)
         self.have_label = self.parameter("have_label", True)
+        self.silent_interval = self.parameter("silent_interval", 0)
+        self.each_instance = self.parameter("instance_duration", 0, required=True)
         self.loop = self.parameter("loop", True)
 
         self.data_generator = self._get_data()
         self.device = layer.device
+        self.new_data = False
+        self.silent_iteration = 0
 
     def _get_data(self):
         while self.loop:
@@ -59,13 +67,39 @@ def _get_data(self):
 
             if batch_y is not None:
                 batch_y = batch_y.to(self.device)
-                each_instance = num_instance // torch.numel(batch_y)
+                self.each_instance = num_instance // torch.numel(batch_y)
 
             for i in range(num_instance):
                 x = batch_x[i].view((-1,)) if batch_x is not None else None
                 loc = batch_loc[i].view((-1,)) if batch_loc is not None else None
-                y = batch_y[i // each_instance] if batch_y is not None else None
+                y = (
+                    batch_y[i // self.each_instance]
+                    if batch_y is not None
+                    else None
+                )
+                if i % self.each_instance == self.each_instance - 1:
+                    self.new_data = True
                 yield x, loc, y
 
     def forward(self, layer):
+        if self.silent_interval and self.new_data:
+            if self.silent_iteration == 0:
+                layer.x = (
+                    layer.tensor(mode="zeros", dtype=torch.bool, dim=layer.x.shape)
+                    if layer.x is not None
+                    else None
+                )
+                layer.loc = (
+                    layer.tensor(mode="zeros", dtype=torch.bool, dim=layer.loc.shape)
+                    if layer.loc is not None
+                    else None
+                )
+
+            self.silent_iteration += 1
+
+            if self.silent_iteration == self.silent_interval:
+                self.new_data = False
+                self.silent_iteration = 0
+            return
+
         layer.x, layer.loc, layer.y = next(self.data_generator)
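
Taken together, the new logic presents each sample for instance_duration steps and then, when silent_interval is nonzero, zeroes the layer input for that many steps before fetching the next sample. A minimal wiring sketch, assuming the PymoNNtorch convention of passing behavior parameters as keyword arguments; dataloader is the data source named in the docstring (its plumbing is elided in this diff), train_loader is a placeholder for any torch DataLoader, and all values are illustrative:

from conex.behaviors.layer.dataset import SpikeNdDataset

# Each sample drives the layer for 30 steps, followed by 5 silent steps,
# so one sample occupies 35 simulation steps in total.
dataset_behavior = SpikeNdDataset(
    dataloader=train_loader,  # placeholder: any torch DataLoader over encoded input
    instance_duration=30,
    silent_interval=5,
    have_location=False,
    have_sensory=True,
    have_label=True,
    loop=True,
)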
4 changes: 4 additions & 0 deletions conex/behaviors/synapses/specs.py
@@ -98,6 +98,8 @@ class WeightInitializer(Behavior):
 
     def initialize(self, synapse):
         init_mode = self.parameter("mode", None)
+        scale = self.parameter("scale", 1)
+        offset = self.parameter("offset", 0)
         synapse.weights = self.parameter("weights", None)
         synapse.weight_shape = self.parameter("weight_shape", None)
         synapse.kernel_shape = self.parameter("kernel_shape", None)
@@ -110,6 +112,8 @@ def initialize(self, synapse):
             mode=init_mode, dim=synapse.weight_shape
         )
 
+        synapse.weights = synapse.weights * scale + offset
+
 
 class WeightNormalization(Behavior):
     """
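
The new scale and offset parameters post-process whatever mode generates, i.e. weights = weights * scale + offset. A hedged sketch; the "uniform" mode string is an assumption about the underlying tensor initializer, and the values are illustrative:

from conex.behaviors.synapses.specs import WeightInitializer

# Draw weights uniformly in [0, 1), then rescale them to [0.2, 0.7).
winit = WeightInitializer(mode="uniform", scale=0.5, offset=0.2)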
44 changes: 44 additions & 0 deletions conex/helpers/data.py
@@ -0,0 +1,44 @@
+from torch.utils.data import Dataset
+
+
+class LocationDataset(Dataset):
+    """
+    A custom dataset class for data with a triple (image, location, label) nature.
+
+    Args:
+        dataset (Dataset): An instance of a dataset.
+        pre_transform (Transform): A transformation to apply to images. If given, it should return an `(image, location)` tuple.
+        post_transform (Transform): A transformation applied to images. Suitable for encodings.
+        location_transform (Transform): A transformation applied to location data. Suitable for encodings.
+        target_transform (Transform): A transformation applied to labels.
+    """
+
+    def __init__(
+        self,
+        dataset,
+        pre_transform=None,
+        post_transform=None,
+        location_transform=None,
+        target_transform=None,
+    ):
+        self.dataset = dataset
+        self.pre_transform = pre_transform
+        self.post_transform = post_transform
+        self.location_transform = location_transform
+        self.target_transform = target_transform
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        image, label = self.dataset[idx]
+        location = None
+        if self.pre_transform:
+            image, location = self.pre_transform(image)
+        if self.post_transform:
+            image = self.post_transform(image)
+        if self.location_transform:
+            location = self.location_transform(location)
+        if self.target_transform:
+            label = self.target_transform(label)
+        return image, location, label
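
A minimal usage sketch: wrap a torchvision dataset so __getitem__ yields (image, location, label). The fake_pre_transform below is a hypothetical stand-in for this commit's crop/location transforms and simply fabricates a location tensor:

import torch
from torchvision import transforms
from torchvision.datasets import MNIST
from conex.helpers.data import LocationDataset


def fake_pre_transform(image):
    # Stand-in for a real crop transform: return the image unchanged
    # together with a dummy location map of the same shape.
    return image, torch.zeros_like(image)


mnist = MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
dataset = LocationDataset(dataset=mnist, pre_transform=fake_pre_transform)
image, location, label = dataset[0]  # image: 1x28x28, location: 1x28x28, label: int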
56 changes: 0 additions & 56 deletions conex/helpers/encoders.py

This file was deleted.

110 changes: 110 additions & 0 deletions conex/helpers/filters.py
@@ -0,0 +1,110 @@
+import torch
+
+
+def DoGFilter(
+    size,
+    sigma_1,
+    sigma_2,
+    step=1.0,
+    zero_mean=False,
+    one_sum=False,
+    device=None,
+    dtype=None,
+):
+    """Difference of Gaussians.
+
+    Makes a square mono-colored DoG filter.
+
+    Args:
+        size (int): Filter size.
+        sigma_1 (float): First standard deviation.
+        sigma_2 (float): Second standard deviation.
+        step (float, optional): Scaling factor for axes. Defaults to 1.0.
+        zero_mean (bool, optional): Whether to scale negative values in order to have zero mean. Defaults to False.
+        one_sum (bool, optional): Whether to divide values so that the maximum possible dot product equals one. Defaults to False.
+        device (str, optional): Device to locate the filter on. Defaults to None.
+        dtype (dtype, optional): Datatype of the desired filter. Defaults to None.
+
+    Returns:
+        tensor: The desired DoG filter.
+    """
+    scale = (size - 1) / 2
+
+    v_range = torch.arange(-scale, scale + step, step, dtype=dtype, device=device)
+    x, y = torch.meshgrid(v_range, v_range, indexing="ij")
+
+    g_values = -(x**2 + y**2) / 2
+
+    dog_1 = torch.exp(g_values / (sigma_1**2)) / sigma_1
+    dog_2 = torch.exp(g_values / (sigma_2**2)) / sigma_2
+
+    dog_filter = (dog_1 - dog_2) / torch.sqrt(
+        torch.tensor(2 * torch.pi, device=device, dtype=dtype)
+    )
+
+    if zero_mean:
+        p_sum = torch.sum(dog_filter[dog_filter > 0])
+        n_sum = torch.sum(dog_filter[dog_filter < 0])
+        dog_filter[dog_filter < 0] *= -p_sum / n_sum
+
+    if one_sum:
+        dog_filter /= torch.sum(torch.abs(dog_filter))
+
+    return dog_filter
+
+
+def GaborFilter(
+    size,
+    labda,
+    theta,
+    sigma,
+    gamma,
+    step=1.0,
+    zero_mean=False,
+    one_sum=False,
+    device=None,
+    dtype=None,
+):
+    """Gabor filter.
+
+    Makes a square mono-colored Gabor filter.
+
+    Args:
+        size (int): Filter size.
+        labda (float): The wavelength of the filter.
+        theta (float): The orientation of the filter.
+        sigma (float): The standard deviation of the filter.
+        gamma (float): The aspect ratio for the filter.
+        step (float, optional): Scaling factor for axes. Defaults to 1.0.
+        zero_mean (bool, optional): Whether to scale negative values in order to have zero mean. Defaults to False.
+        one_sum (bool, optional): Whether to divide values so that the maximum possible dot product equals one. Defaults to False.
+        device (str, optional): Device to locate the filter on. Defaults to None.
+        dtype (dtype, optional): Datatype of the desired filter. Defaults to None.
+
+    Returns:
+        tensor: The desired Gabor filter.
+    """
+    scale = (size - 1) / 2
+
+    v_range = torch.arange(-scale, scale + step, step, dtype=dtype, device=device)
+    x, y = torch.meshgrid(v_range, v_range, indexing="ij")
+
+    x_rotated = x * torch.cos(
+        torch.tensor(theta, device=device, dtype=dtype)
+    ) + y * torch.sin(torch.tensor(theta, device=device, dtype=dtype))
+    y_rotated = -x * torch.sin(
+        torch.tensor(theta, device=device, dtype=dtype)
+    ) + y * torch.cos(torch.tensor(theta, device=device, dtype=dtype))
+
+    gabor_filter = torch.exp(
+        -(x_rotated**2 + (gamma**2 * y_rotated**2)) / (2 * sigma**2)
+    ) * torch.cos(2 * torch.pi * x_rotated / labda)
+
+    if zero_mean:
+        p_sum = torch.sum(gabor_filter[gabor_filter > 0])
+        n_sum = torch.sum(gabor_filter[gabor_filter < 0])
+        gabor_filter[gabor_filter < 0] *= -p_sum / n_sum
+
+    if one_sum:
+        gabor_filter /= torch.sum(torch.abs(gabor_filter))
+
+    return gabor_filter
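
Both constructors return plain 2-D tensors, so they drop straight into torch.nn.functional.conv2d as convolution kernels. A small sketch with illustrative parameter values:

import math

import torch
import torch.nn.functional as F

from conex.helpers.filters import DoGFilter, GaborFilter

dog = DoGFilter(size=7, sigma_1=1.0, sigma_2=2.0, one_sum=True)
gabor = GaborFilter(size=7, labda=4.0, theta=math.pi / 4, sigma=2.0, gamma=0.5, one_sum=True)

image = torch.rand(1, 1, 28, 28)  # dummy single-channel image batch
center_surround = F.conv2d(image, dog.view(1, 1, 7, 7))  # DoG response map
oriented_edges = F.conv2d(image, gabor.view(1, 1, 7, 7))  # 45-degree Gabor response map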
55 changes: 0 additions & 55 deletions conex/helpers/mask_transforms.py

This file was deleted.
