Skip to content

Commit 4b487a8

Browse files
IsaH57mmschlk
andauthored
Add Game.compute function and related test to the Game class (mmschlk#397)
* add compute() function and related test to the Game class * adjust compute() * add function and test to generate interaction lookup from coalition * fix lookup generation from coalitions * adjusted return type * checks that compute is not normalized in test * add to CHANGELOG.md --------- Co-authored-by: Maximilian <[email protected]>
1 parent 74c31bc commit 4b487a8

File tree

6 files changed

+113
-0
lines changed

6 files changed

+113
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- adds ``verbose`` parameter to the ``Explainer``'s ``explain_X()`` method to control weather a progress bar is shown or not which is defaulted to ``False``. [#391](https://github.com/mmschlk/shapiq/pull/391)
1111
- made `InteractionValues.get_n_order()` and `InteractionValues.get_n_order_values()` function more efficient by iterating over the stored interactions and not over the powerset of all potential interactions, which made the function not usable for higher player counts (models with many features, and results obtained from `TreeExplainer`). Note, this change does not really help `get_n_order_values()` as it still needs to create a numpy array of shape `n_players` times `order` [#372](https://github.com/mmschlk/shapiq/pull/372)
1212
- streamlined the ``network_plot()`` plot function to use the ``si_graph_plot()`` as its backend function. This allows for more flexibility in the plot function and makes it easier to use the same code for different purposes. In addition, the ``si_graph_plot`` was modified to make plotting more easy and allow for more flexibility with new parameters. [#349](https://github.com/mmschlk/shapiq/pull/349)
13+
- adds `Game.compute()` method to the `shapiq.Game` class to compute game values without changing the state of the game object. The compute method also introduces a `shapiq.utils.sets.generate_interaction_lookup_from_coalitions()` utility method which creates an interaction lookup dict from an array of coalitions. [#397](https://github.com/mmschlk/shapiq/pull/397)
1314

1415
#### Testing, Code-Quality and Documentation
1516
- activates ``"ALL"`` rules in ``ruff-format`` configuration to enforce stricter code quality checks and addressed around 500 (not automatically solvable) issues in the code base. [#391](https://github.com/mmschlk/shapiq/pull/391)

shapiq/games/base.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,35 @@ def precompute(self, coalitions: np.ndarray | None = None) -> None:
419419
self.coalition_lookup = coalitions_dict
420420
self.precompute_flag = True
421421

422+
def compute(
423+
self, coalitions: np.ndarray | None = None
424+
) -> tuple[np.ndarray, dict[tuple[int, ...], int], float]:
425+
"""Compute the game values for all or a given set of coalitions.
426+
427+
Args:
428+
coalitions: The coalitions to evaluate.
429+
430+
Returns:
431+
A tuple containing:
432+
- The computed game values in the same order of the coalitions.
433+
- A lookup dictionary mapping from coalitions to the indices in the array.
434+
- The normalization value used to center/normalize the game values.
435+
436+
Note:
437+
This method does not change the state of the game and does not normalize the values.
438+
439+
Examples:
440+
>>> from shapiq.games.benchmark import DummyGame
441+
>>> game = DummyGame(4, interaction=(1, 2))
442+
>>> game.compute(np.array([[0, 1, 0, 0], [0, 1, 1, 0]], dtype=bool))
443+
(array([0.25, 1.5]), {(1): 0, (1, 2): 1.5}, 0.0)
444+
445+
"""
446+
coalitions: np.ndarray = self._check_coalitions(coalitions)
447+
game_values = self.value_function(coalitions)
448+
449+
return game_values, self.coalition_lookup, self.normalization_value
450+
422451
def save_values(self, path: Path | str) -> None:
423452
"""Saves the game values to the given path.
424453

shapiq/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .sets import (
77
count_interactions,
88
generate_interaction_lookup,
9+
generate_interaction_lookup_from_coalitions,
910
get_explicit_subsets,
1011
pair_subset_sizes,
1112
powerset,
@@ -23,6 +24,7 @@
2324
"split_subsets_budget",
2425
"get_explicit_subsets",
2526
"generate_interaction_lookup",
27+
"generate_interaction_lookup_from_coalitions",
2628
"transform_coalitions_to_array",
2729
"transform_array_to_coalitions",
2830
"count_interactions",

shapiq/utils/sets.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,31 @@ def generate_interaction_lookup(
232232
}
233233

234234

235+
def generate_interaction_lookup_from_coalitions(
236+
coalitions: np.ndarray,
237+
) -> dict[tuple[Any, ...], int]:
238+
"""Generates a lookup dictionary for interactions based on an array of coalitions.
239+
240+
Args:
241+
coalitions: An array of player coalitions.
242+
243+
Returns:
244+
A dictionary that maps interactions to their index in the values vector
245+
246+
Example:
247+
>>> coalitions = np.array([
248+
... [1, 0, 1],
249+
... [0, 1, 1],
250+
... [1, 1, 0],
251+
... [0, 0, 1]
252+
... ])
253+
>>> generate_interaction_lookup_from_coalitions(coalitions)
254+
{(0, 2): 0, (1, 2): 1, (0, 1): 2, (2,): 3}
255+
256+
"""
257+
return {tuple(np.where(coalition)[0]): idx for idx, coalition in enumerate(coalitions)}
258+
259+
235260
def transform_coalitions_to_array(
236261
coalitions: Collection[tuple[int, ...]],
237262
n_players: int | None = None,

tests/tests_games/test_base_game.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,27 @@ def test_exact_computer_call():
326326
sv = game.exact_values(index=index, order=order)
327327
assert sv.index == index
328328
assert sv.max_order == order
329+
330+
331+
def test_compute():
332+
"""Tests the compute function with and without returned normalization."""
333+
normalization_value = 1.0 # not zero
334+
335+
n_players = 3
336+
game = DummyGame(n=n_players, interaction=(0, 1))
337+
338+
coalitions = np.array([[1, 0, 0], [0, 1, 1]])
339+
340+
# Make sure normalization value is added
341+
game.normalization_value = normalization_value
342+
assert game.normalize
343+
344+
result = game.compute(coalitions=coalitions)
345+
assert len(result[0]) == len(coalitions) # number of coalitions is correct
346+
assert result[2] == normalization_value
347+
assert len(result) == 3 # game_values, normalization_value and coalition_lookup
348+
349+
# check if the game values are correct and that they are not normalized from compute
350+
game_values = result[0]
351+
assert game(coalitions[0]) + normalization_value == pytest.approx(game_values[0])
352+
assert game(coalitions[1]) + normalization_value == pytest.approx(game_values[1])

tests/tests_utils/test_utils_sets.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from shapiq.utils import (
99
count_interactions,
1010
generate_interaction_lookup,
11+
generate_interaction_lookup_from_coalitions,
1112
get_explicit_subsets,
1213
pair_subset_sizes,
1314
powerset,
@@ -110,6 +111,37 @@ def test_generate_interaction_lookup(n, min_order, max_order, expected):
110111
assert generate_interaction_lookup(n, min_order, max_order) == expected
111112

112113

114+
@pytest.mark.parametrize(
115+
("coalitions", "expected"),
116+
[
117+
(
118+
np.array([[1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 0, 1]]),
119+
{(0, 2): 0, (1, 2): 1, (0, 1): 2, (2,): 3},
120+
),
121+
(
122+
np.array([[1, 1, 1], [0, 1, 0], [1, 0, 0], [0, 0, 1]]),
123+
{(0, 1, 2): 0, (1,): 1, (0,): 2, (2,): 3},
124+
),
125+
(
126+
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
127+
{(0,): 0, (1,): 1, (2,): 2},
128+
),
129+
(
130+
np.array([[1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 1, 0]]),
131+
{(0, 1, 3): 0, (2, 3): 1, (0, 2): 2},
132+
),
133+
(
134+
np.array([[0, 0, 0], [1, 1, 1]]),
135+
{(): 0, (0, 1, 2): 1},
136+
),
137+
],
138+
)
139+
def test_generate_interaction_lookup_from_coalitions(coalitions, expected):
140+
"""Tests the generate_interaction_lookup_from_coalitions function."""
141+
result = generate_interaction_lookup_from_coalitions(coalitions)
142+
assert result == expected
143+
144+
113145
@pytest.mark.parametrize(
114146
("coalitions", "n_player", "expected"),
115147
[

0 commit comments

Comments
 (0)