1111import logging
1212from theano .compile import Mode
1313import theano
14+ import theano .tensor as T
15+ import theano .sandbox .cuda as cuda
1416import numpy as np
1517from pylearn2 .models .dbm import flatten
1618from pylearn2 .utils import contains_nan , contains_inf
@@ -36,6 +38,23 @@ class NanGuardMode(Mode):
3638 If True, raise an error when a value greater than 1e10 is encountered.
3739 """
3840 def __init__ (self , nan_is_error , inf_is_error , big_is_error = True ):
41+ if cuda .cuda_available :
42+ self .guard_input = cuda .fvector ('nan_guard' )
43+ if nan_is_error or inf_is_error :
44+ self .gpumin = theano .function (
45+ [self .guard_input ], T .min (self .guard_input ),
46+ mode = 'FAST_RUN'
47+ )
48+ if inf_is_error :
49+ self .gpumax = theano .function (
50+ [self .guard_input ], T .max (self .guard_input ),
51+ mode = 'FAST_RUN'
52+ )
53+ if big_is_error :
54+ self .gpuabsmax = theano .function (
55+ [self .guard_input ], T .max (T .abs_ (self .guard_input )),
56+ mode = 'FAST_RUN'
57+ )
3958 def do_check_on (var , nd , f , is_input ):
4059 """
4160 Checks `var` for NaNs / Infs. If detected, raises an exception
@@ -56,15 +75,31 @@ def do_check_on(var, nd, f, is_input):
5675 """
5776 error = False
5877 if nan_is_error :
59- if contains_nan (var ):
78+ err = False
79+ if cuda .cuda_available and isinstance (var , cuda .CudaNdarray ):
80+ err = np .isnan (self .gpumin (var .reshape (var .size )))
81+ else :
82+ err = contains_nan (var )
83+ if err :
6084 logger .error ('NaN detected' )
6185 error = True
6286 if inf_is_error :
63- if contains_inf (var ):
87+ err = False
88+ if cuda .cuda_available and isinstance (var , cuda .CudaNdarray ):
89+ err = (np .isinf (self .gpumin (var .reshape (var .size ))) or \
90+ np .isinf (self .gpumax (var .reshape (var .size ))))
91+ else :
92+ err = contains_inf (var )
93+ if err :
6494 logger .error ('Inf detected' )
6595 error = True
6696 if big_is_error :
67- if np .abs (var ).max () > 1e10 :
97+ err = False
98+ if cuda .cuda_available and isinstance (var , cuda .CudaNdarray ):
99+ err = (self .gpuabsmax (var .reshape (var .size )) > 1e10 )
100+ else :
101+ err = (np .abs (var ).max () > 1e10 )
102+ if err :
68103 logger .error ('Big value detected' )
69104 error = True
70105 if error :
0 commit comments