Changed the show function to return a dictionary in case display=False, and implemented the .to function by changing _apply (I had to use a trick), to be able to send to a specific device (to be tested)
k4ntz committed Dec 19, 2020
1 parent 52bf68d commit da7fee1
Showing 8 changed files with 101 additions and 55 deletions.
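A hypothetical usage sketch of the two behaviours the commit message describes, with the import path, constructor defaults, and dictionary keys assumed from the diff below; the device transfer is marked "to be tested" by the author, so treat this as illustrative only:

```python
# Illustrative only (untested): show(display=False) returns the plot data as a
# dictionary, and .to()/cuda() now accept a specific device string.
import torch
from rational.torch import Rational  # import path assumed

act = Rational()                      # defaults to the leaky_relu approximation
if torch.cuda.is_available():
    act.to("cuda:0")                  # routed through the new _apply/.to logic

data = act.show(display=False)        # no window is opened
print(data.keys())                    # expected: dict_keys(['hist', 'line', 'fitted_function'])
```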
16 changes: 8 additions & 8 deletions rational/_cuda/rational_cuda_kernels.cu
@@ -3028,7 +3028,7 @@ std::vector<torch::Tensor> rational_cuda_backward_A_7_6(torch::Tensor grad_outpu



// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|



@@ -3299,7 +3299,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_3_3(torch::Tensor grad_outpu



// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|



@@ -3607,7 +3607,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_4_4(torch::Tensor grad_outpu



// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|



@@ -3952,7 +3952,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_5_5(torch::Tensor grad_outpu



// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|



@@ -4334,7 +4334,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_6_6(torch::Tensor grad_outpu



// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|



@@ -4753,7 +4753,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_7_7(torch::Tensor grad_outpu



// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|



@@ -5209,7 +5209,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_8_8(torch::Tensor grad_outpu



// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|



@@ -5538,7 +5538,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_5_4(torch::Tensor grad_outpu



// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|



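For reference, the version-B rational function in the updated kernel comments (denominator coefficients now indexed b_1 through b_n) can be written out in plain NumPy. This is an illustrative sketch of the formula only, not code from the repository; the CUDA kernels above also implement the backward pass:

```python
import numpy as np

def rational_version_b(x, a, b):
    """Version-B rational: P(x) / (1 + |b_1*x + b_2*x^2 + ... + b_n*x^n|)."""
    numerator = sum(a_i * x**i for i, a_i in enumerate(a))           # a_0 + a_1*x + ...
    denominator = sum(b_j * x**(j + 1) for j, b_j in enumerate(b))   # b_1*x + b_2*x^2 + ...
    return numerator / (1.0 + np.abs(denominator))

x = np.linspace(-3.0, 3.0, 7)
print(rational_version_b(x, a=[0.0, 1.0, 0.5], b=[0.3, 0.2]))
```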
8 changes: 4 additions & 4 deletions rational/keras/rational_keras_functions.py
@@ -10,7 +10,7 @@ def get_xps(weight_denominator, weight_numerator, z):
return xps


def Rational_PYTORCH_A_F(z, weight_numerator, weight_denominator, training):
def Rational_KERAS_A_F(z, weight_numerator, weight_denominator, training):
# P(X) / Q(X) = a_0 + a_1 * X + ... + a_n * X ^ n /
# 1 + | b_0 | | X | + | b_1 | | X | ^ 2 + ... + | b_i | | X | ^ {i + 1}

@@ -29,7 +29,7 @@ def Rational_PYTORCH_A_F(z, weight_numerator, weight_denominator, training):
return numerator / denominator


def Rational_PYTORCH_B_F(z, weight_numerator, weight_denominator, training):
def Rational_KERAS_B_F(z, weight_numerator, weight_denominator, training):
# P(X) / Q(X) = a_0 + a_1 * X + ... + a_n * X ^ n /
# 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|

@@ -48,7 +48,7 @@ def Rational_PYTORCH_B_F(z, weight_numerator, weight_denominator, training):
return numerator / (1 + tf.abs(denominator))


def Rational_PYTORCH_C_F(z, weight_numerator, weight_denominator, training):
def Rational_KERAS_C_F(z, weight_numerator, weight_denominator, training):
# P(X) / Q(X) = a_0 + a_1 * X + ... + a_n * X ^ n /
# eps + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|

@@ -67,5 +67,5 @@ def Rational_PYTORCH_C_F(z, weight_numerator, weight_denominator, training):
return numerator / (0.1 + tf.abs(denominator))


def Rational_PYTORCH_D_F(x, weight_numerator, weight_denominator, training):
def Rational_KERAS_D_F(x, weight_numerator, weight_denominator, training):
raise NotImplementedError()
8 changes: 4 additions & 4 deletions rational/keras/rationals.py
@@ -46,13 +46,13 @@ def __init__(self, approx_func="leaky_relu", degrees=(5, 4), cuda=False,
self.denominator = tf.Variable(initial_value=w_denominator, trainable=trainable and train_denominator)

if version == "A":
rational_func = Rational_PYTORCH_A_F
rational_func = Rational_KERAS_A_F
elif version == "B":
rational_func = Rational_PYTORCH_B_F
rational_func = Rational_KERAS_B_F
elif version == "C":
rational_func = Rational_PYTORCH_C_F
rational_func = Rational_KERAS_C_F
elif version == "D":
rational_func = Rational_PYTORCH_D_F
rational_func = Rational_KERAS_D_F
else:
raise ValueError("version %s not implemented" % version)

14 changes: 1 addition & 13 deletions rational/numpy/rationals.py
@@ -81,7 +81,7 @@ def torch(self, cuda=None, trainable=True, train_numerator=True,
requires_grad=trainable and train_denominator)
return rtorch

def fit(self, function, x_range=np.arange(-3., 3., 0.1), show=False):
def fit(self, function, x_range=np.arange(-3., 3., 0.1)):
"""
Compute the parameters a, b, c, and d to have the neurally equivalent \
function of the provided one as close as possible to this rational \
@@ -109,18 +109,6 @@ def fit(self, function, x_range=np.arange(-3., 3., 0.1), show=False):
from rational.utils import find_closest_equivalent
(a, b, c, d), distance = find_closest_equivalent(self, function,
x_range)
if show:
import matplotlib.pyplot as plt
import torch
plt.plot(x_range, self(x_range), label="Rational (self)")
if '__name__' in dir(function):
func_label = function.__name__
else:
func_label = str(function)
result = a * function(c * torch.tensor(x_range) + d) + b
plt.plot(x_range, result, label=f"Fitted {func_label}")
plt.legend()
plt.show()
return (a, b, c, d), distance

def __repr__(self):
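With the plotting code removed, the numpy fit only computes and returns the closest affine-transformed equivalent. A hypothetical call might look as follows; the import path and constructor signature are assumed, and the target function must broadcast over the sample points:

```python
# Illustrative only: fit returns the parameters of a*f(c*x + d) + b that bring
# f closest to this rational function over x_range, plus the final distance.
import numpy as np
from rational.numpy import Rational  # import path assumed

rat = Rational("leaky_relu")
(a, b, c, d), distance = rat.fit(np.tanh, x_range=np.arange(-3., 3., 0.1))
print((a, b, c, d), distance)
```

The interactive plot that used to live here is now produced by the torch wrapper's `fit(..., show=True)`, shown in the next file.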
90 changes: 72 additions & 18 deletions rational/torch/rationals.py
@@ -172,7 +172,12 @@ def __init__(self, approx_func="leaky_relu", degrees=(5, 4), cuda=None,

if cuda is None:
cuda = torch_cuda_available()
device = "cuda" if cuda else "cpu"
if cuda is True:
device = "cuda"
elif cuda is False:
device = "cpu"
else:
device = cuda

w_numerator, w_denominator = get_parameters(version, degrees,
approx_func)
@@ -190,7 +195,7 @@

self.init_approximation = approx_func

if cuda:
if "cuda" in device:
if version == "A":
rational_func = Rational_CUDA_A_F
elif version == "B":
@@ -247,7 +252,7 @@ def cpu(self):
self.numerator = nn.Parameter(self.numerator.cpu())
self.denominator = nn.Parameter(self.denominator.cpu())

def cuda(self):
def cuda(self, device="0"):
if self.version == "A":
rational_func = Rational_CUDA_A_F
elif self.version == "B":
@@ -258,30 +263,37 @@ def cuda(self):
rational_func = Rational_CUDA_D_F
else:
raise ValueError("version %s not implemented" % self.version)
self.device = "cuda"
if "cuda" in device:
self.device = f"device"
else:
self.device = f"cuda:{device}"
self.activation_function = rational_func.apply
self.numerator = nn.Parameter(self.numerator.cuda())
self.denominator = nn.Parameter(self.denominator.cuda())
self.numerator = nn.Parameter(self.numerator.cuda(self.device))
self.denominator = nn.Parameter(self.denominator.cuda(self.device))

def to(self, device):
if "cpu" in str(device):
self.cpu()
elif "cuda" in str(device):
self.cuda()
self.cuda(device)

def _apply(self, fn):
if "Module.cpu" in str(fn):
self.cpu()
elif "Module.cuda" in str(fn):
self.cuda()
elif "Module.to" in str(fn):
device = fn.__closure__[1].cell_contents
assert type(device) == torch.device # otherwise loop on __closure__
self.to(device)
else:
return super._apply(fn)

def numpy(self):
"""
Returns a numpy version of this activation function.
"""
from rational import Rational as Rational_numpy
from rational.numpy import Rational as Rational_numpy
rational_n = Rational_numpy(self.init_approximation, self.degrees,
self.version)
rational_n.numerator = self.numerator.tolist()
@@ -312,27 +324,56 @@ def fit(self, function, x=None, show=False):
dist: The final distance between the rational function and the \
fitted one
"""
used_dist = False
rational_numpy = self.numpy()
if x is not None:
return rational_numpy.fit(function, x, show)
(a, b, c, d), distance = rational_numpy.fit(function, x)
else:
return rational_numpy.fit(function, show=show)
if self.distribution is not None:
freq, bins = _cleared_arrays(self.distribution)
x = bins
used_dist = True
else:
import numpy as np
x = np.arange(-3., 3., 0.1)
(a, b, c, d), distance = rational_numpy.fit(function, x)
if show:
import matplotlib.pyplot as plt
import torch
plt.plot(x, rational_numpy(x), label="Rational (self)")
if '__name__' in dir(function):
func_label = function.__name__
else:
func_label = str(function)
result = a * function(c * torch.tensor(x) + d) + b
plt.plot(x, result, label=f"Fitted {func_label}")
if used_dist:
ax = plt.gca()
ax2 = ax.twinx()
ax2.set_yticks([])
grey_color = (0.5, 0.5, 0.5, 0.6)
ax2.bar(bins, freq, width=bins[1] - bins[0],
color=grey_color, edgecolor=grey_color)
plt.legend()
plt.show()
return (a, b, c, d), distance

def best_fit(self, functions_list, x=None):
def best_fit(self, functions_list, x=None, shows=False):
if self.distribution is not None:
freq, bins = _cleared_arrays(self.distribution)
x = bins
(a, b, c, d), distance = self.fit(functions_list[0], x=x)
(a, b, c, d), distance = self.fit(functions_list[0], x=x, show=shows)
min_dist = distance
params = (a, b, c, d)
final_function = functions_list[0]
for func in functions_dict[1:]:
(a, b, c, d), distance = self.fit(functions_list[0], x=x)
for func in functions_list[1:]:
(a, b, c, d), distance = self.fit(functions_list[0], x=x, show=shows)
print(f"{func}: {distance}")
if min_dist > distance:
min_dist = distance
params = (a, b, c, d)
final_func = func
print(f"{func} is the new best fitted function")
self.best_fitted_function = final_func
self.best_fitted_function_params = params
return final_func, (a, b, c, d)
@@ -429,6 +470,9 @@ def input_retrieve_mode(self, auto_stop=True, max_saves=1000, bin_width=0.1):
together.\n
Default ``1000``
"""
if self._handle_retrieve_mode is not None:
print("Already in retrieve mode")
return
from physt import h1 as hist1
self.distribution = hist1(None, "fixed_width", bin_width=bin_width,
adaptive=True)
@@ -447,17 +491,21 @@ def training_mode(self):
print("Training mode, no longer retrieving the input.")
self._handle_retrieve_mode.remove()

def show(self, input_range=None, display=True):
def show(self, input_range=None, fitted_function=True, display=True):
"""
Show the function using `matplotlib`.
Arguments:
input_range (range):
The range to print the function on.\n
Default ``None``
fitted_function (bool):
If ``True``, displays the best fitted function if searched.
Otherwise, returns it. \n
Default ``True``
display (bool):
If ``True``, displays the graph.
Otherwise, returns it. \n
Otherwise, returns a dictionary with functions informations. \n
Default ``True``
"""
freq = None
@@ -504,8 +552,14 @@ def show(self, input_range=None, display=True):
else:
hist_dict = {"bins": bins, "freq": freq,
"width": bins[1] - bins[0]}
if "best_fitted_function" not in vars(self) or self.best_fitted_function is None:
fitted_function = None
else:
fitted_function = {"function": self.best_fitted_function,
"params": (a, b, c, d)}
return {"hist": hist_dict,
"line": {"x": inputs_np, "y": outputs_np}}
"line": {"x": inputs_np, "y": outputs_np},
"fitted_function": fitted_function}


def _save_input(self, input, output):
Expand All @@ -518,6 +572,7 @@ def _save_input_auto_stop(self, input, output):
if self.inputs_saved > self._max_saves:
self.training_mode()


def _cleared_arrays(hist, tolerance=0.001):
hist = hist.normalize()
freq, bins = hist.numpy_like
@@ -529,7 +584,6 @@ def _cleared_arrays(hist, tolerance=0.001):
return freq[first:last], bins[first:last - 1]



class AugmentedRational(nn.Module):
"""
Augmented Rational activation function inherited from ``Rational``
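The `_apply` override added to `rational/torch/rationals.py` leans on the fact that `nn.Module.to` hands `_apply` a conversion closure that captures the target device; the cell index is not guaranteed, which is what the `# otherwise loop on __closure__` comment hints at. A minimal, untested sketch of the more defensive variant (scanning the cells instead of assuming `__closure__[1]`) could look like this:

```python
# Sketch only: recover the torch.device captured by the closure that
# nn.Module.to() passes down to Module._apply(fn).
import torch

def device_from_apply_fn(fn):
    """Return the torch.device found in fn's closure cells, or None."""
    for cell in fn.__closure__ or ():
        if isinstance(cell.cell_contents, torch.device):
            return cell.cell_contents
    return None

# Stand-alone illustration of the same mechanism:
def make_convert(device):
    def convert(t):
        return t.to(device)
    return convert

print(device_from_apply_fn(make_convert(torch.device("cuda:1"))))  # -> cuda:1
```

Inside `_apply`, such a helper would replace the `fn.__closure__[1].cell_contents` lookup and the accompanying type assert.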
12 changes: 6 additions & 6 deletions rational/utils/find_init_weights.py
@@ -23,18 +23,18 @@ def plot_result(x_array, rational_array, target_array,
plt.show()


def append_to_config_file(params, approx_name, w_params, d_params, overright=None):
def append_to_config_file(params, approx_name, w_params, d_params, overwrite=None):
rational_full_name = f'Rational_version_{params["version"]}{params["nd"]}/{params["dd"]}'
cfd = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
with open(f'{cfd}/rationals_config.json') as json_file:
rationals_dict = json.load(json_file) # rational_version -> approx_func
approx_name = approx_name.lower()
if rational_full_name in rationals_dict:
if approx_name in rationals_dict[rational_full_name]:
if overright is None:
overright = input(f'Rational_{params["version"]} approximation of {approx_name} already exist. \
if overwrite is None:
overwrite = input(f'Rational_{params["version"]} approximation of {approx_name} already exist. \
\nDo you want to replace it ? (y/n)') in ["y", "yes"]
if not overright:
if not overwrite:
print("Parameters not stored")
return
else:
@@ -79,7 +79,7 @@ def typed_input(text, type, choice_list=None):


def find_weights(function, function_name=None, degrees=None, bounds=None,
version=None, plot=None, save=None, overright=None):
version=None, plot=None, save=None, overwrite=None):
# To be changed by the function you want to approximate
if function_name is None:
function_name = input("approximated function name: ")
@@ -130,7 +130,7 @@ def function_to_approx(x):
if save is None:
save = input("Do you want to store them in the json file ? (y/n)") in ["y", "yes"]
if save:
append_to_config_file(params, function_name, w_params, d_params, overright)
append_to_config_file(params, function_name, w_params, d_params, overwrite)
else:
print("Parameters not stored")
return w_params, d_params
