diff --git a/rational/_cuda/rational_cuda_kernels.cu b/rational/_cuda/rational_cuda_kernels.cu
index 5b6c4d3..920e090 100644
--- a/rational/_cuda/rational_cuda_kernels.cu
+++ b/rational/_cuda/rational_cuda_kernels.cu
@@ -3028,7 +3028,7 @@ std::vector<torch::Tensor> rational_cuda_backward_A_7_6(torch::Tensor grad_outpu
-// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
+// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|
@@ -3299,7 +3299,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_3_3(torch::Tensor grad_outpu
-// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
+// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|
@@ -3607,7 +3607,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_4_4(torch::Tensor grad_outpu
-// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
+// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|
@@ -3952,7 +3952,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_5_5(torch::Tensor grad_outpu
-// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
+// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|
@@ -4334,7 +4334,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_6_6(torch::Tensor grad_outpu
-// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
+// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|
@@ -4753,7 +4753,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_7_7(torch::Tensor grad_outpu
-// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
+// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|
@@ -5209,7 +5209,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_8_8(torch::Tensor grad_outpu
-// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
+// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|
@@ -5538,7 +5538,7 @@ std::vector<torch::Tensor> rational_cuda_backward_B_5_4(torch::Tensor grad_outpu
-// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
+// P(X)/Q(X) = a_0 + a_1*X + a_2*X^2 + ... + a_n*X^n / 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n|
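The corrected comments describe the version-B rational activation: the n denominator weights are indexed b_1 through b_n, one per power of X, and the absolute value wraps the whole polynomial, so Q(X) = 1 + |b_1*X + b_2*X^2 + ... + b_n*X^n| never vanishes. A minimal NumPy sketch of that formula; `rational_version_b` is an illustrative helper, not code from this patch:

    import numpy as np

    def rational_version_b(x, a, b):
        # P(x) = a_0 + a_1*x + ... + a_m*x^m
        p = sum(a_i * x**i for i, a_i in enumerate(a))
        # Q(x) = 1 + |b_1*x + b_2*x^2 + ... + b_n*x^n|
        # (powers start at 1, so b[0] multiplies x, matching the fixed comment)
        q = 1.0 + np.abs(sum(b_j * x**(j + 1) for j, b_j in enumerate(b)))
        return p / q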
diff --git a/rational/keras/rational_keras_functions.py b/rational/keras/rational_keras_functions.py
index 15b9709..3ffebd5 100644
--- a/rational/keras/rational_keras_functions.py
+++ b/rational/keras/rational_keras_functions.py
@@ -10,7 +10,7 @@ def get_xps(weight_denominator, weight_numerator, z):
     return xps
 
 
-def Rational_PYTORCH_A_F(z, weight_numerator, weight_denominator, training):
+def Rational_KERAS_A_F(z, weight_numerator, weight_denominator, training):
     # P(X) / Q(X) = a_0 + a_1 * X + ... + a_n * X ^ n /
     #              1 + | b_0 | | X | + | b_1 | | X | ^ 2 + ... + | b_i | | X | ^ {i + 1}
@@ -29,7 +29,7 @@ def Rational_PYTORCH_A_F(z, weight_numerator, weight_denominator, training):
     return numerator / denominator
 
 
-def Rational_PYTORCH_B_F(z, weight_numerator, weight_denominator, training):
+def Rational_KERAS_B_F(z, weight_numerator, weight_denominator, training):
     # P(X) / Q(X) = a_0 + a_1 * X + ... + a_n * X ^ n /
     #              1 + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
@@ -48,7 +48,7 @@ def Rational_PYTORCH_B_F(z, weight_numerator, weight_denominator, training):
     return numerator / (1 + tf.abs(denominator))
 
 
-def Rational_PYTORCH_C_F(z, weight_numerator, weight_denominator, training):
+def Rational_KERAS_C_F(z, weight_numerator, weight_denominator, training):
     # P(X) / Q(X) = a_0 + a_1 * X + ... + a_n * X ^ n /
     #              eps + |b_0*X + b_1*X^2 + ... + b_{n-1}*X^n|
@@ -67,5 +67,5 @@ def Rational_PYTORCH_C_F(z, weight_numerator, weight_denominator, training):
     return numerator / (0.1 + tf.abs(denominator))
 
 
-def Rational_PYTORCH_D_F(x, weight_numerator, weight_denominator, training):
+def Rational_KERAS_D_F(x, weight_numerator, weight_denominator, training):
     raise NotImplementedError()
diff --git a/rational/keras/rationals.py b/rational/keras/rationals.py
index d7b2ea0..b844938 100644
--- a/rational/keras/rationals.py
+++ b/rational/keras/rationals.py
@@ -46,13 +46,13 @@ def __init__(self, approx_func="leaky_relu", degrees=(5, 4), cuda=False,
         self.denominator = tf.Variable(initial_value=w_denominator,
                                        trainable=trainable and train_denominator)
         if version == "A":
-            rational_func = Rational_PYTORCH_A_F
+            rational_func = Rational_KERAS_A_F
         elif version == "B":
-            rational_func = Rational_PYTORCH_B_F
+            rational_func = Rational_KERAS_B_F
         elif version == "C":
-            rational_func = Rational_PYTORCH_C_F
+            rational_func = Rational_KERAS_C_F
         elif version == "D":
-            rational_func = Rational_PYTORCH_D_F
+            rational_func = Rational_KERAS_D_F
         else:
             raise ValueError("version %s not implemented" % version)
diff --git a/rational/numpy/rationals.py b/rational/numpy/rationals.py
index f082d81..837e3d5 100644
--- a/rational/numpy/rationals.py
+++ b/rational/numpy/rationals.py
@@ -81,7 +81,7 @@ def torch(self, cuda=None, trainable=True, train_numerator=True,
                                    requires_grad=trainable and train_denominator)
         return rtorch
 
-    def fit(self, function, x_range=np.arange(-3., 3., 0.1), show=False):
+    def fit(self, function, x_range=np.arange(-3., 3., 0.1)):
         """
         Compute the parameters a, b, c, and d to have the neurally equivalent \
         function of the provided one as close as possible to this rational \
@@ -109,18 +109,6 @@ def fit(self, function, x_range=np.arange(-3., 3., 0.1), show=False):
         from rational.utils import find_closest_equivalent
         (a, b, c, d), distance = find_closest_equivalent(self, function,
                                                          x_range)
-        if show:
-            import matplotlib.pyplot as plt
-            import torch
-            plt.plot(x_range, self(x_range), label="Rational (self)")
-            if '__name__' in dir(function):
-                func_label = function.__name__
-            else:
-                func_label = str(function)
-            result = a * function(c * torch.tensor(x_range) + d) + b
-            plt.plot(x_range, result, label=f"Fitted {func_label}")
-            plt.legend()
-            plt.show()
         return (a, b, c, d), distance
 
     def __repr__(self):
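With the plotting branch removed, the numpy `fit` now only computes and returns the closest affine-transformed equivalent. A usage sketch, assuming the numpy `Rational` keeps its default `leaky_relu` approximation and torch is installed (`find_closest_equivalent` evaluates the candidate on torch tensors):

    import torch
    from rational.numpy import Rational

    rat = Rational()
    # finds a, b, c, d minimising || rat(x) - (a*tanh(c*x + d) + b) ||^2
    (a, b, c, d), dist = rat.fit(torch.tanh)
    print(f"closest equivalent: {a:.3f}*tanh({c:.3f}*x + {d:.3f}) + {b:.3f}")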
diff --git a/rational/torch/rationals.py b/rational/torch/rationals.py
index 9c6e537..76b03df 100644
--- a/rational/torch/rationals.py
+++ b/rational/torch/rationals.py
@@ -172,7 +172,12 @@ def __init__(self, approx_func="leaky_relu", degrees=(5, 4), cuda=None,
         if cuda is None:
             cuda = torch_cuda_available()
-        device = "cuda" if cuda else "cpu"
+        if cuda is True:
+            device = "cuda"
+        elif cuda is False:
+            device = "cpu"
+        else:
+            device = cuda
 
         w_numerator, w_denominator = get_parameters(version, degrees,
                                                     approx_func)
@@ -190,7 +195,7 @@
 
         self.init_approximation = approx_func
 
-        if cuda:
+        if "cuda" in device:
             if version == "A":
                 rational_func = Rational_CUDA_A_F
             elif version == "B":
@@ -247,7 +252,7 @@ def cpu(self):
         self.numerator = nn.Parameter(self.numerator.cpu())
         self.denominator = nn.Parameter(self.denominator.cpu())
 
-    def cuda(self):
+    def cuda(self, device="0"):
         if self.version == "A":
             rational_func = Rational_CUDA_A_F
         elif self.version == "B":
@@ -258,22 +263,29 @@ def cuda(self):
             rational_func = Rational_CUDA_D_F
         else:
             raise ValueError("version %s not implemented" % self.version)
-        self.device = "cuda"
+        if "cuda" in str(device):
+            self.device = str(device)
+        else:
+            self.device = f"cuda:{device}"
         self.activation_function = rational_func.apply
-        self.numerator = nn.Parameter(self.numerator.cuda())
-        self.denominator = nn.Parameter(self.denominator.cuda())
+        self.numerator = nn.Parameter(self.numerator.cuda(self.device))
+        self.denominator = nn.Parameter(self.denominator.cuda(self.device))
 
     def to(self, device):
         if "cpu" in str(device):
             self.cpu()
         elif "cuda" in str(device):
-            self.cuda()
+            self.cuda(device)
 
     def _apply(self, fn):
         if "Module.cpu" in str(fn):
             self.cpu()
         elif "Module.cuda" in str(fn):
             self.cuda()
+        elif "Module.to" in str(fn):
+            device = fn.__closure__[1].cell_contents
+            assert type(device) == torch.device  # otherwise loop on __closure__
+            self.to(device)
         else:
             return super._apply(fn)
@@ -281,7 +293,7 @@ def numpy(self):
         """
         Returns a numpy version of this activation function.
         """
-        from rational import Rational as Rational_numpy
+        from rational.numpy import Rational as Rational_numpy
         rational_n = Rational_numpy(self.init_approximation, self.degrees,
                                     self.version)
         rational_n.numerator = self.numerator.tolist()
@@ -312,27 +324,56 @@ def fit(self, function, x=None, show=False):
             dist: The final distance between the rational function and the \
             fitted one
         """
+        used_dist = False
         rational_numpy = self.numpy()
         if x is not None:
-            return rational_numpy.fit(function, x, show)
+            (a, b, c, d), distance = rational_numpy.fit(function, x)
         else:
-            return rational_numpy.fit(function, show=show)
+            if self.distribution is not None:
+                freq, bins = _cleared_arrays(self.distribution)
+                x = bins
+                used_dist = True
+            else:
+                import numpy as np
+                x = np.arange(-3., 3., 0.1)
+            (a, b, c, d), distance = rational_numpy.fit(function, x)
+        if show:
+            import matplotlib.pyplot as plt
+            import torch
+            plt.plot(x, rational_numpy(x), label="Rational (self)")
+            if '__name__' in dir(function):
+                func_label = function.__name__
+            else:
+                func_label = str(function)
+            result = a * function(c * torch.tensor(x) + d) + b
+            plt.plot(x, result, label=f"Fitted {func_label}")
+            if used_dist:
+                ax = plt.gca()
+                ax2 = ax.twinx()
+                ax2.set_yticks([])
+                grey_color = (0.5, 0.5, 0.5, 0.6)
+                ax2.bar(bins, freq, width=bins[1] - bins[0],
+                        color=grey_color, edgecolor=grey_color)
+            plt.legend()
+            plt.show()
+        return (a, b, c, d), distance
 
-    def best_fit(self, functions_list, x=None):
+    def best_fit(self, functions_list, x=None, show=False):
         if self.distribution is not None:
             freq, bins = _cleared_arrays(self.distribution)
             x = bins
-        (a, b, c, d), distance = self.fit(functions_list[0], x=x)
+        (a, b, c, d), distance = self.fit(functions_list[0], x=x, show=show)
         min_dist = distance
         params = (a, b, c, d)
         final_function = functions_list[0]
-        for func in functions_dict[1:]:
-            (a, b, c, d), distance = self.fit(functions_list[0], x=x)
+        for func in functions_list[1:]:
+            (a, b, c, d), distance = self.fit(func, x=x, show=show)
             print(f"{func}: {distance}")
             if min_dist > distance:
                 min_dist = distance
                 params = (a, b, c, d)
                 final_func = func
+                print(f"{func} is the new best fitted function")
         self.best_fitted_function = final_func
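The plotting dropped from the numpy class resurfaces here in the torch-side `fit`, together with the distribution-aware choice of x. A usage sketch (hypothetical session, default approximation assumed):

    import torch
    from rational.torch import Rational

    rat = Rational(cuda=False)
    # x=None: uses the retrieved input distribution when one was collected,
    # otherwise falls back to np.arange(-3., 3., 0.1)
    (a, b, c, d), dist = rat.fit(torch.sigmoid, show=False)
    print(f"{a:.3f}*sigmoid({c:.3f}*x + {d:.3f}) + {b:.3f} (distance {dist:.4f})")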
         self.best_fitted_function_params = params
         return final_func, (a, b, c, d)
@@ -429,6 +470,9 @@ def input_retrieve_mode(self, auto_stop=True, max_saves=1000, bin_width=0.1):
                 together.\n
                 Default ``1000``
         """
+        if self._handle_retrieve_mode is not None:
+            print("Already in retrieve mode")
+            return
         from physt import h1 as hist1
         self.distribution = hist1(None, "fixed_width", bin_width=bin_width,
                                   adaptive=True)
@@ -447,7 +491,7 @@ def training_mode(self):
         print("Training mode, no longer retrieving the input.")
         self._handle_retrieve_mode.remove()
 
-    def show(self, input_range=None, display=True):
+    def show(self, input_range=None, fitted_function=True, display=True):
         """
         Show the function using `matplotlib`.
@@ -455,9 +499,13 @@
             input_range (range):
                 The range to print the function on.\n
                 Default ``None``
+            fitted_function (bool):
+                If ``True``, also displays the best fitted function,
+                if one has been searched with ``best_fit``. \n
+                Default ``True``
             display (bool):
                 If ``True``, displays the graph.
-                Otherwise, returns it. \n
+                Otherwise, returns a dictionary with the function's information. \n
                 Default ``True``
         """
         freq = None
@@ -504,8 +552,14 @@
         else:
             hist_dict = {"bins": bins, "freq": freq,
                          "width": bins[1] - bins[0]}
+        if "best_fitted_function" not in vars(self) or self.best_fitted_function is None:
+            fitted_function = None
+        else:
+            fitted_function = {"function": self.best_fitted_function,
+                               "params": self.best_fitted_function_params}
         return {"hist": hist_dict,
-                "line": {"x": inputs_np, "y": outputs_np}}
+                "line": {"x": inputs_np, "y": outputs_np},
+                "fitted_function": fitted_function}
 
     def _save_input(self, input, output):
@@ -518,6 +572,7 @@ def _save_input_auto_stop(self, input, output):
         if self.inputs_saved > self._max_saves:
             self.training_mode()
 
+
 def _cleared_arrays(hist, tolerance=0.001):
     hist = hist.normalize()
     freq, bins = hist.numpy_like
@@ -529,7 +584,6 @@ def _cleared_arrays(hist, tolerance=0.001):
     return freq[first:last], bins[first:last - 1]
 
-
 class AugmentedRational(nn.Module):
     """
     Augmented Rational activation function inherited from ``Rational``
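With `best_fit` recording its winner and `show(display=False)` now exposing it, the returned dictionary can be consumed programmatically. A sketch, assuming `rat` is a `Rational` instance on which `best_fit` has already been run:

    import torch

    best_func, params = rat.best_fit([torch.tanh, torch.sigmoid])
    info = rat.show(display=False)
    if info["fitted_function"] is not None:
        print(info["fitted_function"]["function"],
              info["fitted_function"]["params"])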
diff --git a/rational/utils/find_init_weights.py b/rational/utils/find_init_weights.py
index b4fbac8..6363a12 100644
--- a/rational/utils/find_init_weights.py
+++ b/rational/utils/find_init_weights.py
@@ -23,7 +23,7 @@ def plot_result(x_array, rational_array, target_array,
     plt.show()
 
 
-def append_to_config_file(params, approx_name, w_params, d_params, overright=None):
+def append_to_config_file(params, approx_name, w_params, d_params, overwrite=None):
     rational_full_name = f'Rational_version_{params["version"]}{params["nd"]}/{params["dd"]}'
     cfd = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
     with open(f'{cfd}/rationals_config.json') as json_file:
@@ -31,10 +31,10 @@
     approx_name = approx_name.lower()
     if rational_full_name in rationals_dict:
         if approx_name in rationals_dict[rational_full_name]:
-            if overright is None:
-                overright = input(f'Rational_{params["version"]} approximation of {approx_name} already exist. \
+            if overwrite is None:
+                overwrite = input(f'Rational_{params["version"]} approximation of {approx_name} already exists. \
 \nDo you want to replace it ? (y/n)') in ["y", "yes"]
-            if not overright:
+            if not overwrite:
                 print("Parameters not stored")
                 return
         else:
@@ -79,7 +79,7 @@ def typed_input(text, type, choice_list=None):
 
 
 def find_weights(function, function_name=None, degrees=None, bounds=None,
-                 version=None, plot=None, save=None, overright=None):
+                 version=None, plot=None, save=None, overwrite=None):
     # To be changed by the function you want to approximate
     if function_name is None:
         function_name = input("approximated function name: ")
@@ -130,7 +130,7 @@ def function_to_approx(x):
     if save is None:
         save = input("Do you want to store them in the json file ? (y/n)") in ["y", "yes"]
     if save:
-        append_to_config_file(params, function_name, w_params, d_params, overright)
+        append_to_config_file(params, function_name, w_params, d_params, overwrite)
     else:
         print("Parameters not stored")
     return w_params, d_params
diff --git a/rational/utils/utils.py b/rational/utils/utils.py
index d7a021c..8c57eec 100644
--- a/rational/utils/utils.py
+++ b/rational/utils/utils.py
@@ -91,7 +91,7 @@ def find_closest_equivalent(rational_func, new_func, x):
     def equivalent_func(x_array, a, b, c, d):
         return a * new_func(c * torch.tensor(x_array) + d) + b
 
-    params = curve_fit(equivalent_func, x, y, initials, bounds=(x.min(), x.max()))
+    params = curve_fit(equivalent_func, x, y, initials)
     a, b, c, d = params[0]
     final_func_output = np.array(equivalent_func(x, a, b, c, d))
     final_distance = ((y - final_func_output)**2).sum()
diff --git a/scripts/compute_all_weights.py b/scripts/compute_all_weights.py
index 24e140b..e8be589 100644
--- a/scripts/compute_all_weights.py
+++ b/scripts/compute_all_weights.py
@@ -13,5 +13,9 @@ def swish(x):
     print("-" * 30)
     print(f"Computing weights for {act_n}")
     print("-" * 30)
+    show_plot = False
+    save_in_file = True
+    overwrite = True
     for version in versions:
-        find_weights(act_f, act_n, (5, 4), (-3, 3), version, True, True, True)
+        find_weights(act_f, act_n, (5, 4), (-3, 3), version, show_plot,
+                     save_in_file, overwrite)
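A note on the `utils.py` change: in `scipy.optimize.curve_fit`, `bounds` constrains the fitted parameters (here a, b, c, d), not the evaluation interval, so `bounds=(x.min(), x.max())` was silently clamping every parameter to roughly [-3, 3]; removing it leaves the parameters free. A standalone sketch of the corrected call, with `np.tanh` standing in for the rational function being matched:

    import numpy as np
    from scipy.optimize import curve_fit

    def equivalent_func(x_array, a, b, c, d):
        # same shape as in find_closest_equivalent: a * f(c*x + d) + b
        return a * np.tanh(c * x_array + d) + b

    x = np.arange(-3., 3., 0.1)
    y = np.tanh(x)                    # stand-in for the rational's outputs
    initials = [1.1, 0.1, 1.1, 0.1]   # perturbed identity transform as start
    popt, _ = curve_fit(equivalent_func, x, y, p0=initials)
    a, b, c, d = popt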