diff --git a/simple.json b/simple.json index 5f9a5be..68c21de 100644 --- a/simple.json +++ b/simple.json @@ -1 +1 @@ -{"training_data":50.0,"evaluation_data":0.0} \ No newline at end of file +{"training_data":75.0,"evaluation_data":60.0} \ No newline at end of file diff --git a/src/grid_types.py b/src/grid_types.py index 8592432..f636ad8 100644 --- a/src/grid_types.py +++ b/src/grid_types.py @@ -14,6 +14,13 @@ # Direction vectors for 8 directions (N, NE, E, SE, S, SW, W, NW) DIRECTIONS8 = [(0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1)] +# Symmetries: horizontal, vertical, diagonal, anti-diagonal w.r.t. the origin +class Symmetry(Enum): + HORIZONTAL = auto() + VERTICAL = auto() + DIAGONAL = auto() # x == y + ANTI_DIAGONAL = auto() # x == -y + from enum import Enum, auto from typing import NamedTuple @@ -104,11 +111,6 @@ class Rotation(str, Enum): CLOCKWISE = "Clockwise" COUNTERCLOCKWISE = "CounterClockwise" - -class Axis(str, Enum): - HORIZONTAL = "Horizontal" - VERTICAL = "Vertical" - Color = NewType("Color", int) # Define the custom color scheme as a list of colors diff --git a/src/objects.py b/src/objects.py index 1e605ff..c18df5d 100644 --- a/src/objects.py +++ b/src/objects.py @@ -8,7 +8,7 @@ Cell, GridData, Rotation, - Axis, + Symmetry, BLACK, Color, RigidTransformation, @@ -22,7 +22,6 @@ GREEN, RED, YELLOW, - Axis, Rotation, color_scheme, ) @@ -192,7 +191,10 @@ def create_object( return Object(origin=(min_x, min_y), data=data._data) connected_components = find_connected_components( - grid=self, diagonals=diagonals, allow_black=allow_black, multicolor=multicolor + grid=self, + diagonals=diagonals, + allow_black=allow_black, + multicolor=multicolor, ) detected_objects = [ create_object(self, component, background_color) @@ -212,6 +214,14 @@ def flipud(self) -> "Object": x: np.ndarray = np.flipud(self._data.copy()) return Object(origin=self.origin, data=x) + def flip_diagonal(self) -> "Object": + x: np.ndarray = self._data.copy().T + return Object(origin=self.origin, data=x) + + def flip_anti_diagonal(self) -> "Object": + x: np.ndarray = np.flipud(np.fliplr(self._data.copy())) + return Object(origin=self.origin, data=x) + def invert(self) -> "Object": x: np.ndarray = 1 - self._data.copy() return Object(origin=self.origin, data=x) @@ -227,11 +237,17 @@ def translate_in_place(self, dx: int, dy: int) -> None: new_origin = (self.origin[0] + dx, self.origin[1] + dy) self.origin = new_origin - def flip(self, axis: Axis) -> "Object": - if axis == Axis.HORIZONTAL: + def flip(self, symmetry: Symmetry) -> "Object": + if symmetry == Symmetry.HORIZONTAL: return self.fliplr() - else: + elif symmetry == Symmetry.VERTICAL: return self.flipud() + elif symmetry == Symmetry.DIAGONAL: + return self.flip_diagonal() + elif symmetry == Symmetry.ANTI_DIAGONAL: + return self.flip_anti_diagonal() + else: + raise ValueError(f"Unknown symmetry type: {symmetry}") def num_cells(self, color: Optional[int]) -> int: if color is None: @@ -418,6 +434,55 @@ def detect_colored_objects( """ return visual_cortex.find_colored_objects(self, background_color) + def is_symmetric(self, symmetry: Symmetry) -> bool: + """ + Check if the object is symmetric with respect to the given symmetry type. + + Args: + symmetry (Symmetry): The type of symmetry to check for. + + Returns: + bool: True if the object is symmetric, False otherwise. + """ + data = self._data + width, height = self.size + + if symmetry == Symmetry.HORIZONTAL: # (x, y) == (w-x-1, y) + return np.array_equal(data, np.fliplr(data)) + + elif symmetry == Symmetry.VERTICAL: # (x, y) == (x, h-y-1) + return np.array_equal(data, np.flipud(data)) + + elif symmetry == Symmetry.DIAGONAL: # (x, y) == (y, x) + if height != width: + return False + return np.array_equal(data, data.T) # Transpose for diagonal symmetry + + elif symmetry == Symmetry.ANTI_DIAGONAL: # (x, y) == (w-x-1, h-y-1) + if height != width: + return False + # fliplr: + # then fliplr: + return np.array_equal(data, np.fliplr(data.T)) + + else: + raise ValueError(f"Unknown symmetry type: {symmetry}") + + def find_symmetries(self) -> List[Symmetry]: + """ + Find all symmetries of the object. + """ + symmetries = [] + if self.is_symmetric(Symmetry.HORIZONTAL): + symmetries.append(Symmetry.HORIZONTAL) + if self.is_symmetric(Symmetry.VERTICAL): + symmetries.append(Symmetry.VERTICAL) + if self.is_symmetric(Symmetry.DIAGONAL): + symmetries.append(Symmetry.DIAGONAL) + if self.is_symmetric(Symmetry.ANTI_DIAGONAL): + symmetries.append(Symmetry.ANTI_DIAGONAL) + return symmetries + def display( input: Object, output: Object = Object(np.array([[0]])), title: Optional[str] = None @@ -522,12 +587,12 @@ def test_rotate(): def test_flip(): grid = Object(np.array([[1, 2], [3, 4]])) - flipped_grid = grid.flip(Axis.HORIZONTAL) + flipped_grid = grid.flip(Symmetry.HORIZONTAL) assert flipped_grid == Object( np.array([[2, 1], [4, 3]]) ), f"Expected [[2, 1], [4, 3]], but got {flipped_grid}" - flipped_grid = grid.flip(Axis.VERTICAL) + flipped_grid = grid.flip(Symmetry.VERTICAL) assert flipped_grid == Object( np.array([[3, 4], [1, 2]]) ), f"Expected [[3, 4], [1, 2]], but got {flipped_grid}" diff --git a/src/simple.py b/src/simple.py new file mode 100644 index 0000000..89e5ac0 --- /dev/null +++ b/src/simple.py @@ -0,0 +1,1487 @@ +from typing import ( + Any, + Callable, + List, + Optional, + Tuple, + TypeVar, + Generic, + Union, +) + +from color_features import detect_color_features +from logger import logger +from objects import Object, display, display_multiple +from load_data import Example, Task, Tasks, training_data, evaluation_data +from rule_based_selector import DecisionRule, select_object_minimal +from shape_features import detect_shape_features +from symmetry_features import detect_symmetry_features +from symmetry import ( + find_periodic_symmetry_with_unknowns, + find_non_periodic_symmetry, + fill_grid, + PeriodicGridSymmetry, + NonPeriodicGridSymmetry, +) +from visual_cortex import find_rectangular_objects, regularity_score +import numpy as np +from dataclasses import dataclass +from grid_normalization import ClockwiseRotation, XReflection, RigidTransformation + +# returns the index of the object to pick +ObjectPicker = Callable[[List[Object]], int] + + +class Config: + find_xform = True + find_matched_objects = False + try_remove_main_color = False + difficulty = 1000 + task_name: str | None = None + # task_name = "e9afcf9a.json" # map 2 colored objects + # task_name = "0dfd9992.json" + # task_name = "05269061.json" + # task_name = "47996f11.json" + # task_name = "f9d67f8b.json" + # task_name = "47996f11.json" + task_fractal = "8f2ea7aa.json" # fractal expansion + task_puzzle = "97a05b5b.json" # puzzle-like, longest in DSL (59 lines) + whitelisted_tasks: List[str] = [] + whitelisted_tasks.append(task_puzzle) + # find_xform_color = True + display_not_found = False + display_this_task = False + only_simple_examples = False + only_inpainting_puzzles = True + max_size = 9 + max_colors = 4 + + +def filter_simple_xforms(task: Task, task_name: str): + examples = task.train + tests = task.test + for example in examples: + input = example[0] + output = example[1] + if ( + input.width > Config.max_size + or input.height > Config.max_size + or input.size != output.size + or input.get_colors(allow_black=True) != output.get_colors(allow_black=True) + or len(input.get_colors(allow_black=True)) > Config.max_colors + ): + return False + return True + + +GridAndObjects = Tuple[Object, List[Object]] + +T = TypeVar("T", bound=Union[Object, GridAndObjects]) +State = str + +Primitive = Callable[[Object, str, int], Object] +Match = Tuple[State, Callable[[T], T]] +Xform = Callable[[List[Example[T]], str, int], Optional[Match[T]]] + + +@dataclass +class XformEntry(Generic[T]): + xform: Xform[T] + difficulty: int + + +def check_primitive_on_examples( + prim: Callable[[Object, str, int], Object], + examples: List[Example[Object]], + task_name: str, + nesting_level: int, +) -> Optional[Match[Object]]: + logger.debug(f"{' ' * nesting_level}Checking primitive {prim.__name__}") + for i, example in enumerate(examples): + logger.debug(f"{' ' * nesting_level} Example {i+1}/{len(examples)}") + input = example[0] + output = example[1] + new_output = prim(input, task_name, nesting_level) + if new_output != output: + logger.debug(f"{' ' * nesting_level} Example {i+1} failed") + return None + state = "prim" + return (state, lambda i: prim(i, task_name, nesting_level)) + + +def primitive_to_xform(primitive: Primitive) -> Xform[Object]: + def xform( + examples: List[Example], + task_name: str, + nesting_level: int, + ) -> Optional[Match]: + result = check_primitive_on_examples( + primitive, examples, task_name, nesting_level + ) + if result is None: + return None + state, apply_state = result + return (state, apply_state) + + xform.__name__ = primitive.__name__ + return xform + + +def translate_down_1(input: Object, task_name: str, nesting_level: int): + obj = input.copy() + obj.translate_in_place(dx=0, dy=1) + result = Object.empty(obj.size) + result.add_object_in_place(obj) + return result + + +primitives: List[Primitive] = [ + translate_down_1, +] + + +def xform_identity( + examples: List[Example], task_name: str, nesting_level: int +) -> Optional[Match]: + def identity(input: Object, task_name: str, nesting_level: int): + return input + + return check_primitive_on_examples(identity, examples, task_name, nesting_level) + + +# TODO: This is currently not used but it illustrates how to compose primitives +def xform_two_primitives_in_sequence( + examples: List[Example], task_name: str, nesting_level: int +) -> Optional[Match]: + # try to apply two primitives in sequence, and return the first pair that works + for p1 in primitives: + for p2 in primitives: + + def composed_primitive(input: Object, task_name: str, nesting_level: int): + r1 = p1(input, task_name, nesting_level + 1) + return p2(r1, task_name, nesting_level + 1) + + if check_primitive_on_examples( + composed_primitive, examples, task_name, nesting_level + ): + state = f"({p1.__name__}, {p2.__name__})" + solve = lambda input: p2( + p1(input, task_name, nesting_level + 1), + task_name, + nesting_level + 1, + ) + return (state, solve) + return None + + +def check_matching_colored_objects_count_and_color(examples: List[Example]) -> bool: + for input, output in examples: + input_objects = input.detect_colored_objects(background_color=0) + output_objects = output.detect_colored_objects(background_color=0) + if len(input_objects) != len(output_objects): + return False + + different_color = any( + input_object.first_color != output_object.first_color + for input_object, output_object in zip(input_objects, output_objects) + ) + + if different_color: + return False + return True + + +def match_colored_objects( + examples: List[Example[Object]], + task_name: str, + nesting_level: int, +) -> Optional[Match[Object]]: + + logger.info( + f"{' ' * nesting_level}match_colored_objects examples:{len(examples)} task_name:{task_name} nesting_level:{nesting_level}" + ) + + color_match = check_matching_colored_objects_count_and_color(examples) + if color_match is None: + return None + # now the colored input + + # each example has the same number of input and output objects + # so we can turn those lists into and ObjectListExample + object_list_examples: List[Example[GridAndObjects]] = [] + + def get_background_color(input: Object) -> int: + background_color = 0 # TODO: determine background color + return background_color + + def get_grid_and_objects(input: Object) -> GridAndObjects: + background_color = get_background_color(input) + input_objects: List[Object] = input.detect_colored_objects(background_color) + return (input, input_objects) + + input_grid_and_objects: GridAndObjects + output_grid_and_objects: GridAndObjects + for input, output in examples: + input_grid_and_objects = get_grid_and_objects(input) + output_grid_and_objects = get_grid_and_objects(output) + input_objects = input_grid_and_objects[1] + output_objects = output_grid_and_objects[1] + + if len(input_objects) == 0 or len(output_objects) == 0: + return None + + if False: + display_multiple( + [ + (input_object, output_object) + for input_object, output_object in zip( + input_objects, output_objects + ) + ], + title=f"Colored Objects [Exam]", + ) + + object_list_example: Example[GridAndObjects] = ( + input_grid_and_objects, + output_grid_and_objects, + ) + object_list_examples.append(object_list_example) + + for list_xform in list_xforms: + match: Optional[Match[GridAndObjects]] = list_xform.xform( + object_list_examples, task_name, nesting_level + 1 + ) + if match is not None: + apply_state_list_xform: Callable[[GridAndObjects], GridAndObjects] + state_list_xform, apply_state_list_xform = match + + def solve(input: Object) -> Object: + background_color = get_background_color(input) + input_objects = input.detect_colored_objects(background_color) + grid_and_objects: GridAndObjects = (input, input_objects) + s, output_objects = apply_state_list_xform(grid_and_objects) + output_grid = None + if False: + display_multiple( + list(zip(input_objects, output_objects)), + title=f"Output Objects", + ) + for output in output_objects: + if output_grid is None: + output_grid = output.copy() + else: + output_grid.add_object_in_place(output) + assert output_grid is not None + if False: + display(output_grid, title=f"Output Grid") + return output_grid + + return ( + f"{list_xform.xform.__name__}({state_list_xform})", + solve, + ) + else: + logger.info( + f"{' ' * nesting_level}Xform {list_xform.xform.__name__} is not applicable" + ) + + return None + + +class CanvasGridMatch: + @staticmethod + def find_canvas_objects( + inputs: List[Object], outputs: Optional[List[Object]] + ) -> Optional[List[Object]]: + """Finds the largest objects in each example's input. Check the size is the same as the output's size, if provided.""" + canvas_objects = [] + for i, input in enumerate(inputs): + objects = input.detect_objects() + canvas = max(objects, key=lambda obj: obj.area, default=None) + if canvas is None: + return None + if outputs is not None: + output = outputs[i] + if canvas.size != output.size: + return None + canvas_objects.append(canvas) + + return canvas_objects + + @staticmethod + def solve_puzzle( + input: Object, task_name: str, nesting_level: int, canvas: Object + ) -> Optional[Tuple[Object, List[Object]]]: + + canvas_color = canvas.main_color() + compound_objects = input.detect_objects( + background_color=canvas_color, multicolor=True + ) + # remove canvas and input grid from grid_objects + compound_objects = [obj for obj in compound_objects if obj.area < canvas.area] + + def compound_object_get_handle(cobj: Object) -> Optional[Object]: + colors = set(cobj.get_colors(allow_black=True)) + if canvas_color not in colors: + logger.debug(f"Canvas color: {canvas_color} not in colors: {colors}") + return None + colors.remove(canvas_color) + if len(colors) != 1: + logger.debug(f"Canvas color: {canvas_color} colors: {colors}") + return None + other_color = colors.pop() + handle = cobj.detect_colored_objects(background_color=other_color)[0] + logger.debug(f"Canvas color: {canvas_color} Handle color: {other_color}") + + return handle + + handles = [] + handles_shapes = [] + for cobj in compound_objects: + handle = compound_object_get_handle(cobj) + if handle is None: + return None + handles.append(handle) + handles_shapes.append(handle.get_shape()) + holes = canvas.detect_objects(allow_black=True, background_color=canvas_color) + # remove object of size the entire max_area_object + holes = [obj for obj in holes if obj.area < canvas.area] + holes_shapes = [obj.get_shape(background_color=canvas_color) for obj in holes] + + logger.debug(f"compound_objects:{len(compound_objects)} holes:{len(holes)}") + + if len(compound_objects) != len(holes): + return None + if len(compound_objects) == 0: + return None + + # display grid_objects alongsize holes + if False: + display_multiple( + [(o1, o2) for o1, o2 in zip(compound_objects, holes)], + title=f"Compound Objects and Canvas Holes", + ) + + # display objects_holes alonside canvas_holes + if False: + display_multiple( + [ + (o1, o2) + for o1, o2 in zip( + handles, + holes, + ) + ], + title=f"Handles and Holes", + ) + + if False: + display_multiple( + [(o1, o2) for o1, o2 in zip(handles_shapes, holes_shapes)], + title=f"Handles Shapes and Holes Shapes", + ) + + matches = [] + num_cases = len(handles_shapes) + xform = equal_modulo_rigid_transformation + matched_handles = set() + matched_holes = set() + transformed_compound_objects = compound_objects + holes_origins = [(0, 0)] * num_cases + for i in range(num_cases): + match = None + for j in range(num_cases): + if i in matched_handles or j in matched_holes: + continue + handle = handles[i] + hole = holes[j] + handle_shape = handles_shapes[i] + hole_shape = holes_shapes[j] + match = xform( + [ + ( + handle_shape, + hole_shape, + ) + ], + task_name, + nesting_level + 1, + ) + if match is None: + continue + matched_handles.add(i) + matched_holes.add(j) + state, solve = match + + cobj = compound_objects[i] + + transformed_compound_objects[i] = solve(compound_objects[i]) + holes_origins[i] = hole.origin + + matches.append(match) + if match is None: + logger.info(f"No match found for handle {i}") + return None + logger.debug(f"matches found:{len(matches)}") + + if False: + display_multiple( + [(obj, obj) for obj in transformed_compound_objects], + title=f"Transformed Compound Objects", + ) + + new_objects = [] + for i, obj in enumerate(transformed_compound_objects): + new_handle = compound_object_get_handle(obj) + if new_handle is None: + return None + ox, oy = holes_origins[i] + ox -= new_handle.origin[0] + oy -= new_handle.origin[1] + obj.origin = (ox, oy) + logger.debug(f"new origin: {obj.origin}") + new_objects.append(obj) + + return canvas, new_objects + + @staticmethod + def canvas_grid_xform( + examples: List[Example[Object]], + task_name: str, + nesting_level: int, + ) -> Optional[Match[Object]]: + # every example has a canvas + canvas_objects = CanvasGridMatch.find_canvas_objects( + inputs=[input for input, _ in examples], + outputs=[output for _, output in examples], + ) + if canvas_objects is None: + return None + # Config.display_this_task = True + + for i, (input, output) in enumerate(examples): + canvas = canvas_objects[i] + solution = CanvasGridMatch.solve_puzzle( + input, task_name, nesting_level, canvas + ) + if solution is None: + return None + canvas, new_objects = solution + new_input = input.copy() + for obj in new_objects: + ox, oy = obj.origin + ox += canvas.origin[0] + oy += canvas.origin[1] + obj.origin = (ox, oy) + logger.debug(f"new origin: {obj.origin}") + new_input.add_object_in_place(obj) + + state = "canvas_grid_xform" + + def solve(input: Object) -> Object: + canvas_objects = CanvasGridMatch.find_canvas_objects([input], None) + if canvas_objects is None: + return input + canvas = canvas_objects[0] + solution = CanvasGridMatch.solve_puzzle( + input, task_name, nesting_level, canvas + ) + if solution is None: + logger.info(f"No solution found for input") + return input + canvas, new_objects = solution + output = canvas.copy() + output.origin = (0, 0) + for obj in new_objects: + output.add_object_in_place(obj) + if False: + display(output, title=f"Output") + return output + + match = (state, solve) + return match + + +def equal_modulo_rigid_transformation( + examples: List[Example], task_name: str, nesting_level: int +) -> Optional[Match]: + for x_reflection in XReflection: + for rotation in ClockwiseRotation: + rigid_transformation = RigidTransformation(rotation, x_reflection) + all_examples_correct = True + for input, output in examples: + trasformed_input = input.apply_rigid_xform(rigid_transformation) + if trasformed_input != output: + all_examples_correct = False + break + if all_examples_correct: + state = f"({rigid_transformation})" + solve = lambda input: input.apply_rigid_xform(rigid_transformation) + return (state, solve) + return None + + +class InpaintingMatch: + @staticmethod + def is_inpainting_puzzle(examples: List[Example[Object]]) -> bool: + # check the inpainting conditions on all examples + for input, output in examples: + if InpaintingMatch.check_inpainting_conditions(input, output) is None: + return False + return True + + @staticmethod + def check_inpainting_conditions(input: Object, output: Object) -> Optional[int]: + # Check if input and output are the same size + if input.size != output.size: + return None + + # check if input has one more color than output + if len(input.get_colors(allow_black=True)) - 1 != len( + output.get_colors(allow_black=True) + ): + return None + colors_only_in_input = set(input.get_colors(allow_black=True)) - set( + output.get_colors(allow_black=True) + ) + if len(colors_only_in_input) != 1: + return None + color = colors_only_in_input.pop() + + # check if input and output are the same except for the color + for x in range(input.width): + for y in range(input.height): + if input[x, y] == color: + continue + if input[x, y] != output[x, y]: + return None + + # check if output has high regularity score + # if regularity_score(output) >= 0.5: + # return None + # Config.display_this_task = True + + return color + + @staticmethod + def compute_shared_symmetries( + examples, + ) -> Optional[Tuple[NonPeriodicGridSymmetry, PeriodicGridSymmetry, int]]: + non_periodic_shared = None + periodic_shared = None + + color_only_in_input = None + + for i, (input, output) in enumerate(examples): + color = InpaintingMatch.check_inpainting_conditions(input, output) + if color is None: + return None + if color_only_in_input is None: + color_only_in_input = color + else: + if color != color_only_in_input: + logger.info(f"Color mismatch: {color} != {color_only_in_input}") + return None + + non_periodic_symmetry_output = find_non_periodic_symmetry(output, color) + if non_periodic_shared is None: + non_periodic_shared = non_periodic_symmetry_output + else: + non_periodic_shared = non_periodic_shared.intersection( + non_periodic_symmetry_output + ) + + periodic_symmetry_output = find_periodic_symmetry_with_unknowns( + output, color + ) + if periodic_shared is None: + periodic_shared = periodic_symmetry_output + else: + periodic_shared = periodic_shared.intersection(periodic_symmetry_output) + + logger.info( + f"#{i} From Output {non_periodic_symmetry_output} {periodic_symmetry_output}" + ) + if ( + periodic_shared is None + or non_periodic_shared is None + or color_only_in_input is None + ): + return None + + return non_periodic_shared, periodic_shared, color_only_in_input + + @staticmethod + def apply_shared( + input: Object, + non_periodic_shared: NonPeriodicGridSymmetry, + periodic_shared: PeriodicGridSymmetry, + color: int, + ) -> Object: + filled_grid = fill_grid( + input, + non_periodic_symmetry=non_periodic_shared, + periodic_symmetry=periodic_shared, + unknown=color, + ) + return filled_grid + + @staticmethod + def inpainting_xform( + examples: List[Example[Object]], + task_name: str, + nesting_level: int, + ) -> Optional[Match[Object]]: + + # Web view: open -a /Applications/Safari.app "https://arcprize.org/play?task=484b58aa" + # Tasks with average_regularity_score < 0.5: + + # bd4472b8 + # 8e5a5113 + # 62b74c02 + # 4aab4007 # solve two halves (top left and bottom right) separately + # ef26cbf6 + # 1e97544e # solve two halves (top left and bottom right) separately + # 4cd1b7b2 # sudoku + + # f9d67f8b # solve two halves (top left and bottom right) separately + + shared_symmetries = InpaintingMatch.compute_shared_symmetries(examples) + if shared_symmetries is None: + return None + non_periodic_shared, periodic_shared, color_only_in_input = shared_symmetries + + logger.info( + f"inpainting_xform examples:{len(examples)} task_name:{task_name} nesting_level:{nesting_level} non_periodic_symmetries:{non_periodic_shared}" + ) + + for i, (input, output) in enumerate(examples): + filled_grid = fill_grid( + input, + periodic_shared, + non_periodic_shared, + color_only_in_input, + ) + is_correct = filled_grid == output + logger.info( + f"#{i} Shared {non_periodic_shared} {periodic_shared} is_correct: {is_correct}" + ) + if not is_correct: + display(input, filled_grid, title=f"{is_correct} Shared Symm") + if is_correct: + logger.info(f"#{i} Found correct solution using shared symmetries") + pass + else: + break + else: + state = f"symmetry({non_periodic_shared}, {periodic_shared})" + + def solve_shared(input: Object) -> Object: + filled_grid = fill_grid( + input, + periodic_shared, + non_periodic_shared, + color_only_in_input, + ) + if False: + logger.info( + f"Test Shared {non_periodic_shared} {periodic_shared} is_correct: {is_correct}" + ) + display(input, filled_grid, title=f"Test Shared") + return filled_grid + + return (state, solve_shared) + + def solve_find_symmetry(input: Object) -> Object: + periodic_symmetry_input = find_periodic_symmetry_with_unknowns( + input, color_only_in_input + ) + filled_grid = fill_grid( + input, + periodic_symmetry=periodic_symmetry_input, + unknown=color_only_in_input, + ) + return filled_grid + + Config.display_this_task = True + state = "find_symmetry_for_each_input" + return (state, solve_find_symmetry) + + +gridxforms: List[XformEntry[Object]] = [ + XformEntry(match_colored_objects, 3), + XformEntry(xform_identity, 1), + XformEntry(equal_modulo_rigid_transformation, 2), + XformEntry(primitive_to_xform(translate_down_1), 2), + XformEntry(CanvasGridMatch.canvas_grid_xform, 2), + XformEntry(InpaintingMatch.inpainting_xform, 2), +] + + +class ExpansionMatch: + + @staticmethod + def check_fractal_expansion_sizes(examples: List[Example[GridAndObjects]]): + """ + Check if every input is NxN and the output's size is N^2xN^2 + """ + for (input_grid, input_objects), (output_grid, output_objects) in examples: + if len(input_objects) == 0 or len(output_objects) == 0: + return False + for input_obj, output_obj in zip(input_objects, output_objects): + # Ensure input is NxN (i.e., square) + if input_obj.width != input_obj.height: + return False + # Ensure output is N^2xN^2 + if ( + output_obj.width != input_obj.width**2 + or output_obj.height != input_obj.height**2 + ): + return False + return True + + # TODDO: replace this with inferring a function from (grid, pixel coordinates) to output grid (of the same size) + @staticmethod + def fractal_expansion( + examples: List[Example[Object]], + task_name: str, + nesting_level: int, + ) -> Optional[Match[Object]]: + + def map_function_numpy_inplace( + output_grid: np.ndarray, + input_grid: np.ndarray, + original_image: np.ndarray, + background_color, + ) -> None: + width, height = input_grid.shape + for x in range(width): + for y in range(height): + color = input_grid[x, y] + if color != background_color: + output_grid[ + x * width : (x + 1) * width, y * height : (y + 1) * height + ] = original_image + + def apply_recursive_expansion_numpy_inplace( + original_image: np.ndarray, background_color + ) -> np.ndarray: + width, height = original_image.shape + output_grid = np.full((width * width, height * height), background_color) + map_function_numpy_inplace( + output_grid, original_image, original_image, background_color + ) + return output_grid + + if False: + display_multiple(examples, title=f"Fractal Expansion Examples") + + for input, output in examples: + input_image = input._data + output_image = output._data + expanded_image = apply_recursive_expansion_numpy_inplace(input_image, 0) + if False: + display( + Object(expanded_image), + Object(output_image), + title=f"Expanded vs Output", + ) + if not np.array_equal(expanded_image, output_image): + return None + + state = "fractal_expansion" + + def solve(input: Object) -> Object: + if isinstance(input, Object): + return Object(apply_recursive_expansion_numpy_inplace(input._data, 0)) + else: + assert False + + return (state, solve) + + @staticmethod + def stretch_height( + examples: List[Example[Object]], + task_name: str, + nesting_level: int, + ) -> Optional[Match[Object]]: + # TODO: implement the inference of the boolean function + for i, (input, output) in enumerate(examples): + if output.height != 2 or output.width != input.width: + return None + color = input.first_color + for x in range(input.width): + for y in range(input.height): + is_filled = x % 2 == 0 if y == 0 else x % 2 == 1 + if input.origin != (0, 0): + # TODO: fix this + is_filled = not is_filled + output_color = output[x, y] + predicted_color = color if is_filled else 0 + if output_color != predicted_color: + logger.info( + f"Example {i} failed: output_color {output_color} != predicted_color {predicted_color}" + ) + assert False + state = "stretch_height" + + def solve(input: Object) -> Object: + output = Object.empty((input.width, input.height * 2)) + color = input.first_color + for x in range(input.width): + for y in range(input.height * 2): + is_filled = x % 2 == 0 if y == 0 else x % 2 == 1 + if input.origin != (0, 0): + # TODO: fix this + is_filled = not is_filled + if is_filled: + output[x, y] = color + return output + + match = (state, solve) + return match + + +expansion_xforms: List[XformEntry[Object]] = [ + XformEntry(ExpansionMatch.fractal_expansion, 1), + XformEntry(ExpansionMatch.stretch_height, 1), +] + + +def out_objects_are_a_subset( + inputs: List[Object], outputs: List[Object] +) -> Optional[List[Tuple[int, int]]]: + """ + Determines if the output objects are a subset of the input objects based on their colors. + + Checks if each color in the output set is present in the input set. Returns a mapping + of input indices to output indices if the subset condition is met, or None if not satisfied + or if any output color is not present in the input. + """ + input_colors = [input_obj.first_color for input_obj in inputs] + output_colors = [output_obj.first_color for output_obj in outputs] + + input_to_output_indices = [] + + for ic in input_colors: + if ic in output_colors: + input_to_output_indices.append( + (input_colors.index(ic), output_colors.index(ic)) + ) + for oc in output_colors: + if oc not in input_colors and False: + display_multiple( + [ + (input_obj, output_obj) + for input_obj, output_obj in zip(inputs, outputs) + ], + title=f"Input vs Output", + ) + return None # Output color not in input + + return input_to_output_indices + + +class MapFunctionMatch: + @staticmethod + def stretch_height( + examples: List[Example[Object]], + task_name: str, + nesting_level: int, + ) -> Optional[Match[Object]]: + logger.info( + f"stretch_height examples:{len(examples)} task_name:{task_name} nesting_level:{nesting_level}" + ) + origin = None + for input, output in examples: + if origin is None: + origin = output.origin + if origin != output.origin: + logger.info( + f"Output origin: {output.origin} != Expected origin: {origin}" + ) + return None + if input.width != output.width: + logger.info( + f"Input width: {input.width} != Output width: {output.width}" + ) + return None + if input.height * 2 != output.height: + logger.info( + f"Input height * 2: {input.height * 2} != Output height: {output.height}" + ) + return None + logger.info( + f"stretch_height origin:{output.origin} width:{output.width} height:{output.height}" + ) + if False: + display(input, output, title=f"stretch_height") + # TODO: need to adjust the origin from the call to the expansion xform + for xform in expansion_xforms: + match = xform.xform(examples, task_name, nesting_level) + if match is not None: + return match + return None + + +map_xforms: List[XformEntry[Object]] = [XformEntry(MapFunctionMatch.stretch_height, 1)] + + +from typing import List, Tuple + + +class ObjectListMatch: + @staticmethod + def check_list_of_objects_subset( + examples: List[Example[GridAndObjects]], + ) -> Optional[List[Tuple[int, int]]]: + """ + Check if the output objects are a subset of the input objects based on their colors. + Returns a list of indices of the input objects that correspond to the output objects. + The same list must apply to all examples. + """ + input_to_output_indices_list = [] + for (_, input_objects), (_, output_objects) in examples: + if len(input_objects) < 2: + return None + input_to_output_indices = out_objects_are_a_subset( + input_objects, output_objects + ) + if input_to_output_indices is None: + return None + # store the indices + input_to_output_indices_list.append(input_to_output_indices) + # check if they are all the same + if len(set(tuple(indices) for indices in input_to_output_indices_list)) != 1: + return None + logger.info(f"input_to_output_indices_list: {input_to_output_indices_list}") + input_to_output_indices = input_to_output_indices_list[0] + if len(input_to_output_indices) == 0: + return None + return input_to_output_indices + + @staticmethod + def map_first_input_to_output_grid( + examples: List[Example[GridAndObjects]], + ) -> List[Example[Object]]: + input_output_objects_examples: List[Example[Object]] = [] + for (input_grid, input_objects), (output_grid, output_objects) in examples: + input_output_objects_examples.append((input_objects[0], output_grid)) + + return input_output_objects_examples + + @staticmethod + def match_list_of_objects( + examples: List[Example[GridAndObjects]], + task_name: str, + nesting_level: int, + ) -> Optional[Match[GridAndObjects]]: + logger.info( + f"{' ' * nesting_level}match_list_of_objects examples:{len(examples)} task_name:{task_name} nesting_level:{nesting_level}" + ) + + if ExpansionMatch.check_fractal_expansion_sizes(examples): + input_output_objects_examples = ( + ObjectListMatch.map_first_input_to_output_grid(examples) + ) + + # now pattern match recursively + match: Optional[Match[Object]] = find_xform_for_examples( + expansion_xforms, + input_output_objects_examples, + task_name, + nesting_level + 1, + ) + if match is not None: + state, solve = match + + def solve_grid_and_objects( + grid_and_objects: GridAndObjects, + ) -> GridAndObjects: + grid, objects = grid_and_objects + return (grid, [solve(obj) for obj in objects]) + + return state, solve_grid_and_objects + + # check if the input objects can be matched to the output objects + input_to_output_indices = ObjectListMatch.check_list_of_objects_subset(examples) + if input_to_output_indices is not None: + logger.info( + f"{' ' * nesting_level}Found input_to_output_indices: {input_to_output_indices}" + ) + + new_examples_train: List[List[Example[Object]]] = [ + [] for _ in input_to_output_indices + ] + for (_, e_inputs), (_, e_outputs) in examples: + for i, (input_index, output_index) in enumerate( + input_to_output_indices + ): + new_examples_train[i].append( + (e_inputs[input_index], e_outputs[output_index]) + ) + + for xform in map_xforms: + matches = [] # for each input/output index pair, the match + for i, (input_index, output_index) in enumerate( + input_to_output_indices + ): + match = xform.xform( + new_examples_train[i], + task_name, + nesting_level, + ) + if match is None: + logger.info( + f"Xform {xform.xform.__name__} index:{output_index} failed: no match" + ) + return None + else: + matches.append(match) + + logger.info(f"Xform {xform.xform.__name__} succeeded") + + new_state = "{" + for i, (s, _) in enumerate(matches): + new_state += f"{i}:{s}, " + new_state += "}" + + def solve_grid_and_objects( + grid_and_objects: GridAndObjects, + ) -> GridAndObjects: + input_grid, input_objects = grid_and_objects + outputs = [] + assert input_to_output_indices is not None + for i, (input_index, output_index) in enumerate( + input_to_output_indices + ): + state, solve = matches[i] + output = solve(input_objects[input_index]) + outputs.append(output) + return (input_grid, outputs) + + return new_state, solve_grid_and_objects + + logger.info(f"{' ' * nesting_level}TODO: more cases of match_list_of_objects") + + return None + + +list_xforms: List[XformEntry[GridAndObjects]] = [ + XformEntry(ObjectListMatch.match_list_of_objects, 4), +] + + +def find_xform_for_examples( + xforms: List[XformEntry[Object]], + examples: List[Example[Object]], + task_name: str, + nesting_level: int, + xform_name: List[str] = [], +) -> Optional[Match[Object]]: + logger.info( + f"\n{' ' * nesting_level}find_xform_for_examples examples:{len(examples)} task_name:{task_name} nesting_level:{nesting_level}" + ) + + for xform in xforms: + if Config.difficulty < xform.difficulty: + continue + func = xform.xform + logger.debug(f"{' ' * nesting_level}Checking xform {func.__name__}") + match = func(examples, task_name, nesting_level + 1) + if match is not None: + logger.info( + f"{' ' * nesting_level}Xform {xform.xform.__name__} state:{match[0]} is correct for examples" + ) + xform_name.append(xform.xform.__name__) + return match + else: + logger.info( + f"{' ' * nesting_level}Xform {func.__name__} is not applicable" + ) + + return None + + +def find_xform( + xforms: List[XformEntry[Object]], + examples: List[Example[Object]], + tests: List[Example[Object]], + task_name: str, + nesting_level: int, +) -> Optional[Match[Object]]: + logger.info( + f"\n{' ' * nesting_level}find_xform examples:{len(examples)} tests:{len(tests)} task_name:{task_name} nesting_level:{nesting_level}" + ) + + xform_name_list = ["no_xform"] + match = find_xform_for_examples( + xforms, examples, task_name, nesting_level, xform_name_list + ) + if match is None: + return None + xform_name = xform_name_list[-1] + + state, solve = match + + for i, test_example in enumerate(tests): + test_input = test_example[0] + test_output = test_example[1] + result_on_test = solve(test_input) + if result_on_test != test_output: + logger.info(f"Xform {xform_name} state:{state} failed for test input {i}") + if False: + width, height = test_output.size + for x in range(width): + for y in range(height): + if test_output[x, y] != result_on_test[x, y]: + logger.info( + f"Xform {xform_name} state:{state} failed for test input {i} at {x},{y}: {test_output[x, y]} != {result_on_test[x, y]}" + ) + display( + test_output, + result_on_test, + title=f"Ex{i} state:{state}", + ) + return None + + logger.info(f"Xform {xform_name} state:{state} succeeded for all tests") + return match + + +# ObjectMatch is a type alias representing a match between a list of detected input objects +# and the index of the object within that list that is identical to the output object. +# +# The first element of the tuple (List[Object]) contains all the detected input objects, +# while the second element (int) specifies the index of the object in this list that is +# identical to the output object in terms of size and data. +ObjectMatch = Tuple[List[Object], int] + + +def detect_common_features(matched_objects: List[ObjectMatch], initial_difficulty: int): + def detect_common_symmetry_features() -> Optional[DecisionRule]: + common_decision_rule = None + for input_objects, index in matched_objects: + embeddings = [detect_symmetry_features(obj) for obj in input_objects] + decision_rule = select_object_minimal(embeddings, index) + if decision_rule is not None: + logger.debug(f" Decision rule (Symmetry): {decision_rule}") + if common_decision_rule is None: + common_decision_rule = decision_rule + else: + common_decision_rule = common_decision_rule.intersection( + decision_rule + ) + if common_decision_rule is None: + break + else: + logger.debug(f" No decision rule found (Symmetry)") + common_decision_rule = None + break + return common_decision_rule + + def detect_common_color_features() -> Optional[DecisionRule]: + common_decision_rule = None + for input_objects, index in matched_objects: + embeddings = [ + detect_color_features(obj, input_objects) for obj in input_objects + ] + decision_rule = select_object_minimal(embeddings, index) + if decision_rule is not None: + logger.debug(f" Decision rule (Color): {decision_rule}") + if common_decision_rule is None: + common_decision_rule = decision_rule + else: + common_decision_rule = common_decision_rule.intersection( + decision_rule + ) + if common_decision_rule is None: + break + else: + logger.debug(f" No decision rule found (Color)") + common_decision_rule = None + break + return common_decision_rule + + def detect_common_shape_features(level: int) -> Optional[DecisionRule]: + common_decision_rule = None + for input_objects, index in matched_objects: + embeddings = [ + detect_shape_features(obj, input_objects, level) + for obj in input_objects + ] + decision_rule = select_object_minimal(embeddings, index) + if decision_rule is not None: + logger.debug(f" Decision rule (Shape): {decision_rule}") + if common_decision_rule is None: + common_decision_rule = decision_rule + else: + common_decision_rule = common_decision_rule.intersection( + decision_rule + ) + if common_decision_rule is None: + break + else: + logger.debug(f" No decision rule found (Shape)") + common_decision_rule = None + break + return common_decision_rule + + common_decision_rule = None + features_used = None + + # Try detecting common features in the order of shape, color, and symmetry + + if common_decision_rule is None and Config.difficulty >= initial_difficulty + 1: + common_decision_rule = detect_common_shape_features(initial_difficulty + 1) + features_used = "Shape" + + if common_decision_rule is None and Config.difficulty >= initial_difficulty + 2: + common_decision_rule = detect_common_color_features() + features_used = "Color" + + if common_decision_rule is None and Config.difficulty >= initial_difficulty + 3: + common_decision_rule = detect_common_symmetry_features() + features_used = "Symmetry" + assert num_difficulties_matching == 3 + + return common_decision_rule, features_used + + +def find_matched_objects( + examples: List[Example], task_type: str +) -> Optional[List[ObjectMatch]]: + """ + Identifies and returns a list of matched input objects that correspond to the output objects + in the given examples. For each example, it detects candidate objects in the input grid + and matches them with the output grid based on size and data. If all examples have a match, + the function returns the list of matched objects; otherwise, it returns None. + + Args: + examples: A list of examples, each containing an input and output grid. + task_type: A string indicating the type of task (e.g., 'train' or 'test'). + + Returns: + A list of ObjectMatch tuples if matches are found for all examples, otherwise None. + """ + + def candidate_objects_for_matching(input: Object, output: Object) -> List[Object]: + """ + Detects objects in the input grid that are candidates for matching the output grid. + """ + if output.has_frame(): + # If the output is a frame, detect objects in the input as frames + logger.debug(" Output is a frame") + num_colors_output = len(output.get_colors(allow_black=True)) + return find_rectangular_objects(input, allow_multicolor=num_colors_output > 1) + + def find_matching_input_object( + input_objects: List[Object], output: Object + ) -> Optional[int]: + for i, io in enumerate(input_objects): + if io.size == output.size and np.array_equal(io._data, output._data): + logger.debug(f" Input object matching output: {io}") + return i + return None + + def get_matched_objects(examples: List[Example]) -> Optional[List[ObjectMatch]]: + matched_objects: List[ObjectMatch] = [] + + for example in examples: + input = example[0] + output = example[1] + logger.info(f" {task_type} {input.size} -> {output.size}") + + input_objects = candidate_objects_for_matching(input, output) + matched_object_index = find_matching_input_object(input_objects, output) + + if matched_object_index is not None: + matched_objects.append((input_objects, matched_object_index)) + + return matched_objects if len(matched_objects) == len(examples) else None + + matched_objects = get_matched_objects(examples) + return matched_objects + + +num_difficulties_xform = max(xform.difficulty for xform in gridxforms) +num_difficulties_matching = 3 +num_difficulties_total = num_difficulties_xform + num_difficulties_matching + + +def process_tasks(tasks: Tasks, set: str): + num_correct = 0 + num_incorrect = 0 + for task_name, task in tasks.items(): + Config.display_this_task = False + if Config.task_name and task_name != Config.task_name: + continue + if ( + filter_simple_xforms(task, task_name) == False + and Config.only_simple_examples + and task_name not in Config.whitelisted_tasks + ): + continue + if Config.only_inpainting_puzzles and not InpaintingMatch.is_inpainting_puzzle(task.train): + continue + logger.info(f"\n***Task: {task_name} {set}***") + + examples = task.train + + tests = task.test + task_type = "train" + + if True: + current_difficulty = 0 + + if Config.find_xform: + correct_xform = find_xform(gridxforms, examples, tests, task_name, 0) + if correct_xform is not None: + num_correct += 1 + continue + + current_difficulty += num_difficulties_xform + + if Config.find_matched_objects: + # Check if the input objects can be matched to the output objects + logger.debug(f"Checking common features for {task_name} {set}") + matched_objects = find_matched_objects(examples, task_type) + if matched_objects: + # If the input objects can be matched to the output objects, try to detect common features + # to determine the correct object to pick + logger.debug( + f"XXX Matched {len(matched_objects)}/{len(examples)} {task_name} {set}" + ) + common_decision_rule, features_used = detect_common_features( + matched_objects, current_difficulty + ) + if common_decision_rule: + logger.info( + f"Common decision rule ({features_used}): {common_decision_rule}" + ) + num_correct += 1 + continue + else: + logger.warning( + f"Could not find common decision rule for {task_name} {set}" + ) + current_difficulty += num_difficulties_matching + + if Config.display_not_found: + Config.display_this_task = True + if Config.display_this_task: + grids = [(example[0], example[1]) for example in examples] + display_multiple(grids, title=f"{task_name} {set}") + + # If no valid dimensions could be determined, give up + logger.warning( + f"Could not find correct transformation for {task_name} {set} examples" + ) + num_incorrect += 1 + + return num_correct, num_incorrect + + +def compute_perc_correct(num_correct: int, num_incorrect: int) -> Optional[float]: + if num_correct + num_incorrect > 0: + return int(1000 * num_correct / (num_correct + num_incorrect)) / 10 + return None + + +def simple(): + num_correct_tr, num_incorrect_tr = process_tasks(training_data, "training_data") + num_correct_ev, num_incorrect_ev = process_tasks(evaluation_data, "evaluation_data") + perc_correct_tr = compute_perc_correct(num_correct_tr, num_incorrect_tr) + perc_correct_ev = compute_perc_correct(num_correct_ev, num_incorrect_ev) + + def log_evaluation_results(set: str, num_correct: int, num_incorrect: int): + perc_correct = compute_perc_correct(num_correct, num_incorrect) + if perc_correct is not None: + logger.error( + f"{set.capitalize()} data: " + f"Correct: {num_correct}, Incorrect: {num_incorrect}, Score: {perc_correct}%" + ) + + logger.error("\n***Summary***") + log_evaluation_results("training", num_correct_tr, num_incorrect_tr) + log_evaluation_results("evaluation", num_correct_ev, num_incorrect_ev) + + # Write summary of results to JSON file + with open("simple.json", "w") as f: + f.write( + f'{{"training_data":{perc_correct_tr},"evaluation_data":{perc_correct_ev}}}' + ) + + +def generate_inpainting_puzzle(): + np.random.seed(42) # For reproducibility + num_attempts = 0 + num_puzzles = 10 + num_found = 0 + while True: + # Define grid size and colors + grid_size = (6, 6) + output_colors = [1, 2] # Colors that appear in both input and output + inpainting_color = 0 # Extra color that appears only in the input + + # Create an output grid with random placement of output colors + output_grid_data = np.random.choice(output_colors, size=grid_size) + + # Create an input grid by adding the inpainting color at random positions + input_grid_data = output_grid_data.copy() + num_inpainting_pixels = 5 # Number of pixels to replace with inpainting color + inpainting_positions = np.random.choice( + grid_size[0] * grid_size[1], num_inpainting_pixels, replace=False + ) + for pos in inpainting_positions: + x = pos % grid_size[0] + y = pos // grid_size[1] + input_grid_data[x, y] = inpainting_color + + # Create Object instances + input_object = Object(input_grid_data) + output_object = Object(output_grid_data) + + # Check if the puzzle passes the inpainting conditions + color = InpaintingMatch.check_inpainting_conditions(input_object, output_object) + if color is None: + print("Puzzle does not pass inpainting conditions. Trying again.") + continue + + # Compute regularity score of the output + score = regularity_score(output_object) + print(f"Output grid regularity score: {score}") + + # Compute shared symmetries + examples = [(input_object, output_object)] + num_attempts += 1 + shared_symmetries = InpaintingMatch.compute_shared_symmetries(examples) + + if shared_symmetries is None: + print(f"Failed to compute shared symmetries after {num_attempts} attempts. Trying again.") + continue + + non_periodic_shared, periodic_shared, color = shared_symmetries + print(f"Shared symmetries computed successfully after {num_attempts} attempts.") + print(f"Non-periodic symmetry: {non_periodic_shared}") + print(f"Periodic symmetry: {periodic_shared}") + + # Try to solve the puzzle using the inpainting strategy + solved_output = InpaintingMatch.apply_shared(input_object, non_periodic_shared, periodic_shared, color) + + if solved_output == output_object: + num_found += 1 + print(f"Generated a solvable inpainting puzzle after {num_attempts} attempts.") + # Display the solved output (assuming display function is available) + if True: # Change to True if you want to display the puzzle + display(input_object, solved_output, title="Solvable Inpainting Puzzle Solution") + if num_found >= num_puzzles: + break + else: + print("Puzzle cannot be solved using the inpainting strategy. Trying again.") + + + +# Call the function to generate puzzles +if __name__ == "__main__": + generate_inpainting_puzzle() diff --git a/src/symmetry.py b/src/symmetry.py new file mode 100644 index 0000000..3ee54c7 --- /dev/null +++ b/src/symmetry.py @@ -0,0 +1,563 @@ +from dataclasses import dataclass +import numpy as np +from typing import Optional, Tuple +from objects import Object +from grid_types import Symmetry +from typing import TYPE_CHECKING +from math import gcd, lcm + +# To avoid circular imports +if TYPE_CHECKING: + from objects import Object as Object_t +else: + Object_t = None + + +@dataclass(frozen=True) +class PeriodicGridSymmetry: + px: Optional[int] = None # periodic horizontal + py: Optional[int] = None # periodic vertical + pd: Optional[int] = None # periodic diagonal + pa: Optional[int] = None # periodic anti-diagonal + + def intersection(self, other: "PeriodicGridSymmetry") -> "PeriodicGridSymmetry": + def intersect(a: Optional[int], b: Optional[int]) -> Optional[int]: + if a is None or b is None: + return None + if a == b: + return a + else: + return lcm(a, b) + + return PeriodicGridSymmetry( + intersect(self.px, other.px), + intersect(self.py, other.py), + intersect(self.pd, other.pd), + intersect(self.pa, other.pa), + ) + + +@dataclass(frozen=True) +class NonPeriodicGridSymmetry: + hx: bool = False # non-periodic horizontal + vy: bool = False # non-periodic vertical + dg: bool = False # non-periodic diagonal + ag: bool = False # non-periodic anti-diagonal + offset: Tuple[int, int] = (0, 0) # offset for symmetry checks + + def intersection( + self, other: "NonPeriodicGridSymmetry" + ) -> "NonPeriodicGridSymmetry": + # If offsets differ, symmetries involving translations should be invalidated (set to False) + if self.offset != other.offset: + return NonPeriodicGridSymmetry( + hx=False, + vy=False, + dg=False, + ag=False, + offset=(0, 0), # Reset the offset since they are different + ) + else: + # If offsets are the same, apply logical "and" to the symmetries + return NonPeriodicGridSymmetry( + hx=self.hx and other.hx, + vy=self.vy and other.vy, + dg=self.dg and other.dg, + ag=self.ag and other.ag, + offset=self.offset, # Offsets match, so we keep the offset + ) + + +def check_vertical_symmetry_with_unknowns(grid: Object, period: int, unknown: int): + """ + Check if rows repeat every 'period' rows, allowing for unknown cells. + """ + width, height = grid.size + for x in range(width): + for y in range(period, height): + if ( + grid[x, y] != unknown + and grid[x, y - period] != unknown + and grid[x, y] != grid[x, y - period] + ): + return False + return True + + +def check_horizontal_symmetry_with_unknowns(grid: Object, period: int, unknown: int): + """ + Check if columns repeat every 'period' columns, allowing for unknown cells. + """ + width, height = grid.size + for x in range(period, width): + for y in range(height): + if ( + grid[x, y] != unknown + and grid[x - period, y] != unknown + and grid[x, y] != grid[x - period, y] + ): + return False + return True + + +def check_diagonal_symmetry_with_unknowns(grid: Object, period: int, unknown: int): + """ + Check if the grid has diagonal symmetry with a given period, allowing for unknown cells. + Moving diagonally, we check that the same element is found every 'period' steps, without wrapping around. + """ + width, height = grid.size + + # Only iterate over the range where diagonal steps are valid + for x in range(width): + for y in range(height): + next_x = x + period + next_y = y + period + if next_x >= width or next_y >= height: + continue + + if ( + grid[x, y] != unknown + and grid[next_x, next_y] != unknown + and grid[x, y] != grid[next_x, next_y] + ): + return False + return True + + +def check_anti_diagonal_symmetry_with_unknowns(grid: Object, period: int, unknown: int): + """ + Check if the grid has anti-diagonal symmetry with a given period, allowing for unknown cells. + Moving anti-diagonally (bottom-left to top-right), we check that the same element is found every 'period' steps. + """ + width, height = grid.size + + # Only iterate over the range where anti-diagonal steps are valid + for x in range(width): + for y in range(height): + next_x = x + period + next_y = y - period + if next_x >= width or next_y < 0: + continue + + if ( + grid[x, y] != unknown + and grid[next_x, next_y] != unknown + and grid[x, y] != grid[next_x, next_y] + ): + return False + return True + + +def find_periodic_symmetry_with_unknowns( + grid: Object, unknown: int +) -> PeriodicGridSymmetry: + """ + Find the smallest periods px, py, pd, pa (if any) and non-periodic symmetries with unknowns. + """ + width, height = grid.size + + # Find smallest horizontal symmetry modulo px + px = None + for possible_px in range(1, width // 2 + 1): + if check_horizontal_symmetry_with_unknowns(grid, possible_px, unknown): + px = possible_px + break + + # Find smallest vertical symmetry modulo py + py = None + for possible_py in range(1, height // 2 + 1): + if check_vertical_symmetry_with_unknowns(grid, possible_py, unknown): + py = possible_py + break + + # Find smallest diagonal symmetry modulo pd + pd = None + # Ensure the grid is square for diagonal symmetry + if width == height: + for possible_pd in range(1, width // 2 + 1): + if check_diagonal_symmetry_with_unknowns(grid, possible_pd, unknown): + pd = possible_pd + break + + # Find smallest anti-diagonal symmetry modulo pa + pa = None + # Ensure the grid is square for anti-diagonal symmetry + if width == height: + for possible_pa in range(1, width // 2 + 1): + if check_anti_diagonal_symmetry_with_unknowns(grid, possible_pa, unknown): + pa = possible_pa + break + + return PeriodicGridSymmetry(px, py, pd, pa) + + +def find_non_periodic_symmetry(grid: Object, unknown: int) -> NonPeriodicGridSymmetry: + """ + Find the non-periodic symmetries of the grid, considering offsets. + """ + width, height = grid.size + max_distance = max(width, height) // 2 + + def check_symmetry_with_offset(symmetry_func): + offset = find_matching_subgrid_offset( + grid, symmetry_func(grid), max_distance, unknown + ) + return offset is not None, offset if offset else (0, 0) + + hx, hx_offset = check_symmetry_with_offset(lambda g: g.flip(Symmetry.HORIZONTAL)) + vy, vy_offset = check_symmetry_with_offset(lambda g: g.flip(Symmetry.VERTICAL)) + dg, dg_offset = check_symmetry_with_offset(lambda g: g.flip(Symmetry.DIAGONAL)) + ag, ag_offset = check_symmetry_with_offset(lambda g: g.flip(Symmetry.ANTI_DIAGONAL)) + + # combine the offsets + offset = (0, 0) + if hx: + # for horizontal symmetry, only the x-offset is relevant + offset = (hx_offset[0], offset[1]) + if vy: + # for vertical symmetry, only the y-offset is relevant + offset = (offset[0], vy_offset[1]) + # diagonal symmetry has offset dy-dx, so for square grids the offset is (0, 0) + # will assume that dx==dy and assume there's nothing to check + if ag and (hx or vy): + if ag_offset != offset: + # anti-diagonal symmetry is not invariant wrt translations + # so we set all symmetries to False as this is a contradiction + hx = False + vy = False + dg = False + ag = False + offset = (0, 0) + + return NonPeriodicGridSymmetry(hx, vy, dg, ag, offset) + + +def find_source_value( + filled_grid: Object, + x: int, + y: int, + periodic_symmetry: PeriodicGridSymmetry, + non_periodic_symmetry: NonPeriodicGridSymmetry, + unknown: int, +): + """ + Find a source value for the given destination coordinates based on symmetry. + """ + px, py, pd, pa = ( + periodic_symmetry.px, + periodic_symmetry.py, + periodic_symmetry.pd, + periodic_symmetry.pa, + ) + width, height = filled_grid.size + for x_src in range(x % px, width, px) if px is not None else [x]: + for y_src in range(y % py, height, py) if py is not None else [y]: + if filled_grid[x_src, y_src] != unknown: + return filled_grid[x_src, y_src] + + # Search based on diagonal (pd) symmetry if provided + if pd is not None and width == height: + size = (width // pd) * pd + + # Walk along the diagonal in both directions, by starting negative and going positive + for i in range(-size, size, pd): + x_src = x + i + y_src = y + i + + if ( + 0 <= x_src < size + and 0 <= y_src < size + and filled_grid[x_src, y_src] != unknown + ): + return filled_grid[x_src, y_src] + + # Search based on anti-diagonal symmetry (bottom-left to top-right) + if pa is not None and width == height: + size = (width // pa) * pa + + # Walk along the anti-diagonal in both directions, by starting negative and going positive + for i in range(-size, size, pa): + x_src = x + i + y_src = y - i + + if ( + 0 <= x_src < size + and 0 <= y_src < size + and filled_grid[x_src, y_src] != unknown + ): + return filled_grid[x_src, y_src] + + # Check non-periodic symmetries with offset + dx, dy = non_periodic_symmetry.offset + + x_dest_sym = x - dx + y_dest_sym = y - dy + + def fill_from_symmetry(x_src, y_src): + if ( + 0 <= x_src < width + and 0 <= y_src < height + and filled_grid[x_src, y_src] != unknown + ): + return True + else: + return False + + hx = non_periodic_symmetry.hx + vy = non_periodic_symmetry.vy + dg = non_periodic_symmetry.dg + ag = non_periodic_symmetry.ag + + if hx: + # (x,y) -> (x-dx, y-dy) -> ((w-dx)-1-x+dx, y-dy) -> + # -> ((w-dx)-1-x+dx+dx, y-dy+dy) = (w-1-x+dx, y) + x_src, y_src = width - 1 - x + dx, y + if fill_from_symmetry(x_src, y_src): + return filled_grid[x_src, y_src] + + if vy: + x_src, y_src = x, height - 1 - y + dy + if fill_from_symmetry(x_src, y_src): + return filled_grid[x_src, y_src] + + if dg: + # (x,y) -> (x-dx, y-dy) -> (y-dy, x-dx) -> (y-dy+dx, x-dx+dy) + x_src, y_src = y - dy + dx, x - dx + dy + if fill_from_symmetry(x_src, y_src): + return filled_grid[x_src, y_src] + + if ag: + # (x,y) -> (x-dx, y-dy) -> ((h-dy)-1-y+dy, (w-dx)-1-x+dx) -> + # -> (h-dy-1-y+dy+dx, w-dx-1-x+dx+dy) == (h-1-y+dx, w-1-x+dy) + x_src = height - 1 - y + dx + y_src = width - 1 - x + dy + if fill_from_symmetry(x_src, y_src): + return filled_grid[x_src, y_src] + + return unknown + + +def fill_grid( + grid: Object, + periodic_symmetry: PeriodicGridSymmetry = PeriodicGridSymmetry(), + non_periodic_symmetry: NonPeriodicGridSymmetry = NonPeriodicGridSymmetry(), + unknown: int = 0, +): + """ + Fills the unknown cells in a grid based on detected horizontal and vertical symmetries. + + This function fills each unknown cell in the grid by propagating values from known cells, + using the provided horizontal (px) and vertical (py) symmetry periods. It starts at each + destination cell and looks for a matching source cell at symmetrical positions, based on + the periods px and py. + + Args: + grid (Object): The grid containing known and unknown values to be filled. + symmetry (Symmetry): The symmetry object containing the periods px, py, pd, and pa. + unknown (int): The value representing unknown cells in the grid, which will be filled. + + Returns: + Object: A new grid with all unknown cells filled using the provided symmetry periods. + """ + width, height = grid.size + filled_grid = grid.copy() + + # Loop over all destination cells + for x_dest in range(width): + for y_dest in range(height): + if ( + filled_grid[x_dest, y_dest] == unknown + ): # If the destination cell is unknown + filled_grid[x_dest, y_dest] = find_source_value( + filled_grid, + x_dest, + y_dest, + periodic_symmetry, + non_periodic_symmetry, + unknown, + ) + + return filled_grid + + +def test_find_and_fill_symmetry(): + from objects import Object + + grid_xy = Object( + np.array( + [ + [0, 0, 0, 2, 1, 2], + [0, 7, 0, 0, 3, 7], + [1, 2, 0, 2, 1, 2], + [0, 0, 3, 0, 0, 0], + [1, 2, 0, 2, 1, 2], + [0, 0, 3, 0, 0, 0], + ] + ) + ) + + grid_y = Object( + np.array( + [ + [1, 2, 1, 9, 1, 2], + [3, 7, 3, 7, 3, 7], + [1, 0, 1, 9, 0, 2], + [3, 7, 3, 7, 3, 7], + [0, 0, 1, 9, 0, 0], + [3, 7, 3, 7, 0, 0], + ] + ) + ) + + grid_diagonal = Object( + np.array( + [ + [1, 2, 1, 2, 4, 2], + [3, 0, 3, 0, 3, 7], + [1, 2, 0, 2, 1, 2], + [3, 7, 3, 7, 3, 7], + [5, 9, 0, 2, 1, 2], + [8, 7, 3, 7, 3, 7], + ] + ) + ) + + def test_grid(grid: Object, unknown: int, title: str): + periodic_symmetry = find_periodic_symmetry_with_unknowns(grid, unknown) + non_periodic_symmetry = find_non_periodic_symmetry(grid, unknown) + print(f"{title}: {periodic_symmetry}, {non_periodic_symmetry}") + filled_grid = fill_grid(grid, periodic_symmetry, non_periodic_symmetry, unknown) + if False: + print(f"grid: {grid}") + print(f"filled_grid: {filled_grid}") + assert unknown not in filled_grid._data + return filled_grid + + test_grid(grid_xy, 0, "grid_xy") # horizontal and vertical symmetry + test_grid(grid_y, 0, "grid_y") # vertical symmetry + test_grid(grid_diagonal, 0, "grid_diagonal") # diagonal symmetry + + # Add a new test case for offset symmetry + grid_offset = Object( + np.array( + [ + [9, 1, 2, 3, 0, 1], + [9, 0, 0, 6, 5, 4], + [9, 7, 8, 9, 8, 7], + [9, 1, 2, 3, 2, 1], + [9, 4, 5, 6, 5, 4], + [9, 1, 2, 3, 2, 1], + ] + ) + ) + test_grid(grid_offset, 0, "grid_offset") # horizontal symmetry with offset + + +# Function to check if the visible parts of grid g2 match grid g1 at offset (ox, oy) +# with an "unknown" value that is equal to any other value in comparisons +def check_visible_subgrid_with_unknown( + g1: Object, g2: Object, ox: int, oy: int, unknown: int +) -> bool: + W1, H1 = g1.size # Width and height of g1 + W2, H2 = g2.size # Width and height of g2 + + # Iterate over g2's width (W2) and height (H2) + for x in range(W2): + for y in range(H2): + gx, gy = x + ox, y + oy + # Only check for visible parts (i.e., within the bounds of g1) + if 0 <= gx < W1 and 0 <= gy < H1: + val_g1 = g1[gx, gy] + val_g2 = g2[x, y] + # Treat 'unknown' as matching any value + if val_g1 != val_g2 and val_g1 != unknown and val_g2 != unknown: + return False + return True + + +# Iterator that yields all offsets with increasing Manhattan distance +def manhattan_offset_iterator(): + d = 0 + while True: + for i in range(-d, d + 1): + j = d - abs(i) + yield (i, j) + if j != 0: # Avoid adding the same point twice (i, 0) and (i, -0) + yield (i, -j) + d += 1 + + +# Brute-force search to find the matching subgrid offset with expanding Manhattan distances +def find_matching_subgrid_offset( + g1: Object, g2: Object, max_distance: int, unknown: int +) -> Optional[Tuple[int, int]]: + """ + Returns a tuple (ox, oy) or None, where: + + - (ox, oy) is the offset that satisfies the following properties: + 1. The Manhattan distance |ox| + |oy| is minimized and is <= max_distance. + 2. For all coordinates (x, y) in g2, if (x + ox, y + oy) is within the bounds of g1, + g1 and g2 are considered equal at those coordinates, *modulo unknown values*. + + - Definition of equality modulo unknown: + g1[a, b] ~= g2[c, d] if: + - g1[a, b] == g2[c, d], or + - g1[a, b] == unknown, or + - g2[c, d] == unknown. + + In other words, the unknown value is treated as matching any other value. + + - None is returned if no such offset exists within the given max_distance. + """ + + offset_iter = manhattan_offset_iterator() + + # Iterate through the offsets generated by the iterator + for _ in range( + (2 * max_distance + 1) ** 2 + ): # Check all offsets within the LxL limit + ox, oy = next(offset_iter) + # If the Manhattan distance exceeds the limit, stop the search + if abs(ox) + abs(oy) > max_distance: + break + if check_visible_subgrid_with_unknown(g1, g2, ox, oy, unknown): + return (ox, oy) # Return the offset if a match is found + + return None # Return None if no valid offset is found + + +def test_find_matching_subgrid_offset(): + # Example usage with unknown value + g1 = Object( + np.array( + [ + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 10], + [11, 12, 13, 14, 15], + [16, 17, 18, 19, 20], + [21, 22, 23, 24, 25], + ] + ) + ) + + # Modify g2 so the center has some values from g1 with 0 treated as unknown + g2 = Object( + np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 1, 2, 3, 4], + [0, 6, 7, 8, 9], + [0, 11, 12, 13, 14], + ] + ) + ) + + unknown_value = 0 # Treat 0 as "unknown" + result = find_matching_subgrid_offset(g1, g2, max_distance=3, unknown=unknown_value) + assert result == (-1, -2), f"Expected (-1, -2), but got {result}" + + +if __name__ == "__main__": + test_find_and_fill_symmetry() + test_find_matching_subgrid_offset() diff --git a/src/visual_cortex.py b/src/visual_cortex.py index 4116f7d..e315f9f 100644 --- a/src/visual_cortex.py +++ b/src/visual_cortex.py @@ -9,6 +9,7 @@ import numpy as np from typing import TYPE_CHECKING + # To avoid circular imports if TYPE_CHECKING: from objects import Object as Object_t @@ -52,7 +53,12 @@ def precompute_sums(grid: Object_t, color: int) -> Tuple[ndarray, ndarray]: def is_frame_dp( - row_sum: np.ndarray, col_sum: np.ndarray, top: int, left: int, bottom: int, right: int + row_sum: np.ndarray, + col_sum: np.ndarray, + top: int, + left: int, + bottom: int, + right: int, ) -> bool: """Check if the rectangle defined by (top, left) to (bottom, right) forms a frame using precomputed sums.""" if ( @@ -233,6 +239,7 @@ def is_frame_part_of_lattice(grid: Object_t, frame: Frame, foreground: int) -> b return False return True + def find_dividing_lines(grid: Object_t, color: int) -> Tuple[List[int], List[int]]: """Find the indices of vertical and horizontal lines that span the entire grid.""" @@ -277,6 +284,7 @@ def extract_subgrid_of_color( continue sub_grid_data = grid._data[prev_h:h, prev_v:v] from objects import Object + row.append(Object(np.array(sub_grid_data))) prev_v = v + 1 subgrid.append(row) @@ -357,43 +365,56 @@ def eval_with_lattice_check(): def test_lattices(): # Correct Lattice Grid from objects import Object - grid = Object(np.array([ - [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - ])) + + grid = Object( + np.array( + [ + [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + ] + ) + ) frame = (2, 2, 5, 8) is_lattice = is_frame_part_of_lattice(grid, frame, 1) assert is_lattice == True, f"Correct Lattice Grid: Frame {frame}" # Interrupted Lattice Grid - grid = Object(np.array([ - [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 9, 1, 0], # Break in the lattice pattern - [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - ])) + grid = Object( + np.array( + [ + [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 9, 1, 0], # Break in the lattice pattern + [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + ] + ) + ) frame = (2, 2, 5, 5) is_lattice = is_frame_part_of_lattice(grid, frame, 1) assert is_lattice == False, f"Interrupted Lattice Grid: Frame {frame}" # Break outside frames that fit in the grid does not affect lattice check - grid = Object(np.array([ - [0, 0, 2, 0, 0, 2, 0, 0, 2, 0], - [0, 0, 2, 0, 0, 2, 0, 0, 2, 0], - [2, 2, 2, 2, 2, 2, 2, 2, 2, 0], - [0, 0, 2, 0, 0, 2, 0, 0, 2, 0], - [0, 0, 2, 0, 0, 2, 0, 0, 2, 0], - [2, 2, 2, 2, 2, 2, 2, 2, 2, 9], # Break near edge - [0, 0, 2, 0, 0, 2, 0, 0, 2, 0], - ])) + grid = Object( + np.array( + [ + [0, 0, 2, 0, 0, 2, 0, 0, 2, 0], + [0, 0, 2, 0, 0, 2, 0, 0, 2, 0], + [2, 2, 2, 2, 2, 2, 2, 2, 2, 0], + [0, 0, 2, 0, 0, 2, 0, 0, 2, 0], + [0, 0, 2, 0, 0, 2, 0, 0, 2, 0], + [2, 2, 2, 2, 2, 2, 2, 2, 2, 9], # Break near edge + [0, 0, 2, 0, 0, 2, 0, 0, 2, 0], + ] + ) + ) frame = (2, 2, 5, 5) is_lattice = is_frame_part_of_lattice(grid, frame, 2) assert is_lattice == True, f"Break outside frames: Frame {frame}" @@ -402,6 +423,7 @@ def test_lattices(): def test_subgrid_extraction(): # Example grid with dividing lines from objects import Object + grid = Object( np.array( [ @@ -409,7 +431,7 @@ def test_subgrid_extraction(): [2, 2, 1, 3, 3, 1, 4, 4, 1, 5], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [6, 6, 1, 7, 7, 1, 8, 8, 1, 9], - [6, 6, 1, 7, 7, 1, 8, 8, 1, 9], + [6, 6, 1, 7, 7, 1, 8, 8, 1, 9], [6, 6, 1, 7, 7, 1, 8, 8, 1, 9], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 1, 3, 3, 1, 4, 4, 1, 5], @@ -426,8 +448,12 @@ def test_subgrid_extraction(): 3, 4, ), f"Test failed: Subgrid dimensions: {height}x{width}" - assert subgrid[0][0] == Object(np.array([[2, 2], [2, 2]])), "Test failed: Subgrid[0][0]" - assert subgrid[0][1] == Object(np.array([[3, 3], [3, 3]])), "Test failed: Subgrid[0][1]" + assert subgrid[0][0] == Object( + np.array([[2, 2], [2, 2]]) + ), "Test failed: Subgrid[0][0]" + assert subgrid[0][1] == Object( + np.array([[3, 3], [3, 3]]) + ), "Test failed: Subgrid[0][1]" assert subgrid[0][3] == Object(np.array([[5], [5]])), "Test failed: Subgrid[0][3]" assert subgrid[2][3] == Object(np.array([[5]])), "Test failed: Subgrid[2][3]" @@ -456,10 +482,13 @@ def extract_object_by_color(grid: Object_t, color: int) -> Object_t: data[data != color] = 0 from objects import Object + return Object(np.array(data), origin) -def find_colored_objects(grid: Object_t, background_color: Optional[int]) -> List[Object_t]: +def find_colored_objects( + grid: Object_t, background_color: Optional[int] +) -> List[Object_t]: """ Finds and returns a list of all distinct objects within the grid based on color. @@ -468,6 +497,7 @@ def find_colored_objects(grid: Object_t, background_color: Optional[int]) -> Lis Each object is represented as an instance of the `Object` class. """ from objects import Object + colors = grid.get_colors(allow_black=True) objects: List[Object] = [] for color in colors: @@ -583,6 +613,7 @@ def is_valid_rectangle(origin: Cell, height: int, width: int, color: int) -> boo for r in range(origin[0], origin[0] + height) ] from objects import Object + current_object = Object( np.array(object_grid_data), origin, @@ -592,16 +623,46 @@ def is_valid_rectangle(origin: Cell, height: int, width: int, color: int) -> boo return objects +def regularity_score(grid: Object_t) -> float: + """ + Score how regular a grid is by scoring every cell. + A cell is penalized if one of those in the 8 directions around it has the same color. + The score of the cell is the sum of those in the 8 directions. + The score of the grid is the average score of all cells. + """ + width, height = grid.width, grid.height + total_score = 0 + from grid_types import DIRECTIONS8 + + for x in range(width): + for y in range(height): + cell_score = 0 + cell_color = grid[x, y] + for dx, dy in DIRECTIONS8: + nx, ny = x + dx, y + dy + if 0 <= nx < width and 0 <= ny < height: + if grid[nx, ny] == cell_color: + cell_score += 1 + total_score += cell_score + + return total_score / (width * height * 8) + + def test_detect_rectangular_objects() -> None: from objects import Object - grid = Object(np.array([ - [0, 0, 0, 0, 0, 0], - [0, 1, 1, 1, 0, 0], - [0, 1, 0, 0, 1, 0], - [0, 0, 1, 0, 1, 0], - [0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0], - ])) + + grid = Object( + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 1, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0], + ] + ) + ) objects: List[Object_t] = find_rectangular_objects(grid, allow_multicolor=False) for obj in objects: @@ -612,14 +673,19 @@ def test_detect_rectangular_objects() -> None: def test_several_rectangular_objects_of_different_color(): from objects import Object - grid = Object(np.array([ - [0, 0, 0, 0, 0, 0], - [0, 1, 1, 1, 0, 0], - [0, 1, 0, 0, 2, 0], - [0, 0, 1, 0, 2, 2], - [0, 0, 0, 1, 2, 0], - [0, 0, 0, 0, 0, 0], - ])) + + grid = Object( + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 1, 0, 0, 2, 0], + [0, 0, 1, 0, 2, 2], + [0, 0, 0, 1, 2, 0], + [0, 0, 0, 0, 0, 0], + ] + ) + ) objects = find_rectangular_objects(grid, allow_multicolor=False) for obj in objects: