Skip to content

Commit

Permalink
Fix decode.py to remove the correct axis. (k2-fsa#50)
Browse files Browse the repository at this point in the history
* Fix decode.py to remove the correct axis.

* Run GitHub actions manually.
  • Loading branch information
csukuangfj authored Sep 17, 2021
1 parent 9a6e048 commit cc77cb3
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/run-yesno-recipe.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,20 @@ on:
branches:
- master
pull_request:
branches:
- master
types: [labeled]

jobs:
run-yesno-recipe:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
# os: [ubuntu-18.04, macos-10.15]
# TODO: enable macOS for CPU testing
os: [ubuntu-18.04]
python-version: [3.8]
torch: ["1.8.1"]
k2-version: ["1.8.dev20210917"]
fail-fast: false

steps:
Expand All @@ -54,10 +56,8 @@ jobs:
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip black flake8
python3 -m pip install -U pip
python3 -m pip install k2==1.7.dev20210914+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
python3 -m pip install torchaudio==0.7.2
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
# We are in ./icefall and there is a file: requirements.txt in it
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ on:
branches:
- master
pull_request:
branches:
- master
types: [labeled]

jobs:
test:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04, macos-10.15]
python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"]
k2-version: ["1.7.dev20210914"]
k2-version: ["1.8.dev20210917"]

fail-fast: false

Expand Down
10 changes: 5 additions & 5 deletions icefall/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def nbest_decoding(
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)

# Remove 0 (epsilon) and -1 from word_seq
word_seq = word_seq.remove_values_leq(0)
Expand Down Expand Up @@ -432,7 +432,7 @@ def rescore_with_n_best_list(
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)

# Remove epsilons and -1 from word_seq
word_seq = word_seq.remove_values_leq(0)
Expand Down Expand Up @@ -669,7 +669,7 @@ def nbest_oracle(
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)

word_seq = word_seq.remove_values_leq(0)
unique_word_seq, _, _ = word_seq.unique(
Expand Down Expand Up @@ -761,7 +761,7 @@ def rescore_with_attention_decoder(
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)

# Remove epsilons and -1 from word_seq
word_seq = word_seq.remove_values_leq(0)
Expand Down Expand Up @@ -815,7 +815,7 @@ def rescore_with_attention_decoder(
token_seq = k2.ragged.index(lattice.tokens, path)
else:
token_seq = lattice.tokens.index(path)
token_seq = token_seq.remove_axis(1)
token_seq = token_seq.remove_axis(token_seq.num_axes - 2)

# Remove epsilons and -1 from token_seq
token_seq = token_seq.remove_values_leq(0)
Expand Down

0 comments on commit cc77cb3

Please sign in to comment.