Skip to content

Commit f7816ba

Browse files
committed
Use sphinx_autodoc_typehints
1 parent 6c69e88 commit f7816ba

File tree

5 files changed

+150
-131
lines changed

5 files changed

+150
-131
lines changed

Diff for: adaptive/learner/learner1D.py

+29-17
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from copy import copy, deepcopy
77
from numbers import Integral as Int
88
from numbers import Real
9-
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
9+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Sequence, Tuple, Union
1010

1111
import cloudpickle
1212
import numpy as np
@@ -24,12 +24,22 @@
2424
partial_function_from_dataframe,
2525
)
2626

27+
if TYPE_CHECKING:
28+
import holoviews
29+
2730
try:
2831
from typing import TypeAlias
2932
except ImportError:
3033
# Remove this when we drop support for Python 3.9
3134
from typing_extensions import TypeAlias
3235

36+
try:
37+
from typing import Literal
38+
except ImportError:
39+
# Remove this when we drop support for Python 3.7
40+
from typing_extensions import Literal
41+
42+
3343
try:
3444
import pandas
3545

@@ -145,7 +155,7 @@ def resolution_loss_function(
145155
146156
Returns
147157
-------
148-
loss_function : callable
158+
loss_function
149159
150160
Examples
151161
--------
@@ -230,12 +240,12 @@ class Learner1D(BaseLearner):
230240
231241
Parameters
232242
----------
233-
function : callable
243+
function
234244
The function to learn. Must take a single real parameter and
235245
return a real number or 1D array.
236-
bounds : pair of reals
246+
bounds
237247
The bounds of the interval on which to learn 'function'.
238-
loss_per_interval: callable, optional
248+
loss_per_interval
239249
A function that returns the loss for a single interval of the domain.
240250
If not provided, then a default is used, which uses the scaled distance
241251
in the x-y plane as the loss. See the notes for more details.
@@ -356,15 +366,15 @@ def to_dataframe(
356366
357367
Parameters
358368
----------
359-
with_default_function_args : bool, optional
369+
with_default_function_args
360370
Include the ``learner.function``'s default arguments as a
361371
column, by default True
362-
function_prefix : str, optional
372+
function_prefix
363373
Prefix to the ``learner.function``'s default arguments' names,
364374
by default "function."
365-
x_name : str, optional
375+
x_name
366376
Name of the input value, by default "x"
367-
y_name : str, optional
377+
y_name
368378
Name of the output value, by default "y"
369379
370380
Returns
@@ -403,16 +413,16 @@ def load_dataframe(
403413
404414
Parameters
405415
----------
406-
df : pandas.DataFrame
416+
df
407417
The data to load.
408-
with_default_function_args : bool, optional
418+
with_default_function_args
409419
The ``with_default_function_args`` used in ``to_dataframe()``,
410420
by default True
411-
function_prefix : str, optional
421+
function_prefix
412422
The ``function_prefix`` used in ``to_dataframe``, by default "function."
413-
x_name : str, optional
423+
x_name
414424
The ``x_name`` used in ``to_dataframe``, by default "x"
415-
y_name : str, optional
425+
y_name
416426
The ``y_name`` used in ``to_dataframe``, by default "y"
417427
"""
418428
self.tell_many(df[x_name].values, df[y_name].values)
@@ -795,17 +805,19 @@ def _loss(
795805
loss = mapping[ival]
796806
return finite_loss(ival, loss, self._scale[0])
797807

798-
def plot(self, *, scatter_or_line: str = "scatter"):
808+
def plot(
809+
self, *, scatter_or_line: Literal["scatter", "line"] = "scatter"
810+
) -> holoviews.Overlay:
799811
"""Returns a plot of the evaluated data.
800812
801813
Parameters
802814
----------
803-
scatter_or_line : str, default: "scatter"
815+
scatter_or_line
804816
Plot as a scatter plot ("scatter") or a line plot ("line").
805817
806818
Returns
807819
-------
808-
plot : `holoviews.Overlay`
820+
plot
809821
Plot of the evaluated data.
810822
"""
811823
if scatter_or_line not in ("scatter", "line"):

Diff for: adaptive/learner/learner2D.py

+45-40
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import OrderedDict
66
from copy import copy
77
from math import sqrt
8-
from typing import Callable, Iterable
8+
from typing import TYPE_CHECKING, Callable, Iterable
99

1010
import cloudpickle
1111
import numpy as np
@@ -22,6 +22,9 @@
2222
partial_function_from_dataframe,
2323
)
2424

25+
if TYPE_CHECKING:
26+
import holoviews
27+
2528
try:
2629
import pandas
2730

@@ -40,11 +43,11 @@ def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]:
4043
4144
Parameters
4245
----------
43-
ip : `scipy.interpolate.LinearNDInterpolator` instance
46+
ip
4447
4548
Returns
4649
-------
47-
deviations : list
50+
deviations
4851
The deviation per triangle.
4952
"""
5053
values = ip.values / (ip.values.ptp(axis=0).max() or 1)
@@ -79,11 +82,11 @@ def areas(ip: LinearNDInterpolator) -> np.ndarray:
7982
8083
Parameters
8184
----------
82-
ip : `scipy.interpolate.LinearNDInterpolator` instance
85+
ip
8386
8487
Returns
8588
-------
86-
areas : numpy.ndarray
89+
areas
8790
The area per triangle in ``ip.tri``.
8891
"""
8992
p = ip.tri.points[ip.tri.simplices]
@@ -99,11 +102,11 @@ def uniform_loss(ip: LinearNDInterpolator) -> np.ndarray:
99102
100103
Parameters
101104
----------
102-
ip : `scipy.interpolate.LinearNDInterpolator` instance
105+
ip
103106
104107
Returns
105108
-------
106-
losses : numpy.ndarray
109+
losses
107110
Loss per triangle in ``ip.tri``.
108111
109112
Examples
@@ -136,7 +139,7 @@ def resolution_loss_function(
136139
137140
Returns
138141
-------
139-
loss_function : callable
142+
loss_function
140143
141144
Examples
142145
--------
@@ -173,11 +176,11 @@ def minimize_triangle_surface_loss(ip: LinearNDInterpolator) -> np.ndarray:
173176
174177
Parameters
175178
----------
176-
ip : `scipy.interpolate.LinearNDInterpolator` instance
179+
ip
177180
178181
Returns
179182
-------
180-
losses : numpy.ndarray
183+
losses
181184
Loss per triangle in ``ip.tri``.
182185
183186
Examples
@@ -217,11 +220,11 @@ def default_loss(ip: LinearNDInterpolator) -> np.ndarray:
217220
218221
Parameters
219222
----------
220-
ip : `scipy.interpolate.LinearNDInterpolator` instance
223+
ip
221224
222225
Returns
223226
-------
224-
losses : numpy.ndarray
227+
losses
225228
Loss per triangle in ``ip.tri``.
226229
"""
227230
dev = np.sum(deviations(ip), axis=0)
@@ -241,15 +244,15 @@ def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarr
241244
242245
Parameters
243246
----------
244-
triangle : numpy.ndarray
247+
triangle
245248
The coordinates of a triangle with shape (3, 2).
246-
max_badness : int
249+
max_badness
247250
The badness at which the point is either chosen on a edge or
248251
in the middle.
249252
250253
Returns
251254
-------
252-
point : numpy.ndarray
255+
point
253256
The x and y coordinate of the suggested new point.
254257
"""
255258
a, b, c = triangle
@@ -267,17 +270,17 @@ def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarr
267270
return point
268271

269272

270-
def triangle_loss(ip):
273+
def triangle_loss(ip: LinearNDInterpolator) -> list[float]:
271274
r"""Computes the average of the volumes of the simplex combined with each
272275
neighbouring point.
273276
274277
Parameters
275278
----------
276-
ip : `scipy.interpolate.LinearNDInterpolator` instance
279+
ip
277280
278281
Returns
279282
-------
280-
triangle_loss : list
283+
triangle_loss
281284
The mean volume per triangle.
282285
283286
Notes
@@ -311,13 +314,13 @@ class Learner2D(BaseLearner):
311314
312315
Parameters
313316
----------
314-
function : callable
317+
function
315318
The function to learn. Must take a tuple of two real
316319
parameters and return a real number.
317-
bounds : list of 2-tuples
320+
bounds
318321
A list ``[(a1, b1), (a2, b2)]`` containing bounds,
319322
one per dimension.
320-
loss_per_triangle : callable, optional
323+
loss_per_triangle
321324
A function that returns the loss for every triangle.
322325
If not provided, then a default is used, which uses
323326
the deviation from a linear estimate, as well as
@@ -424,19 +427,19 @@ def to_dataframe(
424427
425428
Parameters
426429
----------
427-
with_default_function_args : bool, optional
430+
with_default_function_args
428431
Include the ``learner.function``'s default arguments as a
429432
column, by default True
430-
function_prefix : str, optional
433+
function_prefix
431434
Prefix to the ``learner.function``'s default arguments' names,
432435
by default "function."
433-
seed_name : str, optional
436+
seed_name
434437
Name of the seed parameter, by default "seed"
435-
x_name : str, optional
438+
x_name
436439
Name of the input x value, by default "x"
437-
y_name : str, optional
440+
y_name
438441
Name of the input y value, by default "y"
439-
z_name : str, optional
442+
z_name
440443
Name of the output value, by default "z"
441444
442445
Returns
@@ -475,18 +478,18 @@ def load_dataframe(
475478
476479
Parameters
477480
----------
478-
df : pandas.DataFrame
481+
df
479482
The data to load.
480-
with_default_function_args : bool, optional
483+
with_default_function_args
481484
The ``with_default_function_args`` used in ``to_dataframe()``,
482485
by default True
483-
function_prefix : str, optional
486+
function_prefix
484487
The ``function_prefix`` used in ``to_dataframe``, by default "function."
485-
x_name : str, optional
488+
x_name
486489
The ``x_name`` used in ``to_dataframe``, by default "x"
487-
y_name : str, optional
490+
y_name
488491
The ``y_name`` used in ``to_dataframe``, by default "y"
489-
z_name : str, optional
492+
z_name
490493
The ``z_name`` used in ``to_dataframe``, by default "z"
491494
"""
492495
data = df.set_index([x_name, y_name])[z_name].to_dict()
@@ -538,7 +541,7 @@ def interpolated_on_grid(
538541
539542
Parameters
540543
----------
541-
n : int, optional
544+
n
542545
Number of points in x and y. If None (default) this number is
543546
evaluated by looking at the size of the smallest triangle.
544547
@@ -611,14 +614,14 @@ def interpolator(self, *, scaled: bool = False) -> LinearNDInterpolator:
611614
612615
Parameters
613616
----------
614-
scaled : bool
617+
scaled
615618
Use True if all points are inside the
616619
unit-square [(-0.5, 0.5), (-0.5, 0.5)] or False if
617620
the data points are inside the ``learner.bounds``.
618621
619622
Returns
620623
-------
621-
interpolator : `scipy.interpolate.LinearNDInterpolator`
624+
interpolator
622625
623626
Examples
624627
--------
@@ -755,7 +758,9 @@ def remove_unfinished(self) -> None:
755758
if p not in self.data:
756759
self._stack[p] = np.inf
757760

758-
def plot(self, n=None, tri_alpha=0):
761+
def plot(
762+
self, n: int = None, tri_alpha: float = 0
763+
) -> holoviews.Overlay | holoviews.HoloMap:
759764
r"""Plot the Learner2D's current state.
760765
761766
This plot function interpolates the data on a regular grid.
@@ -764,16 +769,16 @@ def plot(self, n=None, tri_alpha=0):
764769
765770
Parameters
766771
----------
767-
n : int
772+
n
768773
Number of points in x and y. If None (default) this number is
769774
evaluated by looking at the size of the smallest triangle.
770-
tri_alpha : float
775+
tri_alpha
771776
The opacity ``(0 <= tri_alpha <= 1)`` of the triangles overlayed
772777
on top of the image. By default the triangulation is not visible.
773778
774779
Returns
775780
-------
776-
plot : `holoviews.core.Overlay` or `holoviews.core.HoloMap`
781+
plot
777782
A `holoviews.core.Overlay` of
778783
``holoviews.Image * holoviews.EdgePaths``. If the
779784
`learner.function` returns a vector output, a

0 commit comments

Comments
 (0)