Skip to content

Commit

Permalink
docformatter
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Oct 17, 2024
1 parent dad0a02 commit 6aa5773
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
15 changes: 10 additions & 5 deletions src/gluonts/torch/model/mq_cnn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,8 @@ def __init__(
self.feat_static_real_dim = 0

def create_transformation(self) -> Chain:
"""Creates transformation to be applied to input dataset
"""
Creates transformation to be applied to input dataset.
Returns:
Chain:
Expand Down Expand Up @@ -577,7 +578,8 @@ def create_transformation(self) -> Chain:
return Chain(transforms)

def _create_instance_splitter(self, mode: str) -> Chain:
"""Creates instance splitter to be applied to the dataset
"""
Creates instance splitter to be applied to the dataset.
Args:
mode (str): `training`, `validation` or `test`
Expand Down Expand Up @@ -676,7 +678,8 @@ def create_training_data_loader(
shuffle_buffer_length: Optional[int] = None,
**kwargs,
) -> Iterable:
"""Creates data loader for the training dataset
"""
Creates data loader for the training dataset.
Args:
data (Dataset): training dataset
Expand Down Expand Up @@ -706,7 +709,8 @@ def create_validation_data_loader(
module: MQCNNLightningModule,
**kwargs,
) -> Iterable:
"""Creates data loader for the validation dataset
"""
Creates data loader for the validation dataset.
Args:
data (Dataset): validation dataset
Expand Down Expand Up @@ -769,7 +773,8 @@ def create_predictor(
transformation: Transformation,
module: MQCNNLightningModule,
) -> PyTorchPredictor:
"""Creates predictor for inference
"""
Creates predictor for inference.
Args:
transformation (Transformation): transformation to be applied to data input to predictor
Expand Down
8 changes: 5 additions & 3 deletions src/gluonts/torch/model/mq_cnn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,9 @@ def forward(

class Enc2Dec(nn.Module):
"""
Integrates the encoder_output_static, encoder_output_dynamic and future_features_dynamic
and passes them through as the dynamic input to the decoder.
Integrates the encoder_output_static, encoder_output_dynamic and
future_features_dynamic and passes them through as the dynamic input to the
decoder.
Parameters:
------------
Expand Down Expand Up @@ -439,7 +440,8 @@ def _get_local_mlp(self, init_dim, final_dim, hidden_dimension_seq):
return local_mlp

def forward(self, encoded_input: Tensor, future_input: Tensor) -> Tensor:
"""Forward pass for MQCNN decoder
"""
Forward pass for MQCNN decoder.
Args:
encoded_input (Tensor):
Expand Down
7 changes: 5 additions & 2 deletions src/gluonts/transform/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,9 @@ def flatmap_transform(


class ForkingSequenceSplitter(FlatMapTransformation):
"""Forking sequence splitter used by MQ-CNN Model"""
"""
Forking sequence splitter used by MQ-CNN Model.
"""

@validated()
def __init__(
Expand All @@ -597,7 +599,8 @@ def __init__(
start_input_field: str = FieldName.TARGET,
lead_time: int = 0,
) -> None:
"""Creates forking sequences
"""
Creates forking sequences.
Args:
instance_sampler ([type]):
Expand Down

0 comments on commit 6aa5773

Please sign in to comment.