9
9
from tensorlayer .layers .utils import (get_variable_with_initializer )
10
10
from tensorlayer import logging
11
11
12
- __all__ = ['Module' , 'SequentialLayer' ]
12
+ __all__ = ['Module' , 'SequentialLayer' , 'LayerList' ]
13
13
14
14
_global_layer_name_dict = {}
15
15
Parameter_ = tf .Variable
@@ -606,19 +606,19 @@ def __init__(self, *args):
606
606
def __getitem__ (self , index ):
607
607
if isinstance (index , slice ):
608
608
return self .__class__ (OrderedDict (list (self ._layers .items ())[index ]))
609
- index = self . _valid_index (len (self ), index )
609
+ index = _valid_index (len (self ), index )
610
610
return list (self ._layers .values ())[index ]
611
611
612
612
def __setitem__ (self , index , layer ):
613
- if self . _valid_module (layer ):
614
- index = self . _valid_index (len (self ), index )
613
+ if _valid_module (layer ):
614
+ index = _valid_index (len (self ), index )
615
615
key = list (self ._layers .keys ())[index ]
616
616
self ._layers [key ] = layer
617
617
self .layer_list = list (self ._layers .values ())
618
618
619
619
def __delitem__ (self , index ):
620
620
if isinstance (index , int ):
621
- index = self . _valid_index (len (self ), index )
621
+ index = _valid_index (len (self ), index )
622
622
key = list (self ._layers .keys ())[index ]
623
623
del self ._layers [key ]
624
624
elif isinstance (index , slice ):
@@ -633,7 +633,7 @@ def __len__(self):
633
633
return len (self ._layers )
634
634
635
635
def append (self , layer ):
636
- if self . _valid_module (layer ):
636
+ if _valid_module (layer ):
637
637
self ._layers [str (len (self ))] = layer
638
638
self .layer_list = list (self ._layers .values ())
639
639
return self
@@ -646,16 +646,133 @@ def forward(self, input_data):
646
646
input_data = layer (input_data )
647
647
return input_data
648
648
649
- def _valid_index (self , layer_num , index ):
650
- if not isinstance (index , int ):
651
- raise TypeError ("Index {} is not int type" )
652
- if not - layer_num <= index < layer_num :
653
- raise IndexError (
654
- "Index should be a number in range [{}, {}), but got {}" .format (- layer_num , layer_num , index )
655
- )
656
- return index % layer_num
657
-
658
- def _valid_module (self , layer ):
659
- if issubclass (layer .__class__ , Module ):
660
- return True
661
- raise TypeError ('Module {} is not subclass of Module' .format (layer ))
649
+
650
+ class LayerList (Module ):
651
+ """
652
+ Holds Modules in a list.
653
+
654
+ LayerList can be used like a regular Python list, support
655
+ '__getitem__', '__setitem__', '__delitem__', '__len__', '__iter__' and '__iadd__',
656
+ but module it contains are properly registered, and will be visible by all Modules methods.
657
+
658
+ Parameters
659
+ ----------
660
+ args : list
661
+ List of subclass of Module.
662
+ Methods
663
+ ---------
664
+ __init__()
665
+ Initializing the Layer.
666
+ insert()
667
+ Inserts a given layer before a given index in the list.
668
+ extend()
669
+ Appends layers from a Python iterable to the end of the list.
670
+ append()
671
+ Appends a given layer to the end of the list.
672
+
673
+ Examples
674
+ ---------
675
+ Args:
676
+ args (list, optional): List of subclass of Module.
677
+
678
+ Examples:
679
+
680
+ """
681
+ def __init__ (self , * args , ** kwargs ):
682
+ super (LayerList , self ).__init__ ()
683
+ if len (args ) == 1 :
684
+ self .extend (args [0 ])
685
+
686
+ def __getitem__ (self , index ):
687
+ if isinstance (index , slice ):
688
+ return self .__class__ (list (self ._layers .values ())[index ])
689
+ if isinstance (index , int ):
690
+ index = _valid_index (len (self ), index )
691
+ return self ._layers [str (index )]
692
+ raise TypeError ('Index {} is not int type or slice type' .format (index ))
693
+
694
+ def __setitem__ (self , index , layer ):
695
+ if not isinstance (index , int ) and _valid_module (layer ):
696
+ raise TypeError ('Index {} is not int type' .format (index ))
697
+ index = _valid_index (len (self ), index )
698
+ self ._layers [str (index )] = layer
699
+
700
+ def __delitem__ (self , index ):
701
+ if isinstance (index , int ):
702
+ index = _valid_index (len (self ), index )
703
+ del self ._layers [str (index )]
704
+ elif isinstance (index , slice ):
705
+ keys = list (self ._layers .keys ())[index ]
706
+ for key in keys :
707
+ del self ._layers [key ]
708
+ else :
709
+ raise TypeError ('Index {} is not int type or slice type' .format (index ))
710
+ temp_dict = OrderedDict ()
711
+ for idx , layer in enumerate (self ._layers .values ()):
712
+ temp_dict [str (idx )] = layer
713
+ self ._layers = temp_dict
714
+
715
+ def __len__ (self ):
716
+ return len (self ._layers )
717
+
718
+ def __iter__ (self ):
719
+ return iter (self ._layers .values ())
720
+
721
+ def __iadd__ (self , layers ):
722
+ self .extend (layers )
723
+ return self
724
+
725
+ def insert (self , index , layer ):
726
+ """
727
+ Inserts a given layer before a given index in the list.
728
+
729
+ """
730
+
731
+ idx = _valid_index (len (self ), index )
732
+ _valid_module (layer )
733
+ length = len (self )
734
+ while length > idx :
735
+ self ._layers [str (length )] = self ._layers [str (length - 1 )]
736
+ length -= 1
737
+ self ._layers [str (idx )] = layer
738
+
739
+ def extend (self , layers ):
740
+ """
741
+ Appends layers from a Python iterable to the end of the list.
742
+
743
+ """
744
+
745
+ if not isinstance (layers , list ):
746
+ raise TypeError ('Modules {} should be list of sublayers' .format (layers ))
747
+ for layer in layers :
748
+ if _valid_module (layer ):
749
+ self ._layers [str (len (self ))] = layer
750
+ return self
751
+
752
+ def append (self , layer ):
753
+ """
754
+ Appends a given layer to the end of the list.
755
+
756
+ """
757
+
758
+ if _valid_module (layer ):
759
+ self ._layers [str (len (self ))] = layer
760
+
761
+ def forward (self , * inputs ):
762
+ raise NotImplementedError
763
+
764
+
765
+ def _valid_index (layer_num , index ):
766
+ if not isinstance (index , int ):
767
+ raise TypeError ("Index {} is not int type" )
768
+ if not - layer_num <= index < layer_num :
769
+ raise IndexError (
770
+ "Index should be a number in range [{}, {}), but got {}" .format (- layer_num , layer_num , index )
771
+ )
772
+ return index % layer_num
773
+
774
+
775
+ def _valid_module (layer ):
776
+ if issubclass (layer .__class__ , Module ):
777
+ return True
778
+ raise TypeError ('Module {} is not subclass of Module' .format (layer ))
0 commit comments