Skip to content

Commit

Permalink
Update requirements.
Browse files Browse the repository at this point in the history
  • Loading branch information
Update requirements bot committed Mar 25, 2024
1 parent 6243668 commit 0a46123
Show file tree
Hide file tree
Showing 5 changed files with 829 additions and 723 deletions.
18 changes: 6 additions & 12 deletions jopfra/flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,13 @@ def size(self) -> int:
return int(np.prod(self._batch_shape))

@overload
def flatten(self, y: AnyNDArray) -> AnyNDArray:
...
def flatten(self, y: AnyNDArray) -> AnyNDArray: ...

@overload
def flatten(self, y: tc.Tensor) -> tc.Tensor:
...
def flatten(self, y: tc.Tensor) -> tc.Tensor: ...

@overload
def flatten(self, y: Evaluation) -> Evaluation:
...
def flatten(self, y: Evaluation) -> Evaluation: ...

@check_shapes(
"y: [batch_shape..., item_shape...]",
Expand All @@ -57,16 +54,13 @@ def flatten(
return np.reshape(y, (self.size,) + y.shape[self.dim :])

@overload
def unflatten(self, y: AnyNDArray) -> AnyNDArray:
...
def unflatten(self, y: AnyNDArray) -> AnyNDArray: ...

@overload
def unflatten(self, y: tc.Tensor) -> tc.Tensor:
...
def unflatten(self, y: tc.Tensor) -> tc.Tensor: ...

@overload
def unflatten(self, y: Evaluation) -> Evaluation:
...
def unflatten(self, y: Evaluation) -> Evaluation: ...

@check_shapes(
"y: [prod_batch_shape, item_shape...]",
Expand Down
11 changes: 5 additions & 6 deletions jopfra/minimisers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

class SingleMinimiser(ABC):
@abstractmethod
def single_minimise(self, problem: Problem, batch_shape: tuple[int, ...]) -> Evaluation:
...
def single_minimise(self, problem: Problem, batch_shape: tuple[int, ...]) -> Evaluation: ...


single_minimisers: dict[str, SingleMinimiser] = {}
Expand All @@ -21,8 +20,7 @@ def single_minimise(self, problem: Problem, batch_shape: tuple[int, ...]) -> Eva

class StoppingCriteria(ABC):
@abstractmethod
def stop(self) -> Stop:
...
def stop(self) -> Stop: ...


class IterMinimiserAdapter(SingleMinimiser):
Expand All @@ -43,8 +41,9 @@ def single_minimise(self, problem: Problem, batch_shape: tuple[int, ...]) -> Eva

class IterMinimiser(ABC):
@abstractmethod
def iter_minimise(self, problem: Problem, batch_shape: tuple[int, ...]) -> Iterator[Evaluation]:
...
def iter_minimise(
self, problem: Problem, batch_shape: tuple[int, ...]
) -> Iterator[Evaluation]: ...

def to_single(self, criteria: StoppingCriteria) -> IterMinimiserAdapter:
return IterMinimiserAdapter(self, criteria)
Expand Down
39 changes: 13 additions & 26 deletions jopfra/problems/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,37 +39,30 @@ def n_inputs(self) -> int:

class Problem(Protocol):
@property
def name(self) -> str:
...
def name(self) -> str: ...

@property
def n_inputs(self) -> int:
...
def n_inputs(self) -> int: ...

@property
def domain_lower(self) -> AnyNDArray:
...
def domain_lower(self) -> AnyNDArray: ...

@property
def domain_upper(self) -> AnyNDArray:
...
def domain_upper(self) -> AnyNDArray: ...

@property
def known_optima(self) -> Collection[AnyNDArray]:
...
def known_optima(self) -> Collection[AnyNDArray]: ...

@check_shapes(
"x: [batch..., n_inputs]",
"return: [batch...]",
)
def __call__(self, x: AnyNDArray) -> Evaluation:
...
def __call__(self, x: AnyNDArray) -> Evaluation: ...

@check_shapes(
"x: [n_inputs]",
)
def plot(self, dest: MiscDir, x: AnyNDArray) -> None:
...
def plot(self, dest: MiscDir, x: AnyNDArray) -> None: ...


problems: dict[str, Problem] = {}
Expand All @@ -83,12 +76,10 @@ def plot(self, dest: MiscDir, x: AnyNDArray) -> None:


class ProblemFunc(Protocol):
def __call__(self, x: AnyNDArray) -> tuple[AnyNDArray, AnyNDArray]:
...
def __call__(self, x: AnyNDArray) -> tuple[AnyNDArray, AnyNDArray]: ...

@property
def __name__(self) -> str:
...
def __name__(self) -> str: ...


check_plot_shapes = check_shapes(
Expand All @@ -97,8 +88,7 @@ def __name__(self) -> str:


class PlotFunc(Protocol):
def __call__(self, dest: MiscDir, x: AnyNDArray) -> None:
...
def __call__(self, dest: MiscDir, x: AnyNDArray) -> None: ...


@dataclass(order=True, frozen=True)
Expand Down Expand Up @@ -160,17 +150,14 @@ def _wrap(func: ProblemFunc) -> Problem:


class TorchProblemFunc(Protocol):
def __call__(self, x: tc.Tensor) -> tc.Tensor:
...
def __call__(self, x: tc.Tensor) -> tc.Tensor: ...

@property
def __name__(self) -> str:
...
def __name__(self) -> str: ...


class TorchPlotFunc(Protocol):
def __call__(self, dest: MiscDir, x: tc.Tensor) -> None:
...
def __call__(self, dest: MiscDir, x: tc.Tensor) -> None: ...


def torch_problem(
Expand Down
Loading

0 comments on commit 0a46123

Please sign in to comment.