Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to latest linters #29

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added .dict-speechbrain.txt
Empty file.
4 changes: 2 additions & 2 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[flake8]
ignore = E203, E266, E501, W503
ignore = E203, E266, E501, W503, DOC105, DOC106, DOC107, DOC203, DOC403, DOC404, DOC405, DOC501, DOC502
# line length is intentionally set to 80 here because black uses Bugbear
# See https://github.com/psf/black/blob/master/README.md#line-length for more details
max-line-length = 80
max-complexity = 18
select = B,C,E,F,W,T4,B9
select = B,C,E,F,W,T4,B9,DOC
16 changes: 12 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,26 @@ repos:
args: [--maxkb=1024]

- repo: https://github.com/psf/black
rev: 19.10b0
rev: 24.3.0
hooks:
- id: black
types: [python]
additional_dependencies: ['click==8.0.4']
additional_dependencies: ['click==8.1.7']
- repo: https://github.com/PyCQA/flake8
rev: 3.7.9
rev: 7.0.0
hooks:
- id: flake8
types: [python]

- repo: https://github.com/adrienverge/yamllint
rev: v1.23.0
rev: v1.35.1
hooks:
- id: yamllint

- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
hooks:
- id: codespell
args: [--ignore-words=.dict-speechbrain.txt]
additional_dependencies:
- tomli
106 changes: 51 additions & 55 deletions benchmarks/CL_MASR/analyze_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ def parse_train_log(train_log: "str") -> "Dict[str, ndarray]":

Arguments
---------
train_log:
train_log: str
The path to the train log.

Returns
-------
The metrics, i.e. a dict that maps names of
the metrics to their corresponding values.
The metrics, i.e. a dict that maps names of
the metrics to their corresponding values.

Examples
--------
Expand Down Expand Up @@ -107,16 +107,16 @@ def compute_wer_matrix(

Arguments
---------
wers:
wers: ndarray
The word error rate for each locale.
num_base_locales:
num_base_locales: int
The number of base locales.
num_new_locales:
num_new_locales: int
The number of new locales.

Returns
-------
The word error rate matrix.
The word error rate matrix.

Raises
------
Expand Down Expand Up @@ -152,12 +152,12 @@ def compute_awer(wer_matrix: "ndarray") -> "ndarray":

Arguments
---------
wer_matrix:
wer_matrix: ndarray
The word error rate matrix.

Returns
-------
The average word error rate.
The average word error rate.

References
----------
Expand Down Expand Up @@ -185,12 +185,12 @@ def compute_bwt(wer_matrix: "ndarray") -> "ndarray":

Arguments
---------
wer_matrix:
wer_matrix: ndarray
The word error rate matrix.

Returns
-------
The backward transfer.
The backward transfer.

References
----------
Expand Down Expand Up @@ -220,14 +220,14 @@ def compute_im(wer_matrix: "ndarray", refs: "ndarray") -> "ndarray":

Arguments
---------
wer_matrix:
wer_matrix: ndarray
The word error rate matrix.
refs:
refs: ndarray
The intransigence measure references (joint fine-tuning).

Returns
-------
The intransigence measure.
The intransigence measure.

References
----------
Expand Down Expand Up @@ -255,14 +255,14 @@ def compute_fwt(wer_matrix: "ndarray", refs: "ndarray") -> "ndarray":

Arguments
---------
wer_matrix:
wer_matrix: ndarray
The word error rate matrix.
refs:
refs: ndarray
The forward transfer references (single task fine-tuning).

Returns
-------
The forward transfer.
The forward transfer.

Examples
--------
Expand All @@ -289,31 +289,31 @@ def plot_wer(
usetex: "bool" = False,
hide_legend: "bool" = False,
style_file_or_name: "str" = "classic",
) -> "None":
):
"""Plot word error rates extracted from a
continual learning train log.

Arguments
---------
wers:
wers: ndarray
The word error rates (base + new locales).
output_image:
output_image: str
The path to the output image.
base_locales:
base_locales: Sequence[str]
The base locales.
new_locales:
new_locales: Sequence[str]
The new locales.
xlabel:
xlabel: str
The x-axis label.
figsize:
figsize: Tuple[float, float]
The figure size.
title:
title: str
The plot title.
usetex:
usetex: bool
True to render text with LaTeX, False otherwise.
hide_legend:
hide_legend: bool
True to hide the legend, False otherwise.
style_file_or_name:
style_file_or_name: str
The path to a Matplotlib style file or the name of one
of Matplotlib built-in styles
(see https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html).
Expand Down Expand Up @@ -435,7 +435,7 @@ def plot_wer(
margin={"t": 60, "b": 60},
)
fig.write_html(
f"{output_image.rsplit('.', 1)[0]}.html", include_plotlyjs=True,
f"{output_image.rsplit('.', 1)[0]}.html", include_plotlyjs=True
)
except ImportError:
logging.warning(
Expand All @@ -455,32 +455,32 @@ def plot_metric(
usetex: "bool" = False,
hide_legend: "bool" = False,
style_file_or_name: "str" = "classic",
) -> "None":
):
"""Plot a continual learning metric.

Arguments
---------
metric_csv_file:
metric_csv_file: str
The path to the continual learning metric CSV file.
output_image:
output_image: str
The path to the output image.
xlabel:
xlabel: str
The x-axis label.
ylabel:
ylabel: str
The y-axis label.
xticks:
xticks: List[str]
The x-ticks.
figsize:
figsize: Tuple[float, float]
The figure size.
title:
title: str
The plot title.
opacity:
opacity: float
The confidence interval opacity.
usetex:
usetex: bool
True to render text with LaTeX, False otherwise.
hide_legend:
hide_legend: bool
True to hide the legend, False otherwise.
style_file_or_name:
style_file_or_name: str
The path to a Matplotlib style file or the name of one
of Matplotlib built-in styles
(see https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html).
Expand Down Expand Up @@ -650,7 +650,7 @@ def hex_to_rgb(hex_color: "str") -> "Tuple":
margin={"t": 60, "b": 60},
)
fig.write_html(
f"{output_image.rsplit('.', 1)[0]}.html", include_plotlyjs=True,
f"{output_image.rsplit('.', 1)[0]}.html", include_plotlyjs=True
)
except ImportError:
logging.warning(
Expand Down Expand Up @@ -689,9 +689,7 @@ def hex_to_rgb(hex_color: "str") -> "Tuple":
# fmt: on
help="forward transfer references",
)
parser.add_argument(
"-f", "--format", default="png", help="image format",
)
parser.add_argument("-f", "--format", default="png", help="image format")
parser.add_argument(
"-s",
"--figsize",
Expand All @@ -700,17 +698,15 @@ def hex_to_rgb(hex_color: "str") -> "Tuple":
type=float,
help="figure size",
)
parser.add_argument("-t", "--title", default=None, help="title")
parser.add_argument(
"-t", "--title", default=None, help="title",
)
parser.add_argument(
"-o", "--opacity", default=0.15, help="confidence interval opacity",
"-o", "--opacity", default=0.15, help="confidence interval opacity"
)
parser.add_argument(
"--hide_legend", action="store_true", help="hide legend",
"--hide_legend", action="store_true", help="hide legend"
)
parser.add_argument(
"-u", "--usetex", action="store_true", help="render text with LaTeX",
"-u", "--usetex", action="store_true", help="render text with LaTeX"
)
parser.add_argument(
"--order",
Expand Down Expand Up @@ -831,7 +827,7 @@ def hex_to_rgb(hex_color: "str") -> "Tuple":
avg_mean = np.mean(avg)
# Assuming independence, sigma^2 = sum_1^n sigma_i^2 / n^2
avg_stddev = np.sqrt(
np.nansum(stddev ** 2) / (~np.isnan(stddev)).sum() ** 2
np.nansum(stddev**2) / (~np.isnan(stddev)).sum() ** 2
)
csv_writer.writerow(
[group_name]
Expand All @@ -851,9 +847,9 @@ def hex_to_rgb(hex_color: "str") -> "Tuple":
f"{name.lower().replace(' ', '_')}.{args.format}",
),
xlabel=None,
ylabel=f"{name} (\%)"
if args.usetex
else f"{name} (%)", # noqa: W605
ylabel=(
f"{name} (\\%)" if args.usetex else f"{name} (%)"
), # noqa: W605
xticks=["base"] + [f"L{i}" for i in range(1, 1 + len(new_locales))],
figsize=args.figsize,
title=args.title,
Expand Down
Loading
Loading