31 changes: 11 additions & 20 deletions pylearn2/sandbox/cuda_convnet/base_acts.py
@@ -58,10 +58,16 @@


class BaseActs(GpuOp):
"""Shared code for wrapping various convnet operations.

:param conv: if conv is True, we do "normal"
cross-correlation. Otherwise, we do a Locally-connected layer
with unshared weights.
"""
Shared code for wrapping various convnet operations.
"""
def __init__(self, pad=0, partial_sum=None, stride=1):
__props__ = ('pad', 'partial_sum', 'stride', 'conv',
'dense_connectivity', 'copy_non_contiguous')

def __init__(self, pad=0, partial_sum=None, stride=1, conv=True):

if not isinstance(pad, py_integer_types):
raise TypeError("pad must be an int")
@@ -75,6 +81,8 @@ def __init__(self, pad=0, partial_sum=None, stride=1):
self.copy_non_contiguous = 0
# TODO: support sparse connectivity pattern
self.dense_connectivity = True
assert conv in [False, True]
self.conv = conv
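The new flag threads through FilterActs, WeightActs, and ImageActs below. A hedged usage sketch (not part of the diff; the constructor signature is the one added above):

    # conv=True keeps the old shared-weight cross-correlation path;
    # conv=False selects the locally-connected, unshared-weights kernels.
    from pylearn2.sandbox.cuda_convnet.filter_acts import FilterActs
    op_conv = FilterActs(pad=0, partial_sum=None, stride=1, conv=True)
    op_local = FilterActs(pad=0, partial_sum=None, stride=1, conv=False)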

def c_header_dirs(self):
if config.pthreads.inc_dir:
@@ -129,23 +137,6 @@ def _argument_dimension_check(self, arg_name, ndim):
}
""" % locals()

-    def __eq__(self, other):
-        return (type(self) == type(other) and
-                self.partial_sum == other.partial_sum and
-                self.pad == other.pad and
-                self.dense_connectivity == other.dense_connectivity and
-                self.stride == other.stride and
-                self.copy_non_contiguous == other.copy_non_contiguous)
-
-    def __hash__(self):
-        msg = []
-        msg.append(self.__class__.__name__)
-        for val in (self.partial_sum, self.pad, self.dense_connectivity,
-                    self.stride, self.copy_non_contiguous):
-            msg.append(str(val))
-
-        return hash(tuple(msg))

# Make sure the cuda_convnet library is compiled and up-to-date
def make_thunk(self, node, storage_map, compute_map, no_recycling):
if not convnet_available():
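The deleted __eq__ and __hash__ above are subsumed by the new __props__ declaration: Theano derives both from the listed attributes. A minimal sketch of the derived behavior, assuming Theano-style Op semantics:

    # Sketch of what __props__ is assumed to generate: equality and hashing
    # over the declared attribute tuple, plus the concrete type.
    class PropsLike(object):
        __props__ = ('pad', 'partial_sum', 'stride', 'conv',
                     'dense_connectivity', 'copy_non_contiguous')

        def _props(self):
            return tuple(getattr(self, p) for p in self.__props__)

        def __eq__(self, other):
            return type(self) == type(other) and self._props() == other._props()

        def __hash__(self):
            return hash((type(self),) + self._props())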
73 changes: 43 additions & 30 deletions pylearn2/sandbox/cuda_convnet/filter_acts.py
@@ -154,7 +154,8 @@ def c_code(self, node, name, inputs, outputs, sub):
basic_setup = """
#define scaleTargets 0
#define scaleOutput 1
"""
int conv = %d;
""" % self.conv

if self.dense_connectivity:
basic_setup += """
@@ -200,6 +201,16 @@ def c_code(self, node, name, inputs, outputs, sub):
"""
num_braces += 1

+        # p + (m_x - 1) * s + f >= i_x
+        # p + (m_x - 1) * s >= i_x - f
+        # m_x = (i_x - f - p) / s + 1
+        div_ms_y = "((imgSizeY - 2*paddingStart - filter_rows) / moduleStride)"
+        div_ms_x = "((imgSizeX - 2*paddingStart - filter_cols) / moduleStride)"
+        mod_ms_y = "((imgSizeY - 2*paddingStart - filter_rows) % moduleStride)"
+        mod_ms_x = "((imgSizeX - 2*paddingStart - filter_cols) % moduleStride)"
+        target_rows = "%s + ((%s > 0) ? 1 : 0) + 1" % (div_ms_y, mod_ms_y)
+        target_cols = "%s + ((%s > 0) ? 1 : 0) + 1" % (div_ms_x, mod_ms_x)

# Convert filters into nv_filters, an NVMatrix, for compatibility
# with the cuda-convnet functions
setup_nv_filters = self._argument_contiguity_check("filters") + """
@@ -217,11 +228,16 @@ def c_code(self, node, name, inputs, outputs, sub):
const int filter_cols = filters_dims[2];
const int num_filters = filters_dims[3];

-            if (numGroups * filter_channels != img_channels)
+            const int target_rows = %(target_rows)s;
+            const int target_cols = %(target_cols)s;
+            int filterModuleMult = conv ? 1 : target_rows * target_cols;
+            printf("target row, col %%d %%d filterModuleMult %%d\\n", target_rows, target_cols, filterModuleMult);
+
+            if (((numGroups * filter_channels) / filterModuleMult) != img_channels)
            {
                PyErr_Format(PyExc_ValueError,
-                            "# input channels mismatch. images have %%d but filters have %%d groups of %%d for a total of %%d.",
-                            img_channels, numGroups, filter_channels, numGroups * filter_channels);
+                            "# input channels mismatch. images have %%d but filters have %%d groups of %%d for a total of %%d. filterModuleMult=%%d",
+                            img_channels, numGroups, filter_channels, numGroups * filter_channels, filterModuleMult);
%(fail)s;
}
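The check encodes a shape convention: for a local (conv=False) layer the filters' channel axis is folded together with one filter bank per output module, so dividing by filterModuleMult recovers the image channel count. An illustration with concrete numbers (the fold layout is an assumption, taken from the filters allocated in the new test below):

    img_channels, filter_rows, filter_cols, num_filters = 3, 4, 4, 16
    target_rows, target_cols = 7, 6
    conv_shape = (img_channels, filter_rows, filter_cols, num_filters)
    local_shape = (img_channels * target_rows * target_cols,
                   filter_rows, filter_cols, num_filters)
    filterModuleMult = target_rows * target_cols   # 1 when conv
    assert local_shape[0] // filterModuleMult == img_channels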

@@ -249,29 +265,18 @@

{ // setup_nv_filters brace 2


-            NVMatrix nv_filters(%(filters)s, filter_channels * filter_rows *
-                filter_cols, num_filters, "filter_acts:nv_filters");
+            NVMatrix nv_filters(%(filters)s,
+                                filter_channels * filter_rows * filter_cols,
+                                num_filters, "filter_acts:nv_filters");
"""
num_braces += 2

-        # p + (m_x - 1) * s + f >= i_x
-        # p + (m_x - 1) * s >= i_x - f
-        # m_x = (i_x - f - p) / s + 1
-        div_ms_y = "((imgSizeY - 2*paddingStart - filter_rows) / moduleStride)"
-        div_ms_x = "((imgSizeX - 2*paddingStart - filter_cols) / moduleStride)"
-        mod_ms_y = "((imgSizeY - 2*paddingStart - filter_rows) % moduleStride)"
-        mod_ms_x = "((imgSizeX - 2*paddingStart - filter_cols) % moduleStride)"
-        target_rows = "%s + ((%s > 0) ? 1 : 0) + 1" % (div_ms_y, mod_ms_y)
-        target_cols = "%s + ((%s > 0) ? 1 : 0) + 1" % (div_ms_x, mod_ms_x)

setup_nv_targets = """


int target_dims [] = {
num_filters,
-                %(target_rows)s,
-                %(target_cols)s,
+                target_rows,
+                target_cols,
batch_size };

#define numModulesY target_dims[1]
@@ -303,10 +308,16 @@ def c_code(self, node, name, inputs, outputs, sub):
# nv_filters.getNumRows() by numFilterColors
#
do_convolution = """
-        convFilterActs(nv_images, nv_filters, nv_targets,
-                       imgSizeY, numModulesY, numModulesX,
-                       paddingStart, moduleStride, img_channels,
-                       numGroups, scaleTargets, scaleOutput);
+        if (conv)
+            convFilterActs(nv_images, nv_filters, nv_targets,
+                           imgSizeY, numModulesY, numModulesX,
+                           paddingStart, moduleStride, img_channels,
+                           numGroups, scaleTargets, scaleOutput);
+        else
+            localFilterActs(nv_images, nv_filters, nv_targets,
+                            imgSizeY, numModulesY, numModulesX,
+                            paddingStart, moduleStride, img_channels,
+                            numGroups, scaleTargets, scaleOutput);
"""

braces = '}' * num_braces
@@ -322,13 +333,13 @@ def c_code(self, node, name, inputs, outputs, sub):

return rval

-    def c_code_cache_version(self):
+    # def c_code_cache_version(self):
        """
        .. todo::

            WRITEME
        """
-        return (10,)
+    #     return (10,)

def R_op(self, inputs, evals):
"""
@@ -375,8 +386,10 @@ def grad(self, inputs, dout):

ishape = images.shape[1:3]
fshape = filters.shape[1:3]
-        d_images = ImageActs(self.pad, self.partial_sum, self.stride)(
-            dout, filters, ishape)
-        d_filters = WeightActs(self.pad, self.partial_sum, self.stride)(
-            images, dout, fshape)[0]
+        d_images = ImageActs(
+            self.pad, self.partial_sum, self.stride, self.conv)(
+            dout, filters, ishape)
+        d_filters = WeightActs(
+            self.pad, self.partial_sum, self.stride, self.conv)(
+            images, dout, fshape)[0]
return d_images, d_filters
31 changes: 21 additions & 10 deletions pylearn2/sandbox/cuda_convnet/img_acts.py
@@ -171,12 +171,15 @@ def grad(self, inputs, g_outputs):
from pylearn2.sandbox.cuda_convnet.filter_acts import FilterActs
from pylearn2.sandbox.cuda_convnet.weight_acts import WeightActs

-        g_filters = WeightActs(stride=self.stride,
-                               partial_sum=self.partial_sum, pad=self.pad)(
-            g_images, hid_acts, filters.shape[1:3])[0]
+        g_filters = WeightActs(
+            stride=self.stride, partial_sum=self.partial_sum,
+            pad=self.pad, conv=self.conv)(
+            g_images, hid_acts, filters.shape[1:3])[0]
        assert not isinstance(g_filters, list)
-        g_hid_acts = FilterActs(stride=self.stride, pad=self.pad,
-                                partial_sum=self.partial_sum)(g_images, filters)
+        g_hid_acts = FilterActs(
+            stride=self.stride, pad=self.pad,
+            partial_sum=self.partial_sum, conv=self.conv)(
+            g_images, filters)

return [g_hid_acts, g_filters, DisconnectedType()()]

@@ -202,7 +205,8 @@ def c_code(self, node, name, inputs, outputs, sub):
basic_setup = """
#define scaleTargets 0
#define scaleOutput 1
"""
int conv = %d;
""" % self.conv

if self.dense_connectivity:
basic_setup += """
@@ -332,6 +336,7 @@ def c_code(self, node, name, inputs, outputs, sub):
intp_dtype, 0);
target_rows = *((npy_intp *)PyArray_GETPTR1(casted_shape, 0));
target_cols = *((npy_intp *)PyArray_GETPTR1(casted_shape, 1));
int filterModuleMult = conv ? 1 : target_rows * target_cols;
{
int target_dims [] = {
filter_channels,
@@ -372,10 +377,16 @@ def c_code(self, node, name, inputs, outputs, sub):
# nv_filters.getNumRows() by numFilterColors
#
do_convolution = """
-        convImgActs(nv_hid_acts, nv_filters, nv_targets,
-                    imgSizeY, imgSizeX, numModulesY,
-                    paddingStart, moduleStride, filter_channels,
-                    numGroups);
+        if (conv)
+            convImgActs(nv_hid_acts, nv_filters, nv_targets,
+                        imgSizeY, imgSizeX, numModulesY,
+                        paddingStart, moduleStride, filter_channels,
+                        numGroups, scaleTargets, scaleOutput);
+        else
+            localImgActs(nv_hid_acts, nv_filters, nv_targets,
+                        imgSizeY, imgSizeX, numModulesY,
+                        paddingStart, moduleStride, filter_channels,
+                        numGroups, scaleTargets, scaleOutput);
"""

braces = '}' * num_braces
@@ -391,10 +402,10 @@ def c_code(self, node, name, inputs, outputs, sub):

return rval

-    def c_code_cache_version(self):
+    # def c_code_cache_version(self):
        """
        .. todo::

            WRITEME
        """
-        return (9,)
+    #     return (9,)
68 changes: 67 additions & 1 deletion pylearn2/sandbox/cuda_convnet/tests/test_filter_acts.py
@@ -1,4 +1,5 @@
from __future__ import print_function
+import warnings

__authors__ = "Ian Goodfellow"
__copyright__ = "Copyright 2010-2012, Universite de Montreal"
@@ -11,6 +12,8 @@
skip_if_no_gpu()

import numpy as np

import theano
from theano import shared
from theano.tensor import grad, constant
from pylearn2.sandbox.cuda_convnet.filter_acts import FilterActs
@@ -20,7 +23,6 @@
from theano.tensor.nnet.conv import conv2d
from theano import function
from theano import tensor as T
-import warnings


def test_match_valid_conv():
@@ -80,6 +82,70 @@ def test_match_valid_conv():
assert False


def test_match_valid_local():

    # Tests that running FilterActs with conv=False (a locally-connected
    # layer with unshared weights) and no padding is the same as running
    # theano's conv2D in valid mode

rng = np.random.RandomState([2012, 10, 9])

batch_size = 5
rows = 10
cols = 9
filter_rows = 4
filter_cols = filter_rows
out_rows = rows - filter_rows + 1
out_cols = cols - filter_cols + 1
num_filters = 16
channels = 3

images = shared(rng.uniform(
-1., 1., (channels, rows, cols, batch_size)).astype('float32'),
name='images')
filters = shared(rng.uniform(
-1., 1., (channels * out_rows * out_cols,
filter_rows, filter_cols, num_filters)).astype('float32'),
name='filters')

gpu_images = gpu_from_host(images)
gpu_filters = gpu_from_host(filters)

output = FilterActs(conv=False)(gpu_images, gpu_filters)
output = host_from_gpu(output)

images_bc01 = images.dimshuffle(3, 0, 1, 2)
filters_bc01 = filters.dimshuffle(3, 0, 1, 2)
filters_bc01 = filters_bc01[:, :, ::-1, ::-1]

output_conv2d = conv2d(images_bc01, filters_bc01,
border_mode='valid')

output_conv2d = output_conv2d.dimshuffle(1, 2, 3, 0)

f = function([], [output, output_conv2d])
theano.printing.debugprint(f)

output, output_conv2d = f()

warnings.warn("""test_match_valid_conv success criterion is not very strict. Can we verify that this is OK?
One possibility is that theano is numerically unstable and Alex's code is better.
Probably theano CPU 64 bit is OK but it's worth checking the others.""")
if np.abs(output - output_conv2d).max() > 2.4e-6:
assert type(output) == type(output_conv2d)
assert output.dtype == output_conv2d.dtype
if output.shape != output_conv2d.shape:
print('cuda-convnet shape: ', output.shape)
print('theano shape: ', output_conv2d.shape)
assert False
err = np.abs(output - output_conv2d)
print('absolute error range: ', (err.min(), err.max()))
print('mean absolute error: ', err.mean())
print('cuda-convnet value range: ', (output.min(), output.max()))
print('theano value range: ', (output_conv2d.min(),
output_conv2d.max()))
assert False
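A stricter criterion than the tolerance above is possible in principle: tile one shared filter bank across all modules and require conv=False to reproduce conv=True exactly. A hedged helper sketch (hypothetical; the fold order of the channels-by-modules axis is an assumption):

    import numpy as np

    def tile_filters_for_local(conv_filters, out_rows, out_cols):
        # conv_filters: (channels, fr, fc, nf), the layout FilterActs uses.
        # Returns (channels * out_rows * out_cols, fr, fc, nf) with the same
        # bank repeated for every output module.
        return np.tile(conv_filters, (out_rows * out_cols, 1, 1, 1))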


def test_match_valid_conv_strided():

# Tests that running FilterActs with stride is the same as running