Skip to content

Commit 339fb03

Browse files
committed
add LayerList
1 parent fd6b260 commit 339fb03

File tree

4 files changed

+183
-29
lines changed

4 files changed

+183
-29
lines changed

tensorlayer/initializers/paddle_initializers.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,56 @@
99
import paddle
1010

1111
__all__ = [
12-
'Zeros', 'Ones', 'Constant', 'RandomUniform', 'RandomNormal', 'TruncatedNormal',
12+
'Initializer', 'Zeros', 'Ones', 'Constant', 'RandomUniform', 'RandomNormal', 'TruncatedNormal',
1313
'deconv2d_bilinear_upsampling_initializer', 'HeNormal'
1414
]
1515

16+
class Initializer(object):
17+
"""Initializer base class: all initializers inherit from this class.
18+
"""
19+
20+
def __call__(self, shape, dtype=None):
21+
"""Returns a tensor object initialized as specified by the initializer.
22+
23+
Parameters
24+
----------
25+
shape : tuple of int.
26+
The shape of the tensor.
27+
dtype : Optional dtype of the tensor.
28+
If not provided will return tensor of `tl.float32`.
29+
30+
Returns
31+
-------
32+
33+
"""
34+
raise NotImplementedError
35+
36+
def get_config(self):
37+
"""Returns the configuration of the initializer as a JSON-serializable dict.
38+
39+
Returns
40+
-------
41+
A JSON-serializable Python dict.
42+
"""
43+
return {}
44+
45+
@classmethod
46+
def from_config(cls, config):
47+
"""Instantiates an initializer from a configuration dictionary.
48+
49+
Parameters
50+
----------
51+
config : A python dictionary.
52+
It will typically be the output of `get_config`.
53+
54+
Returns
55+
-------
56+
An Initializer instance.
57+
"""
58+
if 'dtype' in config:
59+
config.pop('dtype')
60+
return cls(**config)
61+
1662

1763
class Zeros(ConstantInitializer):
1864
"""Initializer that generates tensors initialized to 0.

tensorlayer/layers/core/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,5 @@
99
from .core_tensorflow import *
1010
elif BACKEND == 'paddle':
1111
from .core_paddle import *
12-
elif BACKEND == 'dragon':
13-
from .core_dragon import *
1412
else:
1513
raise ("Unsupported backend:", BACKEND)

tensorlayer/layers/core/core_tensorflow.py

Lines changed: 136 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tensorlayer.layers.utils import (get_variable_with_initializer)
1010
from tensorlayer import logging
1111

12-
__all__ = ['Module', 'SequentialLayer']
12+
__all__ = ['Module', 'SequentialLayer', 'LayerList']
1313

1414
_global_layer_name_dict = {}
1515
Parameter_ = tf.Variable
@@ -606,19 +606,19 @@ def __init__(self, *args):
606606
def __getitem__(self, index):
607607
if isinstance(index, slice):
608608
return self.__class__(OrderedDict(list(self._layers.items())[index]))
609-
index = self._valid_index(len(self), index)
609+
index = _valid_index(len(self), index)
610610
return list(self._layers.values())[index]
611611

612612
def __setitem__(self, index, layer):
613-
if self._valid_module(layer):
614-
index = self._valid_index(len(self), index)
613+
if _valid_module(layer):
614+
index = _valid_index(len(self), index)
615615
key = list(self._layers.keys())[index]
616616
self._layers[key] = layer
617617
self.layer_list = list(self._layers.values())
618618

619619
def __delitem__(self, index):
620620
if isinstance(index, int):
621-
index = self._valid_index(len(self), index)
621+
index = _valid_index(len(self), index)
622622
key = list(self._layers.keys())[index]
623623
del self._layers[key]
624624
elif isinstance(index, slice):
@@ -633,7 +633,7 @@ def __len__(self):
633633
return len(self._layers)
634634

635635
def append(self, layer):
636-
if self._valid_module(layer):
636+
if _valid_module(layer):
637637
self._layers[str(len(self))] = layer
638638
self.layer_list = list(self._layers.values())
639639
return self
@@ -646,16 +646,133 @@ def forward(self, input_data):
646646
input_data = layer(input_data)
647647
return input_data
648648

649-
def _valid_index(self, layer_num, index):
650-
if not isinstance(index, int):
651-
raise TypeError("Index {} is not int type")
652-
if not -layer_num <= index < layer_num:
653-
raise IndexError(
654-
"Index should be a number in range [{}, {}), but got {}".format(-layer_num, layer_num, index)
655-
)
656-
return index % layer_num
657-
658-
def _valid_module(self, layer):
659-
if issubclass(layer.__class__, Module):
660-
return True
661-
raise TypeError('Module {} is not subclass of Module'.format(layer))
649+
650+
class LayerList(Module):
651+
"""
652+
Holds Modules in a list.
653+
654+
LayerList can be used like a regular Python list, support
655+
'__getitem__', '__setitem__', '__delitem__', '__len__', '__iter__' and '__iadd__',
656+
but module it contains are properly registered, and will be visible by all Modules methods.
657+
658+
Parameters
659+
----------
660+
args : list
661+
List of subclass of Module.
662+
Methods
663+
---------
664+
__init__()
665+
Initializing the Layer.
666+
insert()
667+
Inserts a given layer before a given index in the list.
668+
extend()
669+
Appends layers from a Python iterable to the end of the list.
670+
append()
671+
Appends a given layer to the end of the list.
672+
673+
Examples
674+
---------
675+
Args:
676+
args (list, optional): List of subclass of Module.
677+
678+
Examples:
679+
680+
"""
681+
def __init__(self, *args, **kwargs):
682+
super(LayerList, self).__init__()
683+
if len(args) == 1:
684+
self.extend(args[0])
685+
686+
def __getitem__(self, index):
687+
if isinstance(index, slice):
688+
return self.__class__(list(self._layers.values())[index])
689+
if isinstance(index, int):
690+
index = _valid_index(len(self), index)
691+
return self._layers[str(index)]
692+
raise TypeError('Index {} is not int type or slice type'.format(index))
693+
694+
def __setitem__(self, index, layer):
695+
if not isinstance(index, int) and _valid_module(layer):
696+
raise TypeError('Index {} is not int type'.format(index))
697+
index = _valid_index(len(self), index)
698+
self._layers[str(index)] = layer
699+
700+
def __delitem__(self, index):
701+
if isinstance(index, int):
702+
index = _valid_index(len(self), index)
703+
del self._layers[str(index)]
704+
elif isinstance(index, slice):
705+
keys = list(self._layers.keys())[index]
706+
for key in keys:
707+
del self._layers[key]
708+
else:
709+
raise TypeError('Index {} is not int type or slice type'.format(index))
710+
temp_dict = OrderedDict()
711+
for idx, layer in enumerate(self._layers.values()):
712+
temp_dict[str(idx)] = layer
713+
self._layers = temp_dict
714+
715+
def __len__(self):
716+
return len(self._layers)
717+
718+
def __iter__(self):
719+
return iter(self._layers.values())
720+
721+
def __iadd__(self, layers):
722+
self.extend(layers)
723+
return self
724+
725+
def insert(self, index, layer):
726+
"""
727+
Inserts a given layer before a given index in the list.
728+
729+
"""
730+
731+
idx = _valid_index(len(self), index)
732+
_valid_module(layer)
733+
length = len(self)
734+
while length > idx:
735+
self._layers[str(length)] = self._layers[str(length - 1)]
736+
length -= 1
737+
self._layers[str(idx)] = layer
738+
739+
def extend(self, layers):
740+
"""
741+
Appends layers from a Python iterable to the end of the list.
742+
743+
"""
744+
745+
if not isinstance(layers, list):
746+
raise TypeError('Modules {} should be list of sublayers'.format(layers))
747+
for layer in layers:
748+
if _valid_module(layer):
749+
self._layers[str(len(self))] = layer
750+
return self
751+
752+
def append(self, layer):
753+
"""
754+
Appends a given layer to the end of the list.
755+
756+
"""
757+
758+
if _valid_module(layer):
759+
self._layers[str(len(self))] = layer
760+
761+
def forward(self, *inputs):
762+
raise NotImplementedError
763+
764+
765+
def _valid_index(layer_num, index):
766+
if not isinstance(index, int):
767+
raise TypeError("Index {} is not int type")
768+
if not -layer_num <= index < layer_num:
769+
raise IndexError(
770+
"Index should be a number in range [{}, {}), but got {}".format(-layer_num, layer_num, index)
771+
)
772+
return index % layer_num
773+
774+
775+
def _valid_module(layer):
776+
if issubclass(layer.__class__, Module):
777+
return True
778+
raise TypeError('Module {} is not subclass of Module'.format(layer))

tensorlayer/layers/deprecated.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -416,13 +416,6 @@ def TimeDistributedLayer(*args, **kwargs):
416416
raise NonExistingLayerError("TimeDistributedLayer is removed for TF 2.0, please use eager mode instead." + __log__)
417417

418418

419-
__all__ += ['LayerList']
420-
421-
422-
def LayerList(*args, **kwargs):
423-
raise NonExistingLayerError("LayerList(list)(input_data) --> SequentialLayer(list)(input_data)" + __log__)
424-
425-
426419
__all__ += ['ModelLayer']
427420

428421

0 commit comments

Comments
 (0)