Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions FWCore/ParameterSet/python/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1990,7 +1990,9 @@ def __init__(self,*arg,**args):

def testProcessDumpPython(self):
self.assertEqual(Process("test").dumpPython(),
"""import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")
"""import FWCore.ParameterSet.Config as cms

process = cms.Process("test")

process.maxEvents = cms.untracked.PSet(
input = cms.optional.untracked.int32,
Expand All @@ -2011,7 +2013,7 @@ def testProcessDumpPython(self):
emptyRunLumiMode = cms.obsolete.untracked.string,
eventSetup = cms.untracked.PSet(
forceNumberOfConcurrentIOVs = cms.untracked.PSet(

allowAnyLabel_=cms.required.untracked.uint32
),
numberOfConcurrentIOVs = cms.untracked.uint32(1)
),
Expand Down
4 changes: 4 additions & 0 deletions FWCore/ParameterSet/python/Mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ def dumpPython(self, options=PrintOptions()):
# usings need to go first
resultList = usings
resultList.extend(others)
if self.__validator is not None:
options.indent()
resultList.append(options.indentation()+"allowAnyLabel_="+self.__validator.dumpPython(options))
options.unindent()
return ',\n'.join(resultList)+'\n'
def __repr__(self):
return self.dumpPython()
Expand Down
73 changes: 71 additions & 2 deletions FWCore/ParameterSet/python/Types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def __setattr__(self,name, value):
if v is not None:
return setattr(v,name,value)
else:
if not name.startswith('_'):
raise AttributeError("%r object has no attribute %r" % (self.__class__.__name__, name))
return object.__setattr__(self, name, value)
def __bool__(self):
v = self.__dict__.get('_ProxyParameter__value',None)
Expand All @@ -71,7 +73,9 @@ def dumpPython(self, options=PrintOptions()):
v = "cms."+self._dumpPythonName()
if not _ParameterTypeBase.isTracked(self):
v+=".untracked"
return v+'.'+self.__type.__name__
if hasattr(self.__type, "__name__"):
return v+'.'+self.__type.__name__
return v+'.'+self.__type.dumpPython(options)
def validate_(self,value):
return isinstance(value,self.__type)
def convert_(self,value):
Expand Down Expand Up @@ -138,6 +142,19 @@ def __call__(self,value):
raise RuntimeError("Cannot convert "+str(value)+" to 'allowed' type")
return chosenType(value)

class _PSetTemplate(object):
def __init__(self, *args, **kargs):
self._pset = PSet(*args,**kargs)
self.__dict__['_PSetTemplate__value'] = None
def __call__(self, value):
self.__dict__
return self._pset.clone(**value)
def dumpPython(self, options=PrintOptions()):
v =self.__dict__.get('_ProxyParameter__value',None)
if v is not None:
return v.dumpPython(options)
return "PSetTemplate(\n"+_Parameterizable.dumpPython(self._pset, options)+options.indentation()+")"


class _ProxyParameterFactory(object):
"""Class type for ProxyParameter types to allow nice syntax"""
Expand All @@ -160,7 +177,17 @@ def __call__(self, *args):
return self.type(_AllowedParameterTypes(*args))

return _AllowedWrapper(self.__isUntracked, self.__type)

if name == 'PSetTemplate':
class _PSetTemplateWrapper(object):
def __init__(self, untracked, type):
self.untracked = untracked
self.type = type
def __call__(self,*args,**kargs):
if self.untracked:
return untracked(self.type(_PSetTemplate(*args,**kargs)))
return self.type(_PSetTemplate(*args,**kargs))
return _PSetTemplateWrapper(self.__isUntracked, self.__type)

type = globals()[name]
if not issubclass(type, _ParameterTypeBase):
raise AttributeError
Expand Down Expand Up @@ -1859,6 +1886,27 @@ def testRequired(self):
self.assertEqual(p1.foo.value(),3)
self.failIf(p1.foo.isTracked())
self.assertRaises(ValueError,setattr,p1, 'bar', 'bad')
#PSetTemplate use
p1 = PSet(aPSet = required.PSetTemplate())
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.required.PSetTemplate(\n\n )\n)')
p1.aPSet = dict()
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n\n )\n)')
p1 = PSet(aPSet=required.PSetTemplate(a=required.int32))
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.required.PSetTemplate(\n a = cms.required.int32\n )\n)')
p1.aPSet = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n a = cms.int32(5)\n )\n)')
self.assertEqual(p1.aPSet.a.value(), 5)
p1 = PSet(aPSet=required.untracked.PSetTemplate(a=required.int32))
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.required.untracked.PSetTemplate(\n a = cms.required.int32\n )\n)')
p1.aPSet = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.untracked.PSet(\n a = cms.int32(5)\n )\n)')
self.assertEqual(p1.aPSet.a.value(), 5)
p1 = PSet(allowAnyLabel_=required.PSetTemplate(a=required.int32))
self.assertEqual(p1.dumpPython(), 'cms.PSet(\n allowAnyLabel_=cms.required.PSetTemplate(\n a = cms.required.int32\n )\n)')
p1.foo = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n foo = cms.PSet(\n a = cms.int32(5)\n ),\n allowAnyLabel_=cms.required.PSetTemplate(\n a = cms.required.int32\n )\n)')
self.assertEqual(p1.foo.a.value(), 5)

def testOptional(self):
p1 = PSet(anInt = optional.int32)
self.assert_(hasattr(p1,"anInt"))
Expand Down Expand Up @@ -1887,6 +1935,27 @@ def testOptional(self):
self.failIf(p1.f)
p1.f.append(3)
self.assert_(p1.f)
#PSetTemplate use
p1 = PSet(aPSet = optional.PSetTemplate())
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.optional.PSetTemplate(\n\n )\n)')
p1.aPSet = dict()
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n\n )\n)')
p1 = PSet(aPSet=optional.PSetTemplate(a=optional.int32))
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.optional.PSetTemplate(\n a = cms.optional.int32\n )\n)')
p1.aPSet = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n a = cms.int32(5)\n )\n)')
self.assertEqual(p1.aPSet.a.value(), 5)
p1 = PSet(aPSet=optional.untracked.PSetTemplate(a=optional.int32))
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.optional.untracked.PSetTemplate(\n a = cms.optional.int32\n )\n)')
p1.aPSet = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.untracked.PSet(\n a = cms.int32(5)\n )\n)')
self.assertEqual(p1.aPSet.a.value(), 5)
p1 = PSet(allowAnyLabel_=optional.PSetTemplate(a=optional.int32))
self.assertEqual(p1.dumpPython(), 'cms.PSet(\n allowAnyLabel_=cms.optional.PSetTemplate(\n a = cms.optional.int32\n )\n)')
p1.foo = dict(a=5)
self.assertEqual(p1.dumpPython(),'cms.PSet(\n foo = cms.PSet(\n a = cms.int32(5)\n ),\n allowAnyLabel_=cms.optional.PSetTemplate(\n a = cms.optional.int32\n )\n)')
self.assertEqual(p1.foo.a.value(), 5)


def testAllowed(self):
p1 = PSet(aValue = required.allowed(int32, string))
Expand Down