Skip to content

Commit

Permalink
names
Browse files Browse the repository at this point in the history
  • Loading branch information
ourownstory committed Sep 12, 2024
1 parent 5357f8e commit f209d97
Showing 1 changed file with 14 additions and 26 deletions.
40 changes: 14 additions & 26 deletions neuralprophet/utils_time_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,13 @@ def stack_lagged_regressors(self, df_tensors, feature_list, current_idx, config)
return current_idx + num_features
return current_idx

def stack_additive_events(
self,
df_tensors,
feature_list,
current_idx,
additive_event_and_holiday_names,
):
def stack_additive_events(self, df_tensors, feature_list, current_idx, names):
"""
Stack the additive event and holiday features.
"""
if additive_event_and_holiday_names:
if names:
additive_events_tensor = torch.cat(
[df_tensors[name].unsqueeze(-1) for name in additive_event_and_holiday_names],
[df_tensors[name].unsqueeze(-1) for name in names],
dim=1,
)
feature_list.append(additive_events_tensor)
Expand All @@ -264,16 +258,12 @@ def stack_additive_events(
return current_idx + additive_events_tensor.size(1)
return current_idx

def stack_multiplicative_events(
self, df_tensors, feature_list, current_idx, multiplicative_event_and_holiday_names
):
def stack_multiplicative_events(self, df_tensors, feature_list, current_idx, names):
"""
Stack the multiplicative event and holiday features.
"""
if multiplicative_event_and_holiday_names:
multiplicative_events_tensor = torch.cat(
[df_tensors[name].unsqueeze(-1) for name in multiplicative_event_and_holiday_names], dim=1
)
if names:
multiplicative_events_tensor = torch.cat([df_tensors[name].unsqueeze(-1) for name in names], dim=1)
feature_list.append(multiplicative_events_tensor)
self.feature_indices["multiplicative_events"] = (
current_idx,
Expand All @@ -282,14 +272,12 @@ def stack_multiplicative_events(
return current_idx + multiplicative_events_tensor.size(1)
return current_idx

def stack_additive_regressors(self, df_tensors, feature_list, current_idx, additive_regressors_names):
def stack_additive_regressors(self, df_tensors, feature_list, current_idx, names):
"""
Stack the additive regressor features.
"""
if additive_regressors_names:
additive_regressors_tensor = torch.cat(
[df_tensors[name].unsqueeze(-1) for name in additive_regressors_names], dim=1
)
if names:
additive_regressors_tensor = torch.cat([df_tensors[name].unsqueeze(-1) for name in names], dim=1)
feature_list.append(additive_regressors_tensor)
self.feature_indices["additive_regressors"] = (
current_idx,
Expand All @@ -298,20 +286,20 @@ def stack_additive_regressors(self, df_tensors, feature_list, current_idx, addit
return current_idx + additive_regressors_tensor.size(1)
return current_idx

def stack_multiplicative_regressors(self, df_tensors, feature_list, current_idx, multiplicative_regressors_names):
def stack_multiplicative_regressors(self, df_tensors, feature_list, current_idx, names):
"""
Stack the multiplicative regressor features.
"""
if multiplicative_regressors_names:
if names:
multiplicative_regressors_tensor = torch.cat(
[df_tensors[name].unsqueeze(-1) for name in multiplicative_regressors_names], dim=1
[df_tensors[name].unsqueeze(-1) for name in names], dim=1
) # Shape: [batch_size, num_multiplicative_regressors, 1]
feature_list.append(multiplicative_regressors_tensor)
self.feature_indices["multiplicative_regressors"] = (
current_idx,
current_idx + len(multiplicative_regressors_names) - 1,
current_idx + len(names) - 1,
)
return current_idx + len(multiplicative_regressors_names)
return current_idx + len(names)
return current_idx

def stack_seasonalities(self, df_tensors, feature_list, current_idx, config, seasonalities):
Expand Down

0 comments on commit f209d97

Please sign in to comment.