diff --git a/layers.py b/layers.py index e9df9c3..44d3664 100644 --- a/layers.py +++ b/layers.py @@ -198,7 +198,7 @@ def forward(self, vxin, hxin, h): hx = hx + self.tohori(vx) if self.gates: - vx = self.gate(vx, h, (self.vvf, self.vvg)) + vx = self.gate(vx, h, (self.vvf, self.vvg)) hx = self.gate(hx, h, (self.vhf, self.vhg)) if self.res_connection: diff --git a/models.py b/models.py index 9e7a368..24efd3c 100644 --- a/models.py +++ b/models.py @@ -3,6 +3,10 @@ from layers import * class Gated(nn.Module): + """ + Model combining several gated pixelCNN layers with a conditional input (usually the class) + + """ def __init__(self, input_size, conditional_size, channels, num_layers, k=7, padding=3): super().__init__()