1919
2020from absl import logging
2121from dataclasses import dataclass
22- from typing import List , Dict , Tuple , Any
22+ from typing import Iterable , List , Dict , Tuple , Any
2323
2424import json
2525import os
@@ -45,7 +45,7 @@ def __init__(self):
4545 self ._ranges = {}
4646
4747 def __call__ (self ,
48- module_specs : List [ModuleSpec ],
48+ module_specs : Tuple [ModuleSpec ],
4949 k : int ,
5050 n : int = 20 ) -> List [ModuleSpec ]:
5151 """
@@ -86,20 +86,23 @@ def __init__(self,
8686 data_path : str ,
8787 additional_flags : Tuple [str , ...] = (),
8888 delete_flags : Tuple [str , ...] = ()):
89- self ._module_specs = _build_modulespecs_from_datapath (
90- data_path = data_path ,
91- additional_flags = additional_flags ,
92- delete_flags = delete_flags )
89+ self .module_specs = tuple (
90+ sorted (
91+ _build_modulespecs_from_datapath (
92+ data_path = data_path ,
93+ additional_flags = additional_flags ,
94+ delete_flags = delete_flags ),
95+ key = lambda m : m .size ,
96+ reverse = True ))
9397 self ._root_dir = data_path
94- self ._module_specs .sort (key = lambda m : m .size , reverse = True )
9598
9699 @classmethod
97- def from_module_specs (cls , module_specs : List [ModuleSpec ]):
100+ def from_module_specs (cls , module_specs : Iterable [ModuleSpec ]):
98101 """Construct a Corpus from module specs. Mostly for testing purposes."""
99102 cps = cls .__new__ (cls ) # Avoid calling __init__
100103 super (cls , cps ).__init__ ()
101- cps ._module_specs = list ( module_specs ) # Don't mutate the original list.
102- cps . _module_specs . sort ( key = lambda m : m .size , reverse = True )
104+ cps .module_specs = tuple (
105+ sorted ( module_specs , key = lambda m : m .size , reverse = True ) )
103106 cps .root_dir = None
104107 return cps
105108
@@ -110,23 +113,21 @@ def sample(self,
110113 """Samples `k` module_specs, optionally sorting by size descending."""
111114 # Note: sampler is intentionally defaulted to a mutable object, as the
112115 # only mutable attribute of SamplerBucketRoundRobin is its range cache.
113- k = min (len (self ._module_specs ), k )
116+ k = min (len (self .module_specs ), k )
114117 if k < 1 :
115118 raise ValueError ('Attempting to sample <1 module specs from corpus.' )
116- sampled_specs = sampler (self ._module_specs , k = k )
119+ sampled_specs = sampler (self .module_specs , k = k )
117120 if sort :
118121 sampled_specs .sort (key = lambda m : m .size , reverse = True )
119122 return sampled_specs
120123
121124 def filter (self , p : re .Pattern ):
122125 """Filters module specs, keeping those which match the provided pattern."""
123- self ._module_specs = [ms for ms in self ._module_specs if p .match (ms .name )]
124-
125- def get_modules_copy (self ):
126- return list (self ._module_specs )
126+ self .module_specs = tuple (
127+ ms for ms in self .module_specs if p .match (ms .name ))
127128
128129 def __len__ (self ):
129- return len (self ._module_specs )
130+ return len (self .module_specs )
130131
131132
132133def _build_modulespecs_from_datapath (
0 commit comments