Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data preparation #35

Merged
merged 10 commits into from
Jun 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ Contributors
------------

* Amir Hosein Ebrahimi <[email protected]>

* Saman Arzaghi <[email protected]>
5 changes: 3 additions & 2 deletions Example/test/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from conex import *
from conex.nn.priorities import NEURON_PRIORITIES

from conex.helpers.encoders import Poisson
from conex.helpers.transforms import *
from conex.helpers.transforms.encoders import Poisson
from conex.helpers.transforms.misc import *

from torchvision import transforms
from torchvision.datasets import MNIST
Expand Down Expand Up @@ -67,6 +67,7 @@
depth=1, height=SENSORY_SIZE_HEIGHT, width=SENSORY_SIZE_WIDTH
),
sensory_trace=SENSORY_TRACE_TAU_S,
instance_duration=POISSON_TIME,
)

#############################
Expand Down
4 changes: 2 additions & 2 deletions conex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@

from conex.behaviors.layer.dataset import *

from conex.nn.Structure import *
from conex.nn.Modules import *
from conex.nn.structure import *
from conex.nn.modules import *
from conex.nn.config import *
38 changes: 36 additions & 2 deletions conex/behaviors/layer/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import torch

"""
Behaviors to load datasets
"""
Expand All @@ -17,6 +19,8 @@ class SpikeNdDataset(Behavior):
have_location (bool): Whether dataloader returns location input.
have_sensory (bool): Whether dataloader returns sensory input.
have_label (bool): Whether dataloader returns label of input.
silent_interval (int): The interval of silent activity between two different input.
instance_duration (int): The duration of each instance of input with same target value.
loop (bool): If True, dataloader repeats.
"""

Expand All @@ -27,10 +31,14 @@ def initialize(self, layer):
self.have_location = self.parameter("have_location", False)
self.have_sensory = self.parameter("have_sensory", True)
self.have_label = self.parameter("have_label", True)
self.silent_interval = self.parameter("silent_interval", 0)
self.each_instance = self.parameter("instance_duration", 0, required=True)
self.loop = self.parameter("loop", True)

self.data_generator = self._get_data()
self.device = layer.device
self.new_data = False
self.silent_iteration = 0

def _get_data(self):
while self.loop:
Expand Down Expand Up @@ -59,13 +67,39 @@ def _get_data(self):

if batch_y is not None:
batch_y = batch_y.to(self.device)
each_instance = num_instance // torch.numel(batch_y)
self.each_instance = num_instance // torch.numel(batch_y)

for i in range(num_instance):
x = batch_x[i].view((-1,)) if batch_x is not None else None
loc = batch_loc[i].view((-1,)) if batch_loc is not None else None
y = batch_y[i // each_instance] if batch_y is not None else None
y = (
batch_y[i // self.each_instance]
if batch_y is not None
else None
)
if i % self.each_instance == self.each_instance - 1:
self.new_data = True
yield x, loc, y

def forward(self, layer):
if self.silent_interval and self.new_data:
if self.silent_iteration == 0:
layer.x = (
layer.tensor(mode="zeros", dtype=torch.bool, dim=layer.x.shape)
if layer.x is not None
else None
)
layer.loc = (
layer.tensor(mode="zeros", dtype=torch.bool, dim=layer.loc.shape)
if layer.loc is not None
else None
)

self.silent_iteration += 1

if self.silent_iteration == self.silent_interval:
self.new_data = False
self.silent_iteration = 0
return

layer.x, layer.loc, layer.y = next(self.data_generator)
4 changes: 4 additions & 0 deletions conex/behaviors/synapses/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class WeightInitializer(Behavior):

def initialize(self, synapse):
init_mode = self.parameter("mode", None)
scale = self.parameter("scale", 1)
offset = self.parameter("offset", 0)
synapse.weights = self.parameter("weights", None)
synapse.weight_shape = self.parameter("weight_shape", None)
synapse.kernel_shape = self.parameter("kernel_shape", None)
Expand All @@ -110,6 +112,8 @@ def initialize(self, synapse):
mode=init_mode, dim=synapse.weight_shape
)

synapse.weights = synapse.weights * scale + offset


class WeightNormalization(Behavior):
"""
Expand Down
44 changes: 44 additions & 0 deletions conex/helpers/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from torch.utils.data import Dataset


class LocationDataset(Dataset):
"""
A custom dataset class for data with triple image, location, label nature.

Args:
dataset (Dataset): An Instance of a dataset
pre_transform (Transform): A transformation to apply on images. If given, transformation should return a `(image, location)` tuple.
post_transform (Transform): A Transformation that applies on images. Suitable for encodings.
location_transform (Transform): A Transformation applies on location data. Suitable for encodings.
target_transform (Transform): A Transformation applies on labels.
"""

def __init__(
self,
dataset,
pre_transform=None,
post_transform=None,
location_transform=None,
target_transform=None,
):
self.dataset = dataset
self.pre_transform = pre_transform
self.post_transform = post_transform
self.location_transform = location_transform
self.target_transform = target_transform

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
image, label = self.dataset[idx]
location = None
if self.pre_transform:
image, location = self.pre_transform(image)
if self.post_transform:
image = self.post_transform(image)
if self.location_transform:
location = self.location_transform(location)
if self.target_transform:
label = self.target_transform(label)
return image, location, label
56 changes: 0 additions & 56 deletions conex/helpers/encoders.py

This file was deleted.

110 changes: 110 additions & 0 deletions conex/helpers/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import torch


def DoGFilter(
size,
sigma_1,
sigma_2,
step=1.0,
zero_mean=False,
one_sum=False,
device=None,
dtype=None,
):
"""Difference of Gaussians.
Makes a square mono-colored DoG filter.

Args:
size (int): Filter size.
sigma_1 (float): First standard deviation.
sigma_2 (float): Second standard deviation.
step (float, optional): Scaling factor for axes. Defaults to 1.0.
zero_mean (bool, optional): Whether to scale negative values in order to have zero mean. Defaults to False.
one_sum (bool, optional): Whether to divide values in order to have maximum possible dot product equal to one. Defaults to False.
device (str, optional): Device to locate filter on. Defaults to None.
dtype (dtype, optional): Datatype of desired filter. Defaults to None.

Returns:
tensor: the desired DoG filter
"""
scale = (size - 1) / 2

v_range = torch.arange(-scale, scale + step, step, dtype=dtype, device=device)
x, y = torch.meshgrid(v_range, v_range, indexing="ij")

g_values = -(x**2 + y**2) / 2

dog_1 = torch.exp(g_values / (sigma_1**2)) / sigma_1
dog_2 = torch.exp(g_values / (sigma_2**2)) / sigma_2

dog_filter = (dog_1 - dog_2) / torch.sqrt(
torch.tensor(2 * torch.pi, device=device, dtype=dtype)
)

if zero_mean:
p_sum = torch.sum(dog_filter[dog_filter > 0])
n_sum = torch.sum(dog_filter[dog_filter < 0])
dog_filter[dog_filter < 0] *= -p_sum / n_sum

if one_sum:
dog_filter /= torch.sum(torch.abs(dog_filter))

return dog_filter


def GaborFilter(
size,
labda,
theta,
sigma,
gamma,
step=1.0,
zero_mean=False,
one_sum=False,
device=None,
dtype=None,
):
"""Gabor filter
Makes a square mono-colored Gabor filter.

Args:
size (int): Filter size.
labda (float): The wavelength of the filter.
theta (float): The orientation of the filter.
sigma (float): The standard deviation of the filter.
gamma (float): The aspect ratio for the filter.
step (float, optional): Scaling factor for axes. Defaults to 1.0.
zero_mean (bool, optional): Whether to scale negative values in order to have zero mean. Defaults to False.
one_sum (bool, optional): Whether to divide values in order to have maximum possible dot product equal to one. Defaults to False.
device (str, optional): Device to locate filter on. Defaults to None.
dtype (dtype, optional): Datatype of desired filter. Defaults to None.

Returns:
tensor: the desired Gabor filter
"""

scale = (size - 1) / 2

v_range = torch.arange(-scale, scale + step, step, dtype=dtype, device=device)
x, y = torch.meshgrid(v_range, v_range, indexing="ij")

x_rotated = x * torch.cos(
torch.tensor(theta, device=device, dtype=dtype)
) + y * torch.sin(torch.tensor(theta, device=device, dtype=dtype))
y_rotated = -x * torch.sin(
torch.tensor(theta, device=device, dtype=dtype)
) + y * torch.cos(torch.tensor(theta, device=device, dtype=dtype))

gabor_filter = torch.exp(
-(x_rotated**2 + (gamma**2 * y_rotated**2)) / (2 * sigma**2)
) * torch.cos(2 * torch.pi * x_rotated / labda)

if zero_mean:
p_sum = torch.sum(gabor_filter[gabor_filter > 0])
n_sum = torch.sum(gabor_filter[gabor_filter < 0])
gabor_filter[gabor_filter < 0] *= -p_sum / n_sum

if one_sum:
gabor_filter /= torch.sum(torch.abs(gabor_filter))

return gabor_filter
55 changes: 0 additions & 55 deletions conex/helpers/mask_transforms.py

This file was deleted.

Loading