Skip to content

Commit cd5e853

Browse files
committed
enh: enabled OrCategory and kw filter for sub-classes
The OrCategory follows regular Python semantics, 1 or 2 == 1. Now one can do AtomCategory.kw(Z=5, neighbours={"min":2}) Thanks to Jonas for these suggestions.
1 parent 60fcf1f commit cd5e853

File tree

1 file changed

+52
-5
lines changed

1 file changed

+52
-5
lines changed

sisl/category/base.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,55 @@ def set_name(self, name):
8282
r""" Override the name of the categorization """
8383
self._name = name
8484

85+
@classmethod
86+
def kw(cls, **kwargs):
87+
""" Create categories based on keywords
88+
89+
This will search through the inherited classes and
90+
return and & category object for all keywords.
91+
92+
Since this is a class method one should use this
93+
on the base category class in the given section
94+
of the code.
95+
"""
96+
97+
subcls = set()
98+
work = [cls]
99+
while work:
100+
parent = work.pop()
101+
for child in parent.__subclasses__():
102+
if child not in subcls:
103+
subcls.add(child)
104+
work.append(child)
105+
106+
del work, parent, child
107+
108+
# create dictionary look-up
109+
subcls = {cl.__name__.lower(): cl for cl in subcls}
110+
111+
def get_cat(cl, args):
112+
if isinstance(args, dict):
113+
return cl(**args)
114+
return cl(args)
115+
116+
# Now search keywords and create category
117+
cat = None
118+
for key, args in kwargs.items():
119+
lkey = key.lower()
120+
found = ''
121+
for name, cl in subcls.items():
122+
if name.endswith(lkey):
123+
if found:
124+
raise ValueError(f"{cls.__name__}.kw got a non-unique argument for category name:\n"
125+
f" Searching for {name} and found matches {found} and {name}.")
126+
found = name
127+
if cat is None:
128+
cat = get_cat(cl, args)
129+
else:
130+
cat = cat & get_cat(cl, args)
131+
132+
return cat
133+
85134
@abstractmethod
86135
def categorize(self, *args, **kwargs):
87136
r""" Do categorization """
@@ -134,8 +183,8 @@ def __ne__(self, other):
134183
def __and__(self, other):
135184
return AndCategory(self, other)
136185

137-
#def __or__(self, other):
138-
# return OrCategory(self, other)
186+
def __or__(self, other):
187+
return OrCategory(self, other)
139188

140189
def __xor__(self, other):
141190
return XOrCategory(self, other)
@@ -285,9 +334,7 @@ def categorize(self, *args, **kwargs):
285334
def cmp(a, b):
286335
if isinstance(a, NullCategory):
287336
return b
288-
elif isinstance(b, NullCategory):
289-
return a
290-
return self
337+
return a
291338

292339
if isinstance(catA, list):
293340
return list(map(cmp, catA, catB))

0 commit comments

Comments
 (0)