Nils' nice idea of using max inside the pytorch rat functions
k4ntz committed Dec 8, 2020
1 parent ff443e3 commit 52bf68d
Showing 1 changed file with 21 additions and 38 deletions.
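
The change itself is small: instead of branching on len_num > len_deno (and bailing out otherwise), each variant now builds a single Vandermonde basis wide enough for both coefficient vectors via max(len_num, len_deno). Below is a minimal standalone sketch of that pattern for the B variant; the helper name and the explicit slice of the numerator columns are illustrative assumptions (the committed code relies on broadcasting), and it assumes the usual case where the numerator carries more coefficients than the denominator.

import torch

def rational_b_sketch(x, weight_numerator, weight_denominator):
    # Flatten the input and build one power basis [1, z, z^2, ...] sized by
    # the longer of the two coefficient vectors, instead of branching on
    # which one is longer.
    z = x.view(-1)
    len_num, len_deno = len(weight_numerator), len(weight_denominator)
    xps = torch.vander(z, max(len_num, len_deno), increasing=True)
    # P(z) = a_0 + a_1 z + ... + a_n z^n
    numerator = xps[:, :len_num].mul(weight_numerator).sum(1)
    # Q(z) = 1 + |b_1 z + ... + b_m z^m|
    denominator = xps[:, 1:len_deno + 1].mul(weight_denominator).sum(1).abs()
    return numerator.div(1 + denominator).view(x.shape)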
rational/torch/rational_pytorch_functions.py: 59 changes (21 additions, 38 deletions)
@@ -6,47 +6,34 @@ def Rational_PYTORCH_A_F(x, weight_numerator, weight_denominator, training):

     z = x.view(-1)
     len_num, len_deno = len(weight_numerator), len(weight_denominator)
-    if len_num > len_deno:
-        xps = torch.vander(z, len_num, increasing=True)
-        numerator = xps.mul(weight_numerator).sum(1)
-
-        expanded_dw = torch.cat([torch.tensor([1.]), weight_denominator, \
-                                 torch.zeros(len_num - len_deno - 1)])
-        denominator = xps.mul(expanded_dw).abs().sum(1)
-        return numerator.div(denominator).view(x.shape)
-    else:
-        print("Not implemented yet")
-        exit(1)
+    xps = torch.vander(z, max(len_num, len_deno), increasing=True)
+    numerator = xps.mul(weight_numerator).sum(1)
+    expanded_dw = torch.cat([torch.tensor([1.]), weight_denominator, \
+                             torch.zeros(len_num - len_deno - 1)])
+    denominator = xps.mul(expanded_dw).abs().sum(1)
+    return numerator.div(denominator).view(x.shape)


 def Rational_PYTORCH_B_F(x, weight_numerator, weight_denominator, training):
     # P(X) / Q(X) = a_0 + a_1 * X + ... + a_n * X^n /
     #               1 + |b_1 * X + b_2 * X^2 + ... + b_m * X^m|
     z = x.view(-1)
     len_num, len_deno = len(weight_numerator), len(weight_denominator)
-    if len_num > len_deno:
-        xps = torch.vander(z, len_num, increasing=True)
-        numerator = xps.mul(weight_numerator).sum(1)
-        denominator = xps[:, 1:len_deno+1].mul(weight_denominator).sum(1).abs()
-        return numerator.div(1 + denominator).view(x.shape)
-    else:
-        print("Not implemented yet")
-        exit(1)
+    xps = torch.vander(z, max(len_num, len_deno), increasing=True)
+    numerator = xps.mul(weight_numerator).sum(1)
+    denominator = xps[:, 1:len_deno+1].mul(weight_denominator).sum(1).abs()
+    return numerator.div(1 + denominator).view(x.shape)


 def Rational_PYTORCH_C_F(x, weight_numerator, weight_denominator, training):
     # P(X) / Q(X) = a_0 + a_1 * X + ... + a_n * X^n /
     #               eps + |b_0 + b_1 * X + b_2 * X^2 + ... + b_m * X^m|
     z = x.view(-1)
     len_num, len_deno = len(weight_numerator), len(weight_denominator)
-    if len_num > len_deno:
-        xps = torch.vander(z, len_num, increasing=True)
-        numerator = xps.mul(weight_numerator).sum(1)
-        denominator = xps[:, :len_deno].mul(weight_denominator).sum(1).abs()
-        return numerator.div(0.1 + denominator).view(x.shape)
-    else:
-        print("Not implemented yet")
-        exit(1)
+    xps = torch.vander(z, max(len_num, len_deno), increasing=True)
+    numerator = xps.mul(weight_numerator).sum(1)
+    denominator = xps[:, :len_deno].mul(weight_denominator).sum(1).abs()
+    return numerator.div(0.1 + denominator).view(x.shape)


 def Rational_PYTORCH_D_F(x, weight_numerator, weight_denominator, training, random_deviation=0.1):
@@ -58,14 +45,10 @@ def Rational_PYTORCH_D_F(x, weight_numerator, weight_denominator, training, random_deviation=0.1):
         return Rational_PYTORCH_B_F(x, weight_numerator, weight_denominator, training)
     z = x.view(-1)
     len_num, len_deno = len(weight_numerator), len(weight_denominator)
-    if len_num > len_deno:
-        xps = torch.vander(z, len_num, increasing=True)
-        numerator = xps.mul(weight_numerator.mul(
-            torch.FloatTensor(len_num).uniform_(1-random_deviation,
-                                                1+random_deviation))
-            ).sum(1)
-        denominator = xps[:, 1:len_deno+1].mul(weight_denominator).sum(1).abs()
-        return numerator.div(1 + denominator).view(x.shape)
-    else:
-        print("Not implemented yet")
-        exit(1)
+    xps = torch.vander(z, max(len_num, len_deno), increasing=True)
+    numerator = xps.mul(weight_numerator.mul(
+        torch.FloatTensor(len_num).uniform_(1-random_deviation,
+                                            1+random_deviation))
+        ).sum(1)
+    denominator = xps[:, 1:len_deno+1].mul(weight_denominator).sum(1).abs()
+    return numerator.div(1 + denominator).view(x.shape)
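
A small usage sketch of the updated B variant; the weight sizes, the random values, and the import path are illustrative assumptions rather than part of the commit.

import torch
from rational.torch.rational_pytorch_functions import Rational_PYTORCH_B_F

weight_numerator = torch.randn(6)    # a_0 .. a_5, so len_num = 6
weight_denominator = torch.randn(4)  # b_1 .. b_4, so len_deno = 4
x = torch.linspace(-3.0, 3.0, steps=16).view(4, 4)

y = Rational_PYTORCH_B_F(x, weight_numerator, weight_denominator, training=True)
assert y.shape == x.shape  # the result is reshaped back to the input's shape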
