Skip to content

Commit

Permalink
fix: Show Python traceback when constructing constraints fail (#1082)
Browse files Browse the repository at this point in the history
Fixes #969
  • Loading branch information
Christopher-Chianelli authored Sep 11, 2024
1 parent caa2bab commit 5f44f32
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
16 changes: 16 additions & 0 deletions python/python-core/src/main/python/_timefold_java_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,22 @@ def register_java_class(python_object: Solution_,
return python_object


def wrap_errors(func):
def wrapped_func(*args, **kwargs):
nonlocal func
try:
return func(*args, **kwargs)
except Exception as e:
import traceback
msg = ''.join(traceback.TracebackException.from_exception(e).format())
raise RuntimeError(msg)

wrapped_func.__doc__ = func.__doc__
wrapped_func.__qualname__ = func.__qualname__
wrapped_func.__name__ = func.__name__
return wrapped_func


unique_class_id = 0
"""A unique identifier; used to guarantee the generated class java name is unique"""

Expand Down
9 changes: 5 additions & 4 deletions python/python-core/src/main/python/score/_annotations.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from ._constraint_factory import ConstraintFactory
from ._constraint_builder import Constraint
from .._timefold_java_interop import ensure_init, _generate_constraint_provider_class, register_java_class
from typing import TypeVar, Callable, TYPE_CHECKING

from ._constraint_builder import Constraint
from ._constraint_factory import ConstraintFactory
from .._timefold_java_interop import ensure_init, _generate_constraint_provider_class, register_java_class, wrap_errors

if TYPE_CHECKING:
from ..score import Score

Expand Down Expand Up @@ -46,7 +47,7 @@ def constraint_provider(constraint_provider_function: Callable[[ConstraintFactor
def constraint_provider_wrapper(function):
def wrapped_constraint_provider(constraint_factory):
from ..score import ConstraintFactory
out = function(ConstraintFactory(constraint_factory))
out = wrap_errors(function)(ConstraintFactory(constraint_factory))
return out
java_class = _generate_constraint_provider_class(function, wrapped_constraint_provider)
return register_java_class(wrapped_constraint_provider, java_class)
Expand Down
34 changes: 29 additions & 5 deletions python/python-core/tests/test_user_error.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import re
from dataclasses import dataclass, field
from typing import Annotated, List

import pytest
from timefold.solver import *
from timefold.solver.config import *
from timefold.solver.domain import *
from timefold.solver.heuristic import *
from timefold.solver.score import *

import pytest
import re
from typing import Annotated, List
from dataclasses import dataclass, field


@planning_entity
@dataclass
Expand Down Expand Up @@ -106,6 +106,30 @@ def not_proxied():
)._to_java_solver_config()


def test_constraint_construction_failed():
import inspect
line = inspect.getframeinfo(inspect.stack()[0][0]).lineno + 4

def bad_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(BadEntity)
.penalize(SimpleScore.ONE)
.as_constraint('Penalize each entity')
]

bad_constraints = constraint_provider(bad_constraints)
solver_config = SolverConfig(
solution_class=Solution,
entity_class_list=[Entity],
score_director_factory_config=ScoreDirectorFactoryConfig(
constraint_provider_function=bad_constraints
)
)

with pytest.raises(RuntimeError, match=re.escape(f'line {line}, in bad_constraints')):
SolverFactory.create(solver_config).build_solver()


def test_missing_enterprise():
with pytest.raises(RequiresEnterpriseError, match=re.escape('multithreaded solving')):
solver_config = SolverConfig(
Expand Down

0 comments on commit 5f44f32

Please sign in to comment.