diff --git a/.circleci/config.yml b/.circleci/config.yml index 644fd8b31b7..26b9f600e3c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -218,6 +218,9 @@ jobs: - restore_cache: keys: - data-cache-phantom-kit + - restore_cache: + keys: + - data-cache-ds004388 - run: name: Get data # This limit could be increased, but this is helpful for finding slow ones @@ -252,7 +255,7 @@ jobs: name: Check sphinx log for warnings (which are treated as errors) when: always command: | - ! grep "^.* WARNING: .*$" sphinx_log.txt + ! grep "^.*\(WARNING\|ERROR\): " sphinx_log.txt - run: name: Show profiling output when: always @@ -393,6 +396,10 @@ jobs: key: data-cache-phantom-kit paths: - ~/mne_data/MNE-phantom-KIT-data # (1 G) + - save_cache: + key: data-cache-ds004388 + paths: + - ~/mne_data/ds004388 # (1.8 G) linkcheck: diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 5214c911317..18543b854d0 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -13,10 +13,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + persist-credentials: false - uses: actions/setup-python@v5 with: python-version: '3.12' - run: pip install --upgrade towncrier pygithub gitpython numpy - run: python ./.github/actions/rename_towncrier/rename_towncrier.py - run: python ./tools/dev/ensure_headers.py - - uses: autofix-ci/action@ff86a557419858bb967097bfc916833f5647fa8c + - uses: autofix-ci/action@551dded8c6cc8a1054039c8bc0b8b48c51dfc6ef diff --git a/.github/workflows/automerge.yml b/.github/workflows/automerge.yml new file mode 100644 index 00000000000..68720eaaa34 --- /dev/null +++ b/.github/workflows/automerge.yml @@ -0,0 +1,17 @@ +name: Bot auto-merge +on: pull_request # yamllint disable-line rule:truthy + +jobs: + autobot: + permissions: + contents: write + pull-requests: write + runs-on: ubuntu-latest + # Names can be found with gh api /repos/mne-tools/mne-python/pulls/12998 -q .user.login for example + if: (github.event.pull_request.user.login == 'dependabot[bot]' || github.event.pull_request.user.login == 'pre-commit-ci[bot]' || github.event.pull_request.user.login == 'github-actions[bot]') && github.repository == 'mne-tools/mne-python' + steps: + - name: Enable auto-merge for bot PRs + run: gh pr merge --auto --squash "$PR_URL" + env: + PR_URL: ${{github.event.pull_request.html_url}} + GH_TOKEN: ${{secrets.GITHUB_TOKEN}} diff --git a/.github/workflows/check_changelog.yml b/.github/workflows/check_changelog.yml index cc85b591977..6995c399b34 100644 --- a/.github/workflows/check_changelog.yml +++ b/.github/workflows/check_changelog.yml @@ -5,6 +5,9 @@ on: # yamllint disable-line rule:truthy types: [opened, synchronize, labeled, unlabeled] branches: ["main"] +permissions: + contents: read + jobs: changelog_checker: name: Check towncrier entry in doc/changes/devel/ diff --git a/.github/workflows/circle_artifacts.yml b/.github/workflows/circle_artifacts.yml index fa32e1ce80c..301c6234eb5 100644 --- a/.github/workflows/circle_artifacts.yml +++ b/.github/workflows/circle_artifacts.yml @@ -1,4 +1,7 @@ on: [status] # yamllint disable-line rule:truthy +permissions: + contents: read + statuses: write jobs: circleci_artifacts_redirector_job: if: "${{ startsWith(github.event.context, 'ci/circleci: build_docs') }}" diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 7f348f80778..35a0d8fdc1a 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -39,6 +39,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 + with: + persist-credentials: false # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL @@ -48,11 +50,11 @@ jobs: # If you wish to specify custom queries, you can do so here or in a config file. # By default, queries listed here will override any specified in a config file. # Prefix the list here with "+" to use these queries and those in the config file. - + # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs # queries: security-extended,security-and-quality - + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild @@ -61,7 +63,7 @@ jobs: # ℹ️ Command-line programs to run using the OS shell. # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun - # If the Autobuild fails above, remove it and uncomment the following three lines. + # If the Autobuild fails above, remove it and uncomment the following three lines. # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. # - run: | diff --git a/.github/workflows/credit.yml b/.github/workflows/credit.yml index 96ab7544034..1e4f9ff0931 100644 --- a/.github/workflows/credit.yml +++ b/.github/workflows/credit.yml @@ -6,12 +6,11 @@ on: # yamllint disable-line rule:truthy - cron: '0 0 1 * *' # At 00:00 on day-of-month 1 workflow_dispatch: -permissions: - contents: write - pull-requests: write - jobs: update_credit: + permissions: + contents: write + pull-requests: write name: Update runs-on: ubuntu-latest env: @@ -19,6 +18,8 @@ jobs: GITHUB_TOKEN: ${{ github.token }} steps: - uses: actions/checkout@v4 + with: + persist-credentials: true - uses: actions/setup-python@v5 with: python-version: '3.12' @@ -37,8 +38,8 @@ jobs: git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com" git config --global user.name "github-actions[bot]" git checkout -b credit - git commit -am "MAINT: Update code credit [ci skip]" + git commit -am "MAINT: Update code credit" git push origin credit - PR_NUM=$(gh pr create --base main --head credit --title "MAINT: Update code credit" --body "Created by \"${{ github.workflow }}\" GitHub action." --label "no-changelog-entry-needed") + PR_NUM=$(gh pr create --base main --head credit --title "MAINT: Update code credit" --body "Created by credit [GitHub action](https://github.com/mne-tools/mne-python/actions/runs/${{ github.run_id }})." --label "no-changelog-entry-needed") echo "Opened https://github.com/mne-tools/mne-python/pull/${PR_NUM}" >> $GITHUB_STEP_SUMMARY if: steps.status.outputs.dirty == 'true' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a004ff9ea21..90c83c8130a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -17,6 +17,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + persist-credentials: false - uses: actions/setup-python@v5 with: python-version: '3.10' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f0c9da989b3..298908cdc65 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,6 +18,8 @@ jobs: timeout-minutes: 3 steps: - uses: actions/checkout@v4 + with: + persist-credentials: false - uses: actions/setup-python@v5 with: python-version: '3.12' @@ -56,6 +58,9 @@ jobs: fail-fast: false matrix: include: + - os: ubuntu-latest + python: '3.13' + kind: pip - os: ubuntu-latest python: '3.12' kind: pip-pre @@ -81,44 +86,44 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + persist-credentials: false - run: ./tools/github_actions_env_vars.sh # Xvfb/OpenGL - uses: pyvista/setup-headless-display-action@v3 with: qt: true pyvista: false + wm: false # Python (if pip) - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} if: startswith(matrix.kind, 'pip') # Python (if conda) - - name: Remove numba and dipy - run: | # TODO: Remove when numba 0.59 and dipy 1.8 land on conda-forge - sed -i '/numba/d' environment.yml - sed -i '/dipy/d' environment.yml - sed -i 's/- mne$/- mne-base/' environment.yml - if: matrix.os == 'ubuntu-latest' && startswith(matrix.kind, 'conda') && matrix.python == '3.12' + - name: Fixes for conda + run: | + # For some reason on Linux we get crashes + if [[ "$RUNNER_OS" == "Linux" ]]; then + sed -i "/numba/d" environment.yml + elif [[ "$RUNNER_OS" == "macOS" ]]; then + sed -i "" "s/ - PySide6 .*/ - PySide6 <6.8/g" environment.yml + fi + if: matrix.kind == 'conda' || matrix.kind == 'mamba' - uses: mamba-org/setup-micromamba@v2 with: environment-file: ${{ env.CONDA_ENV }} environment-name: mne create-args: >- python=${{ env.PYTHON_VERSION }} - mamba - nomkl if: ${{ !startswith(matrix.kind, 'pip') }} - # Make sure we have the right Python - - run: python -c "import platform; assert platform.machine() == 'arm64', platform.machine()" - if: matrix.os == 'macos-14' - - run: ./tools/github_actions_dependencies.sh + - run: bash ./tools/github_actions_dependencies.sh # Minimal commands on Linux (macOS stalls) - - run: ./tools/get_minimal_commands.sh - if: ${{ startswith(matrix.os, 'ubuntu') }} - - run: ./tools/github_actions_infos.sh + - run: bash ./tools/get_minimal_commands.sh + if: startswith(matrix.os, 'ubuntu') && matrix.kind != 'minimal' && matrix.kind != 'old' + - run: bash ./tools/github_actions_infos.sh # Check Qt - - run: ./tools/check_qt_import.sh $MNE_QT_BACKEND - if: ${{ env.MNE_QT_BACKEND != '' }} + - run: bash ./tools/check_qt_import.sh $MNE_QT_BACKEND + if: env.MNE_QT_BACKEND != '' - name: Run tests with no testing data run: MNE_SKIP_TESTING_DATASET_TESTS=true pytest -m "not (ultraslowtest or pgtest)" --tb=short --cov=mne --cov-report xml -vv -rfE mne/ if: matrix.kind == 'minimal' @@ -127,8 +132,8 @@ jobs: with: key: ${{ env.TESTING_VERSION }} path: ~/mne_data - - run: ./tools/github_actions_download.sh - - run: ./tools/github_actions_test.sh + - run: bash ./tools/github_actions_download.sh + - run: bash ./tools/github_actions_test.sh # for some reason on macOS we need to run "bash X" in order for a failed test run to show up - uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.mailmap b/.mailmap index 327f5a9a648..133eb2be306 100644 --- a/.mailmap +++ b/.mailmap @@ -2,6 +2,7 @@ Adam Li Adam Li Adam Li Adam Li Alan Leggitt leggitta Alessandro Tonin Lychfindel <58313635+Lychfindel@users.noreply.github.com> +Alex Lepauvre Alex lepauvre Alex Rockhill Alex Alex Rockhill Alex Alex Rockhill Alex Rockhill @@ -356,4 +357,4 @@ Yousra Bekhti Yousra BEKHTI Yousra Bekhti yousrabk Zhi Zhang <850734033@qq.com> ZHANG Zhi <850734033@qq.com> Zhi Zhang <850734033@qq.com> ZHANG Zhi -Ziyi ZENG +Ziyi ZENG ZIYI ZENG diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5546126b2d6..34b3ce9b130 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.0 + rev: v0.9.2 hooks: - id: ruff name: ruff lint mne @@ -70,16 +70,21 @@ repos: name: Copy dependency changes from pyproject.toml to environment.yml language: python entry: ./tools/hooks/update_environment_file.py - files: pyproject.toml + files: '^(pyproject.toml|tools/hooks/update_environment_file.py)$' - repo: local hooks: - id: dependency-sync name: Copy core dependencies from pyproject.toml to README.rst language: python entry: ./tools/hooks/sync_dependencies.py - files: pyproject.toml - additional_dependencies: ["mne"] + files: '^(pyproject.toml|tools/hooks/sync_dependencies.py)$' + additional_dependencies: ["mne==1.9.0"] + # zizmor + - repo: https://github.com/woodruffw/zizmor-pre-commit + rev: v1.2.2 + hooks: + - id: zizmor # these should *not* be run on CIs: ci: diff --git a/CITATION.cff b/CITATION.cff index a55d21e00c0..2ba5fd6b4c6 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -1,9 +1,9 @@ cff-version: 1.2.0 title: "MNE-Python" message: "If you use this software, please cite both the software itself, and the paper listed in the preferred-citation field." -version: 1.8.0 -date-released: "2024-08-18" -commit: 9a760d76971e845b67b619804c1156cc04c9c948 +version: 1.9.0 +date-released: "2024-12-18" +commit: 14938b9657b255a38aa96482a4aaf410e8865859 doi: 10.5281/zenodo.592483 keywords: - MEG @@ -40,15 +40,15 @@ authors: - family-names: Luessi given-names: Martin - family-names: King - given-names: Jean-Remi + given-names: Jean-Rémi - family-names: Höchenberger given-names: Richard - family-names: Goj given-names: Roman - - family-names: Favelier - given-names: Guillaume - family-names: Brunner given-names: Clemens + - family-names: Favelier + given-names: Guillaume - family-names: van Vliet given-names: Marijn - family-names: Wronkiewicz @@ -57,22 +57,22 @@ authors: given-names: Alex - family-names: Holdgraf given-names: Chris - - family-names: Massich - given-names: Joan - - family-names: Bekhti - given-names: Yousra - family-names: Scheltienne given-names: Mathieu + - family-names: Massich + given-names: Joan - family-names: Appelhoff given-names: Stefan + - family-names: Bekhti + given-names: Yousra - family-names: Leggitt given-names: Alan - family-names: Dykstra given-names: Andrew - - family-names: Luke - given-names: Rob - family-names: Trachel given-names: Romain + - family-names: Luke + given-names: Robert - family-names: De Santis given-names: Lorenzo - family-names: Panda @@ -111,12 +111,12 @@ authors: given-names: Jair - family-names: Woodman given-names: Marmaduke + - family-names: Huberty + given-names: Scott - family-names: Lee given-names: Ingoo - family-names: Schulz given-names: Martin - - family-names: Huberty - given-names: Scott - family-names: Foti given-names: Nick - family-names: Nangini @@ -156,6 +156,8 @@ authors: given-names: Ana - family-names: Buran given-names: Brad + - family-names: Woessner + given-names: Jacob - family-names: Massias given-names: Mathurin - family-names: Hämäläinen @@ -168,8 +170,6 @@ authors: given-names: Christopher - family-names: Raimundo given-names: Félix - - family-names: Woessner - given-names: Jacob - family-names: Kaneda given-names: Michiru - family-names: Alday @@ -192,6 +192,8 @@ authors: given-names: Mads - family-names: Gahlot given-names: Tanay + - family-names: Binns + given-names: Thomas S - family-names: Nunes given-names: Adonay - family-names: Gütlin @@ -221,6 +223,10 @@ authors: given-names: Natalie - family-names: Roujansky given-names: Paul + - family-names: Luke + given-names: Rob + - family-names: Ruuskanen + given-names: Santeri - family-names: Kern given-names: Simon - family-names: Rantala @@ -275,10 +281,10 @@ authors: given-names: Nathalie - family-names: Ward given-names: Nick - - family-names: Ruuskanen - given-names: Santeri - family-names: Herbst given-names: Sophie + - family-names: Férat + given-names: Victor - family-names: Radanovic given-names: Ana - family-names: Quinn @@ -299,6 +305,8 @@ authors: given-names: Evgenii - family-names: Mamashli given-names: Fahimeh + - family-names: Belonosov + given-names: Gennadiy - family-names: O'Neill given-names: George - family-names: Marinato @@ -327,6 +335,8 @@ authors: given-names: Nicolas - family-names: Kapralov given-names: Nikolai + - family-names: Chu + given-names: Qian - family-names: Falach given-names: Rotem - family-names: Deslauriers-Gauthier @@ -337,14 +347,10 @@ authors: given-names: Steve - family-names: Bierer given-names: Steven - - family-names: Binns - given-names: Thomas S - family-names: Binns given-names: Thomas Samuel - family-names: Stenner given-names: Tristan - - family-names: Férat - given-names: Victor - family-names: Peterson given-names: Victoria - family-names: Baratz @@ -369,8 +375,8 @@ authors: given-names: Dominique - family-names: Mikulan given-names: Ezequiel - - family-names: Belonosov - given-names: Gennadiy + - family-names: Hofer + given-names: Florian - family-names: Schiratti given-names: Jean-Baptiste - family-names: Evans @@ -416,12 +422,14 @@ authors: given-names: Peter J - family-names: Ablin given-names: Pierre - - family-names: Chu - given-names: Qian + - family-names: Das + given-names: Proloy - family-names: Bertrand given-names: Quentin - family-names: Shoorangiz given-names: Reza + - family-names: Scholz + given-names: Richard - family-names: Hübner given-names: Rodrigo - family-names: Sommariva @@ -460,6 +468,8 @@ authors: given-names: Adina - family-names: Ciok given-names: Alex + - family-names: Lepauvre + given-names: Alex - family-names: Kiefer given-names: Alexander - family-names: Gilbert @@ -534,12 +544,12 @@ authors: given-names: Etienne - family-names: Goldstein given-names: Evgeny + - family-names: Mamashli + given-names: Fahimeh - family-names: Negahbani given-names: Farzin - family-names: Zamberlan given-names: Federico - - family-names: Hofer - given-names: Florian - family-names: Pop given-names: Florin - family-names: Weber @@ -573,6 +583,8 @@ authors: given-names: Ivan - family-names: de Jong given-names: Ivo + - family-names: Phelan + given-names: Jacob - family-names: Kaczmarzyk given-names: Jakub - family-names: Zerfowski @@ -603,6 +615,8 @@ authors: given-names: Laetitia - family-names: Andersen given-names: Lau Møller + - family-names: Almeida + given-names: Leonardo Rochael - family-names: Barbosa given-names: Leonardo S - family-names: Alfine @@ -671,8 +685,6 @@ authors: given-names: Padma - family-names: Silva given-names: Pedro - - family-names: Das - given-names: Proloy - family-names: Li given-names: Quanliang - family-names: Barthélemy @@ -689,8 +701,6 @@ authors: given-names: Reza - family-names: Koehler given-names: Richard - - family-names: Scholz - given-names: Richard - family-names: Stargardsky given-names: Riessarius - family-names: Oostenveld @@ -727,6 +737,8 @@ authors: given-names: Simeon - family-names: Wong given-names: Simeon + - family-names: Hofmann + given-names: Simon M - family-names: Poil given-names: Simon-Shlomo - family-names: Foslien @@ -769,6 +781,8 @@ authors: given-names: Yiping - family-names: Zhang given-names: Zhi + - family-names: ZENG + given-names: Ziyi - name: btkcodedev - name: buildqa - name: luzpaz diff --git a/Makefile b/Makefile index 89adc810eec..80c79edace3 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,7 @@ test-doc: sample_data testing_data $(PYTESTS) --doctest-modules --doctest-ignore-import-errors --doctest-glob='*.rst' ./doc/ --ignore=./doc/auto_examples --ignore=./doc/auto_tutorials --ignore=./doc/_build --ignore=./doc/conf.py --ignore=doc/sphinxext --fulltrace pre-commit: - @pre-commit run -a + @pre-commit run -a --show-diff-on-failure # Aliases for stuff we used to support or users might think of ruff: pre-commit diff --git a/README.rst b/README.rst index 50e0daaa52c..cfa7488324f 100644 --- a/README.rst +++ b/README.rst @@ -72,7 +72,7 @@ The minimum required dependencies to run MNE-Python are: .. ↓↓↓ BEGIN CORE DEPS LIST. DO NOT EDIT! HANDLED BY PRE-COMMIT HOOK ↓↓↓ -- `Python `__ ≥ 3.9 +- `Python `__ ≥ 3.10 - `NumPy `__ ≥ 1.23 - `SciPy `__ ≥ 1.9 - `Matplotlib `__ ≥ 3.6 diff --git a/SECURITY.md b/SECURITY.md index 82d4c9e45de..a8e59476a67 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -10,9 +10,9 @@ without a proper 6-month deprecation cycle. | Version | Supported | | ------- | ------------------------ | -| 1.8.x | :heavy_check_mark: (dev) | -| 1.7.x | :heavy_check_mark: | -| < 1.7 | :x: | +| 1.9.x | :heavy_check_mark: (dev) | +| 1.8.x | :heavy_check_mark: | +| < 1.8 | :x: | ## Reporting a Vulnerability diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 3ca4177174f..7149edac50b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -88,6 +88,7 @@ stages: variables: DISPLAY: ':99' OPENBLAS_NUM_THREADS: '1' + MNE_TEST_ALLOW_SKIP: '^.*(PySide6 causes segfaults).*$' steps: - bash: | set -e @@ -111,7 +112,7 @@ stages: - bash: | set -e python -m pip install --progress-bar off --upgrade pip - python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn pytest-error-for-skips python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1" + python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1,!=6.8.1.1" pandas neo pymatreader antio defusedxml python -m pip uninstall -yq mne python -m pip install --progress-bar off --upgrade -e .[test] displayName: 'Install dependencies with pip' @@ -132,7 +133,7 @@ stages: displayName: 'Cache testing data' - script: python -c "import mne; mne.datasets.testing.data_path(verbose=True)" displayName: 'Get test data' - - script: pytest --error-for-skips -m "ultraslowtest or pgtest" --tb=short --cov=mne --cov-report=xml --cov-report=html -vv mne + - script: pytest -m "ultraslowtest or pgtest" --tb=short --cov=mne --cov-report=xml -vv mne displayName: 'slow and mne-qt-browser tests' # Coverage - bash: bash <(curl -s https://codecov.io/bash) @@ -144,11 +145,9 @@ stages: testRunTitle: 'Publish test results for $(Agent.JobName)' failTaskOnFailedTests: true condition: succeededOrFailed() - - task: PublishCodeCoverageResults@1 + - task: PublishCodeCoverageResults@2 inputs: - codeCoverageTool: Cobertura summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' - reportDirectory: '$(System.DefaultWorkingDirectory)/**/htmlcov' - job: Qt pool: @@ -156,7 +155,8 @@ stages: variables: DISPLAY: ':99' OPENBLAS_NUM_THREADS: '1' - TEST_OPTIONS: "--tb=short --cov=mne --cov-report=xml --cov-report=html --cov-append -vv mne/viz/_brain mne/viz/backends mne/viz/tests/test_evoked.py mne/gui mne/report" + TEST_OPTIONS: "--tb=short --cov=mne --cov-report=xml --cov-append -vv mne/viz/_brain mne/viz/backends mne/viz/tests/test_evoked.py mne/gui mne/report" + MNE_TEST_ALLOW_SKIP: '^.*(PySide6 causes segfaults).*$' steps: - bash: ./tools/setup_xvfb.sh displayName: 'Install Ubuntu dependencies' @@ -192,6 +192,7 @@ stages: set -eo pipefail python -m pip install PyQt6 LD_DEBUG=libs python -c "from PyQt6.QtWidgets import QApplication, QWidget; app = QApplication([]); import matplotlib; matplotlib.use('QtAgg'); import matplotlib.pyplot as plt; plt.figure()" + displayName: 'Check Qt import' - bash: | set -eo pipefail mne sys_info -pd @@ -226,11 +227,9 @@ stages: testRunTitle: 'Publish test results for $(Agent.JobName)' failTaskOnFailedTests: true condition: succeededOrFailed() - - task: PublishCodeCoverageResults@1 + - task: PublishCodeCoverageResults@2 inputs: - codeCoverageTool: Cobertura summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' - reportDirectory: '$(System.DefaultWorkingDirectory)/**/htmlcov' - job: Windows pool: @@ -244,7 +243,7 @@ stages: PYTHONIOENCODING: 'utf-8' AZURE_CI_WINDOWS: 'true' PYTHON_ARCH: 'x64' - timeoutInMinutes: 75 + timeoutInMinutes: 80 strategy: maxParallel: 4 matrix: @@ -285,7 +284,7 @@ stages: displayName: 'Cache testing data' - script: python -c "import mne; mne.datasets.testing.data_path(verbose=True)" displayName: 'Get test data' - - script: pytest -m "not (slowtest or pgtest)" --tb=short --cov=mne --cov-report=xml --cov-report=html -vv mne + - script: pytest -m "not (slowtest or pgtest)" --tb=short --cov=mne --cov-report=xml -vv mne displayName: 'Run tests' - bash: bash <(curl -s https://codecov.io/bash) displayName: 'Codecov' @@ -296,8 +295,6 @@ stages: testRunTitle: 'Publish test results for $(Agent.JobName) $(TEST_MODE) $(PYTHON_VERSION)' failTaskOnFailedTests: true condition: succeededOrFailed() - - task: PublishCodeCoverageResults@1 + - task: PublishCodeCoverageResults@2 inputs: - codeCoverageTool: Cobertura summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' - reportDirectory: '$(System.DefaultWorkingDirectory)/**/htmlcov' diff --git a/codemeta.json b/codemeta.json index 0351ca98a1c..d5ec88c23e9 100644 --- a/codemeta.json +++ b/codemeta.json @@ -5,11 +5,11 @@ "codeRepository": "git+https://github.com/mne-tools/mne-python.git", "dateCreated": "2010-12-26", "datePublished": "2014-08-04", - "dateModified": "2024-08-18", - "downloadUrl": "https://github.com/mne-tools/mne-python/archive/v1.8.0.zip", + "dateModified": "2024-12-18", + "downloadUrl": "https://github.com/mne-tools/mne-python/archive/v1.9.0.zip", "issueTracker": "https://github.com/mne-tools/mne-python/issues", "name": "MNE-Python", - "version": "1.8.0", + "version": "1.9.0", "description": "MNE-Python is an open-source Python package for exploring, visualizing, and analyzing human neurophysiological data. It provides methods for data input/output, preprocessing, visualization, source estimation, time-frequency analysis, connectivity analysis, machine learning, and statistics.", "applicationCategory": "Neuroscience", "developmentStatus": "active", @@ -37,16 +37,16 @@ "macOS" ], "softwareRequirements": [ - "python>=3.9", - "numpy>=1.23,<3", - "scipy>=1.9", - "matplotlib>=3.6", - "tqdm", - "pooch>=1.5", + "python>= 3.10", "decorator", - "packaging", "jinja2", - "lazy_loader>=0.3" + "lazy_loader >= 0.3", + "matplotlib >= 3.6", + "numpy >= 1.23,<3", + "packaging", + "pooch >= 1.5", + "scipy >= 1.9", + "tqdm" ], "author": [ { @@ -112,7 +112,7 @@ { "@type":"Person", "email":"jeanremi.king+github@gmail.com", - "givenName":"Jean-Remi", + "givenName":"Jean-Rémi", "familyName": "King" }, { @@ -127,18 +127,18 @@ "givenName":"Roman", "familyName": "Goj" }, - { - "@type":"Person", - "email":"guillaume.favelier@gmail.com", - "givenName":"Guillaume", - "familyName": "Favelier" - }, { "@type":"Person", "email":"clemens.brunner@gmail.com", "givenName":"Clemens", "familyName": "Brunner" }, + { + "@type":"Person", + "email":"guillaume.favelier@gmail.com", + "givenName":"Guillaume", + "familyName": "Favelier" + }, { "@type":"Person", "email":"w.m.vanvliet@gmail.com", @@ -163,30 +163,30 @@ "givenName":"Chris", "familyName": "Holdgraf" }, - { - "@type":"Person", - "email":"mailsik@gmail.com", - "givenName":"Joan", - "familyName": "Massich" - }, - { - "@type":"Person", - "email":"yousra.bekhti@gmail.com", - "givenName":"Yousra", - "familyName": "Bekhti" - }, { "@type":"Person", "email":"mathieu.scheltienne@gmail.com", "givenName":"Mathieu", "familyName": "Scheltienne" }, + { + "@type":"Person", + "email":"mailsik@gmail.com", + "givenName":"Joan", + "familyName": "Massich" + }, { "@type":"Person", "email":"stefan.appelhoff@mailbox.org", "givenName":"Stefan", "familyName": "Appelhoff" }, + { + "@type":"Person", + "email":"yousra.bekhti@gmail.com", + "givenName":"Yousra", + "familyName": "Bekhti" + }, { "@type":"Person", "email":"leggitta3@gmail.com", @@ -199,18 +199,18 @@ "givenName":"Andrew", "familyName": "Dykstra" }, - { - "@type":"Person", - "email":"code@robertluke.net", - "givenName":"Rob", - "familyName": "Luke" - }, { "@type":"Person", "email":"romain.trachel@inria.fr", "givenName":"Romain", "familyName": "Trachel" }, + { + "@type":"Person", + "email":"code@robertluke.net", + "givenName":"Robert", + "familyName": "Luke" + }, { "@type":"Person", "email":"desantis.lnz@gmail.com", @@ -325,6 +325,12 @@ "givenName":"Marmaduke", "familyName": "Woodman" }, + { + "@type":"Person", + "email":"", + "givenName":"Scott", + "familyName": "Huberty" + }, { "@type":"Person", "email":"dlsrnsladlek@naver.com", @@ -337,12 +343,6 @@ "givenName":"Martin", "familyName": "Schulz" }, - { - "@type":"Person", - "email":"", - "givenName":"Scott", - "familyName": "Huberty" - }, { "@type":"Person", "email":"nfoti01@gmail.com", @@ -463,6 +463,12 @@ "givenName":"Brad", "familyName": "Buran" }, + { + "@type":"Person", + "email":"Woessner.jacob@gmail.com", + "givenName":"Jacob", + "familyName": "Woessner" + }, { "@type":"Person", "email":"mathurin.massias@gmail.com", @@ -499,12 +505,6 @@ "givenName":"Félix", "familyName": "Raimundo" }, - { - "@type":"Person", - "email":"Woessner.jacob@gmail.com", - "givenName":"Jacob", - "familyName": "Woessner" - }, { "@type":"Person", "email":"rcmdnk@gmail.com", @@ -571,6 +571,12 @@ "givenName":"Tanay", "familyName": "Gahlot" }, + { + "@type":"Person", + "email":"t.s.binns@outlook.com", + "givenName":"Thomas S", + "familyName": "Binns" + }, { "@type":"Person", "email":"adonay.s.nunes@gmail.com", @@ -661,6 +667,18 @@ "givenName":"Paul", "familyName": "Roujansky" }, + { + "@type":"Person", + "email":"code@robertluke.net", + "givenName":"Rob", + "familyName": "Luke" + }, + { + "@type":"Person", + "email":"santeri.ruuskanen@aalto.fi", + "givenName":"Santeri", + "familyName": "Ruuskanen" + }, { "@type":"Person", "email":"simon.kern@online.de", @@ -823,18 +841,18 @@ "givenName":"Nick", "familyName": "Ward" }, - { - "@type":"Person", - "email":"santeri.ruuskanen@aalto.fi", - "givenName":"Santeri", - "familyName": "Ruuskanen" - }, { "@type":"Person", "email":"ksherbst@gmail.com", "givenName":"Sophie", "familyName": "Herbst" }, + { + "@type":"Person", + "email":"victor.ferat@live.Fr", + "givenName":"Victor", + "familyName": "Férat" + }, { "@type":"Person", "email":"", @@ -895,6 +913,12 @@ "givenName":"Fahimeh", "familyName": "Mamashli" }, + { + "@type":"Person", + "email":"", + "givenName":"Gennadiy", + "familyName": "Belonosov" + }, { "@type":"Person", "email":"g.o'neill@ucl.ac.uk", @@ -979,6 +1003,12 @@ "givenName":"Nikolai", "familyName": "Kapralov" }, + { + "@type":"Person", + "email":"", + "givenName":"Qian", + "familyName": "Chu" + }, { "@type":"Person", "email":"falachrotem@gmail.com", @@ -1009,12 +1039,6 @@ "givenName":"Steven", "familyName": "Bierer" }, - { - "@type":"Person", - "email":"t.s.binns@outlook.com", - "givenName":"Thomas S", - "familyName": "Binns" - }, { "@type":"Person", "email":"t.s.binns@outlook.com", @@ -1027,12 +1051,6 @@ "givenName":"Tristan", "familyName": "Stenner" }, - { - "@type":"Person", - "email":"victor.ferat@live.Fr", - "givenName":"Victor", - "familyName": "Férat" - }, { "@type":"Person", "email":"victoriapeterson09@gmail.com", @@ -1107,9 +1125,9 @@ }, { "@type":"Person", - "email":"", - "givenName":"Gennadiy", - "familyName": "Belonosov" + "email":"hofaflo@gmail.com", + "givenName":"Florian", + "familyName": "Hofer" }, { "@type":"Person", @@ -1251,9 +1269,9 @@ }, { "@type":"Person", - "email":"", - "givenName":"Qian", - "familyName": "Chu" + "email":"proloy@umd.edu", + "givenName":"Proloy", + "familyName": "Das" }, { "@type":"Person", @@ -1267,6 +1285,12 @@ "givenName":"Reza", "familyName": "Shoorangiz" }, + { + "@type":"Person", + "email":"", + "givenName":"Richard", + "familyName": "Scholz" + }, { "@type":"Person", "email":"rhubner@gmail.com", @@ -1387,6 +1411,12 @@ "givenName":"Alex", "familyName": "Ciok" }, + { + "@type":"Person", + "email":"alex.lepauvre@ae.mpg.de", + "givenName":"Alex", + "familyName": "Lepauvre" + }, { "@type":"Person", "email":"", @@ -1609,6 +1639,12 @@ "givenName":"Evgeny", "familyName": "Goldstein" }, + { + "@type":"Person", + "email":"fmamashli@gmail.com", + "givenName":"Fahimeh", + "familyName": "Mamashli" + }, { "@type":"Person", "email":"farzin.negahbani@gmail.com", @@ -1621,12 +1657,6 @@ "givenName":"Federico", "familyName": "Zamberlan" }, - { - "@type":"Person", - "email":"hofaflo@gmail.com", - "givenName":"Florian", - "familyName": "Hofer" - }, { "@type":"Person", "email":"florinpop@me.com", @@ -1729,6 +1759,12 @@ "givenName":"Ivo", "familyName": "de Jong" }, + { + "@type":"Person", + "email":"jacob.phelan.jp@gmail.com", + "givenName":"Jacob", + "familyName": "Phelan" + }, { "@type":"Person", "email":"", @@ -1819,6 +1855,12 @@ "givenName":"Lau Møller", "familyName": "Andersen" }, + { + "@type":"Person", + "email":"leorochael@gmail.com", + "givenName":"Leonardo Rochael", + "familyName": "Almeida" + }, { "@type":"Person", "email":"lsbarbosa@gmail.com", @@ -2023,12 +2065,6 @@ "givenName":"Pedro", "familyName": "Silva" }, - { - "@type":"Person", - "email":"proloy@umd.edu", - "givenName":"Proloy", - "familyName": "Das" - }, { "@type":"Person", "email":"glia@dtu.dk", @@ -2077,12 +2113,6 @@ "givenName":"Richard", "familyName": "Koehler" }, - { - "@type":"Person", - "email":"", - "givenName":"Richard", - "familyName": "Scholz" - }, { "@type":"Person", "email":"rie.acad@gmail.com", @@ -2191,6 +2221,12 @@ "givenName":"Simeon", "familyName": "Wong" }, + { + "@type":"Person", + "email":"", + "givenName":"Simon M", + "familyName": "Hofmann" + }, { "@type":"Person", "email":"", @@ -2317,6 +2353,12 @@ "givenName":"Zhi", "familyName": "Zhang" }, + { + "@type":"Person", + "email":"ziyizeng@link.cuhk.edu.cn", + "givenName":"Ziyi", + "familyName": "ZENG" + }, { "@type":"Person", "email":"btk.codedev@gmail.com", diff --git a/doc/_static/versions.json b/doc/_static/versions.json index ba4f9fc5d99..478677634c0 100644 --- a/doc/_static/versions.json +++ b/doc/_static/versions.json @@ -1,14 +1,19 @@ [ { - "name": "1.9 (devel)", + "name": "1.10 (devel)", "version": "dev", "url": "https://mne.tools/dev/" }, { - "name": "1.8 (stable)", + "name": "1.9 (stable)", "version": "stable", "url": "https://mne.tools/stable/" }, + { + "name": "1.8", + "version": "1.8", + "url": "https://mne.tools/1.8/" + }, { "name": "1.7", "version": "1.7", diff --git a/doc/api/datasets.rst b/doc/api/datasets.rst index 2b2c92c8654..87730fbd717 100644 --- a/doc/api/datasets.rst +++ b/doc/api/datasets.rst @@ -18,6 +18,7 @@ Datasets brainstorm.bst_auditory.data_path brainstorm.bst_resting.data_path brainstorm.bst_raw.data_path + default_path eegbci.load_data eegbci.standardize fetch_aparc_sub_parcellation diff --git a/doc/api/preprocessing.rst b/doc/api/preprocessing.rst index 1e0e9e56079..9fe3f995cc4 100644 --- a/doc/api/preprocessing.rst +++ b/doc/api/preprocessing.rst @@ -48,6 +48,7 @@ Projections: read_dig_localite make_standard_montage read_custom_montage + transform_to_head compute_dev_head_t read_layout find_layout @@ -115,6 +116,7 @@ Projections: read_ica_eeglab read_fine_calibration write_fine_calibration + apply_pca_obs :py:mod:`mne.preprocessing.nirs`: diff --git a/doc/api/time_frequency.rst b/doc/api/time_frequency.rst index 8923920bdba..a9ab2c34268 100644 --- a/doc/api/time_frequency.rst +++ b/doc/api/time_frequency.rst @@ -31,6 +31,8 @@ Functions that operate on mne-python objects: .. autosummary:: :toctree: ../generated/ + combine_spectrum + combine_tfr csd_tfr csd_fourier csd_multitaper diff --git a/doc/changes/devel/12071.newfeature.rst b/doc/changes/devel/12071.newfeature.rst new file mode 100644 index 00000000000..4e7995e3beb --- /dev/null +++ b/doc/changes/devel/12071.newfeature.rst @@ -0,0 +1 @@ +Add new ``select`` parameter to :func:`mne.viz.plot_evoked_topo` and :meth:`mne.Evoked.plot_topo` to toggle lasso selection of sensors, by `Marijn van Vliet`_. diff --git a/doc/changes/devel/12366.newfeature.rst b/doc/changes/devel/12366.newfeature.rst deleted file mode 100644 index 979c7141504..00000000000 --- a/doc/changes/devel/12366.newfeature.rst +++ /dev/null @@ -1 +0,0 @@ -Add support for `dict` type argument ``ref_channels`` to :func:`mne.set_eeg_reference`, to allow flexible re-referencing (e.g. ``raw.set_eeg_reference(ref_channels={'A1': ['A2', 'A3']})`` will set the new A1 data to be ``A1 - mean(A2, A3)``), by `Alex Lepauvre`_ and `Qian Chu`_ and `Daniel McCloy`_. \ No newline at end of file diff --git a/doc/changes/devel/12656.bugfix.rst b/doc/changes/devel/12656.bugfix.rst new file mode 100644 index 00000000000..3f32dbd23e5 --- /dev/null +++ b/doc/changes/devel/12656.bugfix.rst @@ -0,0 +1 @@ +Fix bug where :func:`mne.export.export_raw` does not correct for recording start time (:attr:`raw.first_time `) when exporting Raw instances to EDF or EEGLAB formats, by `Qian Chu`_. \ No newline at end of file diff --git a/doc/changes/devel/12787.other.rst b/doc/changes/devel/12787.other.rst deleted file mode 100644 index 1f53fdea066..00000000000 --- a/doc/changes/devel/12787.other.rst +++ /dev/null @@ -1 +0,0 @@ -Use custom code in :func:`mne.sys_info` to get the amount of physical memory and a more informative CPU name instead of using the ``psutil`` package, by `Clemens Brunner`_. \ No newline at end of file diff --git a/doc/changes/devel/12792.newfeature.rst b/doc/changes/devel/12792.newfeature.rst deleted file mode 100644 index 81ef79c8a11..00000000000 --- a/doc/changes/devel/12792.newfeature.rst +++ /dev/null @@ -1 +0,0 @@ -Add reader for ANT Neuro files in the ``*.cnt`` format with :func:`~mne.io.read_raw_ant`, by `Mathieu Scheltienne`_, `Eric Larson`_ and `Proloy Das`_. diff --git a/doc/changes/devel/12798.dependency.rst b/doc/changes/devel/12798.dependency.rst deleted file mode 100644 index ef05dab1e8d..00000000000 --- a/doc/changes/devel/12798.dependency.rst +++ /dev/null @@ -1 +0,0 @@ -- Minimum supported dependencies were updated in accordance with SPEC0_, most notably Python 3.10+ is now required. diff --git a/doc/changes/devel/12801.newfeature.rst b/doc/changes/devel/12801.newfeature.rst deleted file mode 100644 index 5f81e025c52..00000000000 --- a/doc/changes/devel/12801.newfeature.rst +++ /dev/null @@ -1 +0,0 @@ -- Add support for a :class:`mne.transforms.Transform` in the argument ``trans`` of the coregistration GUI called with :func:`mne.gui.coregistration`, by `Mathieu Scheltienne`_. diff --git a/doc/changes/devel/12803.bugfix.rst b/doc/changes/devel/12803.bugfix.rst deleted file mode 100644 index c10bddd517b..00000000000 --- a/doc/changes/devel/12803.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Fix handling of MRI file-path in :class:`mne.SourceSpaces` and safeguard saving of :class:`pathlib.Path` with ``h5io`` by casting to :class:`str`, by `Mathieu Scheltienne`_. diff --git a/doc/changes/devel/12804.bugfix.rst b/doc/changes/devel/12804.bugfix.rst deleted file mode 100644 index 87a988a4525..00000000000 --- a/doc/changes/devel/12804.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Cast ``fwd["info"]`` to :class:`~mne.Info` and ``fwd["src"]`` to :class:`~mne.SourceSpaces` when loading a forward solution from an HDF5 file, by `Mathieu Scheltienne`_. diff --git a/doc/changes/devel/12805.newfeature.rst b/doc/changes/devel/12805.newfeature.rst deleted file mode 100644 index 2c77d55d3ba..00000000000 --- a/doc/changes/devel/12805.newfeature.rst +++ /dev/null @@ -1 +0,0 @@ -Added support for ``sensor_scales`` to :meth:`mne.viz.Brain.add_sensors` and :func:`mne.viz.plot_alignment`, by :newcontrib:`Alex Lepauvre`. \ No newline at end of file diff --git a/doc/changes/devel/12811.newfeature.rst b/doc/changes/devel/12811.newfeature.rst deleted file mode 100644 index def54b1a68b..00000000000 --- a/doc/changes/devel/12811.newfeature.rst +++ /dev/null @@ -1 +0,0 @@ -:meth:`~mne.io.Raw` and :meth:`~mne.Epochs.save` now return the path to the saved file(s), by `Victor Ferat`_. diff --git a/doc/changes/devel/12827.other.rst b/doc/changes/devel/12827.other.rst deleted file mode 100644 index 3ccbaa0bff6..00000000000 --- a/doc/changes/devel/12827.other.rst +++ /dev/null @@ -1 +0,0 @@ -Improve documentation clarity of ``fit_transform`` methods for :class:`mne.decoding.SSD`, :class:`mne.decoding.CSP`, and :class:`mne.decoding.SPoC` classes, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/12828.bugfix.rst b/doc/changes/devel/12828.bugfix.rst new file mode 100644 index 00000000000..707385ac698 --- /dev/null +++ b/doc/changes/devel/12828.bugfix.rst @@ -0,0 +1 @@ +Fixed behavior of :func:`mne.viz.plot_source_estimates` where the ``title`` was not displayed properly, by :newcontrib:`Shristi Baral`. diff --git a/doc/changes/devel/12829.apichange.rst b/doc/changes/devel/12829.apichange.rst deleted file mode 100644 index d0bd4c12a46..00000000000 --- a/doc/changes/devel/12829.apichange.rst +++ /dev/null @@ -1 +0,0 @@ -Deprecate ``average`` parameter in ``plot_filters`` and ``plot_patterns`` methods of the :class:`mne.decoding.CSP` and :class:`mne.decoding.SPoC` classes, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/12830.newfeature.rst b/doc/changes/devel/12830.newfeature.rst deleted file mode 100644 index 4d51229392d..00000000000 --- a/doc/changes/devel/12830.newfeature.rst +++ /dev/null @@ -1 +0,0 @@ -:func:`mne.channels.read_custom_montage` may now read a newer version of the ``.elc`` ASA Electrode file format, by `Stefan Appelhoff`_. diff --git a/doc/changes/devel/12834.dependency.rst b/doc/changes/devel/12834.dependency.rst deleted file mode 100644 index ca19423df87..00000000000 --- a/doc/changes/devel/12834.dependency.rst +++ /dev/null @@ -1,2 +0,0 @@ -Importing from ``mne.decoding`` now explicitly requires ``scikit-learn`` to be installed, -by `Eric Larson`_. diff --git a/doc/changes/devel/12842.bugfix.rst b/doc/changes/devel/12842.bugfix.rst deleted file mode 100644 index 75f83683b8f..00000000000 --- a/doc/changes/devel/12842.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Fix bug where :meth:`mne.Epochs.compute_tfr` could not be used with the multitaper method and complex or phase outputs, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/12843.bugfix.rst b/doc/changes/devel/12843.bugfix.rst deleted file mode 100644 index 6f3be428b3a..00000000000 --- a/doc/changes/devel/12843.bugfix.rst +++ /dev/null @@ -1,3 +0,0 @@ -Fixed a bug where split FIF files that were read and then appended to other -:class:`mne.io.Raw` instances had their ``BAD boundary`` annotations incorrectly offset -in samples by the number of split files, by `Eric Larson`_. diff --git a/doc/changes/devel/12843.other.rst b/doc/changes/devel/12843.other.rst deleted file mode 100644 index 5271d6124de..00000000000 --- a/doc/changes/devel/12843.other.rst +++ /dev/null @@ -1 +0,0 @@ -Improve handling of filenames in ``raw.filenames`` by using :class:`~pathlib.Path` instead of :class:`str`, by `Mathieu Scheltienne`_. diff --git a/doc/changes/devel/12844.other.rst b/doc/changes/devel/12844.other.rst deleted file mode 100644 index ce959d8132a..00000000000 --- a/doc/changes/devel/12844.other.rst +++ /dev/null @@ -1 +0,0 @@ -Improve automatic figure scaling of :func:`mne.viz.plot_events`, and event_id and count overview legend when a high amount of unique events is supplied, by `Stefan Appelhoff`_. diff --git a/doc/changes/devel/12846.bugfix.rst b/doc/changes/devel/12846.bugfix.rst deleted file mode 100644 index ce18e8f5201..00000000000 --- a/doc/changes/devel/12846.bugfix.rst +++ /dev/null @@ -1,2 +0,0 @@ -Enforce SI units for Eyetracking data (eyegaze data should be radians of visual angle, not pixels. Pupil size data should be meters). -Updated tutorials so demonstrate how to convert data to SI units before analyses, by `Scott Huberty`_. \ No newline at end of file diff --git a/doc/changes/devel/12853.bugfix.rst b/doc/changes/devel/12853.bugfix.rst deleted file mode 100644 index 18c8afbb8ea..00000000000 --- a/doc/changes/devel/12853.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Prevent the ``colorbar`` parameter being ignored in topomap plots such as :meth:`mne.time_frequency.Spectrum.plot_topomap`, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/12862.other.rst b/doc/changes/devel/12862.other.rst deleted file mode 100644 index 393beeb8a8c..00000000000 --- a/doc/changes/devel/12862.other.rst +++ /dev/null @@ -1 +0,0 @@ -:meth:`mne.preprocessing.ICA.find_bads_muscle` can now be run when passing an ``inst`` without sensor positions. However, it will just use the first of three criteria (slope) to find muscle-related ICA components, by `Stefan Appelhoff`_. diff --git a/doc/changes/devel/12871.newfeature.rst b/doc/changes/devel/12871.newfeature.rst deleted file mode 100644 index 7c6f9e6c9df..00000000000 --- a/doc/changes/devel/12871.newfeature.rst +++ /dev/null @@ -1,2 +0,0 @@ -Added the ``title`` argument to :func:`mne.viz.create_3d_figure`, and -``color`` and ``position`` arguments to :func:`mne.viz.set_3d_title`, by `Eric Larson`_. diff --git a/doc/changes/devel/12875.bugfix.rst b/doc/changes/devel/12875.bugfix.rst deleted file mode 100644 index c4fa57e9100..00000000000 --- a/doc/changes/devel/12875.bugfix.rst +++ /dev/null @@ -1,2 +0,0 @@ -Fix bug where invalid data types (e.g., ``np.ndarray``s) could be used in some -:class:`mne.io.Info` fields like ``info["subject_info"]["weight"]``, by `Eric Larson`_. \ No newline at end of file diff --git a/doc/changes/devel/12877.bugfix.rst b/doc/changes/devel/12877.bugfix.rst deleted file mode 100644 index 2d9ecf2c489..00000000000 --- a/doc/changes/devel/12877.bugfix.rst +++ /dev/null @@ -1,4 +0,0 @@ -When creating a :class:`~mne.time_frequency.SpectrumArray`, the array shape check now -compares against the total of both 'good' and 'bad' channels in the provided -:class:`~mne.Info` (previously only good channels were checked), by -`Mathieu Scheltienne`_. diff --git a/doc/changes/devel/12884.bugfix.rst b/doc/changes/devel/12884.bugfix.rst deleted file mode 100644 index 6c5beda7241..00000000000 --- a/doc/changes/devel/12884.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Fix bug where :ref:`mne coreg` would always show MEG channels even if the "MEG Sensors" checkbox was disabled, by `Eric Larson`_. diff --git a/doc/changes/devel/12896.other.rst b/doc/changes/devel/12896.other.rst deleted file mode 100644 index 7ad9ff17a63..00000000000 --- a/doc/changes/devel/12896.other.rst +++ /dev/null @@ -1 +0,0 @@ -Update governance model, by `Daniel McCloy`_. diff --git a/doc/changes/devel/12901.bugfix.rst b/doc/changes/devel/12901.bugfix.rst deleted file mode 100644 index d68f70f7141..00000000000 --- a/doc/changes/devel/12901.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -:class:`mne.Report` HDF5 files are now written in ``mode='a'`` (append) to allow users to store other data in the HDF5 files, by `Eric Larson`_. diff --git a/doc/changes/devel/12901.newfeature.rst b/doc/changes/devel/12901.newfeature.rst deleted file mode 100644 index 8d0137fce78..00000000000 --- a/doc/changes/devel/12901.newfeature.rst +++ /dev/null @@ -1,8 +0,0 @@ -Improved reporting and plotting options: - -- :meth:`mne.Report.add_projs` can now plot with :func:`mne.viz.plot_projs_joint` rather than :func:`mne.viz.plot_projs_topomap` -- :class:`mne.Report` now has attributes ``img_max_width`` and ``img_max_res`` that can be used to control image scaling. -- :class:`mne.Report` now has an attribute ``collapse`` that allows collapsing sections and/or subsections by default. -- :func:`mne.viz.plot_head_positions` now has a ``totals=True`` option to show the total distance and angle of the head. - -Changes by `Eric Larson`_. diff --git a/doc/changes/devel/12909.bugfix.rst b/doc/changes/devel/12909.bugfix.rst deleted file mode 100644 index 9e2f5672323..00000000000 --- a/doc/changes/devel/12909.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Fix bug in :func:`mne.io.read_raw_gdf` when NumPy >= 2 is used, by `Clemens Brunner`_. \ No newline at end of file diff --git a/doc/changes/devel/12910.newfeature.rst b/doc/changes/devel/12910.newfeature.rst new file mode 100644 index 00000000000..95605c11017 --- /dev/null +++ b/doc/changes/devel/12910.newfeature.rst @@ -0,0 +1 @@ +Added the option to return taper weights from :func:`mne.time_frequency.tfr_array_multitaper`, and taper weights are now stored in the :class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/12911.bugfix.rst b/doc/changes/devel/12911.bugfix.rst deleted file mode 100644 index c04a23d645d..00000000000 --- a/doc/changes/devel/12911.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Allow exporting edf where a channel contains only constant values, by `Florian Hofer`_. diff --git a/doc/changes/devel/12912.newfeature.rst b/doc/changes/devel/12912.newfeature.rst deleted file mode 100644 index 2a7343ebd2c..00000000000 --- a/doc/changes/devel/12912.newfeature.rst +++ /dev/null @@ -1 +0,0 @@ -Added the ``psd_args`` argument to :func:`mne.viz.plot_ica_sources` and :meth:`mne.preprocessing.ICA.plot_sources`, by `Richard Scholz`_. \ No newline at end of file diff --git a/doc/changes/devel/12918.apichange.rst b/doc/changes/devel/12918.apichange.rst deleted file mode 100644 index 958662b1b6f..00000000000 --- a/doc/changes/devel/12918.apichange.rst +++ /dev/null @@ -1 +0,0 @@ -Deprecate ``subject`` parameter in favor of ``subjects`` in :func:`mne.datasets.eegbci.load_data`, by `Stefan Appelhoff`_. diff --git a/doc/changes/devel/12924.bugfix.rst b/doc/changes/devel/12924.bugfix.rst deleted file mode 100644 index 57afa60fbd8..00000000000 --- a/doc/changes/devel/12924.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Fix typos in the Spatio-Spectral Decomposition example, by :newcontrib:`Simon M. Hofmann`. \ No newline at end of file diff --git a/doc/changes/devel/12931.bugfix.rst b/doc/changes/devel/12931.bugfix.rst deleted file mode 100644 index 7c41cd03a7d..00000000000 --- a/doc/changes/devel/12931.bugfix.rst +++ /dev/null @@ -1,3 +0,0 @@ -Fix a bug in :func:`mne.epochs.make_metadata`, where missing values in the columns -generated for ``keep_first`` and ``keep_last`` events were represented by empty strings, -while it should have been ``NA`` values, by `Richard Höchenberger`_. diff --git a/doc/changes/devel/12931.other.rst b/doc/changes/devel/12931.other.rst deleted file mode 100644 index bf0a83534ad..00000000000 --- a/doc/changes/devel/12931.other.rst +++ /dev/null @@ -1 +0,0 @@ -Improve the :ref:`tut-autogenerate-metadata`, by `Clemens Brunner`_ and `Richard Höchenberger`_. diff --git a/doc/changes/devel/12936.bugfix.rst b/doc/changes/devel/12936.bugfix.rst deleted file mode 100644 index 8cb1967d4c4..00000000000 --- a/doc/changes/devel/12936.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Fix decimal places of :class:`float` ``mne.Evoked.nave`` in :meth:`mne.Evoked.plot` and :meth:`mne.Evoked.plot_image`, by `Gennadiy Belonosov`_. diff --git a/doc/changes/devel/12955.bugfix.rst b/doc/changes/devel/12955.bugfix.rst deleted file mode 100644 index 924944da9dd..00000000000 --- a/doc/changes/devel/12955.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Fix duration calculation for the textual (``__repr__``) and html (``_repr_html_``, used by e.g. Jupyter) display of :class:`mne.io.Raw` instances. For example a duration of 1h is now displayed as ``00:01:00`` rather than ``00:59:60``. By :newcontrib:`Leonardo Rochael Almeida`. diff --git a/doc/changes/devel/12955.newfeature.rst b/doc/changes/devel/12955.newfeature.rst deleted file mode 100644 index 8ab68c9a138..00000000000 --- a/doc/changes/devel/12955.newfeature.rst +++ /dev/null @@ -1 +0,0 @@ -Add convenience :attr:`mne.io.Raw.duration` property to centralize duration calculation for the textual (``__repr__``) and html (``_repr_html_``, used by e.g. Jupyter) display of :class:`mne.io.Raw` instances, by :newcontrib:`Leonardo Rochael Almeida`. diff --git a/doc/changes/devel/12960.other.rst b/doc/changes/devel/12960.other.rst deleted file mode 100644 index 5d136d8e1d8..00000000000 --- a/doc/changes/devel/12960.other.rst +++ /dev/null @@ -1 +0,0 @@ -Mention some gotchas that arise from the fact that by default, we pool across dipole orientations when performing source estimation, by `Marijn van Vliet`_ diff --git a/doc/changes/devel/12962.bugfix.rst b/doc/changes/devel/12962.bugfix.rst deleted file mode 100644 index cf70d8458ba..00000000000 --- a/doc/changes/devel/12962.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Fix displayed units in representations of classes such as :class:`mne.io.Raw` to correctly use KiB, MiB, GiB, and so on, by `Clemens Brunner`_. \ No newline at end of file diff --git a/doc/changes/devel/12966.newfeature.rst b/doc/changes/devel/12966.newfeature.rst deleted file mode 100644 index dff334d9b0a..00000000000 --- a/doc/changes/devel/12966.newfeature.rst +++ /dev/null @@ -1 +0,0 @@ -Add ability to use :func:`mne.preprocessing.compute_fine_calibration` with non-Neuromag-style systems, as well as options to control the bad-angle and error tolerances, by `Eric Larson`_. diff --git a/doc/changes/devel/12968.bugfix.rst b/doc/changes/devel/12968.bugfix.rst deleted file mode 100644 index a512cc34ad6..00000000000 --- a/doc/changes/devel/12968.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Gracefully handle invalid patient info when reading EDF files by `Scott Huberty`_. \ No newline at end of file diff --git a/doc/changes/devel/12971.newfeature.rst b/doc/changes/devel/12971.newfeature.rst deleted file mode 100644 index a822dd24ab5..00000000000 --- a/doc/changes/devel/12971.newfeature.rst +++ /dev/null @@ -1 +0,0 @@ -Add support for ``uint16_codec`` argument in :func:`mne.io.read_raw_eeglab` when ``pymatreader`` (which already supported this argument previously) is not installed, by `Clemens Brunner`_. \ No newline at end of file diff --git a/doc/changes/devel/12978.other.rst b/doc/changes/devel/12978.other.rst deleted file mode 100644 index 33e73c2d8f9..00000000000 --- a/doc/changes/devel/12978.other.rst +++ /dev/null @@ -1 +0,0 @@ -Fix a mistake in :ref:`tut-artifact-regression` where the wrong regression coefficients were applied, by :newcontrib:`Jacob Phelan`. diff --git a/doc/changes/devel/12986.bugfix.rst b/doc/changes/devel/12986.bugfix.rst deleted file mode 100644 index 5bacb548fdd..00000000000 --- a/doc/changes/devel/12986.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Fix IndexError when loading CNT file does not have annotations, by :newcontrib:`Ziyi ZENG`. \ No newline at end of file diff --git a/doc/changes/devel/13003.newfeature.rst b/doc/changes/devel/13003.newfeature.rst deleted file mode 100644 index 141265406a8..00000000000 --- a/doc/changes/devel/13003.newfeature.rst +++ /dev/null @@ -1 +0,0 @@ -Added support for saving and loading channel names from FIF in :meth:`mne.channels.DigMontage.save` and :meth:`mne.channels.read_dig_fif` and added the convention that they should be saved as ``-dig.fif``, by `Eric Larson`_. diff --git a/doc/changes/devel/13019.newfeature.rst b/doc/changes/devel/13019.newfeature.rst new file mode 100644 index 00000000000..6fe8ee492bf --- /dev/null +++ b/doc/changes/devel/13019.newfeature.rst @@ -0,0 +1 @@ +Add ``fig.mne`` container for :class:`Colorbar ` in :func:`plot_connectivity_circle ` to allow users to access it directly, by `Santeri Ruuskanen`_. \ No newline at end of file diff --git a/doc/changes/devel/13028.bugfix.rst b/doc/changes/devel/13028.bugfix.rst new file mode 100644 index 00000000000..13e34189eaf --- /dev/null +++ b/doc/changes/devel/13028.bugfix.rst @@ -0,0 +1 @@ +Fix epoch indexing in :class:`mne.time_frequency.EpochsTFRArray` when initialising the class with the default ``drop_log`` parameter, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/13037.newfeature.rst b/doc/changes/devel/13037.newfeature.rst new file mode 100644 index 00000000000..3b28e2294ab --- /dev/null +++ b/doc/changes/devel/13037.newfeature.rst @@ -0,0 +1 @@ +Add PCA-OBS preprocessing for the removal of heart-artefacts from EEG or ESG datasets via :func:`mne.preprocessing.apply_pca_obs`, by :newcontrib:`Emma Bailey` and :newcontrib:`Steinn Hauser Magnusson`. diff --git a/doc/changes/devel/13042.bugfix.rst b/doc/changes/devel/13042.bugfix.rst new file mode 100644 index 00000000000..15208d5b89f --- /dev/null +++ b/doc/changes/devel/13042.bugfix.rst @@ -0,0 +1 @@ +Fix loading and saving of :class:`~mne.time_frequency.EpochsSpectrum` objects that contain slash-separators in their condition names, by `Daniel McCloy`_. diff --git a/doc/changes/devel/13048.bugfix.rst b/doc/changes/devel/13048.bugfix.rst new file mode 100644 index 00000000000..8f0fe46f3c7 --- /dev/null +++ b/doc/changes/devel/13048.bugfix.rst @@ -0,0 +1 @@ +Fix input boxes for the max value not showing when plotting fieldlines with :func:`~mne.viz.plot_evoked_field` when ``show_density=False``, by `Marijn van Vliet`_. diff --git a/doc/changes/devel/13054.newfeature.rst b/doc/changes/devel/13054.newfeature.rst new file mode 100644 index 00000000000..3c89290e7fe --- /dev/null +++ b/doc/changes/devel/13054.newfeature.rst @@ -0,0 +1 @@ +Added :func:`mne.time_frequency.combine_tfr` to allow combining TFRs across tapers, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/13056.bugfix.rst b/doc/changes/devel/13056.bugfix.rst new file mode 100644 index 00000000000..2a7919de289 --- /dev/null +++ b/doc/changes/devel/13056.bugfix.rst @@ -0,0 +1 @@ +Fix bug with saving of anonymized data when helium info is present in measurement info, by `Eric Larson`_. diff --git a/doc/changes/devel/13058.newfeature.rst b/doc/changes/devel/13058.newfeature.rst new file mode 100644 index 00000000000..bbd01fa4552 --- /dev/null +++ b/doc/changes/devel/13058.newfeature.rst @@ -0,0 +1 @@ +Add the function :func:`mne.time_frequency.combine_spectrum` for combining data across :class:`mne.time_frequency.Spectrum` objects, and allow :func:`mne.grand_average` to operate on :class:`mne.time_frequency.Spectrum` objects, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/13062.bugfix.rst b/doc/changes/devel/13062.bugfix.rst new file mode 100644 index 00000000000..9e01fc4c835 --- /dev/null +++ b/doc/changes/devel/13062.bugfix.rst @@ -0,0 +1 @@ +Fix computation of time intervals in :func:`mne.preprocessing.compute_fine_calibration` by `Eric Larson`_. diff --git a/doc/changes/devel/13063.bugfix.rst b/doc/changes/devel/13063.bugfix.rst new file mode 100644 index 00000000000..76eba2032a1 --- /dev/null +++ b/doc/changes/devel/13063.bugfix.rst @@ -0,0 +1 @@ +Fix bug in the colorbars created by :func:`mne.viz.plot_evoked_topomap` by `Santeri Ruuskanen`_. \ No newline at end of file diff --git a/doc/changes/devel/13065.bugfix.rst b/doc/changes/devel/13065.bugfix.rst new file mode 100644 index 00000000000..bbaa07ae127 --- /dev/null +++ b/doc/changes/devel/13065.bugfix.rst @@ -0,0 +1,7 @@ +Improved sklearn class compatibility and compliance, which resulted in some parameters of classes having an underscore appended to their name during ``fit``, such as: + +- :class:`mne.decoding.FilterEstimator` parameter ``picks`` passed to the initializer is set as ``est.picks_`` +- :class:`mne.decoding.UnsupervisedSpatialFilter` parameter ``estimator`` passed to the initializer is set as ``est.estimator_`` + +Unused ``verbose`` class parameters (that had no effect) were removed from :class:`~mne.decoding.PSDEstimator`, :class:`~mne.decoding.TemporalFilter`, and :class:`~mne.decoding.FilterEstimator` as well. +Changes by `Eric Larson`_. diff --git a/doc/changes/devel/13067.bugfix.rst b/doc/changes/devel/13067.bugfix.rst new file mode 100644 index 00000000000..237df7623d5 --- /dev/null +++ b/doc/changes/devel/13067.bugfix.rst @@ -0,0 +1 @@ +Fix bug where taper weights were not correctly applied when computing multitaper power with :meth:`mne.Epochs.compute_tfr` and :func:`mne.time_frequency.tfr_array_multitaper`, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/devel/13069.bugfix.rst b/doc/changes/devel/13069.bugfix.rst new file mode 100644 index 00000000000..7c23221c8df --- /dev/null +++ b/doc/changes/devel/13069.bugfix.rst @@ -0,0 +1 @@ +Fix bug cause by unnecessary assertion when loading mixed frequency EDFs without preloading :func:`mne.io.read_raw_edf` by `Simon Kern`_. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 3dfc742b3b3..eb444c5e594 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -73,6 +73,7 @@ .. _Eberhard Eich: https://github.com/ebeich .. _Eduard Ort: https://github.com/eort .. _Emily Stephen: https://github.com/emilyps14 +.. _Emma Bailey: https://www.cbs.mpg.de/employees/bailey .. _Enrico Varano: https://github.com/enricovara/ .. _Enzo Altamiranda: https://www.linkedin.com/in/enzoalt .. _Eric Larson: https://larsoner.com @@ -273,6 +274,7 @@ .. _Senwen Deng: https://snwn.de .. _Seyed Yahya Shirazi: https://neuromechanist.github.io .. _Sheraz Khan: https://github.com/SherazKhan +.. _Shristi Baral: https://github.com/shristibaral .. _Silvia Cotroneo: https://github.com/sfc-neuro .. _Simeon Wong: https://github.com/dtxe .. _Simon Kern: https://skjerns.de @@ -283,6 +285,7 @@ .. _Stanislas Chambon: https://github.com/Slasnista .. _Stefan Appelhoff: https://stefanappelhoff.com .. _Stefan Repplinger: https://github.com/stfnrpplngr +.. _Steinn Hauser Magnusson: https://github.com/steinnhauser .. _Steven Bethard: https://github.com/bethard .. _Steven Bierer: https://github.com/neurolaunch .. _Steven Gutstein: https://github.com/smgutstein diff --git a/doc/changes/v1.7.rst b/doc/changes/v1.7.rst index dfd3129a18d..6b118612541 100644 --- a/doc/changes/v1.7.rst +++ b/doc/changes/v1.7.rst @@ -75,12 +75,12 @@ Bugfixes - Fix validation of ``ch_type`` in :func:`mne.preprocessing.annotate_muscle_zscore`, by `Mathieu Scheltienne`_. (`#12444 `__) - Fix errant redundant use of ``BIDSPath.split`` when writing split raw and epochs data, by `Eric Larson`_. (`#12451 `__) - Disable config parser interpolation when reading BrainVision files, which allows using the percent sign as a regular character in channel units, by `Clemens Brunner`_. (`#12456 `__) -- - Fix the default color of :meth:`mne.viz.Brain.add_text` to properly contrast with the figure background color, by `Marijn van Vliet`_. (`#12470 `__) -- - Changed default ECoG and sEEG electrode sizes in brain plots to better reflect real world sizes, by `Liberty Hamilton`_ (`#12474 `__) +- Fix the default color of :meth:`mne.viz.Brain.add_text` to properly contrast with the figure background color, by `Marijn van Vliet`_. (`#12470 `__) +- Changed default ECoG and sEEG electrode sizes in brain plots to better reflect real world sizes, by `Liberty Hamilton`_ (`#12474 `__) - Fixed bugs with handling of rank in :class:`mne.decoding.CSP`, by `Eric Larson`_. (`#12476 `__) -- - Fix reading segmented recordings with :func:`mne.io.read_raw_eyelink` by `Dominik Welke`_. (`#12481 `__) +- Fix reading segmented recordings with :func:`mne.io.read_raw_eyelink` by `Dominik Welke`_. (`#12481 `__) - Improve compatibility with other Qt-based GUIs by handling theme icons better, by `Eric Larson`_. (`#12483 `__) -- - Fix problem caused by onsets with NaN values using :func:`mne.io.read_raw_eeglab` by `Jacob Woessner`_ (`#12484 `__) +- Fix problem caused by onsets with NaN values using :func:`mne.io.read_raw_eeglab` by `Jacob Woessner`_ (`#12484 `__) - Fix cleaning of channel names for non vectorview or CTF dataset including whitespaces or dash in their channel names, by `Mathieu Scheltienne`_. (`#12489 `__) - Fix bug with :meth:`mne.preprocessing.ICA.plot_sources` for ``evoked`` data where the legend contained too many entries, by `Eric Larson`_. (`#12498 `__) diff --git a/doc/changes/v1.9.rst b/doc/changes/v1.9.rst new file mode 100644 index 00000000000..0c6f7c1fddc --- /dev/null +++ b/doc/changes/v1.9.rst @@ -0,0 +1,128 @@ +.. _changes_1_9_0: + +Version 1.9.0 (2024-12-18) +========================== + +Dependencies +------------ + +- Minimum supported dependencies were updated in accordance with SPEC0_, most notably Python 3.10+ is now required. (`#12798 `__) +- Importing from ``mne.decoding`` now explicitly requires ``scikit-learn`` to be installed, + by `Eric Larson`_. (`#12834 `__) +- Compatibility improved for Python 3.13, by `Eric Larson`_. (`#13021 `__) + + +Bugfixes +-------- + +- Fix typos in the Spatio-Spectral Decomposition example, by :newcontrib:`Simon M. Hofmann`. (`#12924 `__) +- Fix duration calculation for the textual (``__repr__``) and html (``_repr_html_``, used by e.g. Jupyter) display of :class:`mne.io.Raw` instances. For example a duration of 1h is now displayed as ``00:01:00`` rather than ``00:59:60``. By :newcontrib:`Leonardo Rochael Almeida`. (`#12955 `__) +- Fix IndexError when loading CNT file does not have annotations, by :newcontrib:`Ziyi ZENG`. (`#12986 `__) +- Fix handling of MRI file-path in :class:`mne.SourceSpaces` and safeguard saving of :class:`pathlib.Path` with ``h5io`` by casting to :class:`str`, by `Mathieu Scheltienne`_. (`#12803 `__) +- Cast ``fwd["info"]`` to :class:`~mne.Info` and ``fwd["src"]`` to :class:`~mne.SourceSpaces` when loading a forward solution from an HDF5 file, by `Mathieu Scheltienne`_. (`#12804 `__) +- Fix bug where :meth:`mne.Epochs.compute_tfr` could not be used with the multitaper method and complex or phase outputs, by `Thomas Binns`_. (`#12842 `__) +- Fixed a bug where split FIF files that were read and then appended to other + :class:`mne.io.Raw` instances had their ``BAD boundary`` annotations incorrectly offset + in samples by the number of split files, by `Eric Larson`_. (`#12843 `__) +- Enforce SI units for Eyetracking data (eyegaze data should be radians of visual angle, not pixels. Pupil size data should be meters). + Updated tutorials so demonstrate how to convert data to SI units before analyses, by `Scott Huberty`_. (`#12846 `__) +- Prevent the ``colorbar`` parameter being ignored in topomap plots such as :meth:`mne.time_frequency.Spectrum.plot_topomap`, by `Thomas Binns`_. (`#12853 `__) +- Fix bug where invalid data types (e.g., ``np.ndarray``s) could be used in some + :class:`mne.io.Info` fields like ``info["subject_info"]["weight"]``, by `Eric Larson`_. (`#12875 `__) +- When creating a :class:`~mne.time_frequency.SpectrumArray`, the array shape check now + compares against the total of both 'good' and 'bad' channels in the provided + :class:`~mne.Info` (previously only good channels were checked), by + `Mathieu Scheltienne`_. (`#12877 `__) +- Fix bug where :ref:`mne coreg` would always show MEG channels even if the "MEG Sensors" checkbox was disabled, by `Eric Larson`_. (`#12884 `__) +- :class:`mne.Report` HDF5 files are now written in ``mode='a'`` (append) to allow users to store other data in the HDF5 files, by `Eric Larson`_. (`#12901 `__) +- Fix bug in :func:`mne.io.read_raw_gdf` when NumPy >= 2 is used, by `Clemens Brunner`_. (`#12909 `__) +- Allow exporting edf where a channel contains only constant values, by `Florian Hofer`_. (`#12911 `__) +- Fix a bug in :func:`mne.epochs.make_metadata`, where missing values in the columns + generated for ``keep_first`` and ``keep_last`` events were represented by empty strings, + while it should have been ``NA`` values, by `Richard Höchenberger`_. (`#12931 `__) +- Fix decimal places of :class:`float` ``mne.Evoked.nave`` in :meth:`mne.Evoked.plot` and :meth:`mne.Evoked.plot_image`, by `Gennadiy Belonosov`_. (`#12936 `__) +- Fix displayed units in representations of classes such as :class:`mne.io.Raw` to correctly use KiB, MiB, GiB, and so on, by `Clemens Brunner`_. (`#12962 `__) +- Gracefully handle invalid patient info when reading EDF files by `Scott Huberty`_. (`#12968 `__) +- Correct :func:`mne.io.read_raw_cnt` to read responses and fix exceptions by `Jacob Woessner`_. (`#13007 `__) +- Fix errant detection of software-rendered vs hardware-rendered MESA GL contexts in 3D rendering on Linux, by `Eric Larson`_. (`#13012 `__) +- Fix plot scaling for :meth:`Spectrum.plot(dB=True, amplitude=True) `, by `Daniel McCloy`_. (`#13036 `__) + + +API changes by deprecation +-------------------------- + +- Deprecate ``average`` parameter in ``plot_filters`` and ``plot_patterns`` methods of the :class:`mne.decoding.CSP` and :class:`mne.decoding.SPoC` classes, by `Thomas Binns`_. (`#12829 `__) +- Deprecate ``subject`` parameter in favor of ``subjects`` in :func:`mne.datasets.eegbci.load_data`, by `Stefan Appelhoff`_. (`#12918 `__) + + +New features +------------ + +- Added support for ``sensor_scales`` to :meth:`mne.viz.Brain.add_sensors` and :func:`mne.viz.plot_alignment`, by :newcontrib:`Alex Lepauvre`. (`#12805 `__) +- Add convenience :attr:`mne.io.Raw.duration` property to centralize duration calculation for the textual (``__repr__``) and html (``_repr_html_``, used by e.g. Jupyter) display of :class:`mne.io.Raw` instances, by :newcontrib:`Leonardo Rochael Almeida`. (`#12955 `__) +- Add option to :func:`mne.preprocessing.fix_stim_artifact` to use baseline average to flatten TMS pulse artifact by `Fahimeh Mamashli`_ and `Padma Sundaram`_ and `Mohammad Daneshzand`_. (`#6915 `__) +- Add support for `dict` type argument ``ref_channels`` to :func:`mne.set_eeg_reference`, to allow flexible re-referencing (e.g. ``raw.set_eeg_reference(ref_channels={'A1': ['A2', 'A3']})`` will set the new A1 data to be ``A1 - mean(A2, A3)``), by `Alex Lepauvre`_ and `Qian Chu`_ and `Daniel McCloy`_. (`#12366 `__) +- Add reader for ANT Neuro files in the ``*.cnt`` format with :func:`~mne.io.read_raw_ant`, by `Mathieu Scheltienne`_, `Eric Larson`_ and `Proloy Das`_. (`#12792 `__) +- Add support for a :class:`mne.transforms.Transform` in the argument ``trans`` of the coregistration GUI called with :func:`mne.gui.coregistration`, by `Mathieu Scheltienne`_. (`#12801 `__) +- :meth:`~mne.io.Raw` and :meth:`~mne.Epochs.save` now return the path to the saved file(s), by `Victor Ferat`_. (`#12811 `__) +- :func:`mne.channels.read_custom_montage` may now read a newer version of the ``.elc`` ASA Electrode file format, by `Stefan Appelhoff`_. (`#12830 `__) +- Added the ``title`` argument to :func:`mne.viz.create_3d_figure`, and + ``color`` and ``position`` arguments to :func:`mne.viz.set_3d_title`, by `Eric Larson`_. (`#12871 `__) +- Improved reporting and plotting options: + + - :meth:`mne.Report.add_projs` can now plot with :func:`mne.viz.plot_projs_joint` rather than :func:`mne.viz.plot_projs_topomap` + - :class:`mne.Report` now has attributes ``img_max_width`` and ``img_max_res`` that can be used to control image scaling. + - :class:`mne.Report` now has an attribute ``collapse`` that allows collapsing sections and/or subsections by default. + - :func:`mne.viz.plot_head_positions` now has a ``totals=True`` option to show the total distance and angle of the head. + + Changes by `Eric Larson`_. (`#12901 `__) +- Added the ``psd_args`` argument to :func:`mne.viz.plot_ica_sources` and :meth:`mne.preprocessing.ICA.plot_sources`, by `Richard Scholz`_. (`#12912 `__) +- Add ability to use :func:`mne.preprocessing.compute_fine_calibration` with non-Neuromag-style systems, as well as options to control the bad-angle and error tolerances, by `Eric Larson`_. (`#12966 `__) +- Add support for ``uint16_codec`` argument in :func:`mne.io.read_raw_eeglab` when ``pymatreader`` (which already supported this argument previously) is not installed, by `Clemens Brunner`_. (`#12971 `__) +- Added support for saving and loading channel names from FIF in :meth:`mne.channels.DigMontage.save` and :meth:`mne.channels.read_dig_fif` and added the convention that they should be saved as ``-dig.fif``, by `Eric Larson`_. (`#13003 `__) +- Add new :meth:`Raw.rescale ` method to rescale the data in place, by `Clemens Brunner`_. (`#13018 `__) + + +Other changes +------------- + +- Fix a mistake in :ref:`tut-artifact-regression` where the wrong regression coefficients were applied, by :newcontrib:`Jacob Phelan`. (`#12978 `__) +- Use custom code in :func:`mne.sys_info` to get the amount of physical memory and a more informative CPU name instead of using the ``psutil`` package, by `Clemens Brunner`_. (`#12787 `__) +- Improve documentation clarity of ``fit_transform`` methods for :class:`mne.decoding.SSD`, :class:`mne.decoding.CSP`, and :class:`mne.decoding.SPoC` classes, by `Thomas Binns`_. (`#12827 `__) +- Improve handling of filenames in ``raw.filenames`` by using :class:`~pathlib.Path` instead of :class:`str`, by `Mathieu Scheltienne`_. (`#12843 `__) +- Improve automatic figure scaling of :func:`mne.viz.plot_events`, and event_id and count overview legend when a high amount of unique events is supplied, by `Stefan Appelhoff`_. (`#12844 `__) +- :meth:`mne.preprocessing.ICA.find_bads_muscle` can now be run when passing an ``inst`` without sensor positions. However, it will just use the first of three criteria (slope) to find muscle-related ICA components, by `Stefan Appelhoff`_. (`#12862 `__) +- Update governance model, by `Daniel McCloy`_. (`#12896 `__) +- Improve the :ref:`tut-autogenerate-metadata`, by `Clemens Brunner`_ and `Richard Höchenberger`_. (`#12931 `__) +- Mention some gotchas that arise from the fact that by default, we pool across dipole orientations when performing source estimation, by `Marijn van Vliet`_ (`#12960 `__) +- Repository CI security is now audited using `zizmor `__, by `Eric Larson`_. (`#13011 `__) + +Authors +------- + +* Alex Lepauvre+ +* Britta Westner +* Clemens Brunner +* Daniel McCloy +* Eric Larson +* Fahimeh Mamashli +* Florian Hofer +* Gennadiy Belonosov +* Jacob Phelan +* Jacob Woessner +* Leonardo Rochael Almeida+ +* Mainak Jas +* Marijn van Vliet +* Mathieu Scheltienne +* Proloy Das +* Qian Chu +* Richard Höchenberger +* Richard Scholz +* Santeri Ruuskanen +* Scott Huberty +* Simon M. Hofmann+ +* Stefan Appelhoff +* Thomas Grainger +* Thomas S. Binns +* Victor Férat +* Ziyi ZENG+ diff --git a/doc/conf.py b/doc/conf.py index 7dd6ec90d4f..f1b771571d6 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -355,6 +355,7 @@ "n_frequencies", "n_tests", "n_samples", + "n_peaks", "n_permutations", "nchan", "n_points", @@ -666,6 +667,10 @@ def fix_sklearn_inherited_docstrings(app, what, name, obj, options, lines): r"https://scholar.google.com/scholar\?cites=12188330066413208874&as_ylo=2014", r"https://scholar.google.com/scholar\?cites=1521584321377182930&as_ylo=2013", "https://www.research.chop.edu/imaging", + "http://prdownloads.sourceforge.net/optipng/optipng-0.7.8-win64.zip?download", + "https://sourceforge.net/projects/aespa/files/", + "https://sourceforge.net/projects/ezwinports/files/", + "https://www.mathworks.com/products/compiler/matlab-runtime.html", # 500 server error "https://openwetware.org/wiki/Beauchamp:FreeSurfer", # 503 Server error @@ -688,6 +693,7 @@ def fix_sklearn_inherited_docstrings(app, what, name, obj, options, lines): # SSL problems sometimes "http://ilabs.washington.edu", "https://psychophysiology.cpmc.columbia.edu", + "https://erc.easme-web.eu", ] linkcheck_anchors = False # saves a bit of time linkcheck_timeout = 15 # some can be quite slow @@ -1284,7 +1290,7 @@ def fix_sklearn_inherited_docstrings(app, what, name, obj, options, lines): rst_prolog += f""" .. |{icon}| raw:: html - + """ rst_prolog += """ diff --git a/doc/development/contributing.rst b/doc/development/contributing.rst index 1cca8a8608d..011fd3c11f4 100644 --- a/doc/development/contributing.rst +++ b/doc/development/contributing.rst @@ -860,7 +860,7 @@ data with a meaningful middle (zero-point) and ``Reds`` otherwise. This applies to both visualization functions and tutorials/examples. -.. _run_tests: +.. _run-tests: Running the test suite ~~~~~~~~~~~~~~~~~~~~~~ @@ -1114,6 +1114,6 @@ it can serve as a useful example of what to expect from the PR review process. .. optipng .. _optipng: http://optipng.sourceforge.net/ -.. _optipng for Windows: http://prdownloads.sourceforge.net/optipng/optipng-0.7.7-win32.zip?download +.. _optipng for Windows: http://prdownloads.sourceforge.net/optipng/optipng-0.7.8-win64.zip?download .. include:: ../links.inc diff --git a/doc/development/whats_new.rst b/doc/development/whats_new.rst index 75ece13b5e0..94b92c0f019 100644 --- a/doc/development/whats_new.rst +++ b/doc/development/whats_new.rst @@ -9,6 +9,7 @@ Changes for each version of MNE-Python are listed below. :maxdepth: 1 ../changes/devel.rst + ../changes/v1.9.rst ../changes/v1.8.rst ../changes/v1.7.rst ../changes/v1.6.rst diff --git a/doc/documentation/cited.rst b/doc/documentation/cited.rst index 565698d9c67..343e28d00ee 100644 --- a/doc/documentation/cited.rst +++ b/doc/documentation/cited.rst @@ -3,7 +3,7 @@ Papers citing MNE-Python ======================== -Estimates provided by Google Scholar as of 18 August 2024: +Estimates provided by Google Scholar as of 16 December 2024: -- `MNE (1810) `_ -- `MNE-Python (2860) `_ +- `MNE (1,900) `_ +- `MNE-Python (3,250) `_ diff --git a/doc/install/advanced.rst b/doc/install/advanced.rst index 0fe5a5e324a..61c7bc07aa3 100644 --- a/doc/install/advanced.rst +++ b/doc/install/advanced.rst @@ -281,6 +281,20 @@ of VTK and/or QT are incompatible. This series of commands should fix it: If you installed VTK using ``pip`` rather than ``conda``, substitute the first line for ``pip uninstall -y vtk``. +3D plotting trouble on Linux +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you are having trouble with 3D plotting on Linux, one possibility is that you +are using Wayland for graphics. To check, you can do: + +.. code-block:: console + + $ echo $XDG_SESSION_TYPE + wayland + +If so, you will need to tell Qt to use X11 instead of Wayland. You can do this +by setting ``export QT_QPA_PLATFORM=xcb`` in your terminal session. To make it +permanent for your logins, you can set it for example in ``~/.profile``. .. LINKS diff --git a/doc/install/installers.rst b/doc/install/installers.rst index c29d09ba132..533c0207963 100644 --- a/doc/install/installers.rst +++ b/doc/install/installers.rst @@ -17,7 +17,7 @@ Platform-specific installers :class-content: text-center :name: install-linux - .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.8.0/MNE-Python-1.8.0_0-Linux.sh + .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.9.0/MNE-Python-1.9.0_0-Linux.sh :ref-type: ref :color: primary :shadow: @@ -31,14 +31,14 @@ Platform-specific installers .. code-block:: console - $ sh ./MNE-Python-1.8.0_0-Linux.sh + $ sh ./MNE-Python-1.9.0_0-Linux.sh .. tab-item:: macOS (Intel) :class-content: text-center :name: install-macos-intel - .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.8.0/MNE-Python-1.8.0_0-macOS_Intel.pkg + .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.9.0/MNE-Python-1.9.0_0-macOS_Intel.pkg :ref-type: ref :color: primary :shadow: @@ -54,7 +54,7 @@ Platform-specific installers :class-content: text-center :name: install-macos-apple - .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.8.0/MNE-Python-1.8.0_0-macOS_M1.pkg + .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.9.0/MNE-Python-1.9.0_0-macOS_M1.pkg :ref-type: ref :color: primary :shadow: @@ -70,7 +70,7 @@ Platform-specific installers :class-content: text-center :name: install-windows - .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.8.0/MNE-Python-1.8.0_0-Windows.exe + .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.9.0/MNE-Python-1.9.0_0-Windows.exe :ref-type: ref :color: primary :shadow: @@ -86,7 +86,7 @@ Platform-specific installers .. We have to use a button-link here because button-ref doesn't properly nested parse the inline code - .. button-link:: ./ides.html + .. button-link:: ides.html :ref-type: ref :color: success :shadow: @@ -156,7 +156,7 @@ To remove the MNE-Python distribution provided by our installers above: .. code-block:: bash $ which python - /home/username/mne-python/1.8.0_0/bin/python + /home/username/mne-python/1.9.0_0/bin/python $ rm -Rf /home/$USER/mne-python $ rm /home/$USER/.local/share/applications/mne-python-*.desktop @@ -170,7 +170,7 @@ To remove the MNE-Python distribution provided by our installers above: .. code-block:: bash $ which python - /Users/username/Applications/MNE-Python/1.8.0_0/.mne-python/bin/python + /Users/username/Applications/MNE-Python/1.9.0_0/.mne-python/bin/python $ rm -Rf /Users/$USER/Applications/MNE-Python # if user-specific $ rm -Rf /Applications/MNE-Python # if system-wide diff --git a/doc/references.bib b/doc/references.bib index a129d2f46a2..e2578ed18f2 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -1335,6 +1335,16 @@ @inproceedings{NdiayeEtAl2016 year = {2016} } +@article{NiazyEtAl2005, + author = {Niazy, R. K. and Beckmann, C.F. and Iannetti, G.D. and Brady, J. M. and Smith, S. M.}, + title = {Removal of FMRI environment artifacts from EEG data using optimal basis sets}, + journal = {NeuroImage}, + year = {2005}, + volume = {28}, + pages = {720-737}, + doi = {10.1016/j.neuroimage.2005.06.067.} +} + @article{NicholsHolmes2002, author = {Nichols, Thomas E. and Holmes, Andrew P.}, doi = {10.1002/hbm.1058}, diff --git a/doc/sphinxext/credit_tools.py b/doc/sphinxext/credit_tools.py index 708dcf00ce8..e22bd0b5530 100644 --- a/doc/sphinxext/credit_tools.py +++ b/doc/sphinxext/credit_tools.py @@ -169,7 +169,7 @@ def generate_credit_rst(app=None, *, verbose=False): if author["e"] is not None: if author["e"] not in name_map: unknown_emails.add( - f'{author["e"].ljust(29)} ' + f"{author['e'].ljust(29)} " "https://github.com/mne-tools/mne-python/pull/" f"{commit}/files" ) @@ -178,9 +178,9 @@ def generate_credit_rst(app=None, *, verbose=False): else: name = author["n"] if name in manual_renames: - assert _good_name( - manual_renames[name] - ), f"Bad manual rename: {name}" + assert _good_name(manual_renames[name]), ( + f"Bad manual rename: {name}" + ) name = manual_renames[name] if " " in name: first, last = name.rsplit(" ", maxsplit=1) diff --git a/doc/sphinxext/mne_doc_utils.py b/doc/sphinxext/mne_doc_utils.py index 4811cde5f44..e626838f251 100644 --- a/doc/sphinxext/mne_doc_utils.py +++ b/doc/sphinxext/mne_doc_utils.py @@ -8,6 +8,7 @@ import os import time import warnings +from pathlib import Path import numpy as np import pyvista @@ -16,6 +17,7 @@ import mne from mne.utils import ( _assert_no_instances, + _get_extra_data_path, sizeof_fmt, ) from mne.viz import Brain @@ -95,6 +97,8 @@ def reset_warnings(gallery_conf, fname): r"numpy\.core is deprecated and has been renamed to numpy\._core", # matplotlib "__array_wrap__ must accept context and return_scalar.*", + # nibabel + "__array__ implementation doesn't accept.*", ): warnings.filterwarnings( # deal with other modules having bad imports "ignore", message=f".*{key}.*", category=DeprecationWarning @@ -225,6 +229,7 @@ def reset_modules(gallery_conf, fname, when): mne.viz.ui_events._event_channels ) + orig_when = when when = f"mne/conf.py:Resetter.__call__:{when}:{fname}" # Support stuff like # MNE_SKIP_INSTANCE_ASSERTIONS="Brain,Plotter,BackgroundPlotter,vtkPolyData,_Renderer" make html-memory # noqa: E501 @@ -262,6 +267,25 @@ def reset_modules(gallery_conf, fname, when): mem = sizeof_fmt(process.memory_info().rss) print(f"{prefix}{time.time() - t0:6.1f} s : {mem}".ljust(22)) + if fname == "50_configure_mne.py": + # This messes with the config, so let's do so in a temp dir + if orig_when == "before": + fake_home = Path(_get_extra_data_path()) / "temp" + fake_home.mkdir(exist_ok=True, parents=True) + os.environ["_MNE_FAKE_HOME_DIR"] = str(fake_home) + else: + assert orig_when == "after" + to_del = Path(os.environ["_MNE_FAKE_HOME_DIR"]) + try: + (to_del / "mne-python.json").unlink() + except Exception: + pass + try: + to_del.rmdir() + except Exception: + pass + del os.environ["_MNE_FAKE_HOME_DIR"] + report_scraper = mne.report._ReportScraper() mne_qt_browser_scraper = mne.viz._scraper._MNEQtBrowserScraper() diff --git a/doc/sphinxext/prs/12896.json b/doc/sphinxext/prs/12896.json new file mode 100644 index 00000000000..203d5d05c49 --- /dev/null +++ b/doc/sphinxext/prs/12896.json @@ -0,0 +1,27 @@ +{ + "merge_commit_sha": "a1a05ae11234929f0608d5a6b4fc30206af89031", + "authors": [ + { + "n": "Daniel McCloy", + "e": null + }, + { + "n": "Britta Westner", + "e": "britta.wstnr@gmail.com" + } + ], + "changes": { + "doc/changes/devel/12896.other.rst": { + "a": 1, + "d": 0 + }, + "doc/development/governance.rst": { + "a": 251, + "d": 187 + }, + "doc/overview/people.rst": { + "a": 18, + "d": 11 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/12995.json b/doc/sphinxext/prs/12995.json new file mode 100644 index 00000000000..48d967dc7eb --- /dev/null +++ b/doc/sphinxext/prs/12995.json @@ -0,0 +1,15 @@ +{ + "merge_commit_sha": "d987c2793148eb855e04e3ae9d4871e486d01e46", + "authors": [ + { + "n": "Stefan Appelhoff", + "e": "stefan.appelhoff@mailbox.org" + } + ], + "changes": { + "mne/_fiff/meas_info.py": { + "a": 2, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/12996.json b/doc/sphinxext/prs/12996.json new file mode 100644 index 00000000000..b931ebcc541 --- /dev/null +++ b/doc/sphinxext/prs/12996.json @@ -0,0 +1,39 @@ +{ + "merge_commit_sha": "ec77e7c36ce4a2d7897122513395b0e2418ea151", + "authors": [ + { + "n": "Stefan Appelhoff", + "e": "stefan.appelhoff@mailbox.org" + } + ], + "changes": { + "mne/io/ant/ant.py": { + "a": 1, + "d": 1 + }, + "mne/io/kit/kit.py": { + "a": 1, + "d": 1 + }, + "mne/io/snirf/_snirf.py": { + "a": 1, + "d": 1 + }, + "mne/io/tests/test_raw.py": { + "a": 4, + "d": 4 + }, + "mne/preprocessing/ica.py": { + "a": 1, + "d": 1 + }, + "mne/preprocessing/tests/test_ica.py": { + "a": 2, + "d": 2 + }, + "mne/viz/raw.py": { + "a": 1, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/12997.json b/doc/sphinxext/prs/12997.json new file mode 100644 index 00000000000..7df9ddfa26b --- /dev/null +++ b/doc/sphinxext/prs/12997.json @@ -0,0 +1,15 @@ +{ + "merge_commit_sha": "0e09163e80013a426d419bda0c560ff3d48209bb", + "authors": [ + { + "n": "Stefan Appelhoff", + "e": "stefan.appelhoff@mailbox.org" + } + ], + "changes": { + "tutorials/forward/35_eeg_no_mri.py": { + "a": 9, + "d": 13 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/12998.json b/doc/sphinxext/prs/12998.json new file mode 100644 index 00000000000..289620794a2 --- /dev/null +++ b/doc/sphinxext/prs/12998.json @@ -0,0 +1,131 @@ +{ + "merge_commit_sha": "096243fe43c936587190a8e9e8e86b155446800e", + "authors": [ + { + "n": "github-actions[bot]", + "e": "41898282+github-actions[bot]@users.noreply.github.com" + } + ], + "changes": { + "doc/sphinxext/prs/12931.json": { + "a": 35, + "d": 0 + }, + "doc/sphinxext/prs/12935.json": { + "a": 135, + "d": 0 + }, + "doc/sphinxext/prs/12936.json": { + "a": 19, + "d": 0 + }, + "doc/sphinxext/prs/12937.json": { + "a": 19, + "d": 0 + }, + "doc/sphinxext/prs/12938.json": { + "a": 23, + "d": 0 + }, + "doc/sphinxext/prs/12941.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12942.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12947.json": { + "a": 35, + "d": 0 + }, + "doc/sphinxext/prs/12948.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12951.json": { + "a": 59, + "d": 0 + }, + "doc/sphinxext/prs/12955.json": { + "a": 43, + "d": 0 + }, + "doc/sphinxext/prs/12957.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12958.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12960.json": { + "a": 19, + "d": 0 + }, + "doc/sphinxext/prs/12962.json": { + "a": 35, + "d": 0 + }, + "doc/sphinxext/prs/12966.json": { + "a": 35, + "d": 0 + }, + "doc/sphinxext/prs/12967.json": { + "a": 43, + "d": 0 + }, + "doc/sphinxext/prs/12968.json": { + "a": 19, + "d": 0 + }, + "doc/sphinxext/prs/12970.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12971.json": { + "a": 19, + "d": 0 + }, + "doc/sphinxext/prs/12972.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12973.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12975.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12976.json": { + "a": 23, + "d": 0 + }, + "doc/sphinxext/prs/12978.json": { + "a": 23, + "d": 0 + }, + "doc/sphinxext/prs/12983.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12984.json": { + "a": 19, + "d": 0 + }, + "doc/sphinxext/prs/12986.json": { + "a": 27, + "d": 0 + }, + "doc/sphinxext/prs/12988.json": { + "a": 231, + "d": 0 + }, + "doc/sphinxext/prs/12991.json": { + "a": 27, + "d": 0 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/12999.json b/doc/sphinxext/prs/12999.json new file mode 100644 index 00000000000..4c71e401c33 --- /dev/null +++ b/doc/sphinxext/prs/12999.json @@ -0,0 +1,27 @@ +{ + "merge_commit_sha": "ed2fd8da1e16e449cdfe779491542fecad6ecbcb", + "authors": [ + { + "n": "dependabot[bot]", + "e": "49699333+dependabot[bot]@users.noreply.github.com" + }, + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + ".github/workflows/autofix.yml": { + "a": 1, + "d": 1 + }, + ".mailmap": { + "a": 3, + "d": 0 + }, + "doc/sphinxext/credit_tools.py": { + "a": 9, + "d": 5 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13000.json b/doc/sphinxext/prs/13000.json new file mode 100644 index 00000000000..d813ea98560 --- /dev/null +++ b/doc/sphinxext/prs/13000.json @@ -0,0 +1,27 @@ +{ + "merge_commit_sha": "53792b12a2d60229ada3c946987a184f3915c535", + "authors": [ + { + "n": "Daniel McCloy", + "e": null + }, + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + "azure-pipelines.yml": { + "a": 1, + "d": 1 + }, + "doc/sphinxext/mne_doc_utils.py": { + "a": 0, + "d": 2 + }, + "pyproject.toml": { + "a": 1, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13001.json b/doc/sphinxext/prs/13001.json new file mode 100644 index 00000000000..032c9a5714f --- /dev/null +++ b/doc/sphinxext/prs/13001.json @@ -0,0 +1,19 @@ +{ + "merge_commit_sha": "b19ac58598f55975ca7f6458fa7bfa7d21e00ae0", + "authors": [ + { + "n": "pre-commit-ci[bot]", + "e": "66853113+pre-commit-ci[bot]@users.noreply.github.com" + }, + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + ".pre-commit-config.yaml": { + "a": 1, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13003.json b/doc/sphinxext/prs/13003.json new file mode 100644 index 00000000000..cda225124fa --- /dev/null +++ b/doc/sphinxext/prs/13003.json @@ -0,0 +1,47 @@ +{ + "merge_commit_sha": "7071a0e24e121b28c851565f2e64a0128941e83a", + "authors": [ + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + }, + { + "n": "autofix-ci[bot]", + "e": "114827586+autofix-ci[bot]@users.noreply.github.com" + }, + { + "n": "Stefan Appelhoff", + "e": "stefan.appelhoff@mailbox.org" + } + ], + "changes": { + "doc/changes/devel/13003.newfeature.rst": { + "a": 1, + "d": 0 + }, + "mne/_fiff/_digitization.py": { + "a": 24, + "d": 5 + }, + "mne/_fiff/write.py": { + "a": 5, + "d": 1 + }, + "mne/channels/montage.py": { + "a": 36, + "d": 12 + }, + "mne/channels/tests/test_montage.py": { + "a": 33, + "d": 14 + }, + "mne/viz/tests/test_montage.py": { + "a": 1, + "d": 1 + }, + "tutorials/forward/35_eeg_no_mri.py": { + "a": 1, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13007.json b/doc/sphinxext/prs/13007.json new file mode 100644 index 00000000000..7553cfa048b --- /dev/null +++ b/doc/sphinxext/prs/13007.json @@ -0,0 +1,31 @@ +{ + "merge_commit_sha": "d47c22fc2ba21d31b46ae7816eba054d5e13add9", + "authors": [ + { + "n": "Jacob Woessner", + "e": "Woessner.jacob@gmail.com" + }, + { + "n": "pre-commit-ci[bot]", + "e": "66853113+pre-commit-ci[bot]@users.noreply.github.com" + }, + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + "doc/changes/devel/13007.bugfix.rst": { + "a": 1, + "d": 0 + }, + "mne/io/cnt/cnt.py": { + "a": 19, + "d": 3 + }, + "mne/io/cnt/tests/test_cnt.py": { + "a": 2, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13009.json b/doc/sphinxext/prs/13009.json new file mode 100644 index 00000000000..ecf5894f489 --- /dev/null +++ b/doc/sphinxext/prs/13009.json @@ -0,0 +1,23 @@ +{ + "merge_commit_sha": "391fd88dca91bca78aeacc996fa284712c3ea33b", + "authors": [ + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + "doc/sphinxext/mne_doc_utils.py": { + "a": 22, + "d": 0 + }, + "examples/preprocessing/otp.py": { + "a": 1, + "d": 1 + }, + "mne/datasets/utils.py": { + "a": 5, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13010.json b/doc/sphinxext/prs/13010.json new file mode 100644 index 00000000000..4f4b88a5406 --- /dev/null +++ b/doc/sphinxext/prs/13010.json @@ -0,0 +1,19 @@ +{ + "merge_commit_sha": "4967ecd3f44968408046948b738977aa41521ea2", + "authors": [ + { + "n": "Stefan Appelhoff", + "e": "stefan.appelhoff@mailbox.org" + } + ], + "changes": { + "doc/api/preprocessing.rst": { + "a": 1, + "d": 0 + }, + "mne/channels/__init__.pyi": { + "a": 2, + "d": 0 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13011.json b/doc/sphinxext/prs/13011.json new file mode 100644 index 00000000000..ab057c0e83c --- /dev/null +++ b/doc/sphinxext/prs/13011.json @@ -0,0 +1,47 @@ +{ + "merge_commit_sha": "b329515933915fd077495ea41de876119ac04c97", + "authors": [ + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + }, + { + "n": "autofix-ci[bot]", + "e": "114827586+autofix-ci[bot]@users.noreply.github.com" + }, + { + "n": "Thomas Grainger", + "e": "tagrain@gmail.com" + } + ], + "changes": { + ".github/workflows/autofix.yml": { + "a": 2, + "d": 0 + }, + ".github/workflows/codeql-analysis.yml": { + "a": 5, + "d": 3 + }, + ".github/workflows/credit.yml": { + "a": 6, + "d": 5 + }, + ".github/workflows/release.yml": { + "a": 2, + "d": 0 + }, + ".github/workflows/tests.yml": { + "a": 3, + "d": 0 + }, + ".pre-commit-config.yaml": { + "a": 5, + "d": 0 + }, + "doc/changes/devel/13011.other.rst": { + "a": 1, + "d": 0 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13012.json b/doc/sphinxext/prs/13012.json new file mode 100644 index 00000000000..066940ab6f5 --- /dev/null +++ b/doc/sphinxext/prs/13012.json @@ -0,0 +1,35 @@ +{ + "merge_commit_sha": "9b7b5596ff7c089939bca179b98f1ce0094cb668", + "authors": [ + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + }, + { + "n": "autofix-ci[bot]", + "e": "114827586+autofix-ci[bot]@users.noreply.github.com" + } + ], + "changes": { + "doc/changes/devel/13012.bugfix.rst": { + "a": 1, + "d": 0 + }, + "doc/install/advanced.rst": { + "a": 14, + "d": 0 + }, + "mne/viz/backends/_pyvista.py": { + "a": 8, + "d": 5 + }, + "mne/viz/backends/_qt.py": { + "a": 0, + "d": 1 + }, + "mne/viz/backends/tests/test_renderer.py": { + "a": 12, + "d": 7 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13015.json b/doc/sphinxext/prs/13015.json new file mode 100644 index 00000000000..5b137162759 --- /dev/null +++ b/doc/sphinxext/prs/13015.json @@ -0,0 +1,15 @@ +{ + "merge_commit_sha": "31436fecd881a0e6fb29b83b4f36764ae81dabc7", + "authors": [ + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + ".github/workflows/automerge.yml": { + "a": 17, + "d": 0 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13017.json b/doc/sphinxext/prs/13017.json new file mode 100644 index 00000000000..49b2ec0202b --- /dev/null +++ b/doc/sphinxext/prs/13017.json @@ -0,0 +1,15 @@ +{ + "merge_commit_sha": "b3eb56cf1e8993940daa1df68d5220f3110ecdbb", + "authors": [ + { + "n": "pre-commit-ci[bot]", + "e": "66853113+pre-commit-ci[bot]@users.noreply.github.com" + } + ], + "changes": { + ".pre-commit-config.yaml": { + "a": 1, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13018.json b/doc/sphinxext/prs/13018.json new file mode 100644 index 00000000000..ef9951d9902 --- /dev/null +++ b/doc/sphinxext/prs/13018.json @@ -0,0 +1,39 @@ +{ + "merge_commit_sha": "521d667f9802655a71166823fd890fbd00bae5a8", + "authors": [ + { + "n": "Clemens Brunner", + "e": null + }, + { + "n": "Daniel McCloy", + "e": "dan@mccloy.info" + }, + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + "doc/changes/devel/13018.newfeature.rst": { + "a": 1, + "d": 0 + }, + "doc/sphinxext/related_software.py": { + "a": 4, + "d": 0 + }, + "mne/conftest.py": { + "a": 2, + "d": 0 + }, + "mne/io/base.py": { + "a": 65, + "d": 0 + }, + "mne/io/tests/test_raw.py": { + "a": 17, + "d": 0 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13019.json b/doc/sphinxext/prs/13019.json new file mode 100644 index 00000000000..0d6677670c2 --- /dev/null +++ b/doc/sphinxext/prs/13019.json @@ -0,0 +1,27 @@ +{ + "merge_commit_sha": "4f1f4bbbc1a9d0f828e28de9be4e69c05f86d9f5", + "authors": [ + { + "n": "Santeri Ruuskanen", + "e": null + }, + { + "n": "pre-commit-ci[bot]", + "e": "66853113+pre-commit-ci[bot]@users.noreply.github.com" + } + ], + "changes": { + "doc/changes/devel/13019.newfeature.rst": { + "a": 1, + "d": 0 + }, + "mne/viz/circle.py": { + "a": 2, + "d": 0 + }, + "mne/viz/tests/test_circle.py": { + "a": 3, + "d": 0 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13020.json b/doc/sphinxext/prs/13020.json new file mode 100644 index 00000000000..5513e4c6153 --- /dev/null +++ b/doc/sphinxext/prs/13020.json @@ -0,0 +1,15 @@ +{ + "merge_commit_sha": "730358c1e2e17baf0491c04e0c8269382a29c613", + "authors": [ + { + "n": "Santeri Ruuskanen", + "e": null + } + ], + "changes": { + "doc/development/contributing.rst": { + "a": 1, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13021.json b/doc/sphinxext/prs/13021.json new file mode 100644 index 00000000000..1d9a78d45cc --- /dev/null +++ b/doc/sphinxext/prs/13021.json @@ -0,0 +1,59 @@ +{ + "merge_commit_sha": "bd4a160215be67e2de1df7e0a86e27425b074807", + "authors": [ + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + ".github/workflows/tests.yml": { + "a": 3, + "d": 0 + }, + "azure-pipelines.yml": { + "a": 2, + "d": 1 + }, + "doc/changes/devel/13021.dependency.rst": { + "a": 1, + "d": 0 + }, + "mne/evoked.py": { + "a": 2, + "d": 1 + }, + "mne/tests/test_docstring_parameters.py": { + "a": 1, + "d": 0 + }, + "mne/utils/docs.py": { + "a": 33, + "d": 22 + }, + "mne/utils/tests/test_config.py": { + "a": 1, + "d": 1 + }, + "mne/utils/tests/test_docs.py": { + "a": 16, + "d": 15 + }, + "tools/azure_dependencies.sh": { + "a": 1, + "d": 1 + }, + "tools/circleci_dependencies.sh": { + "a": 1, + "d": 1 + }, + "tools/github_actions_dependencies.sh": { + "a": 5, + "d": 0 + }, + "tools/github_actions_env_vars.sh": { + "a": 7, + "d": 3 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13028.json b/doc/sphinxext/prs/13028.json new file mode 100644 index 00000000000..7a162141000 --- /dev/null +++ b/doc/sphinxext/prs/13028.json @@ -0,0 +1,23 @@ +{ + "merge_commit_sha": "d8149546e7010ee7d2b5a95335aeed0e9b202aaa", + "authors": [ + { + "n": "Thomas S. Binns", + "e": "t.s.binns@outlook.com" + } + ], + "changes": { + "doc/changes/devel/13028.bugfix.rst": { + "a": 1, + "d": 0 + }, + "mne/time_frequency/tests/test_tfr.py": { + "a": 16, + "d": 12 + }, + "mne/time_frequency/tfr.py": { + "a": 7, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13029.json b/doc/sphinxext/prs/13029.json new file mode 100644 index 00000000000..33a84c3f2c6 --- /dev/null +++ b/doc/sphinxext/prs/13029.json @@ -0,0 +1,91 @@ +{ + "merge_commit_sha": "f3a7fde522d69bc5bbc15844718812c7ab6480f4", + "authors": [ + { + "n": "github-actions[bot]", + "e": "41898282+github-actions[bot]@users.noreply.github.com" + } + ], + "changes": { + "doc/sphinxext/prs/12896.json": { + "a": 27, + "d": 0 + }, + "doc/sphinxext/prs/12995.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12996.json": { + "a": 39, + "d": 0 + }, + "doc/sphinxext/prs/12997.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/12998.json": { + "a": 131, + "d": 0 + }, + "doc/sphinxext/prs/12999.json": { + "a": 27, + "d": 0 + }, + "doc/sphinxext/prs/13000.json": { + "a": 27, + "d": 0 + }, + "doc/sphinxext/prs/13001.json": { + "a": 19, + "d": 0 + }, + "doc/sphinxext/prs/13003.json": { + "a": 47, + "d": 0 + }, + "doc/sphinxext/prs/13007.json": { + "a": 31, + "d": 0 + }, + "doc/sphinxext/prs/13009.json": { + "a": 23, + "d": 0 + }, + "doc/sphinxext/prs/13010.json": { + "a": 19, + "d": 0 + }, + "doc/sphinxext/prs/13011.json": { + "a": 47, + "d": 0 + }, + "doc/sphinxext/prs/13012.json": { + "a": 35, + "d": 0 + }, + "doc/sphinxext/prs/13015.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/13017.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/13018.json": { + "a": 39, + "d": 0 + }, + "doc/sphinxext/prs/13019.json": { + "a": 27, + "d": 0 + }, + "doc/sphinxext/prs/13020.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/13021.json": { + "a": 59, + "d": 0 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13031.json b/doc/sphinxext/prs/13031.json new file mode 100644 index 00000000000..428346e1b90 --- /dev/null +++ b/doc/sphinxext/prs/13031.json @@ -0,0 +1,35 @@ +{ + "merge_commit_sha": "5b06ca4bb8f9138bf4af85ea3171d95df07462c5", + "authors": [ + { + "n": "github-actions[bot]", + "e": "41898282+github-actions[bot]@users.noreply.github.com" + }, + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + ".github/workflows/credit.yml": { + "a": 2, + "d": 2 + }, + "doc/documentation/cited.rst": { + "a": 3, + "d": 3 + }, + "doc/sphinxext/prs/13029.json": { + "a": 91, + "d": 0 + }, + "doc/sphinxext/prs/6915.json": { + "a": 43, + "d": 0 + }, + "tools/dev/update_credit_json.py": { + "a": 3, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13032.json b/doc/sphinxext/prs/13032.json new file mode 100644 index 00000000000..041aa7dfe1b --- /dev/null +++ b/doc/sphinxext/prs/13032.json @@ -0,0 +1,15 @@ +{ + "merge_commit_sha": "b38385ef90d0fd8214d54b15c5fd91333c3bc032", + "authors": [ + { + "n": "pre-commit-ci[bot]", + "e": "66853113+pre-commit-ci[bot]@users.noreply.github.com" + } + ], + "changes": { + ".pre-commit-config.yaml": { + "a": 2, + "d": 2 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13035.json b/doc/sphinxext/prs/13035.json new file mode 100644 index 00000000000..e9cf38b09c9 --- /dev/null +++ b/doc/sphinxext/prs/13035.json @@ -0,0 +1,51 @@ +{ + "merge_commit_sha": "dcd26258c5fd83fd3974d73ba2b1e2773c33bc3d", + "authors": [ + { + "n": "Mathieu Scheltienne", + "e": "mathieu.scheltienne@gmail.com" + }, + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + "environment.yml": { + "a": 1, + "d": 1 + }, + "mne/io/ant/ant.py": { + "a": 14, + "d": 12 + }, + "mne/io/ant/tests/test_ant.py": { + "a": 1, + "d": 1 + }, + "mne/preprocessing/tests/test_fine_cal.py": { + "a": 1, + "d": 1 + }, + "mne/utils/check.py": { + "a": 23, + "d": 24 + }, + "mne/utils/tests/test_check.py": { + "a": 7, + "d": 0 + }, + "pyproject.toml": { + "a": 1, + "d": 1 + }, + "tools/circleci_dependencies.sh": { + "a": 1, + "d": 6 + }, + "tutorials/intro/70_report.py": { + "a": 4, + "d": 5 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13036.json b/doc/sphinxext/prs/13036.json new file mode 100644 index 00000000000..6ce98904756 --- /dev/null +++ b/doc/sphinxext/prs/13036.json @@ -0,0 +1,27 @@ +{ + "merge_commit_sha": "41dbdd55eaff77314440ebc8700e0e58b1183113", + "authors": [ + { + "n": "Daniel McCloy", + "e": null + }, + { + "n": "autofix-ci[bot]", + "e": "114827586+autofix-ci[bot]@users.noreply.github.com" + } + ], + "changes": { + "doc/changes/devel/13036.bugfix.rst": { + "a": 1, + "d": 0 + }, + "mne/time_frequency/tests/test_spectrum.py": { + "a": 26, + "d": 0 + }, + "mne/viz/utils.py": { + "a": 3, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13038.json b/doc/sphinxext/prs/13038.json new file mode 100644 index 00000000000..c988969d1dd --- /dev/null +++ b/doc/sphinxext/prs/13038.json @@ -0,0 +1,63 @@ +{ + "merge_commit_sha": "14938b9657b255a38aa96482a4aaf410e8865859", + "authors": [ + { + "n": "github-actions[bot]", + "e": "41898282+github-actions[bot]@users.noreply.github.com" + }, + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + }, + { + "n": "pre-commit-ci[bot]", + "e": "66853113+pre-commit-ci[bot]@users.noreply.github.com" + } + ], + "changes": { + "azure-pipelines.yml": { + "a": 1, + "d": 2 + }, + "doc/sphinxext/prs/13031.json": { + "a": 35, + "d": 0 + }, + "doc/sphinxext/prs/13032.json": { + "a": 15, + "d": 0 + }, + "doc/sphinxext/prs/13035.json": { + "a": 51, + "d": 0 + }, + "doc/sphinxext/prs/13036.json": { + "a": 27, + "d": 0 + }, + "environment.yml": { + "a": 1, + "d": 1 + }, + "mne/conftest.py": { + "a": 3, + "d": 1 + }, + "tools/azure_dependencies.sh": { + "a": 1, + "d": 1 + }, + "tools/circleci_dependencies.sh": { + "a": 0, + "d": 1 + }, + "tools/hooks/update_environment_file.py": { + "a": 0, + "d": 3 + }, + "tutorials/time-freq/50_ssvep.py": { + "a": 1, + "d": 1 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13040.json b/doc/sphinxext/prs/13040.json new file mode 100644 index 00000000000..97a70ab4a32 --- /dev/null +++ b/doc/sphinxext/prs/13040.json @@ -0,0 +1,171 @@ +{ + "merge_commit_sha": "637c231f40d8e6e022ab3ae04fa30911cbe0f78f", + "authors": [ + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + ".pre-commit-config.yaml": { + "a": 1, + "d": 1 + }, + "README.rst": { + "a": 1, + "d": 1 + }, + "environment.yml": { + "a": 4, + "d": 4 + }, + "mne/channels/montage.py": { + "a": 1, + "d": 3 + }, + "mne/commands/mne_flash_bem.py": { + "a": 0, + "d": 20 + }, + "mne/commands/tests/test_commands.py": { + "a": 1, + "d": 2 + }, + "mne/conftest.py": { + "a": 5, + "d": 14 + }, + "mne/datasets/eegbci/eegbci.py": { + "a": 3, + "d": 23 + }, + "mne/datasets/eegbci/tests/test_eegbci.py": { + "a": 0, + "d": 12 + }, + "mne/decoding/csp.py": { + "a": 0, + "d": 13 + }, + "mne/decoding/ems.py": { + "a": 1, + "d": 1 + }, + "mne/decoding/tests/test_base.py": { + "a": 1, + "d": 1 + }, + "mne/fixes.py": { + "a": 2, + "d": 6 + }, + "mne/io/base.py": { + "a": 1, + "d": 18 + }, + "mne/io/egi/tests/test_egi.py": { + "a": 0, + "d": 3 + }, + "mne/io/fiff/tests/test_raw_fiff.py": { + "a": 0, + "d": 8 + }, + "mne/io/neuralynx/neuralynx.py": { + "a": 1, + "d": 1 + }, + "mne/report/report.py": { + "a": 1, + "d": 0 + }, + "mne/source_estimate.py": { + "a": 0, + "d": 16 + }, + "mne/tests/test_source_estimate.py": { + "a": 1, + "d": 2 + }, + "mne/time_frequency/tfr.py": { + "a": 0, + "d": 50 + }, + "mne/utils/linalg.py": { + "a": 1, + "d": 1 + }, + "mne/viz/_3d.py": { + "a": 3, + "d": 5 + }, + "mne/viz/backends/_abstract.py": { + "a": 1, + "d": 1 + }, + "mne/viz/backends/_notebook.py": { + "a": 2, + "d": 2 + }, + "mne/viz/backends/_pyvista.py": { + "a": 1, + "d": 3 + }, + "mne/viz/backends/_qt.py": { + "a": 3, + "d": 3 + }, + "mne/viz/backends/tests/_utils.py": { + "a": 0, + "d": 44 + }, + "mne/viz/backends/tests/test_abstract.py": { + "a": 0, + "d": 2 + }, + "mne/viz/backends/tests/test_renderer.py": { + "a": 3, + "d": 2 + }, + "mne/viz/evoked.py": { + "a": 1, + "d": 1 + }, + "mne/viz/montage.py": { + "a": 4, + "d": 20 + }, + "mne/viz/tests/test_raw.py": { + "a": 1, + "d": 1 + }, + "mne/viz/topomap.py": { + "a": 3, + "d": 5 + }, + "pyproject.toml": { + "a": 4, + "d": 4 + }, + "tools/dev/Makefile": { + "a": 3, + "d": 0 + }, + "tools/environment_old.yml": { + "a": 10, + "d": 9 + }, + "tools/hooks/sync_dependencies.py": { + "a": 11, + "d": 1 + }, + "tools/hooks/update_environment_file.py": { + "a": 37, + "d": 7 + }, + "tools/vulture_allowlist.py": { + "a": 0, + "d": 2 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13041.json b/doc/sphinxext/prs/13041.json new file mode 100644 index 00000000000..26790d0e352 --- /dev/null +++ b/doc/sphinxext/prs/13041.json @@ -0,0 +1,43 @@ +{ + "merge_commit_sha": "90d9c91fe5d46e0390cb22c943fd6e0dbb578838", + "authors": [ + { + "n": "Daniel McCloy", + "e": null + }, + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + "Makefile": { + "a": 1, + "d": 1 + }, + "mne/html_templates/repr/_acquisition.html.jinja": { + "a": 1, + "d": 1 + }, + "mne/html_templates/repr/_frequencies.html.jinja": { + "a": 62, + "d": 0 + }, + "mne/html_templates/repr/spectrum.html.jinja": { + "a": 9, + "d": 48 + }, + "mne/time_frequency/spectrum.py": { + "a": 3, + "d": 1 + }, + "tools/hooks/sync_dependencies.py": { + "a": 1, + "d": 2 + }, + "tutorials/time-freq/10_spectrum_class.py": { + "a": 4, + "d": 5 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13042.json b/doc/sphinxext/prs/13042.json new file mode 100644 index 00000000000..5f2d05f7013 --- /dev/null +++ b/doc/sphinxext/prs/13042.json @@ -0,0 +1,27 @@ +{ + "merge_commit_sha": "ee2d0caa98b23fe40bcddea2a666b6803f64c3dd", + "authors": [ + { + "n": "Daniel McCloy", + "e": null + }, + { + "n": "autofix-ci[bot]", + "e": "114827586+autofix-ci[bot]@users.noreply.github.com" + } + ], + "changes": { + "doc/changes/devel/13042.bugfix.rst": { + "a": 1, + "d": 0 + }, + "mne/time_frequency/spectrum.py": { + "a": 2, + "d": 2 + }, + "mne/time_frequency/tests/test_spectrum.py": { + "a": 8, + "d": 2 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/13043.json b/doc/sphinxext/prs/13043.json new file mode 100644 index 00000000000..a67a958fa80 --- /dev/null +++ b/doc/sphinxext/prs/13043.json @@ -0,0 +1,15 @@ +{ + "merge_commit_sha": "a51e32a9341bd1b44183356f6116af0821546fda", + "authors": [ + { + "n": "pre-commit-ci[bot]", + "e": "66853113+pre-commit-ci[bot]@users.noreply.github.com" + } + ], + "changes": { + ".pre-commit-config.yaml": { + "a": 2, + "d": 2 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/prs/6915.json b/doc/sphinxext/prs/6915.json new file mode 100644 index 00000000000..24abd2f2a8c --- /dev/null +++ b/doc/sphinxext/prs/6915.json @@ -0,0 +1,43 @@ +{ + "merge_commit_sha": "610ec2a50d7a0e17d1ae8a229793a51370dec81d", + "authors": [ + { + "n": "Fahimeh Mamashli", + "e": "fahimeh.mamashli@pfizer.com" + }, + { + "n": "Mainak Jas", + "e": "jasmainak@users.noreply.github.com" + }, + { + "n": "Eric Larson", + "e": "larson.eric.d@gmail.com" + } + ], + "changes": { + ".github/workflows/tests.yml": { + "a": 2, + "d": 2 + }, + "doc/changes/devel/6915.newfeature.rst": { + "a": 1, + "d": 0 + }, + "examples/datasets/brainstorm_data.py": { + "a": 2, + "d": 3 + }, + "mne/preprocessing/stim.py": { + "a": 56, + "d": 9 + }, + "mne/preprocessing/tests/test_stim.py": { + "a": 35, + "d": 0 + }, + "tools/install_pre_requirements.sh": { + "a": 8, + "d": 2 + } + } +} \ No newline at end of file diff --git a/doc/sphinxext/related_software.py b/doc/sphinxext/related_software.py index b5b74b0f90b..2548725390a 100644 --- a/doc/sphinxext/related_software.py +++ b/doc/sphinxext/related_software.py @@ -81,6 +81,10 @@ "Summary": "A graphical user interface for MNE", }, # TODO: these do not set a valid homepage or documentation page on PyPI + "eeg_positions": { + "Home-page": "https://eeg-positions.readthedocs.io", + "Summary": "Compute and plot standard EEG electrode positions.", + }, "mne-features": { "Home-page": "https://mne.tools/mne-features", "Summary": "MNE-Features software for extracting features from multivariate time series", # noqa: E501 @@ -159,9 +163,9 @@ def _get_packages() -> dict[str, str]: assert not dups, f"Duplicates in MANUAL_PACKAGES and PYPI_PACKAGES: {sorted(dups)}" # And the installer and PyPI-only should be disjoint: dups = set(PYPI_PACKAGES) & set(packages) - assert ( - not dups - ), f"Duplicates in PYPI_PACKAGES and installer packages: {sorted(dups)}" + assert not dups, ( + f"Duplicates in PYPI_PACKAGES and installer packages: {sorted(dups)}" + ) for name in PYPI_PACKAGES | set(MANUAL_PACKAGES): if name not in packages: packages.append(name) @@ -169,6 +173,7 @@ def _get_packages() -> dict[str, str]: packages = sorted(packages, key=lambda x: x.lower()) packages = [RENAMES.get(package, package) for package in packages] out = dict() + reasons = [] for package in status_iterator( packages, f"Adding {len(packages)} related software packages: " ): @@ -179,12 +184,17 @@ def _get_packages() -> dict[str, str]: else: md = importlib.metadata.metadata(package) except importlib.metadata.PackageNotFoundError: - pass # raise a complete error later + reasons.append(f"{package}: not found, needs to be installed") + continue # raise a complete error later else: # Every project should really have this + do_continue = False for key in ("Summary",): if key not in md: - raise ExtensionError(f"Missing {repr(key)} for {package}") + reasons.extend(f"{package}: missing {repr(key)}") + do_continue = True + if do_continue: + continue # It is annoying to find the home page url = None if "Home-page" in md: @@ -200,15 +210,17 @@ def _get_packages() -> dict[str, str]: if url is not None: break else: - raise RuntimeError( - f"Could not find Home-page for {package} in:\n" - f"{sorted(set(md))}\nwith Summary:\n{md['Summary']}" + reasons.append( + f"{package}: could not find Home-page in {sorted(md)}" ) + continue out[package]["url"] = url out[package]["description"] = md["Summary"].replace("\n", "") - bad = [package for package in packages if not out[package]] - if bad and REQUIRE_METADATA: - raise ExtensionError(f"Could not find metadata for:\n{' '.join(bad)}") + reason_str = "\n".join(reasons) + if reason_str and REQUIRE_METADATA: + raise ExtensionError( + f"Could not find suitable metadata for related software:\n{reason_str}" + ) return out diff --git a/doc/sphinxext/unit_role.py b/doc/sphinxext/unit_role.py index b52665e8321..bf31ddf76c4 100644 --- a/doc/sphinxext/unit_role.py +++ b/doc/sphinxext/unit_role.py @@ -10,8 +10,7 @@ def unit_role(name, rawtext, text, lineno, inliner, options={}, content=[]): # def pass_error_to_sphinx(rawtext, text, lineno, inliner): msg = inliner.reporter.error( - "The :unit: role requires a space-separated number and unit; " - f"got {text}", + f"The :unit: role requires a space-separated number and unit; got {text}", line=lineno, ) prb = inliner.problematic(rawtext, rawtext, msg) diff --git a/environment.yml b/environment.yml index 18a8ec931fa..78c773e56bf 100644 --- a/environment.yml +++ b/environment.yml @@ -4,7 +4,7 @@ channels: - conda-forge dependencies: - python >=3.10 - - antio >=0.4.0 + - antio >=0.5.0 - darkdetect - decorator - defusedxml @@ -23,16 +23,18 @@ dependencies: - joblib - jupyter - lazy_loader >=0.3 - - matplotlib >=3.6 + - mamba + - matplotlib >=3.7 - mffpy >=0.5.7 - mne-qt-browser - nibabel - nilearn + - nomkl - numba - - numpy >=1.23,<3 - - openmeeg =2.5.12=*_1 + - numpy >=1.25,<3 + - openmeeg >=2.5.5 - packaging - - pandas + - pandas >=2.0 - pillow - pip - pooch >=1.5 @@ -47,7 +49,7 @@ dependencies: - qdarkstyle !=3.2.2 - qtpy - scikit-learn - - scipy >=1.9 + - scipy >=1.11 - sip - snirf - statsmodels @@ -57,5 +59,5 @@ dependencies: - trame - trame-vtk - trame-vuetify - - vtk >=9.2 + - vtk =9.3.1=qt_* - xlrd diff --git a/examples/datasets/brainstorm_data.py b/examples/datasets/brainstorm_data.py index 6331c9f1b29..ab5499fea71 100644 --- a/examples/datasets/brainstorm_data.py +++ b/examples/datasets/brainstorm_data.py @@ -6,9 +6,8 @@ ===================================== Here we compute the evoked from raw for the Brainstorm -tutorial dataset. For comparison, see :footcite:`TadelEtAl2011` and: - - https://neuroimage.usc.edu/brainstorm/Tutorials/MedianNerveCtf +tutorial dataset. For comparison, see :footcite:`TadelEtAl2011` and +https://neuroimage.usc.edu/brainstorm/Tutorials/MedianNerveCtf. """ # Authors: Mainak Jas diff --git a/examples/decoding/linear_model_patterns.py b/examples/decoding/linear_model_patterns.py index c1390cbb0d3..7373c0a18b3 100644 --- a/examples/decoding/linear_model_patterns.py +++ b/examples/decoding/linear_model_patterns.py @@ -79,7 +79,7 @@ # Extract and plot spatial filters and spatial patterns for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)): - # We fitted the linear model onto Z-scored data. To make the filters + # We fit the linear model on Z-scored data. To make the filters # interpretable, we must reverse this normalization step coef = scaler.inverse_transform([coef])[0] diff --git a/examples/inverse/vector_mne_solution.py b/examples/inverse/vector_mne_solution.py index ca953cd2f24..f6ae788c145 100644 --- a/examples/inverse/vector_mne_solution.py +++ b/examples/inverse/vector_mne_solution.py @@ -79,7 +79,7 @@ # inverse was computed with loose=0.2 print( "Absolute cosine similarity between source normals and directions: " - f'{np.abs(np.sum(directions * inv["source_nn"][2::3], axis=-1)).mean()}' + f"{np.abs(np.sum(directions * inv['source_nn'][2::3], axis=-1)).mean()}" ) brain_max = stc_max.plot( initial_time=peak_time, diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py new file mode 100755 index 00000000000..a6c6bb3c2ba --- /dev/null +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -0,0 +1,196 @@ +""" +.. _ex-pcaobs: + +===================================================================================== +Principal Component Analysis - Optimal Basis Sets (PCA-OBS) removing cardiac artefact +===================================================================================== + +This script shows an example of how to use an adaptation of PCA-OBS +:footcite:`NiazyEtAl2005`. PCA-OBS was originally designed to remove +the ballistocardiographic artefact in simultaneous EEG-fMRI. Here, it +has been adapted to remove the delay between the detected R-peak and the +ballistocardiographic artefact such that the algorithm can be applied to +remove the cardiac artefact in EEG (electroencephalography) and ESG +(electrospinography) data. We will illustrate how it works by applying the +algorithm to ESG data, where the effect of removal is most pronounced. + +See: https://www.biorxiv.org/content/10.1101/2024.09.05.611423v1 +for more details on the dataset and application for ESG data. + +""" + +# Authors: Emma Bailey , +# Steinn Hauser Magnusson +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import glob + +import numpy as np + +# %% +# Download sample subject data from OpenNeuro if you haven't already. +# This will download simultaneous EEG and ESG data from a single run of a +# single participant after median nerve stimulation of the left wrist. +import openneuro +from matplotlib import pyplot as plt + +import mne +from mne import Epochs, events_from_annotations +from mne.io import read_raw_eeglab +from mne.preprocessing import find_ecg_events, fix_stim_artifact + +# add the path where you want the OpenNeuro data downloaded. Each run is ~2GB of data +ds = "ds004388" +target_dir = mne.datasets.default_path() / ds +run_name = "sub-001/eeg/*median_run-03_eeg*.set" +if not glob.glob(str(target_dir / run_name)): + target_dir.mkdir(exist_ok=True) + openneuro.download(dataset=ds, target_dir=target_dir, include=run_name[:-4]) +block_files = glob.glob(str(target_dir / run_name)) +assert len(block_files) == 1 + +# %% +# Define the esg channels (arranged in two patches over the neck and lower back). + +esg_chans = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", + "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "L4", + "S6", + "S23", +] + +# Interpolation window in seconds for ESG data to remove stimulation artefact +tstart_esg = -7e-3 +tmax_esg = 7e-3 + +# Define timing of heartbeat epochs in seconds relative to R-peaks +iv_baseline = [-400e-3, -300e-3] +iv_epoch = [-400e-3, 600e-3] + +# %% +# Next, we perform minimal preprocessing including removing the +# stimulation artefact, downsampling and filtering. + +raw = read_raw_eeglab(block_files[0], verbose="error") +raw.set_channel_types(dict(ECG="ecg")) +# Isolate the ESG channels (include the ECG channel for R-peak detection) +raw.pick(esg_chans + ["ECG"]) +# Trim duration and downsample (from 10kHz) to improve example speed +raw.crop(0, 60).load_data().resample(2000) + +# Find trigger timings to remove the stimulation artefact +events, event_dict = events_from_annotations(raw) +trigger_name = "Median - Stimulation" + +fix_stim_artifact( + raw, + events=events, + event_id=event_dict[trigger_name], + tmin=tstart_esg, + tmax=tmax_esg, + mode="linear", + stim_channel=None, +) + +# %% +# Find ECG events and add to the raw structure as event annotations. + +ecg_events, ch_ecg, average_pulse = find_ecg_events(raw, ch_name="ECG") +ecg_event_samples = np.asarray( + [[ecg_event[0] for ecg_event in ecg_events]] +) # Samples only + +qrs_event_time = [ + x / raw.info["sfreq"] for x in ecg_event_samples.reshape(-1) +] # Divide by sampling rate to make times +duration = np.repeat(0.0, len(ecg_event_samples)) +description = ["qrs"] * len(ecg_event_samples) + +raw.annotations.append( + qrs_event_time, duration, description, ch_names=[esg_chans] * len(qrs_event_time) +) + +# %% +# Create evoked response around the detected R-peaks +# before and after cardiac artefact correction. + +events, event_ids = events_from_annotations(raw) +event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} +epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_before = epochs.average() + +# Apply function - modifies the data in place. Optionally high-pass filter +# the data before applying PCA-OBS to remove low frequency drifts +raw = mne.preprocessing.apply_pca_obs( + raw, picks=esg_chans, n_jobs=5, qrs_times=raw.times[ecg_event_samples.reshape(-1)] +) + +epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_after = epochs.average() + +# %% +# Compare evoked responses to assess completeness of artefact removal. + +fig, axes = plt.subplots(1, 1, layout="constrained") +data_before = evoked_before.get_data(units=dict(eeg="uV")).T +data_after = evoked_after.get_data(units=dict(eeg="uV")).T +hs = list() +hs.append(axes.plot(epochs.times, data_before, color="k")[0]) +hs.append(axes.plot(epochs.times, data_after, color="green", label="after")[0]) +axes.set(ylim=[-500, 1000], ylabel="Amplitude (µV)", xlabel="Time (s)") +axes.set(title="ECG artefact removal using PCA-OBS") +axes.legend(hs, ["before", "after"]) +plt.show() + +# %% +# References +# ---------- +# .. footbibliography:: diff --git a/examples/preprocessing/otp.py b/examples/preprocessing/otp.py index df3a6c74ffe..4f2d7619ab8 100644 --- a/examples/preprocessing/otp.py +++ b/examples/preprocessing/otp.py @@ -72,7 +72,7 @@ def compute_bias(raw): idx = epochs.time_as_index(0.036)[0] data = epochs.get_data(copy=False)[:, :, idx].T evoked = mne.EvokedArray(data, epochs.info, tmin=0.0) - dip = fit_dipole(evoked, cov, sphere, n_jobs=None, verbose=False)[0] + dip = fit_dipole(evoked, cov, sphere, verbose=False)[0] actual_pos = mne.dipole.get_phantom_dipoles()[0][dipole_number - 1] misses = 1000 * np.linalg.norm(dip.pos - actual_pos, axis=-1) return misses diff --git a/examples/visualization/evoked_topomap.py b/examples/visualization/evoked_topomap.py index 83d1916c6f9..53b7a60dbba 100644 --- a/examples/visualization/evoked_topomap.py +++ b/examples/visualization/evoked_topomap.py @@ -5,8 +5,8 @@ Plotting topographic maps of evoked data ======================================== -Load evoked data and plot topomaps for selected time points using multiple -additional options. +Load evoked data and plot topomaps for selected time points using +multiple additional options. """ # Authors: Christian Brodbeck # Tal Linzen diff --git a/examples/visualization/evoked_whitening.py b/examples/visualization/evoked_whitening.py index ed05ae3ba11..4bcb4bc8c04 100644 --- a/examples/visualization/evoked_whitening.py +++ b/examples/visualization/evoked_whitening.py @@ -85,7 +85,7 @@ print("Covariance estimates sorted from best to worst") for c in noise_covs: - print(f'{c["method"]} : {c["loglik"]}') + print(f"{c['method']} : {c['loglik']}") # %% # Show the evoked data: diff --git a/ignore_words.txt b/ignore_words.txt index 150a32058e2..12e1a14ae0e 100644 --- a/ignore_words.txt +++ b/ignore_words.txt @@ -41,3 +41,5 @@ connec sme tim whitelists +gotcha +uner diff --git a/mne/_fiff/_digitization.py b/mne/_fiff/_digitization.py index e55fd5d2dae..eb8b6bc396a 100644 --- a/mne/_fiff/_digitization.py +++ b/mne/_fiff/_digitization.py @@ -328,8 +328,7 @@ def _get_data_as_dict_from_dig(dig, exclude_ref_channel=True): dig_coord_frames = set([FIFF.FIFFV_COORD_HEAD]) if len(dig_coord_frames) != 1: raise RuntimeError( - "Only single coordinate frame in dig is supported, " - f"got {dig_coord_frames}" + f"Only single coordinate frame in dig is supported, got {dig_coord_frames}" ) dig_ch_pos_location = np.array(dig_ch_pos_location) dig_ch_pos_location.shape = (-1, 3) # empty will be (0, 3) diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 629d9a4b0ce..51612824a6a 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -455,7 +455,7 @@ def _check_set(ch, projs, ch_type): for proj in projs: if ch["ch_name"] in proj["data"]["col_names"]: raise RuntimeError( - f'Cannot change channel type for channel {ch["ch_name"]} in ' + f"Cannot change channel type for channel {ch['ch_name']} in " f'projector "{proj["desc"]}"' ) ch["kind"] = new_kind @@ -1867,7 +1867,7 @@ def _check_consistency(self, prepend_error=""): ): raise RuntimeError( f'{prepend_error}info["meas_date"] must be a datetime object in UTC' - f' or None, got {repr(self["meas_date"])!r}' + f" or None, got {repr(self['meas_date'])!r}" ) chs = [ch["ch_name"] for ch in self["chs"]] @@ -2493,6 +2493,8 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): hi["meas_date"] = _ensure_meas_date_none_or_dt( tuple(int(t) for t in tag.data), ) + if "meas_date" not in hi: + hi["meas_date"] = None info["helium_info"] = hi del hi @@ -2879,7 +2881,8 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): write_float(fid, FIFF.FIFF_HELIUM_LEVEL, hi["helium_level"]) if hi.get("orig_file_guid") is not None: write_string(fid, FIFF.FIFF_ORIG_FILE_GUID, hi["orig_file_guid"]) - write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"])) + if hi["meas_date"] is not None: + write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"])) end_block(fid, FIFF.FIFFB_HELIUM) del hi @@ -2916,8 +2919,10 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): _write_proc_history(fid, info) -@fill_doc -def write_info(fname, info, data_type=None, reset_range=True): +@verbose +def write_info( + fname, info, *, data_type=None, reset_range=True, overwrite=False, verbose=None +): """Write measurement info in fif file. Parameters @@ -2931,8 +2936,10 @@ def write_info(fname, info, data_type=None, reset_range=True): raw data. reset_range : bool If True, info['chs'][k]['range'] will be set to unity. + %(overwrite)s + %(verbose)s """ - with start_and_end_file(fname) as fid: + with start_and_end_file(fname, overwrite=overwrite) as fid: start_block(fid, FIFF.FIFFB_MEAS) write_meas_info(fid, info, data_type, reset_range) end_block(fid, FIFF.FIFFB_MEAS) @@ -3673,8 +3680,7 @@ def _write_ch_infos(fid, chs, reset_range, ch_names_mapping): # only write new-style channel information if necessary if len(ch_names_mapping): logger.info( - " Writing channel names to FIF truncated to 15 characters " - "with remapping" + " Writing channel names to FIF truncated to 15 characters with remapping" ) for ch in chs: start_block(fid, FIFF.FIFFB_CH_INFO) diff --git a/mne/_fiff/proj.py b/mne/_fiff/proj.py index 0376826138a..d6ec108e34d 100644 --- a/mne/_fiff/proj.py +++ b/mne/_fiff/proj.py @@ -76,7 +76,7 @@ def __repr__(self): # noqa: D105 s += f", active : {self['active']}" s += f", n_channels : {len(self['data']['col_names'])}" if self["explained_var"] is not None: - s += f', exp. var : {self["explained_var"] * 100:0.2f}%' + s += f", exp. var : {self['explained_var'] * 100:0.2f}%" return f"" # speed up info copy by taking advantage of mutability @@ -324,8 +324,7 @@ def apply_proj(self, verbose=None): if all(p["active"] for p in self.info["projs"]): logger.info( - "Projections have already been applied. " - "Setting proj attribute to True." + "Projections have already been applied. Setting proj attribute to True." ) return self @@ -663,9 +662,9 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None): for proj in projs: misc = "active" if proj["active"] else " idle" logger.info( - f' {proj["desc"]} ' - f'({proj["data"]["nrow"]} x ' - f'{len(proj["data"]["col_names"])}) {misc}' + f" {proj['desc']} " + f"({proj['data']['nrow']} x " + f"{len(proj['data']['col_names'])}) {misc}" ) return projs @@ -795,8 +794,7 @@ def _make_projector(projs, ch_names, bads=(), include_active=True, inplace=False if not p["active"] or include_active: if len(p["data"]["col_names"]) != len(np.unique(p["data"]["col_names"])): raise ValueError( - f"Channel name list in projection item {k}" - " contains duplicate items" + f"Channel name list in projection item {k} contains duplicate items" ) # Get the two selection vectors to pick correct elements from @@ -832,7 +830,7 @@ def _make_projector(projs, ch_names, bads=(), include_active=True, inplace=False ) ): warn( - f'Projection vector {repr(p["desc"])} has been ' + f"Projection vector {repr(p['desc'])} has been " f"reduced to {100 * psize:0.2f}% of its " "original magnitude by subselecting " f"{len(vecsel)}/{orig_n} of the original " diff --git a/mne/_fiff/reference.py b/mne/_fiff/reference.py index e70bf5e36c1..b4c050c096d 100644 --- a/mne/_fiff/reference.py +++ b/mne/_fiff/reference.py @@ -102,7 +102,7 @@ def _check_before_dict_reference(inst, ref_dict): raise TypeError( f"{elem_name.capitalize()}s in the ref_channels dict must be strings. " f"Your dict has {elem_name}s of type " - f'{", ".join(map(lambda x: x.__name__, bad_elem))}.' + f"{', '.join(map(lambda x: x.__name__, bad_elem))}." ) # Check that keys are valid channels and values are lists-of-valid-channels @@ -113,8 +113,8 @@ def _check_before_dict_reference(inst, ref_dict): for elem_name, elem in dict(key=keys, value=values).items(): if bad_elem := elem - ch_set: raise ValueError( - f'ref_channels dict contains invalid {elem_name}(s) ' - f'({", ".join(bad_elem)}) ' + f"ref_channels dict contains invalid {elem_name}(s) " + f"({', '.join(bad_elem)}) " "that are not names of channels in the instance." ) # Check that values are not bad channels diff --git a/mne/_fiff/tag.py b/mne/_fiff/tag.py index abc7d32036b..3fd36454d58 100644 --- a/mne/_fiff/tag.py +++ b/mne/_fiff/tag.py @@ -70,8 +70,7 @@ def _frombuffer_rows(fid, tag_size, dtype=None, shape=None, rlims=None): have_shape = tag_size // item_size if want_shape != have_shape: raise ValueError( - f"Wrong shape specified, requested {want_shape} but got " - f"{have_shape}" + f"Wrong shape specified, requested {want_shape} but got {have_shape}" ) if not len(rlims) == 2: raise ValueError("rlims must have two elements") diff --git a/mne/_fiff/tests/test_meas_info.py b/mne/_fiff/tests/test_meas_info.py index 3e3c150573f..a38ecaade50 100644 --- a/mne/_fiff/tests/test_meas_info.py +++ b/mne/_fiff/tests/test_meas_info.py @@ -306,7 +306,9 @@ def test_read_write_info(tmp_path): gantry_angle = info["gantry_angle"] meas_id = info["meas_id"] - write_info(temp_file, info) + with pytest.raises(FileExistsError, match="Destination file exists"): + write_info(temp_file, info) + write_info(temp_file, info, overwrite=True) info = read_info(temp_file) assert info["proc_history"][0]["creator"] == creator assert info["hpi_meas"][0]["creator"] == creator @@ -348,7 +350,7 @@ def test_read_write_info(tmp_path): info["meas_date"] = datetime(1800, 1, 1, 0, 0, 0, tzinfo=timezone.utc) fname = tmp_path / "test.fif" with pytest.raises(RuntimeError, match="must be between "): - write_info(fname, info) + write_info(fname, info, overwrite=True) @testing.requires_testing_data @@ -377,7 +379,7 @@ def test_io_coord_frame(tmp_path): for ch_type in ("eeg", "seeg", "ecog", "dbs", "hbo", "hbr"): info = create_info(ch_names=["Test Ch"], sfreq=1000.0, ch_types=[ch_type]) info["chs"][0]["loc"][:3] = [0.05, 0.01, -0.03] - write_info(fname, info) + write_info(fname, info, overwrite=True) info2 = read_info(fname) assert info2["chs"][0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD @@ -585,7 +587,7 @@ def test_check_consistency(): info2["subject_info"] = {"height": "bad"} -def _test_anonymize_info(base_info): +def _test_anonymize_info(base_info, tmp_path): """Test that sensitive information can be anonymized.""" pytest.raises(TypeError, anonymize_info, "foo") assert isinstance(base_info, Info) @@ -692,14 +694,25 @@ def _adjust_back(e_i, dt): # exp 4 tests is a supplied daysback delta_t_3 = timedelta(days=223 + 364 * 500) + def _check_equiv(got, want, err_msg): + __tracebackhide__ = True + fname_temp = tmp_path / "test.fif" + assert_object_equal(got, want, err_msg=err_msg) + write_info(fname_temp, got, reset_range=False, overwrite=True) + got = read_info(fname_temp) + # this gets changed on write but that's expected + with got._unlock(): + got["file_id"] = want["file_id"] + assert_object_equal(got, want, err_msg=f"{err_msg} (on I/O round trip)") + new_info = anonymize_info(base_info.copy()) - assert_object_equal(new_info, exp_info, err_msg="anon mismatch") + _check_equiv(new_info, exp_info, err_msg="anon mismatch") new_info = anonymize_info(base_info.copy(), keep_his=True) - assert_object_equal(new_info, exp_info_2, err_msg="anon keep_his mismatch") + _check_equiv(new_info, exp_info_2, err_msg="anon keep_his mismatch") new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) - assert_object_equal(new_info, exp_info_3, err_msg="anon daysback mismatch") + _check_equiv(new_info, exp_info_3, err_msg="anon daysback mismatch") with pytest.raises(RuntimeError, match="anonymize_info generated"): anonymize_info(base_info.copy(), daysback=delta_t_3.days) @@ -726,7 +739,7 @@ def _adjust_back(e_i, dt): new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) else: new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) - assert_object_equal( + _check_equiv( new_info, exp_info_3, err_msg="meas_date=None daysback mismatch", @@ -734,7 +747,7 @@ def _adjust_back(e_i, dt): with _record_warnings(): # meas_date is None new_info = anonymize_info(base_info.copy()) - assert_object_equal(new_info, exp_info_3, err_msg="meas_date=None mismatch") + _check_equiv(new_info, exp_info_3, err_msg="meas_date=None mismatch") @pytest.mark.parametrize( @@ -777,8 +790,8 @@ def _complete_info(info): height=2.0, ) info["helium_info"] = dict( - he_level_raw=12.34, - helium_level=45.67, + he_level_raw=np.float32(12.34), + helium_level=np.float32(45.67), meas_date=datetime(2024, 11, 14, 14, 8, 2, tzinfo=timezone.utc), orig_file_guid="e", ) @@ -796,14 +809,13 @@ def _complete_info(info): machid=np.ones(2, int), secs=d[0], usecs=d[1], - date=d, ), experimenter="j", max_info=dict( - max_st=[], - sss_ctc=[], - sss_cal=[], - sss_info=dict(head_pos=None, in_order=8), + max_st=dict(), + sss_ctc=dict(), + sss_cal=dict(), + sss_info=dict(in_order=8), ), date=d, ), @@ -830,8 +842,8 @@ def test_anonymize(tmp_path): # test mne.anonymize_info() events = read_events(event_name) epochs = Epochs(raw, events[:1], 2, 0.0, 0.1, baseline=None) - _test_anonymize_info(raw.info) - _test_anonymize_info(epochs.info) + _test_anonymize_info(raw.info, tmp_path) + _test_anonymize_info(epochs.info, tmp_path) # test instance methods & I/O roundtrip for inst, keep_his in zip((raw, epochs), (True, False)): @@ -1106,7 +1118,7 @@ def test_channel_name_limit(tmp_path, monkeypatch, fname): meas_info, "_read_extended_ch_info", _read_extended_ch_info ) short_proj_names = [ - f"{name[:13 - bool(len(ref_names))]}-{ni}" + f"{name[: 13 - bool(len(ref_names))]}-{ni}" for ni, name in enumerate(long_proj_names) ] assert raw_read.info["projs"][0]["data"]["col_names"] == short_proj_names diff --git a/mne/_fiff/tests/test_pick.py b/mne/_fiff/tests/test_pick.py index 90830e1d5e5..5d1b24247ab 100644 --- a/mne/_fiff/tests/test_pick.py +++ b/mne/_fiff/tests/test_pick.py @@ -136,7 +136,7 @@ def _channel_type_old(info, idx): else: return t - raise ValueError(f'Unknown channel type for {ch["ch_name"]}') + raise ValueError(f"Unknown channel type for {ch['ch_name']}") def _assert_channel_types(info): diff --git a/mne/_fiff/write.py b/mne/_fiff/write.py index 1fc32f0163e..8486ca13121 100644 --- a/mne/_fiff/write.py +++ b/mne/_fiff/write.py @@ -13,7 +13,7 @@ import numpy as np from scipy.sparse import csc_array, csr_array -from ..utils import _file_like, _validate_type, logger +from ..utils import _check_fname, _file_like, _validate_type, logger from ..utils.numerics import _date_to_julian from .constants import FIFF @@ -277,7 +277,7 @@ def end_block(fid, kind): write_int(fid, FIFF.FIFF_BLOCK_END, kind) -def start_file(fname, id_=None): +def start_file(fname, id_=None, *, overwrite=True): """Open a fif file for writing and writes the compulsory header tags. Parameters @@ -294,6 +294,7 @@ def start_file(fname, id_=None): fid = fname fid.seek(0) else: + fname = _check_fname(fname, overwrite=overwrite) fname = str(fname) if op.splitext(fname)[1].lower() == ".gz": logger.debug("Writing using gzip") @@ -311,9 +312,9 @@ def start_file(fname, id_=None): @contextmanager -def start_and_end_file(fname, id_=None): +def start_and_end_file(fname, id_=None, *, overwrite=True): """Start and (if successfully written) close the file.""" - with start_file(fname, id_=id_) as fid: + with start_file(fname, id_=id_, overwrite=overwrite) as fid: yield fid end_file(fid) # we only hit this line if the yield does not err diff --git a/mne/beamformer/_compute_beamformer.py b/mne/beamformer/_compute_beamformer.py index bb947cdd757..16bedc2c317 100644 --- a/mne/beamformer/_compute_beamformer.py +++ b/mne/beamformer/_compute_beamformer.py @@ -507,13 +507,13 @@ def __repr__(self): # noqa: D105 n_channels, ) if self["pick_ori"] is not None: - out += f', {self["pick_ori"]} ori' + out += f", {self['pick_ori']} ori" if self["weight_norm"] is not None: - out += f', {self["weight_norm"]} norm' + out += f", {self['weight_norm']} norm" if self.get("inversion") is not None: - out += f', {self["inversion"]} inversion' + out += f", {self['inversion']} inversion" if "rank" in self: - out += f', rank {self["rank"]}' + out += f", rank {self['rank']}" out += ">" return out @@ -531,7 +531,7 @@ def save(self, fname, overwrite=False, verbose=None): """ _, write_hdf5 = _import_h5io_funcs() - ending = f'-{self["kind"].lower()}.h5' + ending = f"-{self['kind'].lower()}.h5" check_fname(fname, self["kind"], (ending,)) csd_orig = None try: diff --git a/mne/beamformer/tests/test_lcmv.py b/mne/beamformer/tests/test_lcmv.py index 957dbaf5284..9ae5473e190 100644 --- a/mne/beamformer/tests/test_lcmv.py +++ b/mne/beamformer/tests/test_lcmv.py @@ -380,7 +380,7 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): rank = 17 if proj else 20 assert "LCMV" in repr(filters) assert "unknown subject" not in repr(filters) - assert f'{fwd["nsource"]} vert' in repr(filters) + assert f"{fwd['nsource']} vert" in repr(filters) assert "20 ch" in repr(filters) assert f"rank {rank}" in repr(filters) diff --git a/mne/bem.py b/mne/bem.py index d361272fd49..22aa02d2a0d 100644 --- a/mne/bem.py +++ b/mne/bem.py @@ -91,7 +91,7 @@ class ConductorModel(dict): def __repr__(self): # noqa: D105 if self["is_sphere"]: - center = ", ".join(f"{x * 1000.:.1f}" for x in self["r0"]) + center = ", ".join(f"{x * 1000.0:.1f}" for x in self["r0"]) rad = self.radius if rad is None: # no radius / MEG only extra = f"Sphere (no layers): r0=[{center}] mm" @@ -538,7 +538,7 @@ def _assert_complete_surface(surf, incomplete="raise"): prop = tot_angle / (2 * np.pi) if np.abs(prop - 1.0) > 1e-5: msg = ( - f'Surface {_bem_surf_name[surf["id"]]} is not complete (sum of ' + f"Surface {_bem_surf_name[surf['id']]} is not complete (sum of " f"solid angles yielded {prop}, should be 1.)" ) _on_missing(incomplete, msg, name="incomplete", error_klass=RuntimeError) @@ -571,7 +571,7 @@ def _check_surface_size(surf): sizes = surf["rr"].max(axis=0) - surf["rr"].min(axis=0) if (sizes < 0.05).any(): raise RuntimeError( - f'Dimensions of the surface {_bem_surf_name[surf["id"]]} seem too ' + f"Dimensions of the surface {_bem_surf_name[surf['id']]} seem too " f"small ({1000 * sizes.min():9.5f}). Maybe the unit of measure" " is meters instead of mm" ) @@ -599,8 +599,7 @@ def _surfaces_to_bem( # surfs can be strings (filenames) or surface dicts if len(surfs) not in (1, 3) or not (len(surfs) == len(ids) == len(sigmas)): raise ValueError( - "surfs, ids, and sigmas must all have the same " - "number of elements (1 or 3)" + "surfs, ids, and sigmas must all have the same number of elements (1 or 3)" ) for si, surf in enumerate(surfs): if isinstance(surf, str | Path | os.PathLike): @@ -1260,8 +1259,7 @@ def make_watershed_bem( if op.isdir(ws_dir): if not overwrite: raise RuntimeError( - f"{ws_dir} already exists. Use the --overwrite option" - " to recreate it." + f"{ws_dir} already exists. Use the --overwrite option to recreate it." ) else: shutil.rmtree(ws_dir) @@ -2460,7 +2458,7 @@ def check_seghead(surf_path=subj_path / "surf"): logger.info(f"{ii}. Creating {level} tessellation...") logger.info( f"{ii}.1 Decimating the dense tessellation " - f'({len(surf["tris"])} -> {n_tri} triangles)...' + f"({len(surf['tris'])} -> {n_tri} triangles)..." ) points, tris = decimate_surface( points=surf["rr"], triangles=surf["tris"], n_triangles=n_tri diff --git a/mne/channels/__init__.pyi b/mne/channels/__init__.pyi index 0f3417ec59d..05f273a713d 100644 --- a/mne/channels/__init__.pyi +++ b/mne/channels/__init__.pyi @@ -32,6 +32,7 @@ __all__ = [ "read_polhemus_fastscan", "read_vectorview_selection", "rename_channels", + "transform_to_head", "unify_bad_channels", ] from .channels import ( @@ -73,4 +74,5 @@ from .montage import ( read_dig_localite, read_dig_polhemus_isotrak, read_polhemus_fastscan, + transform_to_head, ) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 8fbff33c13e..bf9e58f2819 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -661,17 +661,21 @@ def _pick_projs(self): return self def add_channels(self, add_list, force_update_info=False): - """Append new channels to the instance. + """Append new channels from other MNE objects to the instance. Parameters ---------- add_list : list - A list of objects to append to self. Must contain all the same - type as the current object. + A list of MNE objects to append to the current instance. + The channels contained in the other instances are appended to the + channels of the current instance. Therefore, all other instances + must be of the same type as the current object. + See notes on how to add data coming from an array. force_update_info : bool If True, force the info for objects to be appended to match the - values in ``self``. This should generally only be used when adding - stim channels for which important metadata won't be overwritten. + values of the current instance. This should generally only be + used when adding stim channels for which important metadata won't + be overwritten. .. versionadded:: 0.12 @@ -688,6 +692,12 @@ def add_channels(self, add_list, force_update_info=False): ----- If ``self`` is a Raw instance that has been preloaded into a :obj:`numpy.memmap` instance, the memmap will be resized. + + This function expects an MNE object to be appended (e.g. :class:`~mne.io.Raw`, + :class:`~mne.Epochs`, :class:`~mne.Evoked`). If you simply want to add a + channel based on values of an np.ndarray, you need to create a + :class:`~mne.io.RawArray`. + See `_ """ # avoid circular imports from ..epochs import BaseEpochs @@ -1372,7 +1382,7 @@ def read_ch_adjacency(fname, picks=None): raise ValueError( f"No built-in channel adjacency matrix found with name: " f"{ch_adj_name}. Valid names are: " - f'{", ".join(get_builtin_ch_adjacencies())}' + f"{', '.join(get_builtin_ch_adjacencies())}" ) ch_adj = [a for a in _BUILTIN_CHANNEL_ADJACENCIES if a.name == ch_adj_name][0] diff --git a/mne/channels/montage.py b/mne/channels/montage.py index a6ded682de9..15cef38dec7 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -361,8 +361,7 @@ def __repr__(self): def plot( self, *, - scale=None, - scale_factor=None, + scale=1.0, show_names=True, kind="topomap", show=True, @@ -373,7 +372,6 @@ def plot( return plot_montage( self, scale=scale, - scale_factor=scale_factor, show_names=show_names, kind=kind, show=show, @@ -407,7 +405,7 @@ def save(self, fname, *, overwrite=False, verbose=None): Parameters ---------- fname : path-like - The filename to use. Should end in .fif or .fif.gz. + The filename to use. Should end in ``-dig.fif`` or ``-dig.fif.gz``. %(overwrite)s %(verbose)s @@ -1289,7 +1287,7 @@ def _backcompat_value(pos, ref_pos): f"Not setting position{_pl(extra)} of {len(extra)} {types} " f"channel{_pl(extra)} found in montage:\n{names}\n" "Consider setting the channel types to be of " - f'{docdict["montage_types"]} ' + f"{docdict['montage_types']} " "using inst.set_channel_types before calling inst.set_montage, " "or omit these channels when creating your montage." ) diff --git a/mne/channels/tests/test_channels.py b/mne/channels/tests/test_channels.py index f51b551a1c8..bb886c51a96 100644 --- a/mne/channels/tests/test_channels.py +++ b/mne/channels/tests/test_channels.py @@ -404,8 +404,7 @@ def test_adjacency_matches_ft(tmp_path): if hash_mne.hexdigest() != hash_ft.hexdigest(): raise ValueError( - f"Hash mismatch between built-in and FieldTrip neighbors " - f"for {fname}" + f"Hash mismatch between built-in and FieldTrip neighbors for {fname}" ) diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index 8add1398409..d9306b5e1bd 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -420,12 +420,7 @@ def test_documented(): ), pytest.param( partial(read_dig_hpts, unit="m"), - ( - "eeg Fp1 -95.0 -3. -3.\n" - "eeg AF7 -1 -1 -3\n" - "eeg A3 -2 -2 2\n" - "eeg A 0 0 0" - ), + ("eeg Fp1 -95.0 -3. -3.\neeg AF7 -1 -1 -3\neeg A3 -2 -2 2\neeg A 0 0 0"), make_dig_montage( ch_pos={ "A": [0.0, 0.0, 0.0], diff --git a/mne/commands/mne_flash_bem.py b/mne/commands/mne_flash_bem.py index b6c7a1b795d..63bcb79d9d8 100644 --- a/mne/commands/mne_flash_bem.py +++ b/mne/commands/mne_flash_bem.py @@ -95,17 +95,6 @@ def run(): "been registered with the T1.mgz file." ), ) - parser.add_option( - "-n", - "--noconvert", - dest="noconvert", - action="store_true", - default=False, - help=( - "[DEPRECATED] Assume that the Flash MRI images " - "have already been converted to mgz files" - ), - ) parser.add_option( "-u", "--unwarp", @@ -139,15 +128,6 @@ def run(): help="Use copies instead of symlinks for surfaces", action="store_true", ) - parser.add_option( - "-p", - "--flash-path", - dest="flash_path", - default=None, - help="[DEPRECATED] The directory containing flash5.mgz " - "files (defaults to " - "$SUBJECTS_DIR/$SUBJECT/mri/flash/parameter_maps", - ) options, _ = parser.parse_args() diff --git a/mne/commands/mne_make_scalp_surfaces.py b/mne/commands/mne_make_scalp_surfaces.py index 5b7d020b98d..894ede7fa1a 100644 --- a/mne/commands/mne_make_scalp_surfaces.py +++ b/mne/commands/mne_make_scalp_surfaces.py @@ -49,8 +49,7 @@ def run(): "--force", dest="force", action="store_true", - help="Force creation of the surface even if it has " - "some topological defects.", + help="Force creation of the surface even if it has some topological defects.", ) parser.add_option( "-t", diff --git a/mne/commands/mne_setup_source_space.py b/mne/commands/mne_setup_source_space.py index e536a59f90b..273e833b31c 100644 --- a/mne/commands/mne_setup_source_space.py +++ b/mne/commands/mne_setup_source_space.py @@ -62,8 +62,7 @@ def run(): parser.add_option( "--ico", dest="ico", - help="use the recursively subdivided icosahedron " - "to create the source space.", + help="use the recursively subdivided icosahedron to create the source space.", default=None, type="int", ) diff --git a/mne/commands/tests/test_commands.py b/mne/commands/tests/test_commands.py index ae885d3c21d..30a625e9578 100644 --- a/mne/commands/tests/test_commands.py +++ b/mne/commands/tests/test_commands.py @@ -334,7 +334,7 @@ def test_flash_bem(tmp_path): # First test without flash30 with ArgvSetter( - ("-d", tempdir, "-s", "sample", "-n", "-r", "-3"), + ("-d", tempdir, "-s", "sample", "-r", "-3"), disable_stdout=False, disable_stderr=False, ): @@ -361,7 +361,6 @@ def test_flash_bem(tmp_path): tempdir, "-s", "sample", - "-n", "-3", str(mridata_path / "flash" / "mef30.mgz"), "-5", diff --git a/mne/conftest.py b/mne/conftest.py index 8795ef1e282..8a4586067b3 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -6,6 +6,7 @@ import inspect import os import os.path as op +import re import shutil import sys import warnings @@ -34,6 +35,7 @@ _pl, _record_warnings, _TempDir, + check_version, numerics, ) @@ -78,7 +80,7 @@ collect_ignore = ["export/_brainvision.py", "export/_eeglab.py", "export/_edf.py"] -def pytest_configure(config): +def pytest_configure(config: pytest.Config): """Configure pytest options.""" # Markers for marker in ( @@ -119,7 +121,7 @@ def pytest_configure(config): # we should remove them from here. # - This list should also be considered alongside reset_warnings in # doc/conf.py. - if os.getenv("MNE_IGNORE_WARNINGS_IN_TESTS", "") != "true": + if os.getenv("MNE_IGNORE_WARNINGS_IN_TESTS", "") not in ("true", "1"): first_kind = "error" else: first_kind = "always" @@ -178,6 +180,15 @@ def pytest_configure(config): ignore:__array__ implementation doesn't accept a copy.*:DeprecationWarning # quantities via neo ignore:The 'copy' argument in Quantity is deprecated.*: + # debugpy uses deprecated matplotlib API + ignore:The (non_)?interactive_bk attribute was deprecated.*: + # SWIG (via OpenMEEG) + ignore:.*builtin type swigvarlink has no.*:DeprecationWarning + # eeglabio + ignore:numpy\.core\.records is deprecated.*:DeprecationWarning + ignore:Starting field name with a underscore.*: + # joblib + ignore:process .* is multi-threaded, use of fork/exec.*:DeprecationWarning """ # noqa: E501 for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip() @@ -632,23 +643,20 @@ def _use_backend(backend_name, interactive): def _check_skip_backend(name): from mne.viz.backends._utils import _notebook_vtk_works - from mne.viz.backends.tests._utils import ( - has_imageio_ffmpeg, - has_pyvista, - has_pyvistaqt, - ) - if not has_pyvista(): - pytest.skip("Test skipped, requires pyvista.") - if not has_imageio_ffmpeg(): - pytest.skip("Test skipped, requires imageio-ffmpeg") + pytest.importorskip("pyvista") + pytest.importorskip("imageio_ffmpeg") if name == "pyvistaqt": + pytest.importorskip("pyvistaqt") if not _check_qt_version(): pytest.skip("Test skipped, requires Qt.") - if not has_pyvistaqt(): - pytest.skip("Test skipped, requires pyvistaqt") else: assert name == "notebook", name + pytest.importorskip("jupyter") + pytest.importorskip("ipympl") + pytest.importorskip("trame") + pytest.importorskip("trame_vtk") + pytest.importorskip("trame_vuetify") if not _notebook_vtk_works(): pytest.skip("Test skipped, requires working notebook vtk") @@ -656,10 +664,8 @@ def _check_skip_backend(name): @pytest.fixture(scope="session") def pixel_ratio(): """Get the pixel ratio.""" - from mne.viz.backends.tests._utils import has_pyvista - # _check_qt_version will init an app for us, so no need for us to do it - if not has_pyvista() or not _check_qt_version(): + if not check_version("pyvista", "0.32") or not _check_qt_version(): return 1.0 from qtpy.QtCore import Qt from qtpy.QtWidgets import QMainWindow @@ -1179,10 +1185,55 @@ def qt_windows_closed(request): @pytest.hookimpl(tryfirst=True, hookwrapper=True) def pytest_runtest_makereport(item, call): - """Stash the status of each item.""" + """Stash the status of each item and turn unexpected skips into errors.""" outcome = yield - rep = outcome.get_result() + rep: pytest.TestReport = outcome.get_result() item.stash.setdefault(_phase_report_key, {})[rep.when] = rep + _modify_report_skips(rep) + return rep + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_make_collect_report(collector: pytest.Collector): + """Turn unexpected skips during collection (e.g., module-level) into errors.""" + outcome = yield + rep: pytest.CollectReport = outcome.get_result() + _modify_report_skips(rep) + return rep + + +# Default means "allow all skips". Can use something like "$." to mean +# "never match", i.e., "treat all skips as errors" +_valid_skips_re = re.compile(os.getenv("MNE_TEST_ALLOW_SKIP", ".*")) + + +# To turn unexpected skips into errors, we need to look both at the collection phase +# (for decorated tests) and the call phase (for things like `importorskip` +# within the test body). code adapted from pytest-error-for-skips +def _modify_report_skips(report: pytest.TestReport | pytest.CollectReport): + if not report.skipped: + return + if isinstance(report.longrepr, tuple): + file, lineno, reason = report.longrepr + else: + file, lineno, reason = "", 1, str(report.longrepr) + if _valid_skips_re.match(reason): + return + assert isinstance(report, pytest.TestReport | pytest.CollectReport), type(report) + if file.endswith("doctest.py"): # _python/doctest.py + return + # xfail tests aren't true "skips" but show up as skipped in reports + if getattr(report, "keywords", {}).get("xfail", False): + return + # the above only catches marks, so we need to actually parse the report to catch + # an xfail based on the traceback + if " pytest.xfail( " in reason: + return + if reason.startswith("Skipped: "): + reason = reason[9:] + report.longrepr = f"{file}:{lineno}: UNEXPECTED SKIP: {reason}" + # Make it show up as an error in the report + report.outcome = "error" if isinstance(report, pytest.TestReport) else "failed" @pytest.fixture(scope="function") diff --git a/mne/coreg.py b/mne/coreg.py index f28c6142c96..c7549ee028a 100644 --- a/mne/coreg.py +++ b/mne/coreg.py @@ -876,8 +876,7 @@ def _scale_params(subject_to, subject_from, scale, subjects_dir): subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) if (subject_from is None) != (scale is None): raise TypeError( - "Need to provide either both subject_from and scale " - "parameters, or neither." + "Need to provide either both subject_from and scale parameters, or neither." ) if subject_from is None: @@ -1402,8 +1401,7 @@ def _read_surface(filename, *, on_defects): complete_surface_info(bem, copy=False) except Exception: raise ValueError( - f"Error loading surface from {filename} (see " - "Terminal for details)." + f"Error loading surface from {filename} (see Terminal for details)." ) return bem @@ -2145,8 +2143,7 @@ def omit_head_shape_points(self, distance): mask = self._orig_hsp_point_distance <= distance n_excluded = np.sum(~mask) logger.info( - "Coregistration: Excluding %i head shape points with " - "distance >= %.3f m.", + "Coregistration: Excluding %i head shape points with distance >= %.3f m.", n_excluded, distance, ) diff --git a/mne/cov.py b/mne/cov.py index 8b86119c1d1..694c836d0cd 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -1226,7 +1226,7 @@ def _compute_rank_raw_array( from .io import RawArray return _compute_rank( - RawArray(data, info, copy=None, verbose=_verbose_safe_false()), + RawArray(data, info, copy="auto", verbose=_verbose_safe_false()), rank, scalings, info, @@ -1293,7 +1293,7 @@ def _compute_covariance_auto( data_ = data.copy() name = method_.__name__ if callable(method_) else method_ logger.info( - f'Estimating {cov_kind + (" " if cov_kind else "")}' + f"Estimating {cov_kind + (' ' if cov_kind else '')}" f"covariance using {name.upper()}" ) mp = method_params[method_] @@ -1405,7 +1405,7 @@ def _compute_covariance_auto( # project back cov = np.dot(eigvec.T, np.dot(cov, eigvec)) # undo bias - cov *= data.shape[0] / (data.shape[0] - 1) + cov *= data.shape[0] / max(data.shape[0] - 1, 1) # undo scaling _undo_scaling_cov(cov, picks_list, scalings) method_ = method[ei] @@ -1712,7 +1712,7 @@ def _get_ch_whitener(A, pca, ch_type, rank): logger.info( f" Setting small {ch_type} eigenvalues to zero " - f'({"using" if pca else "without"} PCA)' + f"({'using' if pca else 'without'} PCA)" ) if pca: # No PCA case. # This line will reduce the actual number of variables in data @@ -2400,7 +2400,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): data = tag.data diag = True logger.info( - " %d x %d diagonal covariance (kind = " "%d) found.", + " %d x %d diagonal covariance (kind = %d) found.", dim, dim, cov_kind, @@ -2416,7 +2416,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): data.flat[:: dim + 1] /= 2.0 diag = False logger.info( - " %d x %d full covariance (kind = %d) " "found.", + " %d x %d full covariance (kind = %d) found.", dim, dim, cov_kind, @@ -2425,7 +2425,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): diag = False data = tag.data logger.info( - " %d x %d sparse covariance (kind = %d)" " found.", + " %d x %d sparse covariance (kind = %d) found.", dim, dim, cov_kind, diff --git a/mne/datasets/__init__.pyi b/mne/datasets/__init__.pyi index 44cee84fe7f..2f69a1027e5 100644 --- a/mne/datasets/__init__.pyi +++ b/mne/datasets/__init__.pyi @@ -6,6 +6,7 @@ __all__ = [ "epilepsy_ecog", "erp_core", "eyelink", + "default_path", "fetch_aparc_sub_parcellation", "fetch_dataset", "fetch_fsaverage", @@ -70,6 +71,7 @@ from ._infant import fetch_infant_template from ._phantom.base import fetch_phantom from .utils import ( _download_all_example_data, + default_path, fetch_aparc_sub_parcellation, fetch_hcp_mmp_parcellation, has_dataset, diff --git a/mne/datasets/_fetch.py b/mne/datasets/_fetch.py index 1e38606f908..8f44459ad97 100644 --- a/mne/datasets/_fetch.py +++ b/mne/datasets/_fetch.py @@ -143,8 +143,7 @@ def fetch_dataset( if auth is not None: if len(auth) != 2: raise RuntimeError( - "auth should be a 2-tuple consisting " - "of a username and password/token." + "auth should be a 2-tuple consisting of a username and password/token." ) # processor to uncompress files diff --git a/mne/datasets/config.py b/mne/datasets/config.py index ccd4babacd9..75eff184cd1 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -92,8 +92,8 @@ phantom_kit="0.2", ucl_opm_auditory="0.2", ) -TESTING_VERSIONED = f'mne-testing-data-{RELEASES["testing"]}' -MISC_VERSIONED = f'mne-misc-data-{RELEASES["misc"]}' +TESTING_VERSIONED = f"mne-testing-data-{RELEASES['testing']}" +MISC_VERSIONED = f"mne-misc-data-{RELEASES['misc']}" # To update any other dataset besides `testing` or `misc`, upload the new # version of the data archive itself (e.g., to https://osf.io or wherever) and @@ -118,7 +118,7 @@ hash="md5:d94fe9f3abe949a507eaeb865fb84a3f", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" - f'tar.gz/{RELEASES["testing"]}' + f"tar.gz/{RELEASES['testing']}" ), # In case we ever have to resort to osf.io again... # archive_name='mne-testing-data.tar.gz', @@ -131,8 +131,7 @@ archive_name=f"{MISC_VERSIONED}.tar.gz", # 'mne-misc-data', hash="md5:e343d3a00cb49f8a2f719d14f4758afe", url=( - "https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/" - f'{RELEASES["misc"]}' + f"https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/{RELEASES['misc']}" ), folder_name="MNE-misc-data", config_key="MNE_DATASETS_MISC_PATH", diff --git a/mne/datasets/eegbci/eegbci.py b/mne/datasets/eegbci/eegbci.py index 91d78f57a03..7b6b3e8bdc8 100644 --- a/mne/datasets/eegbci/eegbci.py +++ b/mne/datasets/eegbci/eegbci.py @@ -9,7 +9,7 @@ from os import path as op from pathlib import Path -from ...utils import _url_to_local_path, logger, verbose, warn +from ...utils import _url_to_local_path, logger, verbose from ..utils import _do_path_update, _downloader_params, _get_path, _log_time_size EEGMI_URL = "https://physionet.org/files/eegmmidb/1.0.0/" @@ -94,10 +94,9 @@ def data_path(url, path=None, force_update=False, update_path=None, *, verbose=N @verbose def load_data( - subjects=None, - runs=None, + subjects, + runs, *, - subject=None, path=None, force_update=False, update_path=None, @@ -117,9 +116,6 @@ def load_data( The subjects to use. Can be in the range of 1-109 (inclusive). runs : int | list of int The runs to use (see Notes for details). - subject : int - This parameter is deprecated and will be removed in mne version 1.9. - Please use ``subjects`` instead. path : None | path-like Location of where to look for the EEGBCI data. If ``None``, the environment variable or config parameter ``MNE_DATASETS_EEGBCI_PATH`` is used. If neither @@ -170,22 +166,6 @@ def load_data( """ import pooch - # XXX: Remove this with mne 1.9 ↓↓↓ - # Also remove the subject parameter at that point. - # Also remove the `None` default for subjects and runs params at that point. - if subject is not None: - subjects = subject - warn( - "The ``subject`` parameter is deprecated and will be removed in version " - "1.9. Use the ``subjects`` parameter (note the `s`) to suppress this " - "warning.", - FutureWarning, - ) - del subject - if subjects is None or runs is None: - raise ValueError("You must pass the parameters ``subjects`` and ``runs``.") - # ↑↑↑ - t0 = time.time() if not hasattr(subjects, "__iter__"): diff --git a/mne/datasets/eegbci/tests/test_eegbci.py b/mne/datasets/eegbci/tests/test_eegbci.py index 40ef5ee030f..e9f63fee288 100644 --- a/mne/datasets/eegbci/tests/test_eegbci.py +++ b/mne/datasets/eegbci/tests/test_eegbci.py @@ -2,7 +2,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -import pytest from mne.datasets import eegbci @@ -14,14 +13,3 @@ def test_eegbci_download(tmp_path, fake_retrieve): fnames = eegbci.load_data(subj, runs=[3], path=tmp_path, update_path=False) assert len(fnames) == 1, subj assert fake_retrieve.call_count == 4 - - # XXX: remove in version 1.9 - with pytest.warns(FutureWarning, match="The ``subject``"): - fnames = eegbci.load_data( - subject=subjects, runs=[3], path=tmp_path, update_path=False - ) - assert len(fnames) == 4 - - # XXX: remove in version 1.9 - with pytest.raises(ValueError, match="You must pass the parameters"): - fnames = eegbci.load_data(path=tmp_path, update_path=False) diff --git a/mne/datasets/sleep_physionet/_utils.py b/mne/datasets/sleep_physionet/_utils.py index b97d0611591..7fbcca3a2d7 100644 --- a/mne/datasets/sleep_physionet/_utils.py +++ b/mne/datasets/sleep_physionet/_utils.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import inspect import os import os.path as op @@ -114,12 +115,16 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): data = data.set_index(("Subject - age - sex", "Nr")) data.index.name = "subject" data.columns.names = [None, None] + kwargs = dict() + # TODO VERSION can be removed once we require Pandas 2.1 + if "future_stack" in inspect.getfullargspec(pd.DataFrame.stack).args: + kwargs["future_stack"] = True data = ( data.set_index( [("Subject - age - sex", "Age"), ("Subject - age - sex", "M1/F2")], append=True, ) - .stack(level=0) + .stack(level=0, **kwargs) .reset_index() ) diff --git a/mne/datasets/sleep_physionet/age.py b/mne/datasets/sleep_physionet/age.py index c14282ed202..b5ea1764946 100644 --- a/mne/datasets/sleep_physionet/age.py +++ b/mne/datasets/sleep_physionet/age.py @@ -122,10 +122,7 @@ def fetch_data( ) _on_missing(on_missing, msg) if 13 in subjects and 2 in recording: - msg = ( - "Requested recording 2 for subject 13, but it is not available " - "in corpus." - ) + msg = "Requested recording 2 for subject 13, but it is not available in corpus." _on_missing(on_missing, msg) fnames = [] diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index b7f651f0804..93aabc0841a 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import glob import importlib import inspect import logging @@ -92,6 +93,22 @@ def _dataset_version(path, name): return version +@verbose +def default_path(*, verbose=None): + """Get the default MNE_DATA path. + + Parameters + ---------- + %(verbose)s + + Returns + ------- + data_path : instance of Path + Path to the default MNE_DATA directory. + """ + return _get_path(None, None, None) + + def _get_path(path, key, name): """Get a dataset path.""" # 1. Input @@ -113,7 +130,8 @@ def _get_path(path, key, name): return path # 4. ~/mne_data (but use a fake home during testing so we don't # unnecessarily create ~/mne_data) - logger.info(f"Using default location ~/mne_data for {name}...") + extra = f" for {name}" if name else "" + logger.info(f"Using default location ~/mne_data{extra}...") path = Path(os.getenv("_MNE_FAKE_HOME_DIR", "~")).expanduser() / "mne_data" if not path.is_dir(): logger.info(f"Creating {path}") @@ -319,6 +337,8 @@ def _download_all_example_data(verbose=True): # # verbose=True by default so we get nice status messages. # Consider adding datasets from here to CircleCI for PR-auto-build + import openneuro + paths = dict() for kind in ( "sample testing misc spm_face somato hf_sef multimodal " @@ -356,9 +376,13 @@ def _download_all_example_data(verbose=True): # If the user has SUBJECTS_DIR, respect it, if not, set it to the EEG one # (probably on CircleCI, or otherwise advanced user) - fetch_fsaverage(None) + fetch_fsaverage(subjects_dir=None) logger.info("[done fsaverage]") + # Now also update the sample dataset path, if not already SUBJECTS_DIR + # (some tutorials make use of these files) + fetch_fsaverage(subjects_dir=paths["sample"] / "subjects") + fetch_infant_template("6mo") logger.info("[done infant_template]") @@ -371,6 +395,14 @@ def _download_all_example_data(verbose=True): limo.load_data(subject=1, update_path=True) logger.info("[done limo]") + # for ESG + ds = "ds004388" + target_dir = default_path() / ds + run_name = "sub-001/eeg/*median_run-03_eeg*.set" + if not glob.glob(str(target_dir / run_name)): + target_dir.mkdir(exist_ok=True) + openneuro.download(dataset=ds, target_dir=target_dir, include=run_name[:-4]) + @verbose def fetch_aparc_sub_parcellation(subjects_dir=None, verbose=None): diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 85ed102b514..f73cd976fe3 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -19,7 +19,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.metrics import check_scoring from sklearn.model_selection import KFold, StratifiedKFold, check_cv -from sklearn.utils import check_array, indexable +from sklearn.utils import check_array, check_X_y, indexable from ..parallel import parallel_func from ..utils import _pl, logger, verbose, warn @@ -76,16 +76,20 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator): ) def __init__(self, model=None): + # TODO: We need to set this to get our tag checking to work properly if model is None: model = LogisticRegression(solver="liblinear") - self.model = model def __sklearn_tags__(self): """Get sklearn tags.""" from sklearn.utils import get_tags # added in 1.6 - return get_tags(self.model) + # fit method below does not allow sparse data via check_data, we could + # eventually make it smarter if we had to + tags = get_tags(self.model) + tags.input_tags.sparse = False + return tags def __getattr__(self, attr): """Wrap to model for some attributes.""" @@ -118,7 +122,11 @@ def fit(self, X, y, **fit_params): self : instance of LinearModel Returns the modified instance. """ - X = check_array(X, input_name="X") + if y is not None: + X = check_array(X) + else: + X, y = check_X_y(X, y) + self.n_features_in_ = X.shape[1] if y is not None: y = check_array(y, dtype=None, ensure_2d=False, input_name="y") if y.ndim > 2: @@ -129,6 +137,7 @@ def fit(self, X, y, **fit_params): # fit the Model self.model.fit(X, y, **fit_params) + self.model_ = self.model # for better sklearn compat # Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 1261ca82055..ea38fd58ca3 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -6,7 +6,8 @@ import numpy as np from scipy.linalg import eigh -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted from .._fiff.meas_info import create_info from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh @@ -18,12 +19,12 @@ _verbose_safe_false, fill_doc, pinv, - warn, ) +from .transformer import MNETransformerMixin @fill_doc -class CSP(TransformerMixin, BaseEstimator): +class CSP(MNETransformerMixin, BaseEstimator): """M/EEG signal decomposition using the Common Spatial Patterns (CSP). This class can be used as a supervised decomposition to estimate spatial @@ -113,49 +114,44 @@ def __init__( component_order="mutual_info", ): # Init default CSP - if not isinstance(n_components, int): - raise ValueError("n_components must be an integer.") self.n_components = n_components self.rank = rank self.reg = reg - - # Init default cov_est - if not (cov_est == "concat" or cov_est == "epoch"): - raise ValueError("unknown covariance estimation method") self.cov_est = cov_est - - # Init default transform_into - self.transform_into = _check_option( - "transform_into", transform_into, ["average_power", "csp_space"] - ) - - # Init default log - if transform_into == "average_power": - if log is not None and not isinstance(log, bool): - raise ValueError( - 'log must be a boolean if transform_into == "average_power".' - ) - else: - if log is not None: - raise ValueError('log must be a None if transform_into == "csp_space".') + self.transform_into = transform_into self.log = log - - _validate_type(norm_trace, bool, "norm_trace") self.norm_trace = norm_trace self.cov_method_params = cov_method_params - self.component_order = _check_option( - "component_order", component_order, ("mutual_info", "alternate") + self.component_order = component_order + + def _validate_params(self, *, y): + _validate_type(self.n_components, int, "n_components") + if hasattr(self, "cov_est"): + _validate_type(self.cov_est, str, "cov_est") + _check_option("cov_est", self.cov_est, ("concat", "epoch")) + if hasattr(self, "norm_trace"): + _validate_type(self.norm_trace, bool, "norm_trace") + _check_option( + "transform_into", self.transform_into, ["average_power", "csp_space"] ) - - def _check_Xy(self, X, y=None): - """Check input data.""" - if not isinstance(X, np.ndarray): - raise ValueError(f"X should be of type ndarray (got {type(X)}).") - if y is not None: - if len(X) != len(y) or len(y) < 1: - raise ValueError("X and y must have the same length.") - if X.ndim < 3: - raise ValueError("X must have at least 3 dimensions.") + if self.transform_into == "average_power": + _validate_type( + self.log, + (bool, None), + "log", + extra="when transform_into is 'average_power'", + ) + else: + _validate_type( + self.log, None, "log", extra="when transform_into is 'csp_space'" + ) + _check_option( + "component_order", self.component_order, ("mutual_info", "alternate") + ) + self.classes_ = np.unique(y) + n_classes = len(self.classes_) + if n_classes < 2: + raise ValueError(f"n_classes must be >= 2, but got {n_classes} class") def fit(self, X, y): """Estimate the CSP decomposition on epochs. @@ -172,12 +168,9 @@ def fit(self, X, y): self : instance of CSP Returns the modified instance. """ - self._check_Xy(X, y) - - self._classes = np.unique(y) - n_classes = len(self._classes) - if n_classes < 2: - raise ValueError("n_classes must be >= 2.") + X, y = self._check_data(X, y=y, fit=True, return_y=True) + self._validate_params(y=y) + n_classes = len(self.classes_) if n_classes > 2 and self.component_order == "alternate": raise ValueError( "component_order='alternate' requires two classes, but data contains " @@ -226,13 +219,8 @@ def transform(self, X): If self.transform_into == 'csp_space' then returns the data in CSP space and shape is (n_epochs, n_components, n_times). """ - if not isinstance(X, np.ndarray): - raise ValueError(f"X should be of type ndarray (got {type(X)}).") - if self.filters_ is None: - raise RuntimeError( - "No filters available. Please first fit CSP decomposition." - ) - + check_is_fitted(self, "filters_") + X = self._check_data(X) pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) @@ -304,7 +292,6 @@ def plot_patterns( info, components=None, *, - average=None, ch_type=None, scalings=None, sensors=True, @@ -342,7 +329,6 @@ def plot_patterns( :func:`mne.create_info`. components : float | array of float | None The patterns to plot. If ``None``, all components will be shown. - %(average_plot_evoked_topomap)s %(ch_type_topomap)s scalings : dict | float | None The scalings of the channel types to be applied for plotting. @@ -391,9 +377,6 @@ def plot_patterns( if components is None: components = np.arange(self.n_components) - if average is not None: - warn("`average` is deprecated and will be removed in 1.10.", FutureWarning) - # set sampling frequency to have 1 component per time point info = cp.deepcopy(info) with info._unlock(): @@ -403,7 +386,6 @@ def plot_patterns( # the call plot_topomap fig = patterns.plot_topomap( times=components, - average=average, ch_type=ch_type, scalings=scalings, sensors=sensors, @@ -438,7 +420,6 @@ def plot_filters( info, components=None, *, - average=None, ch_type=None, scalings=None, sensors=True, @@ -476,7 +457,6 @@ def plot_filters( :func:`mne.create_info`. components : float | array of float | None The patterns to plot. If ``None``, all components will be shown. - %(average_plot_evoked_topomap)s %(ch_type_topomap)s scalings : dict | float | None The scalings of the channel types to be applied for plotting. @@ -525,9 +505,6 @@ def plot_filters( if components is None: components = np.arange(self.n_components) - if average is not None: - warn("`average` is deprecated and will be removed in 1.10.", FutureWarning) - # set sampling frequency to have 1 component per time point info = cp.deepcopy(info) with info._unlock(): @@ -537,7 +514,6 @@ def plot_filters( # the call plot_topomap fig = filters.plot_topomap( times=components, - average=average, ch_type=ch_type, scalings=scalings, sensors=sensors, @@ -590,7 +566,7 @@ def _compute_covariance_matrices(self, X, y): covs = [] sample_weights = [] - for ci, this_class in enumerate(self._classes): + for ci, this_class in enumerate(self.classes_): cov, weight = cov_estimator( X[y == this_class], cov_kind=f"class={this_class}", @@ -702,7 +678,7 @@ def _normalize_eigenvectors(self, eigen_vectors, covs, sample_weights): def _order_components( self, covs, sample_weights, eigen_vectors, eigen_values, component_order ): - n_classes = len(self._classes) + n_classes = len(self.classes_) if component_order == "mutual_info" and n_classes > 2: mutual_info = self._compute_mutual_info(covs, sample_weights, eigen_vectors) ix = np.argsort(mutual_info)[::-1] @@ -902,10 +878,8 @@ def fit(self, X, y): self : instance of SPoC Returns the modified instance. """ - self._check_Xy(X, y) - - if len(np.unique(y)) < 2: - raise ValueError("y must have at least two distinct values.") + X, y = self._check_data(X, y=y, fit=True, return_y=True) + self._validate_params(y=y) # The following code is directly copied from pyRiemann diff --git a/mne/decoding/ems.py b/mne/decoding/ems.py index 911b25e6692..5c7557798ef 100644 --- a/mne/decoding/ems.py +++ b/mne/decoding/ems.py @@ -5,15 +5,16 @@ from collections import Counter import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from .._fiff.pick import _picks_to_idx, pick_info, pick_types from ..parallel import parallel_func from ..utils import logger, verbose from .base import _set_cv +from .transformer import MNETransformerMixin -class EMS(TransformerMixin, BaseEstimator): +class EMS(MNETransformerMixin, BaseEstimator): """Transformer to compute event-matched spatial filters. This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire @@ -37,6 +38,16 @@ class EMS(TransformerMixin, BaseEstimator): .. footbibliography:: """ + def __sklearn_tags__(self): + """Return sklearn tags.""" + from sklearn.utils import ClassifierTags + + tags = super().__sklearn_tags__() + if tags.classifier_tags is None: + tags.classifier_tags = ClassifierTags() + tags.classifier_tags.multi_class = False + return tags + def __repr__(self): # noqa: D105 if hasattr(self, "filters_"): return ( @@ -64,11 +75,12 @@ def fit(self, X, y): self : instance of EMS Returns self. """ - classes = np.unique(y) - if len(classes) != 2: + X, y = self._check_data(X, y=y, fit=True, return_y=True) + classes, y = np.unique(y, return_inverse=True) + if len(classes) > 2: raise ValueError("EMS only works for binary classification.") self.classes_ = classes - filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0) + filters = X[y == 0].mean(0) - X[y == 1].mean(0) filters /= np.linalg.norm(filters, axis=0)[None, :] self.filters_ = filters return self @@ -86,13 +98,14 @@ def transform(self, X): X : array, shape (n_epochs, n_times) The input data transformed by the spatial filters. """ + X = self._check_data(X) Xt = np.sum(X * self.filters_, axis=1) return Xt @verbose def compute_ems( - epochs, conditions=None, picks=None, n_jobs=None, cv=None, verbose=None + epochs, conditions=None, picks=None, n_jobs=None, cv=None, *, verbose=None ): """Compute event-matched spatial filter on epochs. @@ -188,7 +201,7 @@ def compute_ems( data[:, this_picks] /= np.std(data[:, this_picks]) # Setup cross-validation. Need to use _set_cv to deal with sklearn - # deprecation of cv objects. + # changes in cv object handling. y = epochs.events[:, 2] _, cv_splits = _set_cv(cv, "classifier", X=y, y=y) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index e3059a3e959..8bd96781185 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -5,18 +5,25 @@ import logging import numpy as np -from sklearn.base import BaseEstimator, MetaEstimatorMixin, TransformerMixin, clone +from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone from sklearn.metrics import check_scoring from sklearn.preprocessing import LabelEncoder -from sklearn.utils import check_array +from sklearn.utils.validation import check_is_fitted from ..parallel import parallel_func -from ..utils import ProgressBar, _parse_verbose, array_split_idx, fill_doc, verbose +from ..utils import ( + ProgressBar, + _parse_verbose, + _verbose_safe_false, + array_split_idx, + fill_doc, +) from .base import _check_estimator +from .transformer import MNETransformerMixin @fill_doc -class SlidingEstimator(MetaEstimatorMixin, TransformerMixin, BaseEstimator): +class SlidingEstimator(MetaEstimatorMixin, MNETransformerMixin, BaseEstimator): """Search Light. Fit, predict and score a series of models to each subset of the dataset @@ -38,7 +45,6 @@ class SlidingEstimator(MetaEstimatorMixin, TransformerMixin, BaseEstimator): List of fitted scikit-learn estimators (one per task). """ - @verbose def __init__( self, base_estimator, @@ -49,7 +55,6 @@ def __init__( allow_2d=False, verbose=None, ): - _check_estimator(base_estimator) self.base_estimator = base_estimator self.n_jobs = n_jobs self.scoring = scoring @@ -102,9 +107,13 @@ def fit(self, X, y, **fit_params): self : object Return self. """ - X = self._check_Xy(X, y) + _check_estimator(self.base_estimator) + X, _ = self._check_Xy(X, y, fit=True) parallel, p_func, n_jobs = parallel_func( - _sl_fit, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_fit, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) self.estimators_ = list() self.fit_params_ = fit_params @@ -153,14 +162,19 @@ def fit_transform(self, X, y, **fit_params): def _transform(self, X, method): """Aux. function to make parallel predictions/transformation.""" - X = self._check_Xy(X) + X, is_nd = self._check_Xy(X) + orig_method = method + check_is_fitted(self) method = _check_method(self.base_estimator, method) if X.shape[-1] != len(self.estimators_): raise ValueError("The number of estimators does not match X.shape[-1]") # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _sl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_transform, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) X_splits = np.array_split(X, n_jobs, axis=-1) @@ -174,6 +188,10 @@ def _transform(self, X, method): ) y_pred = np.concatenate(y_pred, axis=1) + if orig_method == "transform": + y_pred = y_pred.astype(X.dtype) + if orig_method == "predict_proba" and not is_nd: + y_pred = y_pred[:, 0, :] return y_pred def transform(self, X): @@ -196,7 +214,7 @@ def transform(self, X): Xt : array, shape (n_samples, n_estimators) The transformed values generated by each estimator. """ # noqa: E501 - return self._transform(X, "transform").astype(X.dtype) + return self._transform(X, "transform") def predict(self, X): """Predict each data slice/task with a series of independent estimators. @@ -265,15 +283,12 @@ def decision_function(self, X): """ # noqa: E501 return self._transform(X, "decision_function") - def _check_Xy(self, X, y=None): + def _check_Xy(self, X, y=None, fit=False): """Aux. function to check input data.""" # Once we require sklearn 1.1+ we should do something like: - X = check_array(X, ensure_2d=False, allow_nd=True, input_name="X") - if y is not None: - y = check_array(y, dtype=None, ensure_2d=False, input_name="y") - if len(X) != len(y) or len(y) < 1: - raise ValueError("X and y must have the same length.") - if X.ndim < 3: + X = self._check_data(X, y=y, atleast_3d=False, fit=fit) + is_nd = X.ndim >= 3 + if not is_nd: err = None if not self.allow_2d: err = 3 @@ -282,7 +297,7 @@ def _check_Xy(self, X, y=None): if err: raise ValueError(f"X must have at least {err} dimensions.") X = X[..., np.newaxis] - return X + return X, is_nd def score(self, X, y): """Score each estimator on each task. @@ -307,7 +322,7 @@ def score(self, X, y): score : array, shape (n_samples, n_estimators) Score for each estimator/task. """ # noqa: E501 - X = self._check_Xy(X, y) + X, _ = self._check_Xy(X, y) if X.shape[-1] != len(self.estimators_): raise ValueError("The number of estimators does not match X.shape[-1]") @@ -317,7 +332,10 @@ def score(self, X, y): # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _sl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_score, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) X_splits = np.array_split(X, n_jobs, axis=-1) est_splits = np.array_split(self.estimators_, n_jobs) @@ -483,11 +501,16 @@ def __repr__(self): # noqa: D105 def _transform(self, X, method): """Aux. function to make parallel predictions/transformation.""" - X = self._check_Xy(X) + X, is_nd = self._check_Xy(X) + check_is_fitted(self) + orig_method = method method = _check_method(self.base_estimator, method) parallel, p_func, n_jobs = parallel_func( - _gl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _gl_transform, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) context = _create_progressbar_context(self, X, "Transforming") @@ -500,6 +523,10 @@ def _transform(self, X, method): ) y_pred = np.concatenate(y_pred, axis=2) + if orig_method == "transform": + y_pred = y_pred.astype(X.dtype) + if orig_method == "predict_proba" and not is_nd: + y_pred = y_pred[:, 0, 0, :] return y_pred def transform(self, X): @@ -518,6 +545,7 @@ def transform(self, X): Xt : array, shape (n_samples, n_estimators, n_slices) The transformed values generated by each estimator. """ + check_is_fitted(self) return self._transform(X, "transform") def predict(self, X): @@ -603,11 +631,14 @@ def score(self, X, y): score : array, shape (n_samples, n_estimators, n_slices) Score for each estimator / data slice couple. """ # noqa: E501 - X = self._check_Xy(X, y) + X, _ = self._check_Xy(X, y) # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _gl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _gl_score, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) scoring = check_scoring(self.base_estimator, self.scoring) y = _fix_auc(scoring, y) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 8bc0036d315..111ded9f274 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -4,8 +4,10 @@ import numpy as np from scipy.linalg import eigh -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted +from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx from ..cov import Covariance, _regularized_covariance from ..defaults import _handle_default @@ -13,17 +15,17 @@ from ..rank import compute_rank from ..time_frequency import psd_array_welch from ..utils import ( - _check_option, _time_mask, _validate_type, _verbose_safe_false, fill_doc, logger, ) +from .transformer import MNETransformerMixin @fill_doc -class SSD(TransformerMixin, BaseEstimator): +class SSD(MNETransformerMixin, BaseEstimator): """ Signal decomposition using the Spatio-Spectral Decomposition (SSD). @@ -64,7 +66,7 @@ class SSD(TransformerMixin, BaseEstimator): If sort_by_spectral_ratio is set to True, then the SSD sources will be sorted according to their spectral ratio which is calculated based on :func:`mne.time_frequency.psd_array_welch`. The n_fft parameter sets the - length of FFT used. + length of FFT used. The default (None) will use 1 second of data. See :func:`mne.time_frequency.psd_array_welch` for more information. cov_method_params : dict | None (default None) As in :class:`mne.decoding.SPoC` @@ -104,7 +106,25 @@ def __init__( rank=None, ): """Initialize instance.""" - dicts = {"signal": filt_params_signal, "noise": filt_params_noise} + self.info = info + self.filt_params_signal = filt_params_signal + self.filt_params_noise = filt_params_noise + self.reg = reg + self.n_components = n_components + self.picks = picks + self.sort_by_spectral_ratio = sort_by_spectral_ratio + self.return_filtered = return_filtered + self.n_fft = n_fft + self.cov_method_params = cov_method_params + self.rank = rank + + def _validate_params(self, X): + if isinstance(self.info, float): # special case, mostly for testing + self.sfreq_ = self.info + else: + _validate_type(self.info, Info, "info") + self.sfreq_ = self.info["sfreq"] + dicts = {"signal": self.filt_params_signal, "noise": self.filt_params_noise} for param, dd in [("l", 0), ("h", 0), ("l", 1), ("h", 1)]: key = ("signal", "noise")[dd] if param + "_freq" not in dicts[key]: @@ -116,48 +136,47 @@ def __init__( _validate_type(val, ("numeric",), f"{key} {param}_freq") # check freq bands if ( - filt_params_noise["l_freq"] > filt_params_signal["l_freq"] - or filt_params_signal["h_freq"] > filt_params_noise["h_freq"] + self.filt_params_noise["l_freq"] > self.filt_params_signal["l_freq"] + or self.filt_params_signal["h_freq"] > self.filt_params_noise["h_freq"] ): raise ValueError( "Wrongly specified frequency bands!\n" "The signal band-pass must be within the noise " "band-pass!" ) - self.picks = picks - del picks - self.info = info - self.freqs_signal = (filt_params_signal["l_freq"], filt_params_signal["h_freq"]) - self.freqs_noise = (filt_params_noise["l_freq"], filt_params_noise["h_freq"]) - self.filt_params_signal = filt_params_signal - self.filt_params_noise = filt_params_noise - # check if boolean - if not isinstance(sort_by_spectral_ratio, (bool)): - raise ValueError("sort_by_spectral_ratio must be boolean") - self.sort_by_spectral_ratio = sort_by_spectral_ratio - if n_fft is None: - self.n_fft = int(self.info["sfreq"]) - else: - self.n_fft = int(n_fft) - # check if boolean - if not isinstance(return_filtered, (bool)): - raise ValueError("return_filtered must be boolean") - self.return_filtered = return_filtered - self.reg = reg - self.n_components = n_components - self.rank = rank - self.cov_method_params = cov_method_params + self.freqs_signal_ = ( + self.filt_params_signal["l_freq"], + self.filt_params_signal["h_freq"], + ) + self.freqs_noise_ = ( + self.filt_params_noise["l_freq"], + self.filt_params_noise["h_freq"], + ) + _validate_type(self.sort_by_spectral_ratio, (bool,), "sort_by_spectral_ratio") + _validate_type(self.n_fft, ("numeric", None), "n_fft") + self.n_fft_ = min( + int(self.n_fft if self.n_fft is not None else self.sfreq_), + X.shape[-1], + ) + _validate_type(self.return_filtered, (bool,), "return_filtered") + if isinstance(self.info, Info): + ch_types = self.info.get_channel_types(picks=self.picks, unique=True) + if len(ch_types) > 1: + raise ValueError( + "At this point SSD only supports fitting " + f"single channel types. Your info has {len(ch_types)} types." + ) - def _check_X(self, X): + def _check_X(self, X, *, y=None, fit=False): """Check input data.""" - _validate_type(X, np.ndarray, "X") - _check_option("X.ndim", X.ndim, (2, 3)) + X = self._check_data(X, y=y, fit=fit, atleast_3d=False) n_chan = X.shape[-2] - if n_chan != self.info["nchan"]: + if isinstance(self.info, Info) and n_chan != self.info["nchan"]: raise ValueError( "Info must match the input data." f"Found {n_chan} channels but expected {self.info['nchan']}." ) + return X def fit(self, X, y=None): """Estimate the SSD decomposition on raw or epoched data. @@ -176,18 +195,17 @@ def fit(self, X, y=None): self : instance of SSD Returns the modified instance. """ - ch_types = self.info.get_channel_types(picks=self.picks, unique=True) - if len(ch_types) > 1: - raise ValueError( - "At this point SSD only supports fitting " - f"single channel types. Your info has {len(ch_types)} types." - ) - self.picks_ = _picks_to_idx(self.info, self.picks, none="data", exclude="bads") - self._check_X(X) + X = self._check_X(X, y=y, fit=True) + self._validate_params(X) + if isinstance(self.info, Info): + info = self.info + else: + info = create_info(X.shape[-2], self.sfreq_, ch_types="eeg") + self.picks_ = _picks_to_idx(info, self.picks, none="data", exclude="bads") X_aux = X[..., self.picks_, :] - X_signal = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) - X_noise = filter_data(X_aux, self.info["sfreq"], **self.filt_params_noise) + X_signal = filter_data(X_aux, self.sfreq_, **self.filt_params_signal) + X_noise = filter_data(X_aux, self.sfreq_, **self.filt_params_noise) X_noise -= X_signal if X.ndim == 3: X_signal = np.hstack(X_signal) @@ -199,19 +217,19 @@ def fit(self, X, y=None): reg=self.reg, method_params=self.cov_method_params, rank="full", - info=self.info, + info=info, ) cov_noise = _regularized_covariance( X_noise, reg=self.reg, method_params=self.cov_method_params, rank="full", - info=self.info, + info=info, ) # project cov to rank subspace cov_signal, cov_noise, rank_proj = _dimensionality_reduction( - cov_signal, cov_noise, self.info, self.rank + cov_signal, cov_noise, info, self.rank ) eigvals_, eigvects_ = eigh(cov_signal, cov_noise) @@ -226,10 +244,10 @@ def fit(self, X, y=None): # than the initial ordering. This ordering should be also learned when # fitting. X_ssd = self.filters_.T @ X[..., self.picks_, :] - sorter_spec = Ellipsis + sorter_spec = slice(None) if self.sort_by_spectral_ratio: _, sorter_spec = self.get_spectral_ratio(ssd_sources=X_ssd) - self.sorter_spec = sorter_spec + self.sorter_spec_ = sorter_spec logger.info("Done.") return self @@ -248,17 +266,13 @@ def transform(self, X): X_ssd : array, shape ([n_epochs, ]n_components, n_times) The processed data. """ - self._check_X(X) - if self.filters_ is None: - raise RuntimeError("No filters available. Please first call fit") + check_is_fitted(self, "filters_") + X = self._check_X(X) if self.return_filtered: X_aux = X[..., self.picks_, :] - X = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) + X = filter_data(X_aux, self.sfreq_, **self.filt_params_signal) X_ssd = self.filters_.T @ X[..., self.picks_, :] - if X.ndim == 2: - X_ssd = X_ssd[self.sorter_spec][: self.n_components] - else: - X_ssd = X_ssd[:, self.sorter_spec, :][:, : self.n_components, :] + X_ssd = X_ssd[..., self.sorter_spec_, :][..., : self.n_components, :] return X_ssd def fit_transform(self, X, y=None, **fit_params): @@ -308,11 +322,9 @@ def get_spectral_ratio(self, ssd_sources): ---------- .. footbibliography:: """ - psd, freqs = psd_array_welch( - ssd_sources, sfreq=self.info["sfreq"], n_fft=self.n_fft - ) - sig_idx = _time_mask(freqs, *self.freqs_signal) - noise_idx = _time_mask(freqs, *self.freqs_noise) + psd, freqs = psd_array_welch(ssd_sources, sfreq=self.sfreq_, n_fft=self.n_fft_) + sig_idx = _time_mask(freqs, *self.freqs_signal_) + noise_idx = _time_mask(freqs, *self.freqs_noise_) if psd.ndim == 3: mean_sig = psd[:, :, sig_idx].mean(axis=2).mean(axis=0) mean_noise = psd[:, :, noise_idx].mean(axis=2).mean(axis=0) @@ -352,7 +364,7 @@ def apply(self, X): The processed data. """ X_ssd = self.transform(X) - pick_patterns = self.patterns_[self.sorter_spec][: self.n_components].T + pick_patterns = self.patterns_[self.sorter_spec_][: self.n_components].T X = pick_patterns @ X_ssd return X diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 4ec6ed4d281..504e309d53c 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -86,6 +86,8 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3): X = Y.dot(A.T) X += np.random.randn(n_samples, n_features) # add noise X += np.random.rand(n_features) # Put an offset + if n_targets == 1: + Y = Y[:, 0] return X, Y, A @@ -95,7 +97,7 @@ def test_get_coef(): """Test getting linear coefficients (filters/patterns) from estimators.""" lm_classification = LinearModel() assert hasattr(lm_classification, "__sklearn_tags__") - print(lm_classification.__sklearn_tags__) + print(lm_classification.__sklearn_tags__()) assert is_classifier(lm_classification.model) assert is_classifier(lm_classification) assert not is_regressor(lm_classification.model) @@ -273,7 +275,12 @@ def test_get_coef_multiclass(n_features, n_targets): """Test get_coef on multiclass problems.""" # Check patterns with more than 1 regressor X, Y, A = _make_data(n_samples=30000, n_features=n_features, n_targets=n_targets) - lm = LinearModel(LinearRegression()).fit(X, Y) + lm = LinearModel(LinearRegression()) + assert not hasattr(lm, "model_") + lm.fit(X, Y) + # TODO: modifying non-underscored `model` is a sklearn no-no, maybe should be a + # metaestimator? + assert lm.model is lm.model_ assert_array_equal(lm.filters_.shape, lm.patterns_.shape) if n_targets == 1: want_shape = (n_features,) @@ -320,7 +327,7 @@ def test_get_coef_multiclass(n_features, n_targets): ], ) # TODO: Need to fix this properly in LinearModel -@pytest.mark.filterwarnings("ignore:'multi_class' was deprecated in.*:FutureWarning") +@pytest.mark.filterwarnings("ignore:'multi_class' was depr.*:FutureWarning") @pytest.mark.filterwarnings("ignore:lbfgs failed to converge.*:") def test_get_coef_multiclass_full(n_classes, n_channels, n_times): """Test a full example with pattern extraction.""" @@ -473,9 +480,8 @@ def test_cross_val_multiscore(): def test_sklearn_compliance(estimator, check): """Test LinearModel compliance with sklearn.""" ignores = ( - "check_n_features_in", # maybe we should add this someday? - "check_estimator_sparse_data", # we densify "check_estimators_overwrite_params", # self.model changes! + "check_dont_overwrite_parameters", "check_parameters_default_constructible", ) if any(ignore in str(check) for ignore in ignores): diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py index 7a1a83feeaf..e754b6952f9 100644 --- a/mne/decoding/tests/test_csp.py +++ b/mne/decoding/tests/test_csp.py @@ -19,6 +19,7 @@ from sklearn.model_selection import StratifiedKFold, cross_val_score from sklearn.pipeline import Pipeline, make_pipeline from sklearn.svm import SVC +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import Epochs, compute_proj_raw, io, pick_types, read_events from mne.decoding import CSP, LinearModel, Scaler, SPoC, get_coef @@ -139,18 +140,22 @@ def test_csp(): y = epochs.events[:, -1] # Init - pytest.raises(ValueError, CSP, n_components="foo", norm_trace=False) + csp = CSP(n_components="foo") + with pytest.raises(TypeError, match="must be an instance"): + csp.fit(epochs_data, y) for reg in ["foo", -0.1, 1.1]: csp = CSP(reg=reg, norm_trace=False) pytest.raises(ValueError, csp.fit, epochs_data, epochs.events[:, -1]) for reg in ["oas", "ledoit_wolf", 0, 0.5, 1.0]: CSP(reg=reg, norm_trace=False) - for cov_est in ["foo", None]: - pytest.raises(ValueError, CSP, cov_est=cov_est, norm_trace=False) + csp = CSP(cov_est="foo", norm_trace=False) + with pytest.raises(ValueError, match="Invalid value"): + csp.fit(epochs_data, y) + csp = CSP(norm_trace="foo") with pytest.raises(TypeError, match="instance of bool"): - CSP(norm_trace="foo") + csp.fit(epochs_data, y) for cov_est in ["concat", "epoch"]: - CSP(cov_est=cov_est, norm_trace=False) + CSP(cov_est=cov_est, norm_trace=False).fit(epochs_data, y) n_components = 3 # Fit @@ -171,8 +176,8 @@ def test_csp(): # Test data exception pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) - pytest.raises(ValueError, csp.fit, epochs, y) - pytest.raises(ValueError, csp.transform, epochs) + pytest.raises(ValueError, csp.fit, "foo", y) + pytest.raises(ValueError, csp.transform, "foo") # Test plots epochs.pick(picks="mag") @@ -200,7 +205,7 @@ def test_csp(): for cov_est in ["concat", "epoch"]: csp = CSP(n_components=n_components, cov_est=cov_est, norm_trace=False) csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) - assert_equal(len(csp._classes), 4) + assert_equal(len(csp.classes_), 4) assert_array_equal(csp.filters_.shape, [n_channels, n_channels]) assert_array_equal(csp.patterns_.shape, [n_channels, n_channels]) @@ -220,15 +225,17 @@ def test_csp(): # Different normalization return different transform assert np.sum((X_trans["True"] - X_trans["False"]) ** 2) > 1.0 # Check wrong inputs - pytest.raises(ValueError, CSP, transform_into="average_power", log="foo") + csp = CSP(transform_into="average_power", log="foo") + with pytest.raises(TypeError, match="must be an instance of bool"): + csp.fit(epochs_data, epochs.events[:, 2]) # Test csp space transform csp = CSP(transform_into="csp_space", norm_trace=False) assert csp.transform_into == "csp_space" for log in ("foo", True, False): - pytest.raises( - ValueError, CSP, transform_into="csp_space", log=log, norm_trace=False - ) + csp = CSP(transform_into="csp_space", log=log, norm_trace=False) + with pytest.raises(TypeError, match="must be an instance"): + csp.fit(epochs_data, epochs.events[:, 2]) n_components = 2 csp = CSP(n_components=n_components, transform_into="csp_space", norm_trace=False) Xt = csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) @@ -343,8 +350,8 @@ def test_regularized_csp(ch_type, rank, reg): # test init exception pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) - pytest.raises(ValueError, csp.fit, epochs, y) - pytest.raises(ValueError, csp.transform, epochs) + pytest.raises(ValueError, csp.fit, "foo", y) + pytest.raises(ValueError, csp.transform, "foo") csp.n_components = n_components sources = csp.transform(epochs_data) @@ -465,7 +472,9 @@ def test_csp_component_ordering(): """Test that CSP component ordering works as expected.""" x, y = deterministic_toy_data(["class_a", "class_b"]) - pytest.raises(ValueError, CSP, component_order="invalid") + csp = CSP(component_order="invalid") + with pytest.raises(ValueError, match="Invalid value"): + csp.fit(x, y) # component_order='alternate' only works with two classes csp = CSP(component_order="alternate") @@ -480,3 +489,10 @@ def test_csp_component_ordering(): # p_alt arranges them to [0.8, 0.06, 0.5, 0.1] # p_mut arranges them to [0.06, 0.1, 0.8, 0.5] assert_array_almost_equal(p_alt, p_mut[[2, 0, 3, 1]]) + + +@pytest.mark.filterwarnings("ignore:.*Only one sample available.*") +@parametrize_with_checks([CSP(), SPoC()]) +def test_sklearn_compliance(estimator, check): + """Test compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_ems.py b/mne/decoding/tests/test_ems.py index 10774c0681a..dc54303a541 100644 --- a/mne/decoding/tests/test_ems.py +++ b/mne/decoding/tests/test_ems.py @@ -11,6 +11,7 @@ pytest.importorskip("sklearn") from sklearn.model_selection import StratifiedKFold +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import Epochs, io, pick_types, read_events from mne.decoding import EMS, compute_ems @@ -91,3 +92,9 @@ def test_ems(): assert_equal(ems.__repr__(), "") assert_array_almost_equal(filters, np.mean(coefs, axis=0)) assert_array_almost_equal(surrogates, np.vstack(Xt)) + + +@parametrize_with_checks([EMS()]) +def test_sklearn_compliance(estimator, check): + """Test compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index 7cb3a66dd81..e7abfd9209e 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -41,7 +41,7 @@ def make_data(): return X, y -def test_search_light(): +def test_search_light_basic(): """Test SlidingEstimator.""" # https://github.com/scikit-learn/scikit-learn/issues/27711 if platform.system() == "Windows" and check_version("numpy", "2.0.0.dev0"): @@ -52,7 +52,9 @@ def test_search_light(): X, y = make_data() n_epochs, _, n_time = X.shape # init - pytest.raises(ValueError, SlidingEstimator, "foo") + sl = SlidingEstimator("foo") + with pytest.raises(ValueError, match="must be"): + sl.fit(X, y) sl = SlidingEstimator(Ridge()) assert not is_classifier(sl) sl = SlidingEstimator(LogisticRegression(solver="liblinear")) @@ -69,7 +71,8 @@ def test_search_light(): # transforms pytest.raises(ValueError, sl.predict, X[:, :, :2]) y_trans = sl.transform(X) - assert X.dtype == y_trans.dtype == np.dtype(float) + assert X.dtype == float + assert y_trans.dtype == float y_pred = sl.predict(X) assert y_pred.dtype == np.dtype(int) assert_array_equal(y_pred.shape, [n_epochs, n_time]) @@ -344,22 +347,19 @@ def predict_proba(self, X): @pytest.mark.slowtest -@parametrize_with_checks([SlidingEstimator(LogisticRegression(), allow_2d=True)]) +@parametrize_with_checks( + [ + SlidingEstimator(LogisticRegression(), allow_2d=True), + GeneralizingEstimator(LogisticRegression(), allow_2d=True), + ] +) def test_sklearn_compliance(estimator, check): """Test LinearModel compliance with sklearn.""" ignores = ( - "check_estimator_sparse_data", # we densify - "check_classifiers_one_label_sample_weights", # don't handle singleton - "check_classifiers_classes", # dim mismatch + # TODO: we don't handle singleton right (probably) + "check_classifiers_one_label_sample_weights", + "check_classifiers_classes", "check_classifiers_train", - "check_decision_proba_consistency", - "check_parameters_default_constructible", - # Should probably fix these? - "check_estimators_unfitted", - "check_transformer_data_not_an_array", - "check_n_features_in", - "check_fit2d_predict1d", - "check_do_not_raise_errors_in_init_or_set_params", ) if any(ignore in str(check) for ignore in ignores): return diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 198feeb6532..b6cdfc472c3 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -11,6 +11,7 @@ pytest.importorskip("sklearn") from sklearn.pipeline import Pipeline +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import create_info, io from mne.decoding import CSP @@ -101,8 +102,9 @@ def test_ssd(): l_trans_bandwidth=1, h_trans_bandwidth=1, ) + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(TypeError, match="must be an instance "): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # Wrongly specified noise band freq = 2 @@ -115,14 +117,16 @@ def test_ssd(): l_trans_bandwidth=1, h_trans_bandwidth=1, ) + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(ValueError, match="Wrongly specified "): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # filt param no dict filt_params_signal = freqs_sig filt_params_noise = freqs_noise + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(ValueError, match="must be defined"): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # Data type filt_params_signal = dict( @@ -140,15 +144,18 @@ def test_ssd(): ssd = SSD(info, filt_params_signal, filt_params_noise) raw = io.RawArray(X, info) - pytest.raises(TypeError, ssd.fit, raw) + with pytest.raises(ValueError): + ssd.fit(raw) # check non-boolean return_filtered - with pytest.raises(ValueError, match="return_filtered"): - ssd = SSD(info, filt_params_signal, filt_params_noise, return_filtered=0) + ssd = SSD(info, filt_params_signal, filt_params_noise, return_filtered=0) + with pytest.raises(TypeError, match="return_filtered"): + ssd.fit(X) # check non-boolean sort_by_spectral_ratio - with pytest.raises(ValueError, match="sort_by_spectral_ratio"): - ssd = SSD(info, filt_params_signal, filt_params_noise, sort_by_spectral_ratio=0) + ssd = SSD(info, filt_params_signal, filt_params_noise, sort_by_spectral_ratio=0) + with pytest.raises(TypeError, match="sort_by_spectral_ratio"): + ssd.fit(X) # More than 1 channel type ch_types = np.reshape([["mag"] * 10, ["eeg"] * 10], n_channels) @@ -161,7 +168,8 @@ def test_ssd(): # Number of channels info_3 = create_info(ch_names=n_channels + 1, sfreq=sf, ch_types="eeg") ssd = SSD(info_3, filt_params_signal, filt_params_noise) - pytest.raises(ValueError, ssd.fit, X) + with pytest.raises(ValueError, match="channels but expected"): + ssd.fit(X) # Fit n_components = 10 @@ -381,7 +389,7 @@ def test_sorting(): ssd.fit(Xtr) # check sorters - sorter_in = ssd.sorter_spec + sorter_in = ssd.sorter_spec_ ssd = SSD( info, filt_params_signal, @@ -474,5 +482,31 @@ def test_non_full_rank_data(): ssd = SSD(info, filt_params_signal, filt_params_noise) if sys.platform == "darwin": - pytest.skip("Unknown linalg bug (Accelerate?)") + pytest.xfail("Unknown linalg bug (Accelerate?)") ssd.fit(X) + + +@pytest.mark.filterwarnings("ignore:.*invalid value encountered in divide.*") +@pytest.mark.filterwarnings("ignore:.*is longer than.*") +@parametrize_with_checks( + [ + SSD( + 100.0, + dict(l_freq=0.0, h_freq=30.0), + dict(l_freq=0.0, h_freq=40.0), + ) + ] +) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + ignores = ( + "check_methods_sample_order_invariance", + # Shape stuff + "check_fit_idempotent", + "check_methods_subset_invariance", + "check_transformer_general", + "check_transformer_data_not_an_array", + ) + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) diff --git a/mne/decoding/tests/test_time_frequency.py b/mne/decoding/tests/test_time_frequency.py index 37e7d7d8dc2..638cebda21e 100644 --- a/mne/decoding/tests/test_time_frequency.py +++ b/mne/decoding/tests/test_time_frequency.py @@ -10,18 +10,23 @@ pytest.importorskip("sklearn") from sklearn.base import clone +from sklearn.utils.estimator_checks import parametrize_with_checks from mne.decoding.time_frequency import TimeFrequency -def test_timefrequency(): +def test_timefrequency_basic(): """Test TimeFrequency.""" # Init n_freqs = 3 freqs = [20, 21, 22] tf = TimeFrequency(freqs, sfreq=100) + n_epochs, n_chans, n_times = 10, 2, 100 + X = np.random.rand(n_epochs, n_chans, n_times) for output in ["avg_power", "foo", None]: - pytest.raises(ValueError, TimeFrequency, freqs, output=output) + tf = TimeFrequency(freqs, output=output) + with pytest.raises(ValueError, match="Invalid value"): + tf.fit(X) tf = clone(tf) # Clone estimator @@ -30,9 +35,9 @@ def test_timefrequency(): clone(tf) # Fit - n_epochs, n_chans, n_times = 10, 2, 100 - X = np.random.rand(n_epochs, n_chans, n_times) + assert not hasattr(tf, "fitted_") tf.fit(X, None) + assert tf.fitted_ # Transform tf = TimeFrequency(freqs, sfreq=100) @@ -41,9 +46,15 @@ def test_timefrequency(): Xt = tf.transform(X) assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times]) # 2-D X - Xt = tf.transform(X[:, 0, :]) + Xt = tf.fit_transform(X[:, 0, :]) assert_array_equal(Xt.shape, [n_epochs, n_freqs, n_times]) # 3-D with decim tf = TimeFrequency(freqs, sfreq=100, decim=2) - Xt = tf.transform(X) + Xt = tf.fit_transform(X) assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times // 2]) + + +@parametrize_with_checks([TimeFrequency([300, 400], 1000.0, n_cycles=0.25)]) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_transformer.py b/mne/decoding/tests/test_transformer.py index 8dcc3ad74c7..a8afe209d96 100644 --- a/mne/decoding/tests/test_transformer.py +++ b/mne/decoding/tests/test_transformer.py @@ -17,10 +17,14 @@ from sklearn.decomposition import PCA from sklearn.kernel_ridge import KernelRidge +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.utils.estimator_checks import parametrize_with_checks -from mne import Epochs, io, pick_types, read_events +from mne import Epochs, EpochsArray, create_info, io, pick_types, read_events from mne.decoding import ( FilterEstimator, + LinearModel, PSDEstimator, Scaler, TemporalFilter, @@ -36,6 +40,7 @@ data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_dir / "test_raw.fif" event_name = data_dir / "test-eve.fif" +info = create_info(2, 1000.0, "eeg") @pytest.mark.parametrize( @@ -101,9 +106,11 @@ def test_scaler(info, method): assert_array_almost_equal(epochs_data, Xi) # Test init exception - pytest.raises(ValueError, Scaler, None, None) - pytest.raises(TypeError, scaler.fit, epochs, y) - pytest.raises(TypeError, scaler.transform, epochs) + x = Scaler(None, None) + with pytest.raises(ValueError): + x.fit(epochs_data, y) + pytest.raises(ValueError, scaler.fit, "foo", y) + pytest.raises(ValueError, scaler.transform, "foo") epochs_bad = Epochs( raw, events, @@ -164,8 +171,8 @@ def test_filterestimator(): X = filt.fit_transform(epochs_data, y) # Test init exception - pytest.raises(ValueError, filt.fit, epochs, y) - pytest.raises(ValueError, filt.transform, epochs) + pytest.raises(ValueError, filt.fit, "foo", y) + pytest.raises(ValueError, filt.transform, "foo") def test_psdestimator(): @@ -182,14 +189,18 @@ def test_psdestimator(): epochs_data = epochs.get_data(copy=False) psd = PSDEstimator(2 * np.pi, 0, np.inf) y = epochs.events[:, -1] + assert not hasattr(psd, "fitted_") X = psd.fit_transform(epochs_data, y) + assert psd.fitted_ assert X.shape[0] == epochs_data.shape[0] assert_array_equal(psd.fit(epochs_data, y).transform(epochs_data), X) # Test init exception - pytest.raises(ValueError, psd.fit, epochs, y) - pytest.raises(ValueError, psd.transform, epochs) + with pytest.raises(ValueError): + psd.fit("foo", y) + with pytest.raises(ValueError): + psd.transform("foo") def test_vectorizer(): @@ -210,9 +221,16 @@ def test_vectorizer(): assert_equal(vect.fit_transform(data[1:]).shape, (149, 108)) # check if raised errors are working correctly - vect.fit(np.random.rand(105, 12, 3)) - pytest.raises(ValueError, vect.transform, np.random.rand(105, 12, 3, 1)) - pytest.raises(ValueError, vect.inverse_transform, np.random.rand(102, 12, 12)) + X = np.random.default_rng(0).standard_normal((105, 12, 3)) + y = np.arange(X.shape[0]) % 2 + pytest.raises(ValueError, vect.transform, X[..., np.newaxis]) + pytest.raises(ValueError, vect.inverse_transform, X[:, :-1]) + + # And that pipelines work properly + X_arr = EpochsArray(X, create_info(12, 1000.0, "eeg")) + vect.fit(X_arr) + clf = make_pipeline(Vectorizer(), StandardScaler(), LinearModel()) + clf.fit(X_arr, y) def test_unsupervised_spatial_filter(): @@ -235,11 +253,13 @@ def test_unsupervised_spatial_filter(): verbose=False, ) - # Test estimator - pytest.raises(ValueError, UnsupervisedSpatialFilter, KernelRidge(2)) + # Test estimator (must be a transformer) + X = epochs.get_data(copy=False) + usf = UnsupervisedSpatialFilter(KernelRidge(2)) + with pytest.raises(ValueError, match="transform"): + usf.fit(X) # Test fit - X = epochs.get_data(copy=False) n_components = 4 usf = UnsupervisedSpatialFilter(PCA(n_components)) usf.fit(X) @@ -255,7 +275,9 @@ def test_unsupervised_spatial_filter(): # Test with average param usf = UnsupervisedSpatialFilter(PCA(4), average=True) usf.fit_transform(X) - pytest.raises(ValueError, UnsupervisedSpatialFilter, PCA(4), 2) + usf = UnsupervisedSpatialFilter(PCA(4), 2) + with pytest.raises(TypeError, match="average must be"): + usf.fit(X) def test_temporal_filter(): @@ -281,8 +303,8 @@ def test_temporal_filter(): assert X.shape == Xt.shape # Test fit and transform numpy type check - with pytest.raises(ValueError, match="Data to be filtered must be"): - filt.transform([1, 2]) + with pytest.raises(ValueError): + filt.transform("foo") # Test with 2 dimensional data array X = np.random.rand(101, 500) @@ -298,4 +320,36 @@ def test_bad_triage(): filt = TemporalFilter(l_freq=8, h_freq=60, sfreq=160.0) # Used to fail with "ValueError: Effective band-stop frequency (135.0) is # too high (maximum based on Nyquist is 80.0)" + assert not hasattr(filt, "fitted_") filt.fit_transform(np.zeros((1, 1, 481))) + assert filt.fitted_ + + +@pytest.mark.filterwarnings("ignore:.*filter_length.*") +@parametrize_with_checks( + [ + FilterEstimator(info, l_freq=1, h_freq=10), + PSDEstimator(), + Scaler(scalings="mean"), + # Not easy to test Scaler(info) b/c number of channels must match + TemporalFilter(), + UnsupervisedSpatialFilter(PCA()), + Vectorizer(), + ] +) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + ignores = [] + if estimator.__class__.__name__ == "FilterEstimator": + ignores += [ + "check_estimators_overwrite_params", # we modify self.info + "check_methods_sample_order_invariance", + ] + if estimator.__class__.__name__.startswith(("PSD", "Temporal")): + ignores += [ + "check_transformers_unfitted", # allow unfitted transform + "check_methods_sample_order_invariance", + ] + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) diff --git a/mne/decoding/time_frequency.py b/mne/decoding/time_frequency.py index de6ec52155b..29232aaeb9f 100644 --- a/mne/decoding/time_frequency.py +++ b/mne/decoding/time_frequency.py @@ -3,14 +3,16 @@ # Copyright the MNE-Python contributors. import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted from ..time_frequency.tfr import _compute_tfr -from ..utils import _check_option, fill_doc, verbose +from ..utils import _check_option, fill_doc +from .transformer import MNETransformerMixin @fill_doc -class TimeFrequency(TransformerMixin, BaseEstimator): +class TimeFrequency(MNETransformerMixin, BaseEstimator): """Time frequency transformer. Time-frequency transform of times series along the last axis. @@ -59,7 +61,6 @@ class TimeFrequency(TransformerMixin, BaseEstimator): mne.time_frequency.tfr_multitaper """ - @verbose def __init__( self, freqs, @@ -74,9 +75,6 @@ def __init__( verbose=None, ): """Init TimeFrequency transformer.""" - # Check non-average output - output = _check_option("output", output, ["complex", "power", "phase"]) - self.freqs = freqs self.sfreq = sfreq self.method = method @@ -89,6 +87,16 @@ def __init__( self.n_jobs = n_jobs self.verbose = verbose + def __sklearn_tags__(self): + """Return sklearn tags.""" + out = super().__sklearn_tags__() + from sklearn.utils import TransformerTags + + if out.transformer_tags is None: + out.transformer_tags = TransformerTags() + out.transformer_tags.preserves_dtype = [] # real->complex + return out + def fit_transform(self, X, y=None): """Time-frequency transform of times series along the last axis. @@ -123,6 +131,10 @@ def fit(self, X, y=None): # noqa: D401 self : object Return self. """ + # Check non-average output + _check_option("output", self.output, ["complex", "power", "phase"]) + self._check_data(X, y=y, fit=True) + self.fitted_ = True return self def transform(self, X): @@ -130,16 +142,18 @@ def transform(self, X): Parameters ---------- - X : array, shape (n_samples, n_channels, n_times) + X : array, shape (n_samples, [n_channels, ]n_times) The training data samples. The channel dimension can be zero- or 1-dimensional. Returns ------- - Xt : array, shape (n_samples, n_channels, n_freqs, n_times) + Xt : array, shape (n_samples, [n_channels, ]n_freqs, n_times) The time-frequency transform of the data, where n_channels can be zero- or 1-dimensional. """ + X = self._check_data(X, atleast_3d=False) + check_is_fitted(self, "fitted_") # Ensure 3-dimensional X shape = X.shape[1:-1] if not shape: diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index 8eb2dcc5510..6d0c83f42ab 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -3,19 +3,72 @@ # Copyright the MNE-Python contributors. import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator, TransformerMixin, check_array, clone +from sklearn.preprocessing import RobustScaler, StandardScaler +from sklearn.utils import check_X_y +from sklearn.utils.validation import check_is_fitted, validate_data from .._fiff.pick import ( _pick_data_channels, _picks_by_type, _picks_to_idx, pick_info, - pick_types, ) from ..cov import _check_scalings_user +from ..epochs import BaseEpochs from ..filter import filter_data from ..time_frequency import psd_array_multitaper -from ..utils import _check_option, _validate_type, fill_doc, verbose +from ..utils import _check_option, _validate_type, fill_doc + + +class MNETransformerMixin(TransformerMixin): + """TransformerMixin plus some helpers.""" + + def _check_data( + self, + epochs_data, + *, + y=None, + atleast_3d=True, + fit=False, + return_y=False, + multi_output=False, + check_n_features=True, + ): + # Sklearn calls asarray under the hood which works, but elsewhere they check for + # __len__ then look at the size of obj[0]... which is an epoch of shape (1, ...) + # rather than what they expect (shape (...)). So we explicitly get the NumPy + # array to make everyone happy. + if isinstance(epochs_data, BaseEpochs): + epochs_data = epochs_data.get_data(copy=False) + kwargs = dict(dtype=np.float64, allow_nd=True, order="C", force_writeable=True) + if hasattr(self, "n_features_in_") and check_n_features: + if y is None: + epochs_data = validate_data( + self, + epochs_data, + **kwargs, + reset=fit, + ) + else: + epochs_data, y = validate_data( + self, + epochs_data, + y, + **kwargs, + reset=fit, + ) + elif y is None: + epochs_data = check_array(epochs_data, **kwargs) + else: + epochs_data, y = check_X_y( + X=epochs_data, y=y, multi_output=multi_output, **kwargs + ) + if fit: + self.n_features_in_ = epochs_data.shape[1] + if atleast_3d: + epochs_data = np.atleast_3d(epochs_data) + return (epochs_data, y) if return_y else epochs_data class _ConstantScaler: @@ -55,8 +108,9 @@ def fit_transform(self, X, y=None): def _sklearn_reshape_apply(func, return_result, X, *args, **kwargs): """Reshape epochs and apply function.""" - if not isinstance(X, np.ndarray): - raise ValueError(f"data should be an np.ndarray, got {type(X)}.") + _validate_type(X, np.ndarray, "X") + if X.size == 0: + return X.copy() if return_result else None orig_shape = X.shape X = np.reshape(X.transpose(0, 2, 1), (-1, orig_shape[1])) X = func(X, *args, **kwargs) @@ -67,7 +121,7 @@ def _sklearn_reshape_apply(func, return_result, X, *args, **kwargs): @fill_doc -class Scaler(TransformerMixin, BaseEstimator): +class Scaler(MNETransformerMixin, BaseEstimator): """Standardize channel data. This class scales data for each channel. It differs from scikit-learn @@ -109,31 +163,6 @@ def __init__(self, info=None, scalings=None, with_mean=True, with_std=True): self.with_std = with_std self.scalings = scalings - if not (scalings is None or isinstance(scalings, dict | str)): - raise ValueError( - f"scalings type should be dict, str, or None, got {type(scalings)}" - ) - if isinstance(scalings, str): - _check_option("scalings", scalings, ["mean", "median"]) - if scalings is None or isinstance(scalings, dict): - if info is None: - raise ValueError( - f'Need to specify "info" if scalings is {type(scalings)}' - ) - self._scaler = _ConstantScaler(info, scalings, self.with_std) - elif scalings == "mean": - from sklearn.preprocessing import StandardScaler - - self._scaler = StandardScaler( - with_mean=self.with_mean, with_std=self.with_std - ) - else: # scalings == 'median': - from sklearn.preprocessing import RobustScaler - - self._scaler = RobustScaler( - with_centering=self.with_mean, with_scaling=self.with_std - ) - def fit(self, epochs_data, y=None): """Standardize data across channels. @@ -149,11 +178,30 @@ def fit(self, epochs_data, y=None): self : instance of Scaler The modified instance. """ - _validate_type(epochs_data, np.ndarray, "epochs_data") - if epochs_data.ndim == 2: - epochs_data = epochs_data[..., np.newaxis] + epochs_data = self._check_data(epochs_data, y=y, fit=True, multi_output=True) assert epochs_data.ndim == 3, epochs_data.shape - _sklearn_reshape_apply(self._scaler.fit, False, epochs_data, y=y) + + _validate_type(self.scalings, (dict, str, type(None)), "scalings") + if isinstance(self.scalings, str): + _check_option( + "scalings", self.scalings, ["mean", "median"], extra="when str" + ) + if self.scalings is None or isinstance(self.scalings, dict): + if self.info is None: + raise ValueError( + f'Need to specify "info" if scalings is {type(self.scalings)}' + ) + self.scaler_ = _ConstantScaler(self.info, self.scalings, self.with_std) + elif self.scalings == "mean": + self.scaler_ = StandardScaler( + with_mean=self.with_mean, with_std=self.with_std + ) + else: # scalings == 'median': + self.scaler_ = RobustScaler( + with_centering=self.with_mean, with_scaling=self.with_std + ) + + _sklearn_reshape_apply(self.scaler_.fit, False, epochs_data, y=y) return self def transform(self, epochs_data): @@ -174,13 +222,14 @@ def transform(self, epochs_data): This function makes a copy of the data before the operations and the memory usage may be large with big data. """ - _validate_type(epochs_data, np.ndarray, "epochs_data") + check_is_fitted(self, "scaler_") + epochs_data = self._check_data(epochs_data, atleast_3d=False) if epochs_data.ndim == 2: # can happen with SlidingEstimator if self.info is not None: assert len(self.info["ch_names"]) == epochs_data.shape[1] epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape - return _sklearn_reshape_apply(self._scaler.transform, True, epochs_data) + return _sklearn_reshape_apply(self.scaler_.transform, True, epochs_data) def fit_transform(self, epochs_data, y=None): """Fit to data, then transform it. @@ -226,19 +275,20 @@ def inverse_transform(self, epochs_data): This function makes a copy of the data before the operations and the memory usage may be large with big data. """ + epochs_data = self._check_data(epochs_data, atleast_3d=False) squeeze = False # Can happen with CSP if epochs_data.ndim == 2: squeeze = True epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape - out = _sklearn_reshape_apply(self._scaler.inverse_transform, True, epochs_data) + out = _sklearn_reshape_apply(self.scaler_.inverse_transform, True, epochs_data) if squeeze: out = out[..., 0] return out -class Vectorizer(TransformerMixin): +class Vectorizer(MNETransformerMixin, BaseEstimator): """Transform n-dimensional array into 2D array of n_samples by n_features. This class reshapes an n-dimensional array into an n_samples * n_features @@ -275,7 +325,7 @@ def fit(self, X, y=None): self : instance of Vectorizer Return the modified instance. """ - X = np.asarray(X) + X = self._check_data(X, y=y, atleast_3d=False, fit=True, check_n_features=False) self.features_shape_ = X.shape[1:] return self @@ -295,7 +345,7 @@ def transform(self, X): X : array, shape (n_samples, n_features) The transformed data. """ - X = np.asarray(X) + X = self._check_data(X, atleast_3d=False) if X.shape[1:] != self.features_shape_: raise ValueError("Shape of X used in fit and transform must be same") return X.reshape(len(X), -1) @@ -334,7 +384,7 @@ def inverse_transform(self, X): The data transformed into shape as used in fit. The first dimension is of length n_samples. """ - X = np.asarray(X) + X = self._check_data(X, atleast_3d=False, check_n_features=False) if X.ndim not in (2, 3): raise ValueError( f"X should be of 2 or 3 dimensions but has shape {X.shape}" @@ -343,7 +393,7 @@ def inverse_transform(self, X): @fill_doc -class PSDEstimator(TransformerMixin): +class PSDEstimator(MNETransformerMixin, BaseEstimator): """Compute power spectral density (PSD) using a multi-taper method. Parameters @@ -365,7 +415,6 @@ class PSDEstimator(TransformerMixin): n_jobs : int Number of parallel jobs to use (only used if adaptive=True). %(normalization)s - %(verbose)s See Also -------- @@ -375,7 +424,6 @@ class PSDEstimator(TransformerMixin): mne.Evoked.compute_psd """ - @verbose def __init__( self, sfreq=2 * np.pi, @@ -386,8 +434,6 @@ def __init__( low_bias=True, n_jobs=None, normalization="length", - *, - verbose=None, ): self.sfreq = sfreq self.fmin = fmin @@ -398,7 +444,7 @@ def __init__( self.n_jobs = n_jobs self.normalization = normalization - def fit(self, epochs_data, y): + def fit(self, epochs_data, y=None): """Compute power spectral density (PSD) using a multi-taper method. Parameters @@ -413,11 +459,8 @@ def fit(self, epochs_data, y): self : instance of PSDEstimator The modified instance. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - + self._check_data(epochs_data, y=y, fit=True) + self.fitted_ = True # sklearn compliance return self def transform(self, epochs_data): @@ -433,10 +476,7 @@ def transform(self, epochs_data): psd : array, shape (n_signals, n_freqs) or (n_freqs,) The computed PSD. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) + epochs_data = self._check_data(epochs_data) psd, _ = psd_array_multitaper( epochs_data, sfreq=self.sfreq, @@ -452,7 +492,7 @@ def transform(self, epochs_data): @fill_doc -class FilterEstimator(TransformerMixin): +class FilterEstimator(MNETransformerMixin, BaseEstimator): """Estimator to filter RtEpochs. Applies a zero-phase low-pass, high-pass, band-pass, or band-stop @@ -488,7 +528,6 @@ class FilterEstimator(TransformerMixin): See mne.filter.construct_iir_filter for details. If iir_params is None and method="iir", 4th order Butterworth will be used. %(fir_design)s - %(verbose)s See Also -------- @@ -514,13 +553,11 @@ def __init__( method="fir", iir_params=None, fir_design="firwin", - *, - verbose=None, ): self.info = info self.l_freq = l_freq self.h_freq = h_freq - self.picks = _picks_to_idx(info, picks) + self.picks = picks self.filter_length = filter_length self.l_trans_bandwidth = l_trans_bandwidth self.h_trans_bandwidth = h_trans_bandwidth @@ -544,24 +581,11 @@ def fit(self, epochs_data, y): self : instance of FilterEstimator The modified instance. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - - if self.picks is None: - self.picks = pick_types( - self.info, meg=True, eeg=True, ref_meg=False, exclude=[] - ) + self.picks_ = _picks_to_idx(self.info, self.picks) + self._check_data(epochs_data, y=y, fit=True) if self.l_freq == 0: self.l_freq = None - if self.h_freq is not None and self.h_freq > (self.info["sfreq"] / 2.0): - self.h_freq = None - if self.l_freq is not None and not isinstance(self.l_freq, float): - self.l_freq = float(self.l_freq) - if self.h_freq is not None and not isinstance(self.h_freq, float): - self.h_freq = float(self.h_freq) if self.info["lowpass"] is None or ( self.h_freq is not None @@ -594,17 +618,12 @@ def transform(self, epochs_data): X : array, shape (n_epochs, n_channels, n_times) The data after filtering. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - epochs_data = np.atleast_3d(epochs_data) return filter_data( - epochs_data, + self._check_data(epochs_data), self.info["sfreq"], self.l_freq, self.h_freq, - self.picks, + self.picks_, self.filter_length, self.l_trans_bandwidth, self.h_trans_bandwidth, @@ -617,7 +636,7 @@ def transform(self, epochs_data): ) -class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator): +class UnsupervisedSpatialFilter(MNETransformerMixin, BaseEstimator): """Use unsupervised spatial filtering across time and samples. Parameters @@ -630,19 +649,6 @@ class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator): """ def __init__(self, estimator, average=False): - # XXX: Use _check_estimator #3381 - for attr in ("fit", "transform", "fit_transform"): - if not hasattr(estimator, attr): - raise ValueError( - "estimator must be a scikit-learn " - f"transformer, missing {attr} method" - ) - - if not isinstance(average, bool): - raise ValueError( - f"average parameter must be of bool type, got {type(bool)} instead" - ) - self.estimator = estimator self.average = average @@ -661,13 +667,25 @@ def fit(self, X, y=None): self : instance of UnsupervisedSpatialFilter Return the modified instance. """ + # sklearn.utils.estimator_checks.check_estimator(self.estimator) is probably + # too strict for us, given that we don't fully adhere yet, so just check attrs + for attr in ("fit", "transform", "fit_transform"): + if not hasattr(self.estimator, attr): + raise ValueError( + "estimator must be a scikit-learn " + f"transformer, missing {attr} method" + ) + _validate_type(self.average, bool, "average") + X = self._check_data(X, y=y, fit=True) if self.average: X = np.mean(X, axis=0).T else: n_epochs, n_channels, n_times = X.shape # trial as time samples X = np.transpose(X, (1, 0, 2)).reshape((n_channels, n_epochs * n_times)).T - self.estimator.fit(X) + + self.estimator_ = clone(self.estimator) + self.estimator_.fit(X) return self def fit_transform(self, X, y=None): @@ -700,6 +718,8 @@ def transform(self, X): X : array, shape (n_epochs, n_channels, n_times) The transformed data. """ + check_is_fitted(self.estimator_) + X = self._check_data(X) return self._apply_method(X, "transform") def inverse_transform(self, X): @@ -735,7 +755,7 @@ def _apply_method(self, X, method): X = np.transpose(X, [1, 0, 2]) X = np.reshape(X, [n_channels, n_epochs * n_times]).T # apply method - method = getattr(self.estimator, method) + method = getattr(self.estimator_, method) X = method(X) # put it back to n_epochs, n_dimensions X = np.reshape(X.T, [-1, n_epochs, n_times]).transpose([1, 0, 2]) @@ -743,7 +763,7 @@ def _apply_method(self, X, method): @fill_doc -class TemporalFilter(TransformerMixin): +class TemporalFilter(MNETransformerMixin, BaseEstimator): """Estimator to filter data array along the last dimension. Applies a zero-phase low-pass, high-pass, band-pass, or band-stop @@ -817,7 +837,6 @@ class TemporalFilter(TransformerMixin): attenuation using fewer samples than "firwin2". .. versionadded:: 0.15 - %(verbose)s See Also -------- @@ -826,7 +845,6 @@ class TemporalFilter(TransformerMixin): mne.filter.filter_data """ - @verbose def __init__( self, l_freq=None, @@ -840,8 +858,6 @@ def __init__( iir_params=None, fir_window="hamming", fir_design="firwin", - *, - verbose=None, ): self.l_freq = l_freq self.h_freq = h_freq @@ -855,17 +871,12 @@ def __init__( self.fir_window = fir_window self.fir_design = fir_design - if not isinstance(self.n_jobs, int) and self.n_jobs == "cuda": - raise ValueError( - f'n_jobs must be int or "cuda", got {type(self.n_jobs)} instead.' - ) - def fit(self, X, y=None): """Do nothing (for scikit-learn compatibility purposes). Parameters ---------- - X : array, shape (n_epochs, n_channels, n_times) or or shape (n_channels, n_times) + X : array, shape ([n_epochs, ]n_channels, n_times) The data to be filtered over the last dimension. The channels dimension can be zero when passing a 2D array. y : None @@ -875,7 +886,9 @@ def fit(self, X, y=None): ------- self : instance of TemporalFilter The modified instance. - """ # noqa: E501 + """ + self.fitted_ = True # sklearn compliance + self._check_data(X, y=y, atleast_3d=False, fit=True) return self def transform(self, X): @@ -883,7 +896,7 @@ def transform(self, X): Parameters ---------- - X : array, shape (n_epochs, n_channels, n_times) or shape (n_channels, n_times) + X : array, shape ([n_epochs, ]n_channels, n_times) The data to be filtered over the last dimension. The channels dimension can be zero when passing a 2D array. @@ -892,6 +905,7 @@ def transform(self, X): X : array The data after filtering. """ # noqa: E501 + X = self._check_data(X, atleast_3d=False) X = np.atleast_2d(X) if X.ndim > 3: diff --git a/mne/epochs.py b/mne/epochs.py index 761a15199a9..96f247875d9 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1353,6 +1353,7 @@ def plot_topo_image( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): return plot_topo_image_epochs( @@ -1371,6 +1372,7 @@ def plot_topo_image( fig_facecolor=fig_facecolor, fig_background=fig_background, font_color=font_color, + select=select, show=show, ) @@ -1671,8 +1673,7 @@ def _get_data( # we start out with an empty array, allocate only if necessary data = np.empty((0, len(self.info["ch_names"]), len(self.times))) msg = ( - f"for {n_events} events and {len(self._raw_times)} " - "original time points" + f"for {n_events} events and {len(self._raw_times)} original time points" ) if self._decim > 1: msg += " (prior to decimation)" @@ -2301,8 +2302,7 @@ def save( logger.info(f"Splitting into {n_parts} parts") if n_parts > 100: # This must be an error raise ValueError( - f"Split size {split_size} would result in writing " - f"{n_parts} files" + f"Split size {split_size} would result in writing {n_parts} files" ) if len(self.drop_log) > 100000: @@ -3143,7 +3143,7 @@ def _ensure_list(x): raise ValueError( f"The event names in keep_first and keep_last must " f"be mutually exclusive. Specified in both: " - f'{", ".join(sorted(keep_first_and_last))}' + f"{', '.join(sorted(keep_first_and_last))}" ) del keep_first_and_last @@ -3163,7 +3163,7 @@ def _diff_input_strings_vs_event_id(input_strings, input_name, event_id): if event_name_diff: raise ValueError( f"Present in {input_name}, but missing from event_id: " - f'{", ".join(event_name_diff)}' + f"{', '.join(event_name_diff)}" ) _diff_input_strings_vs_event_id( @@ -3556,8 +3556,7 @@ def __init__( if not isinstance(raw, BaseRaw): raise ValueError( - "The first argument to `Epochs` must be an " - "instance of mne.io.BaseRaw" + "The first argument to `Epochs` must be an instance of mne.io.BaseRaw" ) info = deepcopy(raw.info) annotations = raw.annotations.copy() @@ -4441,8 +4440,7 @@ def _get_epoch_from_raw(self, idx, verbose=None): else: # read the correct subset of the data raise RuntimeError( - "Correct epoch could not be found, please " - "contact mne-python developers" + "Correct epoch could not be found, please contact mne-python developers" ) # the following is equivalent to this, but faster: # diff --git a/mne/event.py b/mne/event.py index 723615ea56a..a19270db1e6 100644 --- a/mne/event.py +++ b/mne/event.py @@ -1649,7 +1649,7 @@ def match_event_names(event_names, keys, *, on_missing="raise"): _on_missing( on_missing=on_missing, msg=f'Event name "{key}" could not be found. The following events ' - f'are present in the data: {", ".join(event_names)}', + f"are present in the data: {', '.join(event_names)}", error_klass=KeyError, ) diff --git a/mne/evoked.py b/mne/evoked.py index a985fc30ad7..7bd2355e4ee 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -613,9 +613,11 @@ def plot_topo( background_color="w", noise_cov=None, exclude="bads", + select=False, show=True, ): - """ + """. + Notes ----- .. versionadded:: 0.10.0 @@ -638,6 +640,7 @@ def plot_topo( background_color=background_color, noise_cov=noise_cov, exclude=exclude, + select=select, show=show, ) @@ -961,7 +964,7 @@ def __neg__(self): if out.comment is not None and " + " in out.comment: out.comment = f"({out.comment})" # multiple conditions in evoked - out.comment = f'- {out.comment or "unknown"}' + out.comment = f"- {out.comment or 'unknown'}" return out def get_peak( @@ -1052,8 +1055,7 @@ def get_peak( raise ValueError('Channel type must be "grad" for merge_grads') elif mode == "neg": raise ValueError( - "Negative mode (mode=neg) does not make " - "sense with merge_grads=True" + "Negative mode (mode=neg) does not make sense with merge_grads=True" ) meg = eeg = misc = seeg = dbs = ecog = fnirs = False @@ -1649,12 +1651,12 @@ def combine_evoked(all_evoked, weights): if e.comment is not None and " + " in e.comment: # multiple conditions this_comment = f"({e.comment})" else: - this_comment = f'{e.comment or "unknown"}' + this_comment = f"{e.comment or 'unknown'}" # assemble everything if idx == 0: comment += f"{sign}{weight}{multiplier}{this_comment}" else: - comment += f' {sign or "+"} {weight}{multiplier}{this_comment}' + comment += f" {sign or '+'} {weight}{multiplier}{this_comment}" # special-case: combine_evoked([e1, -e2], [1, -1]) evoked.comment = comment.replace(" - - ", " + ") return evoked @@ -1871,8 +1873,7 @@ def _read_evoked(fname, condition=None, kind="average", allow_maxshield=False): if len(chs) != nchan: raise ValueError( - "Number of channels and number of " - "channel definitions are different" + "Number of channels and number of channel definitions are different" ) ch_names_mapping = _read_extended_ch_info(chs, my_evoked, fid) diff --git a/mne/export/_brainvision.py b/mne/export/_brainvision.py index ba64ba010ce..6503c540f41 100644 --- a/mne/export/_brainvision.py +++ b/mne/export/_brainvision.py @@ -107,6 +107,13 @@ def _export_mne_raw(*, raw, fname, events=None, overwrite=False): def _mne_annots2pybv_events(raw): """Convert mne Annotations to pybv events.""" + # check that raw.annotations.orig_time is the same as raw.info["meas_date"] + # so that onsets are relative to the first sample + # (after further correction for first_time) + if raw.annotations and raw.info["meas_date"] != raw.annotations.orig_time: + raise ValueError( + "Annotations must have the same orig_time as raw.info['meas_date']" + ) events = [] for annot in raw.annotations: # handle onset and duration: seconds to sample, relative to diff --git a/mne/export/_edf.py b/mne/export/_edf.py index ef870692014..e50b05f7056 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -7,6 +7,7 @@ import numpy as np +from ..annotations import _sync_onset from ..utils import _check_edfio_installed, warn _check_edfio_installed() @@ -204,7 +205,9 @@ def _export_raw(fname, raw, physical_range, add_ch_type): for desc, onset, duration, ch_names in zip( raw.annotations.description, - raw.annotations.onset, + # subtract raw.first_time because EDF marks events starting from the first + # available data point and ignores raw.first_time + _sync_onset(raw, raw.annotations.onset, inverse=False), raw.annotations.duration, raw.annotations.ch_names, ): diff --git a/mne/export/_eeglab.py b/mne/export/_eeglab.py index 3c8f896164a..459207f0616 100644 --- a/mne/export/_eeglab.py +++ b/mne/export/_eeglab.py @@ -4,6 +4,7 @@ import numpy as np +from ..annotations import _sync_onset from ..utils import _check_eeglabio_installed _check_eeglabio_installed() @@ -24,11 +25,16 @@ def _export_raw(fname, raw): ch_names = [ch for ch in raw.ch_names if ch not in drop_chs] cart_coords = _get_als_coords_from_chs(raw.info["chs"], drop_chs) - annotations = [ - raw.annotations.description, - raw.annotations.onset, - raw.annotations.duration, - ] + if raw.annotations: + annotations = [ + raw.annotations.description, + # subtract raw.first_time because EEGLAB marks events starting from + # the first available data point and ignores raw.first_time + _sync_onset(raw, raw.annotations.onset, inverse=False), + raw.annotations.duration, + ] + else: + annotations = None eeglabio.raw.export_set( fname, data=raw.get_data(picks=ch_names), diff --git a/mne/export/_egimff.py b/mne/export/_egimff.py index 3792ea4a6a5..185afb5f558 100644 --- a/mne/export/_egimff.py +++ b/mne/export/_egimff.py @@ -53,7 +53,7 @@ def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, verbose= info = evoked[0].info if np.round(info["sfreq"]) != info["sfreq"]: raise ValueError( - f'Sampling frequency must be a whole number. sfreq: {info["sfreq"]}' + f"Sampling frequency must be a whole number. sfreq: {info['sfreq']}" ) sampling_rate = int(info["sfreq"]) diff --git a/mne/export/_export.py b/mne/export/_export.py index 490bf986895..4b93fda917e 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -25,6 +25,14 @@ def export_raw( %(export_warning)s + .. warning:: + When exporting ``Raw`` with annotations, ``raw.info["meas_date"]`` must be the + same as ``raw.annotations.orig_time``. This guarantees that the annotations are + in the same reference frame as the samples. When + :attr:`Raw.first_time ` is not zero (e.g., after + cropping), the onsets are automatically corrected so that onsets are always + relative to the first sample. + Parameters ---------- %(fname_export_params)s @@ -216,7 +224,6 @@ def _infer_check_export_fmt(fmt, fname, supported_formats): supported_str = ", ".join(supported) raise ValueError( - f"Format '{fmt}' is not supported. " - f"Supported formats are {supported_str}." + f"Format '{fmt}' is not supported. Supported formats are {supported_str}." ) return fmt diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 706a83476e4..6f712923c7d 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -122,6 +122,49 @@ def test_export_raw_eeglab(tmp_path): raw.export(temp_fname, overwrite=True) +@pytest.mark.parametrize("tmin", (0, 1, 5, 10)) +def test_export_raw_eeglab_annotations(tmp_path, tmin): + """Test annotations in the exported EEGLAB file. + + All annotations should be preserved and onset corrected. + """ + pytest.importorskip("eeglabio") + raw = read_raw_fif(fname_raw, preload=True) + raw.apply_proj() + annotations = Annotations( + onset=[0.01, 0.05, 0.90, 1.05], + duration=[0, 1, 0, 0], + description=["test1", "test2", "test3", "test4"], + ch_names=[["MEG 0113"], ["MEG 0113", "MEG 0132"], [], ["MEG 0143"]], + ) + raw.set_annotations(annotations) + raw.crop(tmin) + + # export + temp_fname = tmp_path / "test.set" + raw.export(temp_fname) + + # read in the file + with pytest.warns(RuntimeWarning, match="is above the 99th percentile"): + raw_read = read_raw_eeglab(temp_fname, preload=True, montage_units="m") + assert raw_read.first_time == 0 # exportation resets first_time + valid_annot = ( + raw.annotations.onset >= tmin + ) # only annotations in the cropped range gets exported + + # compare annotations before and after export + assert_array_almost_equal( + raw.annotations.onset[valid_annot] - raw.first_time, + raw_read.annotations.onset, + ) + assert_array_equal( + raw.annotations.duration[valid_annot], raw_read.annotations.duration + ) + assert_array_equal( + raw.annotations.description[valid_annot], raw_read.annotations.description + ) + + def _create_raw_for_edf_tests(stim_channel_index=None): rng = np.random.RandomState(12345) ch_types = [ @@ -145,7 +188,7 @@ def _create_raw_for_edf_tests(stim_channel_index=None): edfio_mark = pytest.mark.skipif( - not _check_edfio_installed(strict=False), reason="unsafe use of private module" + not _check_edfio_installed(strict=False), reason="requires edfio" ) @@ -154,6 +197,7 @@ def test_double_export_edf(tmp_path): """Test exporting an EDF file multiple times.""" raw = _create_raw_for_edf_tests(stim_channel_index=2) raw.info.set_meas_date("2023-09-04 14:53:09.000") + raw.set_annotations(Annotations(onset=[1], duration=[0], description=["test"])) # include subject info and measurement date raw.info["subject_info"] = dict( @@ -235,7 +279,7 @@ def test_edf_padding(tmp_path, pad_width): RuntimeWarning, match=( "EDF format requires equal-length data blocks.*" - f"{pad_width/1000:.3g} seconds of edge values were appended.*" + f"{pad_width / 1000:.3g} seconds of edge values were appended.*" ), ): raw.export(temp_fname) @@ -258,8 +302,12 @@ def test_edf_padding(tmp_path, pad_width): @edfio_mark() -def test_export_edf_annotations(tmp_path): - """Test that exporting EDF preserves annotations.""" +@pytest.mark.parametrize("tmin", (0, 0.005, 0.03, 1)) +def test_export_edf_annotations(tmp_path, tmin): + """Test annotations in the exported EDF file. + + All annotations should be preserved and onset corrected. + """ raw = _create_raw_for_edf_tests() annotations = Annotations( onset=[0.01, 0.05, 0.90, 1.05], @@ -268,17 +316,44 @@ def test_export_edf_annotations(tmp_path): ch_names=[["0"], ["0", "1"], [], ["1"]], ) raw.set_annotations(annotations) + raw.crop(tmin) + assert raw.first_time == tmin + + if raw.n_times % raw.info["sfreq"] == 0: + expectation = nullcontext() + else: + expectation = pytest.warns( + RuntimeWarning, match="EDF format requires equal-length data blocks" + ) # export temp_fname = tmp_path / "test.edf" - raw.export(temp_fname) + with expectation: + raw.export(temp_fname) # read in the file raw_read = read_raw_edf(temp_fname, preload=True) - assert_array_equal(raw.annotations.onset, raw_read.annotations.onset) - assert_array_equal(raw.annotations.duration, raw_read.annotations.duration) - assert_array_equal(raw.annotations.description, raw_read.annotations.description) - assert_array_equal(raw.annotations.ch_names, raw_read.annotations.ch_names) + assert raw_read.first_time == 0 # exportation resets first_time + bad_annot = raw_read.annotations.description == "BAD_ACQ_SKIP" + if bad_annot.any(): + raw_read.annotations.delete(bad_annot) + valid_annot = ( + raw.annotations.onset >= tmin + ) # only annotations in the cropped range gets exported + + # compare annotations before and after export + assert_array_almost_equal( + raw.annotations.onset[valid_annot] - raw.first_time, raw_read.annotations.onset + ) + assert_array_equal( + raw.annotations.duration[valid_annot], raw_read.annotations.duration + ) + assert_array_equal( + raw.annotations.description[valid_annot], raw_read.annotations.description + ) + assert_array_equal( + raw.annotations.ch_names[valid_annot], raw_read.annotations.ch_names + ) @edfio_mark() @@ -476,7 +551,7 @@ def test_export_epochs_eeglab(tmp_path, preload): with ctx(): epochs.export(temp_fname) epochs.drop_channels([ch for ch in ["epoc", "STI 014"] if ch in epochs.ch_names]) - epochs_read = read_epochs_eeglab(temp_fname) + epochs_read = read_epochs_eeglab(temp_fname, verbose="error") # head radius assert epochs.ch_names == epochs_read.ch_names cart_coords = np.array([d["loc"][:3] for d in epochs.info["chs"]]) # just xyz cart_coords_read = np.array([d["loc"][:3] for d in epochs_read.info["chs"]]) @@ -580,7 +655,7 @@ def test_export_to_mff_incompatible_sfreq(): """Test non-whole number sampling frequency throws ValueError.""" pytest.importorskip("mffpy", "0.5.7") evoked = read_evokeds(fname_evoked) - with pytest.raises(ValueError, match=f'sfreq: {evoked[0].info["sfreq"]}'): + with pytest.raises(ValueError, match=f"sfreq: {evoked[0].info['sfreq']}"): export_evokeds("output.mff", evoked) diff --git a/mne/filter.py b/mne/filter.py index 025f778d07f..a7d7c883e2f 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -411,8 +411,7 @@ def _prep_for_filtering(x, copy, picks=None): picks = np.tile(picks, n_epochs) + offset elif len(orig_shape) > 3: raise ValueError( - "picks argument is not supported for data with more" - " than three dimensions" + "picks argument is not supported for data with more than three dimensions" ) assert all(0 <= pick < x.shape[0] for pick in picks) # guaranteed by above @@ -434,7 +433,7 @@ def _firwin_design(N, freq, gain, window, sfreq): for this_freq, this_gain in zip(freq[::-1][1:], gain[::-1][1:]): assert this_gain in (0, 1) if this_gain != prev_gain: - # Get the correct N to satistify the requested transition bandwidth + # Get the correct N to satisfy the requested transition bandwidth transition = (prev_freq - this_freq) / 2.0 this_N = int(round(_length_factors[window] / transition)) this_N += 1 - this_N % 2 # make it odd @@ -2873,7 +2872,7 @@ def design_mne_c_filter( h_width = (int(((n_freqs - 1) * h_trans_bandwidth) / (0.5 * sfreq)) + 1) // 2 h_start = int(((n_freqs - 1) * h_freq) / (0.5 * sfreq)) logger.info( - "filter : %7.3f ... %6.1f Hz bins : %d ... %d of %d " "hpw : %d lpw : %d", + "filter : %7.3f ... %6.1f Hz bins : %d ... %d of %d hpw : %d lpw : %d", l_freq, h_freq, l_start, diff --git a/mne/fixes.py b/mne/fixes.py index 2b37de982c5..070d4125d18 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -22,12 +22,10 @@ from math import log import numpy as np +from packaging.version import parse ############################################################################### -# distutils - -# distutils has been deprecated since Python 3.10 and was removed -# from the standard library with the release of Python 3.12. +# distutils LooseVersion removed in Python 3.12 def _compare_version(version_a, operator, version_b): @@ -48,8 +46,6 @@ def _compare_version(version_a, operator, version_b): bool The result of the version comparison. """ - from packaging.version import parse - mapping = {"<": "lt", "<=": "le", "==": "eq", "!=": "ne", ">=": "ge", ">": "gt"} with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") @@ -724,3 +720,16 @@ def minimum_phase(h, method="homomorphic", n_fft=None, *, half=True): n_out = (n_half + len(h) % 2) if half else len(h) return h_minimum[:n_out] + + +# SciPy 1.15 deprecates sph_harm for sph_harm_y and using it will trigger a +# DeprecationWarning. This is a backport of the new function for older SciPy versions. +def sph_harm_y(n, m, theta, phi, *, diff_n=0): + """Wrap scipy.special.sph_harm for sph_harm_y.""" + # Can be removed once we no longer support scipy < 1.15.0 + from scipy import special + + if "sph_harm_y" in special.__dict__: + return special.sph_harm_y(n, m, theta, phi, diff_n=diff_n) + else: + return special.sph_harm(m, n, phi, theta) diff --git a/mne/forward/_field_interpolation.py b/mne/forward/_field_interpolation.py index b505b5e45df..e98a147b560 100644 --- a/mne/forward/_field_interpolation.py +++ b/mne/forward/_field_interpolation.py @@ -96,7 +96,7 @@ def _pinv_trunc(x, miss): varexp /= varexp[-1] n = np.where(varexp >= (1.0 - miss))[0][0] + 1 logger.info( - " Truncating at %d/%d components to omit less than %g " "(%0.2g)", + " Truncating at %d/%d components to omit less than %g (%0.2g)", n, len(s), miss, @@ -111,8 +111,7 @@ def _pinv_tikhonov(x, reg): # _reg_pinv requires square Hermitian, which we have here inv, _, n = _reg_pinv(x, reg=reg, rank=None) logger.info( - f" Truncating at {n}/{len(x)} components and regularizing " - f"with α={reg:0.1e}" + f" Truncating at {n}/{len(x)} components and regularizing with α={reg:0.1e}" ) return inv, n diff --git a/mne/forward/_make_forward.py b/mne/forward/_make_forward.py index 64aadf69fec..6c77f47e312 100644 --- a/mne/forward/_make_forward.py +++ b/mne/forward/_make_forward.py @@ -160,8 +160,7 @@ def _create_meg_coil(coilset, ch, acc, do_es): break else: raise RuntimeError( - "Desired coil definition not found " - f"(type = {ch['coil_type']} acc = {acc})" + f"Desired coil definition not found (type = {ch['coil_type']} acc = {acc})" ) # Apply a coordinate transformation if so desired @@ -295,8 +294,8 @@ def _setup_bem(bem, bem_extra, neeg, mri_head_t, allow_none=False, verbose=None) else: if bem["surfs"][0]["coord_frame"] != FIFF.FIFFV_COORD_MRI: raise RuntimeError( - f'BEM is in {_coord_frame_name(bem["surfs"][0]["coord_frame"])} ' - 'coordinates, should be in MRI' + f"BEM is in {_coord_frame_name(bem['surfs'][0]['coord_frame'])} " + "coordinates, should be in MRI" ) if neeg > 0 and len(bem["surfs"]) == 1: raise RuntimeError( @@ -335,7 +334,7 @@ def _prep_meg_channels( del picks # Get channel info and names for MEG channels - logger.info(f'Read {len(info_meg["chs"])} MEG channels from info') + logger.info(f"Read {len(info_meg['chs'])} MEG channels from info") # Get MEG compensation channels compensator = post_picks = None @@ -352,7 +351,7 @@ def _prep_meg_channels( 'channels. Consider using "ignore_ref=True" in ' "calculation" ) - logger.info(f'{len(info["comps"])} compensation data sets in info') + logger.info(f"{len(info['comps'])} compensation data sets in info") # Compose a compensation data set if necessary # adapted from mne_make_ctf_comp() from mne_ctf_comp.c logger.info("Setting up compensation data...") diff --git a/mne/forward/forward.py b/mne/forward/forward.py index e3e5c08d2f8..f1c2c2d11d7 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -512,7 +512,7 @@ def _merge_fwds(fwds, *, verbose=None): a[k]["row_names"] = a[k]["row_names"] + b[k]["row_names"] a["nchan"] = a["nchan"] + b["nchan"] if len(fwds) > 1: - logger.info(f' Forward solutions combined: {", ".join(combined)}') + logger.info(f" Forward solutions combined: {', '.join(combined)}") return fwd @@ -677,8 +677,7 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=True, verbos # Make sure forward solution is in either the MRI or HEAD coordinate frame if fwd["coord_frame"] not in (FIFF.FIFFV_COORD_MRI, FIFF.FIFFV_COORD_HEAD): raise ValueError( - "Only forward solutions computed in MRI or head " - "coordinates are acceptable" + "Only forward solutions computed in MRI or head coordinates are acceptable" ) # Transform each source space to the HEAD or MRI coordinate frame, @@ -1205,8 +1204,7 @@ def _triage_loose(src, loose, fixed="auto"): if fixed is True: if not all(v == 0.0 for v in loose.values()): raise ValueError( - 'When using fixed=True, loose must be 0. or "auto", ' - f"got {orig_loose}" + f'When using fixed=True, loose must be 0. or "auto", got {orig_loose}' ) elif fixed is False: if any(v == 0.0 for v in loose.values()): @@ -1666,8 +1664,7 @@ def apply_forward( for ch_name in fwd["sol"]["row_names"]: if ch_name not in info["ch_names"]: raise ValueError( - f"Channel {ch_name} of forward operator not present in " - "evoked_template." + f"Channel {ch_name} of forward operator not present in evoked_template." ) # project the source estimate to the sensor space diff --git a/mne/forward/tests/test_make_forward.py b/mne/forward/tests/test_make_forward.py index 37ec6e041b5..a357c5779c9 100644 --- a/mne/forward/tests/test_make_forward.py +++ b/mne/forward/tests/test_make_forward.py @@ -482,7 +482,7 @@ def test_make_forward_solution_openmeeg(n_layers): eeg_atol=100, meg_corr_tol=0.98, eeg_corr_tol=0.98, - meg_rdm_tol=0.1, + meg_rdm_tol=0.11, eeg_rdm_tol=0.2, ) diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py index 98e3fbfc0b3..b365a2eed5a 100644 --- a/mne/gui/_coreg.py +++ b/mne/gui/_coreg.py @@ -1611,8 +1611,7 @@ def _configure_dock(self): func=self._set_subjects_dir, is_directory=True, icon=True, - tooltip="Load the path to the directory containing the " - "FreeSurfer subjects", + tooltip="Load the path to the directory containing the FreeSurfer subjects", layout=subjects_dir_layout, ) self._renderer._layout_add_widget( @@ -1741,8 +1740,7 @@ def _configure_dock(self): self._widgets["omit"] = self._renderer._dock_add_button( name="Omit", callback=self._omit_hsp, - tooltip="Exclude the head shape points that are far away from " - "the MRI head", + tooltip="Exclude the head shape points that are far away from the MRI head", layout=omit_hsp_layout_2, ) self._widgets["reset_omit"] = self._renderer._dock_add_button( diff --git a/mne/html_templates/_templates.py b/mne/html_templates/_templates.py index 9427f2d6a25..1f68303a51e 100644 --- a/mne/html_templates/_templates.py +++ b/mne/html_templates/_templates.py @@ -66,7 +66,7 @@ def _format_time_range(inst) -> str: def _format_projs(info) -> list[str]: """Format projectors.""" - projs = [f'{p["desc"]} ({"on" if p["active"] else "off"})' for p in info["projs"]] + projs = [f"{p['desc']} ({'on' if p['active'] else 'off'})" for p in info["projs"]] return projs diff --git a/mne/html_templates/repr/_acquisition.html.jinja b/mne/html_templates/repr/_acquisition.html.jinja index 0016740cdf8..e1ee4f69dd3 100644 --- a/mne/html_templates/repr/_acquisition.html.jinja +++ b/mne/html_templates/repr/_acquisition.html.jinja @@ -81,7 +81,7 @@ {{ "%0.2f" | format(info["sfreq"]) }} Hz {% endif %} -{% if inst is defined and inst.times is defined %} +{% if inst is defined and inst | has_attr("times") and inst.times is defined %} Time points diff --git a/mne/html_templates/repr/_frequencies.html.jinja b/mne/html_templates/repr/_frequencies.html.jinja new file mode 100644 index 00000000000..b55b8ddf883 --- /dev/null +++ b/mne/html_templates/repr/_frequencies.html.jinja @@ -0,0 +1,62 @@ +{% set section = "Frequencies" %} +{% set section_class_name = section | lower | append_uuid %} + +{# Collapse content during documentation build. #} +{% if collapsed %} +{% set collapsed_row_class = "mne-repr-collapsed" %} +{% else %} +{% set collapsed_row_class = "" %} +{% endif %} + +{%include 'static/_section_header_row.html.jinja' %} + + + + Data type + {{ inst._data_type }} + + + + Computed from + {{ computed_from }} + + + + Estimation method + {{ inst.method }} + + +{% if "taper" in inst._dims %} + + + Number of tapers + {{ inst._mt_weights.size }} + +{% endif %} +{% if inst.freqs is defined %} + + + Frequency range + {{ '%.2f'|format(inst.freqs[0]) }} – {{ '%.2f'|format(inst.freqs[-1]) }} Hz + + + + Number of frequency bins + {{ inst.freqs|length }} + +{%- for unit in units %} + + + {%- if loop.index == 1 %} + Units + {%- endif %} + {{ unit }} + +{%- endfor %} +{% endif %} diff --git a/mne/html_templates/repr/spectrum.html.jinja b/mne/html_templates/repr/spectrum.html.jinja index 11a1d7a886f..40cc2222005 100644 --- a/mne/html_templates/repr/spectrum.html.jinja +++ b/mne/html_templates/repr/spectrum.html.jinja @@ -1,50 +1,11 @@ +{%include '_js_and_css.html.jinja' %} + +{% set info = inst.info %} + - - - - - {%- for unit in units %} - - {%- if loop.index == 1 %} - - {%- endif %} - - - {%- endfor %} - - - - - {%- if inst_type == "Epochs" %} - - - - - {% endif -%} - - - - - - - - - {% if "taper" in spectrum._dims %} - - - - - {% endif %} - - - - - - - - - - - - + {%include '_general.html.jinja' %} + {%include '_acquisition.html.jinja' %} + {%include '_channels.html.jinja' %} + {%include '_frequencies.html.jinja' %} + {%include '_filters.html.jinja' %}
Data type{{ spectrum._data_type }}
Units{{ unit }}
Data source{{ inst_type }}
Number of epochs{{ spectrum.shape[0] }}
Dims{{ spectrum._dims | join(", ") }}
Estimation method{{ spectrum.method }}
Number of tapers{{ spectrum._mt_weights.size }}
Number of channels{{ spectrum.ch_names|length }}
Number of frequency bins{{ spectrum.freqs|length }}
Frequency range{{ '%.2f'|format(spectrum.freqs[0]) }} – {{ '%.2f'|format(spectrum.freqs[-1]) }} Hz
diff --git a/mne/io/ant/ant.py b/mne/io/ant/ant.py index e46aabe8a16..854406267f4 100644 --- a/mne/io/ant/ant.py +++ b/mne/io/ant/ant.py @@ -4,7 +4,6 @@ from __future__ import annotations -import importlib import re from collections import defaultdict from typing import TYPE_CHECKING @@ -16,8 +15,8 @@ from ...annotations import Annotations from ...utils import ( _check_fname, + _soft_import, _validate_type, - check_version, copy_doc, fill_doc, logger, @@ -80,6 +79,8 @@ class RawANT(BaseRaw): Note that the impedance annotation will likely have a duration of ``0``. If the measurement marks a discontinuity, the duration should be modified to cover the discontinuity in its entirety. + encoding : str + Encoding to use for :class:`str` in the CNT file. Defaults to ``'latin-1'``. %(preload)s %(verbose)s """ @@ -93,16 +94,12 @@ def __init__( bipolars: list[str] | tuple[str, ...] | None, impedance_annotation: str, *, + encoding: str = "latin-1", preload: bool | NDArray, verbose=None, ) -> None: logger.info("Reading ANT file %s", fname) - if importlib.util.find_spec("antio") is None: - raise ImportError( - "Missing optional dependency 'antio'. Use pip or conda to install " - "'antio'." - ) - check_version("antio", "0.3.0") + _soft_import("antio", "reading ANT files", min_version="0.5.0") from antio import read_cnt from antio.parser import ( @@ -122,8 +119,7 @@ def __init__( raise ValueError("The impedance annotation cannot be an empty string.") cnt = read_cnt(fname) # parse channels, sampling frequency, and create info - ch_info = read_info(cnt) # load in 2 lines for compat with antio 0.2 and 0.3 - ch_names, ch_units, ch_refs = ch_info[0], ch_info[1], ch_info[2] + ch_names, ch_units, ch_refs, _, _ = read_info(cnt, encoding=encoding) ch_types = _parse_ch_types(ch_names, eog, misc, ch_refs) if bipolars is not None: # handle bipolar channels bipolars_idx = _handle_bipolar_channels(ch_names, ch_refs, bipolars) @@ -139,9 +135,9 @@ def __init__( ch_names, sfreq=cnt.get_sample_frequency(), ch_types=ch_types ) info.set_meas_date(read_meas_date(cnt)) - make, model, serial, site = read_device_info(cnt) + make, model, serial, site = read_device_info(cnt, encoding=encoding) info["device_info"] = dict(type=make, model=model, serial=serial, site=site) - his_id, name, sex, birthday = read_subject_info(cnt) + his_id, name, sex, birthday = read_subject_info(cnt, encoding=encoding) info["subject_info"] = dict( his_id=his_id, first_name=name, @@ -315,6 +311,7 @@ def read_raw_ant( bipolars=None, impedance_annotation="impedance", *, + encoding: str = "latin-1", preload=False, verbose=None, ) -> RawANT: @@ -324,6 +321,10 @@ def read_raw_ant( raw : instance of RawANT A Raw object containing ANT data. See :class:`mne.io.Raw` for documentation of attributes and methods. + + Notes + ----- + .. versionadded:: 1.9 """ return RawANT( fname, @@ -331,6 +332,7 @@ def read_raw_ant( misc=misc, bipolars=bipolars, impedance_annotation=impedance_annotation, + encoding=encoding, preload=preload, verbose=verbose, ) diff --git a/mne/io/ant/tests/test_ant.py b/mne/io/ant/tests/test_ant.py index e51c40cfde6..8c8530d400d 100644 --- a/mne/io/ant/tests/test_ant.py +++ b/mne/io/ant/tests/test_ant.py @@ -17,7 +17,7 @@ from mne.io import BaseRaw, read_raw, read_raw_ant, read_raw_brainvision from mne.io.ant.ant import RawANT -pytest.importorskip("antio", minversion="0.4.0") +pytest.importorskip("antio", minversion="0.5.0") data_path = testing.data_path(download=False) / "antio" diff --git a/mne/io/array/__init__.py b/mne/io/array/__init__.py index aea21ef42ce..ad53f7c817f 100644 --- a/mne/io/array/__init__.py +++ b/mne/io/array/__init__.py @@ -4,4 +4,4 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from .array import RawArray +from ._array import RawArray diff --git a/mne/io/array/array.py b/mne/io/array/_array.py similarity index 100% rename from mne/io/array/array.py rename to mne/io/array/_array.py diff --git a/mne/io/artemis123/tests/test_artemis123.py b/mne/io/artemis123/tests/test_artemis123.py index 039108eb915..610f32ba5da 100644 --- a/mne/io/artemis123/tests/test_artemis123.py +++ b/mne/io/artemis123/tests/test_artemis123.py @@ -35,9 +35,9 @@ def _assert_trans(actual, desired, dist_tol=0.017, angle_tol=5.0): angle = np.rad2deg(_angle_between_quats(quat_est, quat)) dist = np.linalg.norm(trans - trans_est) - assert ( - dist <= dist_tol - ), f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} mm translation" + assert dist <= dist_tol, ( + f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} mm translation" + ) assert angle <= angle_tol, f"{angle:0.3f} > {angle_tol:0.3f}° rotation" diff --git a/mne/io/base.py b/mne/io/base.py index 5580b88ea25..280330367f7 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -1013,8 +1013,7 @@ def get_data( if n_rejected > 0: if reject_by_annotation == "omit": msg = ( - "Omitting {} of {} ({:.2%}) samples, retaining {}" - " ({:.2%}) samples." + "Omitting {} of {} ({:.2%}) samples, retaining {} ({:.2%}) samples." ) logger.info( msg.format( @@ -1500,6 +1499,71 @@ def resample( ) return self, events + @verbose + def rescale(self, scalings, *, verbose=None): + """Rescale channels. + + .. warning:: + MNE-Python assumes data are stored in SI base units. This function should + typically only be used to fix an incorrect scaling factor in the data to get + it to be in SI base units, otherwise unintended problems (e.g., incorrect + source imaging results) and analysis errors can occur. + + Parameters + ---------- + scalings : int | float | dict + The scaling factor(s) by which to multiply the data. If a float, the same + scaling factor is applied to all channels (this works only if all channels + are of the same type). If a dict, the keys must be valid channel types and + the values the scaling factors to apply to the corresponding channels. + %(verbose)s + + Returns + ------- + raw : Raw + The raw object with rescaled data (modified in-place). + + Examples + -------- + A common use case for EEG data is to convert from µV to V, since many EEG + systems store data in µV, but MNE-Python expects the data to be in V. Therefore, + the data needs to be rescaled by a factor of 1e-6. To rescale all channels from + µV to V, you can do:: + + >>> raw.rescale(1e-6) # doctest: +SKIP + + Note that the previous example only works if all channels are of the same type. + If there are multiple channel types, you can pass a dict with the individual + scaling factors. For example, to rescale only EEG channels, you can do:: + + >>> raw.rescale({"eeg": 1e-6}) # doctest: +SKIP + """ + _validate_type(scalings, (int, float, dict), "scalings") + _check_preload(self, "raw.rescale") + + channel_types = self.get_channel_types(unique=True) + + if isinstance(scalings, int | float): + if len(channel_types) == 1: + self.apply_function(lambda x: x * scalings, channel_wise=False) + else: + raise ValueError( + "If scalings is a scalar, all channels must be of the same type. " + "Consider passing a dict instead." + ) + else: + for ch_type in scalings.keys(): + if ch_type not in channel_types: + raise ValueError( + f'Channel type "{ch_type}" is not present in the Raw file.' + ) + for ch_type, ch_scale in scalings.items(): + self.apply_function( + lambda x: x * ch_scale, picks=ch_type, channel_wise=False + ) + + return self + @verbose def crop(self, tmin=0.0, tmax=None, include_tmax=True, *, verbose=None): """Crop raw data file. @@ -2092,7 +2156,7 @@ def append(self, raws, preload=None): for edge_samp in edge_samps: onset = _sync_onset(self, edge_samp / self.info["sfreq"], True) logger.debug( - f"Marking edge at {edge_samp} samples " f"(maps to {onset:0.3f} sec)" + f"Marking edge at {edge_samp} samples (maps to {onset:0.3f} sec)" ) self.annotations.append(onset, 0.0, "BAD boundary") self.annotations.append(onset, 0.0, "EDGE boundary") @@ -3122,7 +3186,7 @@ def concatenate_raws( @fill_doc -def match_channel_orders(insts=None, copy=True, *, raws=None): +def match_channel_orders(insts, copy=True): """Ensure consistent channel order across instances (Raw, Epochs, or Evoked). Parameters @@ -3131,9 +3195,6 @@ def match_channel_orders(insts=None, copy=True, *, raws=None): List of :class:`~mne.io.Raw`, :class:`~mne.Epochs`, or :class:`~mne.Evoked` instances to order. %(copy_df)s - raws : list - This parameter is deprecated and will be removed in mne version 1.9. - Please use ``insts`` instead. Returns ------- @@ -3141,20 +3202,6 @@ def match_channel_orders(insts=None, copy=True, *, raws=None): List of instances (Raw, Epochs, or Evoked) with channel orders matched according to the order they had in the first item in the ``insts`` list. """ - # XXX: remove "raws" parameter and logic below with MNE version 1.9 - # and remove default parameter value of insts - if raws is not None: - warn( - "The ``raws`` parameter is deprecated and will be removed in version " - "1.9. Use the ``insts`` parameter to suppress this warning.", - DeprecationWarning, - ) - insts = raws - elif insts is None: - # both insts and raws is None - raise ValueError( - "You need to pass a list of Raw, Epochs, or Evoked to ``insts``." - ) insts = deepcopy(insts) if copy else insts ch_order = insts[0].ch_names for inst in insts[1:]: diff --git a/mne/io/cnt/cnt.py b/mne/io/cnt/cnt.py index 16c074269bf..da91ee59f9e 100644 --- a/mne/io/cnt/cnt.py +++ b/mne/io/cnt/cnt.py @@ -14,7 +14,7 @@ from ..._fiff.utils import _create_chs, _find_channels, _mult_cal_one, read_str from ...annotations import Annotations from ...channels.layout import _topo_to_sphere -from ...utils import _check_option, _validate_type, fill_doc, warn +from ...utils import _check_option, _explain_exception, _validate_type, fill_doc, warn from ..base import BaseRaw from ._utils import ( CNTEventType3, @@ -150,7 +150,22 @@ def _update_bad_span_onset(accept_reject, onset, duration, description): np.array([e.KeyPad_Accept for e in my_events]) ) - description = np.array([str(e.StimType) for e in my_events]) + # Check to see if there are any button presses + description = [] + for event in my_events: + # Extract the 4-bit fields + # Upper nibble (4 bits) currently not used + # accept = (event.KeyPad_Accept[0] & 0xF0) >> 4 + # Lower nibble (4 bits) keypad button press + keypad = event.KeyPad_Accept[0] & 0x0F + if str(keypad) != "0": + description.append(f"KeyPad Response {keypad}") + elif event.KeyBoard != 0: + description.append(f"Keyboard Response {event.KeyBoard}") + else: + description.append(str(event.StimType)) + + description = np.array(description) onset, duration, description = _update_bad_span_onset( accept_reject, onset / sfreq, duration, description @@ -532,7 +547,8 @@ def __init__( ) except Exception: raise RuntimeError( - "Could not read header from *.cnt file. mne.io.read_raw_cnt " + f"{_explain_exception()}\n" + "WARNING: mne.io.read_raw_cnt " "supports Neuroscan CNT files only. If this file is an ANT Neuro CNT, " "please use mne.io.read_raw_ant instead." ) diff --git a/mne/io/cnt/tests/test_cnt.py b/mne/io/cnt/tests/test_cnt.py index c098b58e6f3..f98253b1317 100644 --- a/mne/io/cnt/tests/test_cnt.py +++ b/mne/io/cnt/tests/test_cnt.py @@ -57,7 +57,8 @@ def test_auto_data(): third = pytest.warns(RuntimeWarning, match="Omitted 6 annot") with first, second, third: raw = read_raw_cnt(input_fname=fname_bad_spans) - + # Test that responses are read properly + assert "KeyPad Response 1" in raw.annotations.description assert raw.info["bads"] == ["F8"] with _no_parse, pytest.warns(RuntimeWarning, match="number of bytes"): diff --git a/mne/io/ctf/ctf.py b/mne/io/ctf/ctf.py index 44a4e39adf6..971ac51c2f6 100644 --- a/mne/io/ctf/ctf.py +++ b/mne/io/ctf/ctf.py @@ -267,7 +267,7 @@ def _get_sample_info(fname, res4, system_clock): fid.seek(offset, 0) this_data = np.fromfile(fid, ">i4", res4["nsamp"]) if len(this_data) != res4["nsamp"]: - raise RuntimeError(f"Cannot read data for trial {t+1}.") + raise RuntimeError(f"Cannot read data for trial {t + 1}.") end = np.where(this_data == 0)[0] if len(end) > 0: n_samp = samp_offset + end[0] diff --git a/mne/io/ctf/info.py b/mne/io/ctf/info.py index 1b96d8bd88f..685a20792d3 100644 --- a/mne/io/ctf/info.py +++ b/mne/io/ctf/info.py @@ -50,8 +50,7 @@ def _pick_isotrak_and_hpi_coils(res4, coils, t): if p["coord_frame"] == FIFF.FIFFV_MNE_COORD_CTF_DEVICE: if t is None or t["t_ctf_dev_dev"] is None: raise RuntimeError( - "No coordinate transformation " - "available for HPI coil locations" + "No coordinate transformation available for HPI coil locations" ) d = dict( kind=kind, diff --git a/mne/io/ctf/tests/test_ctf.py b/mne/io/ctf/tests/test_ctf.py index 4a5dd846655..448ea90baba 100644 --- a/mne/io/ctf/tests/test_ctf.py +++ b/mne/io/ctf/tests/test_ctf.py @@ -243,9 +243,9 @@ def test_read_ctf(tmp_path): # Make sure all digitization points are in the MNE head coord frame for p in raw.info["dig"]: - assert ( - p["coord_frame"] == FIFF.FIFFV_COORD_HEAD - ), "dig points must be in FIFF.FIFFV_COORD_HEAD" + assert p["coord_frame"] == FIFF.FIFFV_COORD_HEAD, ( + "dig points must be in FIFF.FIFFV_COORD_HEAD" + ) if fname.endswith("catch-alp-good-f.ds"): # omit points from .pos file with raw.info._unlock(): diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index bb79c46f24a..09ac24f753e 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -436,21 +436,24 @@ def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, cals, ones[orig_idx, smp_read : smp_read + len(one_i)] = one_i n_smp_read[orig_idx] += len(one_i) + # resample channels with lower sample frequency # skip if no data was requested, ie. only annotations were read - if sum(n_smp_read) > 0: + if any(n_smp_read) > 0: # expected number of samples, equals maximum sfreq smp_exp = data.shape[-1] - assert max(n_smp_read) == smp_exp # resample data after loading all chunks to prevent edge artifacts resampled = False + for i, smp_read in enumerate(n_smp_read): # nothing read, nothing to resample if smp_read == 0: continue # upsample if n_samples is lower than from highest sfreq if smp_read != smp_exp: - assert (ones[i, smp_read:] == 0).all() # sanity check + # sanity check that we read exactly how much we expected + assert (ones[i, smp_read:] == 0).all() + ones[i, :] = resample( ones[i, :smp_read].astype(np.float64), smp_exp, @@ -628,7 +631,7 @@ def _get_info( if len(chs_without_types): msg = ( "Could not determine channel type of the following channels, " - f'they will be set as EEG:\n{", ".join(chs_without_types)}' + f"they will be set as EEG:\n{', '.join(chs_without_types)}" ) logger.info(msg) @@ -712,8 +715,8 @@ def _get_info( if info["highpass"] > info["lowpass"]: warn( - f'Highpass cutoff frequency {info["highpass"]} is greater ' - f'than lowpass cutoff frequency {info["lowpass"]}, ' + f"Highpass cutoff frequency {info['highpass']} is greater " + f"than lowpass cutoff frequency {info['lowpass']}, " "setting values to 0 and Nyquist." ) info["highpass"] = 0.0 diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py index b4f0ab33fa5..ce671ca7e81 100644 --- a/mne/io/edf/tests/test_edf.py +++ b/mne/io/edf/tests/test_edf.py @@ -259,6 +259,24 @@ def test_edf_different_sfreqs(stim_channel): assert_allclose(times1, times2) +@testing.requires_testing_data +@pytest.mark.parametrize("stim_channel", (None, False, "auto")) +def test_edf_different_sfreqs_nopreload(stim_channel): + """Test loading smaller sfreq channels without preloading.""" + # load without preloading, then load a channel that has smaller sfreq + # as other channels, produced an error, see mne-python/issues/12897 + + for i in range(1, 13): + raw = read_raw_edf(input_fname=edf_reduced, verbose="error", preload=False) + + # this should work for channels of all sfreq, even if larger sfreqs + # are present in the file + x1 = raw.get_data(picks=[f"A{i}"], return_times=False) + # load next ch, this is sometimes with a higher sometimes a lower sfreq + x2 = raw.get_data([f"A{i + 1}"], return_times=False) + assert x1.shape == x2.shape + + def test_edf_data_broken(tmp_path): """Test edf files.""" raw = _test_raw_reader( diff --git a/mne/io/egi/egimff.py b/mne/io/egi/egimff.py index b2f08020e15..c3a10fb72cd 100644 --- a/mne/io/egi/egimff.py +++ b/mne/io/egi/egimff.py @@ -106,7 +106,7 @@ def _read_mff_header(filepath): if bad: raise RuntimeError( "EGI epoch first/last samps could not be parsed:\n" - f'{list(epochs["first_samps"])}\n{list(epochs["last_samps"])}' + f"{list(epochs['first_samps'])}\n{list(epochs['last_samps'])}" ) summaryinfo.update(epochs) # index which samples in raw are actually readable from disk (i.e., not diff --git a/mne/io/egi/tests/test_egi.py b/mne/io/egi/tests/test_egi.py index d91b4fbd264..584714d9c8a 100644 --- a/mne/io/egi/tests/test_egi.py +++ b/mne/io/egi/tests/test_egi.py @@ -57,9 +57,6 @@ egi_pause_w1337_events = None egi_pause_w1337_skips = [(21956000.0, 40444000.0), (60936000.0, 89332000.0)] -# TODO: Remove once complete deprecation / FutureWarning about events_as_annonations -pytestmark = pytest.mark.filterwarnings("ignore:.*events_as_annotation.*:FutureWarning") - @requires_testing_data @pytest.mark.parametrize( diff --git a/mne/io/fieldtrip/fieldtrip.py b/mne/io/fieldtrip/fieldtrip.py index 5d94d3e0a80..c8521722003 100644 --- a/mne/io/fieldtrip/fieldtrip.py +++ b/mne/io/fieldtrip/fieldtrip.py @@ -7,7 +7,7 @@ from ...epochs import EpochsArray from ...evoked import EvokedArray from ...utils import _check_fname, _import_pymatreader_funcs -from ..array.array import RawArray +from ..array._array import RawArray from .utils import ( _create_event_metadata, _create_events, diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index 3e89ee335c9..1ae0cc52901 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -479,14 +479,6 @@ def test_concatenate_raws_order(): match_channel_orders(insts=raws, copy=True) raw_concat = concatenate_raws(raws) - # XXX: remove in version 1.9 - with pytest.warns(DeprecationWarning, match="``raws`` parameter is deprecated"): - match_channel_orders(raws=raws) - - # XXX: remove in version 1.9 - with pytest.raises(ValueError, match="need to pass a list"): - match_channel_orders() - # Now passes because all raws have the same order match_channel_orders(insts=raws, copy=False) raw_concat = concatenate_raws(raws) diff --git a/mne/io/fil/tests/test_fil.py b/mne/io/fil/tests/test_fil.py index 06d3d924319..df15dd13353 100644 --- a/mne/io/fil/tests/test_fil.py +++ b/mne/io/fil/tests/test_fil.py @@ -87,9 +87,9 @@ def _fil_megmag(raw_test, raw_mat): mat_list = raw_mat["label"] mat_inds = _match_str(test_list, mat_list) - assert len(mat_inds) == len( - test_inds - ), "Number of magnetometer channels in RAW does not match .mat file!" + assert len(mat_inds) == len(test_inds), ( + "Number of magnetometer channels in RAW does not match .mat file!" + ) a = raw_test._data[test_inds, :] b = raw_mat["trial"][mat_inds, :] * 1e-15 # fT to T @@ -106,9 +106,9 @@ def _fil_stim(raw_test, raw_mat): mat_list = raw_mat["label"] mat_inds = _match_str(test_list, mat_list) - assert len(mat_inds) == len( - test_inds - ), "Number of stim channels in RAW does not match .mat file!" + assert len(mat_inds) == len(test_inds), ( + "Number of stim channels in RAW does not match .mat file!" + ) a = raw_test._data[test_inds, :] b = raw_mat["trial"][mat_inds, :] # fT to T @@ -122,9 +122,9 @@ def _fil_sensorpos(raw_test, raw_mat): grad_list = raw_mat["coil_label"] grad_inds = _match_str(test_list, grad_list) - assert len(grad_inds) == len( - test_inds - ), "Number of channels with position data in RAW does not match .mat file!" + assert len(grad_inds) == len(test_inds), ( + "Number of channels with position data in RAW does not match .mat file!" + ) mat_pos = raw_mat["coil_pos"][grad_inds, :] mat_ori = raw_mat["coil_ori"][grad_inds, :] diff --git a/mne/io/neuralynx/neuralynx.py b/mne/io/neuralynx/neuralynx.py index 2b9bed80ae8..56ff9fa4adb 100644 --- a/mne/io/neuralynx/neuralynx.py +++ b/mne/io/neuralynx/neuralynx.py @@ -77,7 +77,7 @@ def read_raw_neuralynx( ) -# Helper for neo deprecation of exclude_filename -> exclude_filenames in 0.13.2 +# Helper for neo change of exclude_filename -> exclude_filenames in 0.13.2 def _exclude_kwarg(exclude_fnames): from neo.io import NeuralynxIO diff --git a/mne/io/neuralynx/tests/test_neuralynx.py b/mne/io/neuralynx/tests/test_neuralynx.py index ea5cdbccdfb..18578ef4ab7 100644 --- a/mne/io/neuralynx/tests/test_neuralynx.py +++ b/mne/io/neuralynx/tests/test_neuralynx.py @@ -143,9 +143,9 @@ def test_neuralynx(): assert raw.info["meas_date"] == meas_date_utc, "meas_date not set correctly" # test that channel selection worked - assert ( - raw.ch_names == expected_chan_names - ), "labels in raw.ch_names don't match expected channel names" + assert raw.ch_names == expected_chan_names, ( + "labels in raw.ch_names don't match expected channel names" + ) mne_y = raw.get_data() # in V @@ -216,9 +216,9 @@ def test_neuralynx_gaps(): n_expected_gaps = 3 n_expected_missing_samples = 130 assert len(raw.annotations) == n_expected_gaps, "Wrong number of gaps detected" - assert ( - (mne_y[0, :] == 0).sum() == n_expected_missing_samples - ), "Number of true and inferred missing samples differ" + assert (mne_y[0, :] == 0).sum() == n_expected_missing_samples, ( + "Number of true and inferred missing samples differ" + ) # read in .mat files containing original gaps matchans = ["LAHC1_3_gaps.mat", "LAHC2_3_gaps.mat"] diff --git a/mne/io/nirx/nirx.py b/mne/io/nirx/nirx.py index 53a812e7a21..5d9b79b57cc 100644 --- a/mne/io/nirx/nirx.py +++ b/mne/io/nirx/nirx.py @@ -210,7 +210,7 @@ def __init__(self, fname, saturated, *, preload=False, encoding=None, verbose=No ): warn( "Only import of data from NIRScout devices have been " - f'thoroughly tested. You are using a {hdr["GeneralInfo"]["Device"]}' + f"thoroughly tested. You are using a {hdr['GeneralInfo']['Device']}" " device." ) diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index be87a34526a..8f773533ae4 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -533,7 +533,7 @@ def _test_raw_crop(reader, t_prop, kwargs): n_samp = 50 # crop to this number of samples (per instance) crop_t = n_samp / raw_1.info["sfreq"] t_start = t_prop * crop_t # also crop to some fraction into the first inst - extra = f' t_start={t_start}, preload={kwargs.get("preload", False)}' + extra = f" t_start={t_start}, preload={kwargs.get('preload', False)}" stop = (n_samp - 1) / raw_1.info["sfreq"] raw_1.crop(0, stop) assert len(raw_1.times) == 50 @@ -1063,3 +1063,20 @@ def test_last_samp(): raw = read_raw_fif(raw_fname).crop(0, 0.1).load_data() last_data = raw._data[:, [-1]] assert_array_equal(raw[:, -1][0], last_data) + + +def test_rescale(): + """Test rescaling channels.""" + raw = read_raw_fif(raw_fname, preload=True) # multiple channel types + + with pytest.raises(ValueError, match="If scalings is a scalar, all channels"): + raw.rescale(2) # need to use dict + + orig = raw.get_data(picks="eeg") + raw.rescale({"eeg": 2}) # need to use dict + assert_allclose(raw.get_data(picks="eeg"), orig * 2) + + raw.pick("mag") # only a single channel type "mag" + orig = raw.get_data() + raw.rescale(4) # a scalar works + assert_allclose(raw.get_data(), orig * 4) diff --git a/mne/label.py b/mne/label.py index f68144106c3..02bf9dc09c0 100644 --- a/mne/label.py +++ b/mne/label.py @@ -264,8 +264,7 @@ def __init__( if not (len(vertices) == len(values) == len(pos)): raise ValueError( - "vertices, values and pos need to have same " - "length (number of vertices)" + "vertices, values and pos need to have same length (number of vertices)" ) # name @@ -416,7 +415,7 @@ def __sub__(self, other): else: keep = np.arange(len(self.vertices)) - name = f'{self.name or "unnamed"} - {other.name or "unnamed"}' + name = f"{self.name or 'unnamed'} - {other.name or 'unnamed'}" return Label( self.vertices[keep], self.pos[keep], @@ -976,8 +975,7 @@ def _get_label_src(label, src): src = _ensure_src(src) if src.kind != "surface": raise RuntimeError( - "Cannot operate on SourceSpaces that are not " - f"surface type, got {src.kind}" + f"Cannot operate on SourceSpaces that are not surface type, got {src.kind}" ) if label.hemi == "lh": hemi_src = src[0] @@ -1585,8 +1583,7 @@ def stc_to_label( vertno = np.where(src[hemi_idx]["inuse"])[0] if not len(np.setdiff1d(this_vertno, vertno)) == 0: raise RuntimeError( - "stc contains vertices not present " - "in source space, did you morph?" + "stc contains vertices not present in source space, did you morph?" ) tmp = np.zeros((len(vertno), this_data.shape[1])) this_vertno_idx = np.searchsorted(vertno, this_vertno) @@ -2151,8 +2148,7 @@ def _read_annot(fname): cands = _read_annot_cands(dir_name) if len(cands) == 0: raise OSError( - f"No such file {fname}, no candidate parcellations " - "found in directory" + f"No such file {fname}, no candidate parcellations found in directory" ) else: raise OSError( diff --git a/mne/minimum_norm/inverse.py b/mne/minimum_norm/inverse.py index e5129a4822f..7c789503ac1 100644 --- a/mne/minimum_norm/inverse.py +++ b/mne/minimum_norm/inverse.py @@ -673,7 +673,7 @@ def prepare_inverse_operator( inv["eigen_leads"]["data"] = sqrt(scale) * inv["eigen_leads"]["data"] logger.info( - " Scaled noise and source covariance from nave = %d to" " nave = %d", + " Scaled noise and source covariance from nave = %d to nave = %d", inv["nave"], nave, ) @@ -2011,7 +2011,7 @@ def make_inverse_operator( logger.info( f" scaling factor to adjust the trace = {trace_GRGT:g} " f"(nchan = {eigen_fields.shape[0]} " - f'nzero = {(noise_cov["eig"] <= 0).sum()})' + f"nzero = {(noise_cov['eig'] <= 0).sum()})" ) # MNE-ify everything for output diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py index aa3f8294027..5b5c941a9ac 100644 --- a/mne/minimum_norm/tests/test_inverse.py +++ b/mne/minimum_norm/tests/test_inverse.py @@ -130,8 +130,7 @@ def _compare(a, b): for k, v in a.items(): if k not in b and k not in skip_types: raise ValueError( - "First one had one second one didn't:\n" - f"{k} not in {b.keys()}" + f"First one had one second one didn't:\n{k} not in {b.keys()}" ) if k not in skip_types: last_keys.pop() diff --git a/mne/morph.py b/mne/morph.py index 9c475bff1e9..a8278731f3c 100644 --- a/mne/morph.py +++ b/mne/morph.py @@ -200,8 +200,7 @@ def compute_source_morph( if kind not in "surface" and xhemi: raise ValueError( - "Inter-hemispheric morphing can only be used " - "with surface source estimates." + "Inter-hemispheric morphing can only be used with surface source estimates." ) if sparse and kind != "surface": raise ValueError("Only surface source estimates can compute a sparse morph.") @@ -1301,8 +1300,7 @@ def grade_to_vertices(subject, grade, subjects_dir=None, n_jobs=None, verbose=No if isinstance(grade, list): if not len(grade) == 2: raise ValueError( - "grade as a list must have two elements " - "(arrays of output vertices)" + "grade as a list must have two elements (arrays of output vertices)" ) vertices = grade else: @@ -1385,8 +1383,7 @@ def _surf_upsampling_mat(idx_from, e, smooth): smooth = _ensure_int(smooth, "smoothing steps") if smooth <= 0: # == 0 is handled in a shortcut above raise ValueError( - "The number of smoothing operations has to be at least 0, got " - f"{smooth}" + f"The number of smoothing operations has to be at least 0, got {smooth}" ) smooth = smooth - 1 # idx will gradually expand from idx_from -> np.arange(n_tot) diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi index 54f1c825c13..c54685dba34 100644 --- a/mne/preprocessing/__init__.pyi +++ b/mne/preprocessing/__init__.pyi @@ -44,6 +44,7 @@ __all__ = [ "realign_raw", "regress_artifact", "write_fine_calibration", + "apply_pca_obs", ] from . import eyetracking, ieeg, nirs from ._annotate_amplitude import annotate_amplitude @@ -56,6 +57,7 @@ from ._fine_cal import ( write_fine_calibration, ) from ._lof import find_bad_channels_lof +from ._pca_obs import apply_pca_obs from ._peak_finder import peak_finder from ._regress import EOGRegression, read_eog_regression, regress_artifact from .artifact_detection import ( diff --git a/mne/preprocessing/_annotate_amplitude.py b/mne/preprocessing/_annotate_amplitude.py index 943c20c0ba2..0cd5676e703 100644 --- a/mne/preprocessing/_annotate_amplitude.py +++ b/mne/preprocessing/_annotate_amplitude.py @@ -126,7 +126,7 @@ def annotate_amplitude( for ch_type, picks_of_type in _picks_by_type(raw.info, exclude="bads") if np.intersect1d(picks_of_type, picks_, assume_unique=True).size != 0 } - del picks_ # re-using this variable name in for loop + del picks_ # reusing this variable name in for loop # skip BAD_acq_skip sections onsets, ends = _annotations_starts_stops(raw, "bad_acq_skip", invert=True) diff --git a/mne/preprocessing/_fine_cal.py b/mne/preprocessing/_fine_cal.py index 41d20539ce0..06041cd7f8e 100644 --- a/mne/preprocessing/_fine_cal.py +++ b/mne/preprocessing/_fine_cal.py @@ -156,11 +156,12 @@ def compute_fine_calibration( # 1. Rotate surface normals using magnetometer information (if present) # cals = np.ones(len(info["ch_names"])) - time_idxs = raw.time_as_index(np.arange(0.0, raw.times[-1], t_window)) - if len(time_idxs) <= 1: - time_idxs = np.array([0, len(raw.times)], int) - else: - time_idxs[-1] = len(raw.times) + end = len(raw.times) + 1 + time_idxs = np.arange(0, end, int(round(t_window * raw.info["sfreq"]))) + if len(time_idxs) == 1: + time_idxs = np.concatenate([time_idxs, [end]]) + if time_idxs[-1] != end: + time_idxs[-1] = end count = 0 locs = np.array([ch["loc"] for ch in info["chs"]]) zs = locs[mag_picks, -3:].copy() @@ -388,9 +389,11 @@ def _adjust_mag_normals(info, data, origin, ext_order, *, angle_limit, err_limit each_err = _data_err(data, S_tot, cals, axis=-1)[picks_mag] n_bad = (each_err > err_limit).sum() if n_bad: + bad_max = np.argmax(each_err) reason.append( f"{n_bad} residual{_pl(n_bad)} > {err_limit:0.1f}% " - f"(max: {each_err.max():0.2f}%)" + f"(max: {each_err[bad_max]:0.2f}% @ " + f"{info['ch_names'][picks_mag[bad_max]]})" ) reason = ", ".join(reason) if reason: @@ -398,7 +401,7 @@ def _adjust_mag_normals(info, data, origin, ext_order, *, angle_limit, err_limit good = not bool(reason) assert np.allclose(np.linalg.norm(zs, axis=1), 1.0) logger.info(f" Fit mismatch {first_err:0.2f}→{last_err:0.2f}%") - logger.info(f' Data segment {"" if good else "un"}usable{reason}') + logger.info(f" Data segment {'' if good else 'un'}usable{reason}") # Reformat zs and cals to be the n_mags (including bads) assert zs.shape == (len(data), 3) assert cals.shape == (len(data), 1) diff --git a/mne/preprocessing/_pca_obs.py b/mne/preprocessing/_pca_obs.py new file mode 100755 index 00000000000..be226a73889 --- /dev/null +++ b/mne/preprocessing/_pca_obs.py @@ -0,0 +1,333 @@ +"""Principle Component Analysis Optimal Basis Sets (PCA-OBS).""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import math + +import numpy as np +from scipy.interpolate import PchipInterpolator as pchip +from scipy.signal import detrend + +from ..io.fiff.raw import Raw +from ..utils import _PCA, _validate_type, logger, verbose + + +@verbose +def apply_pca_obs( + raw: Raw, + picks: list[str], + *, + qrs_times: np.ndarray, + n_components: int = 4, + n_jobs: int | None = None, + copy: bool = True, + verbose: bool | str | int | None = None, +) -> Raw: + """ + Apply the PCA-OBS algorithm to picks of a Raw object. + + Uses the optimal basis set (OBS) algorithm from :footcite:`NiazyEtAl2005`. + + Parameters + ---------- + raw : instance of Raw + The raw data to process. + %(picks_all_data_noref)s + qrs_times : ndarray, shape (n_peaks,) + Array of times in the Raw data of detected R-peaks in ECG channel. + n_components : int + Number of PCA components to use to form the OBS (default 4). + %(n_jobs)s + copy : bool + If False, modify the Raw instance in-place. + If True (default), copy the raw instance before processing. + %(verbose)s + + Returns + ------- + raw : instance of Raw + The modified raw instance. + + Notes + ----- + .. versionadded:: 1.10 + + References + ---------- + .. footbibliography:: + """ + # sanity checks + _validate_type(qrs_times, np.ndarray, "qrs_times") + if len(qrs_times.shape) > 1: + raise ValueError("qrs_times must be a 1d array") + if qrs_times.dtype not in [int, float]: + raise ValueError("qrs_times must be an array of either integers or floats") + if np.any(qrs_times < 0): + raise ValueError("qrs_times must be strictly positive") + if np.any(qrs_times >= raw.times[-1]): + logger.warning("some out of bound qrs_times will be ignored..") + + if copy: + raw = raw.copy() + + raw.apply_function( + _pca_obs, + picks=picks, + n_jobs=n_jobs, + # args sent to PCA_OBS, convert times to indices + qrs=raw.time_as_index(qrs_times), + n_components=n_components, + ) + + return raw + + +def _pca_obs( + data: np.ndarray, + qrs: np.ndarray, + n_components: int, +) -> np.ndarray: + """Algorithm to remove heart artefact from EEG data (array of length n_times).""" + # set to baseline + data = data - np.mean(data) + + # Allocate memory for artifact which will be subtracted from the data + fitted_art = np.zeros(data.shape) + + # Extract QRS event indexes which are within out data timeframe + peak_idx = qrs[qrs < len(data)] + peak_count = len(peak_idx) + + ################################################################## + # Preparatory work - reserving memory, configure sizes, de-trend # + ################################################################## + # define peak range based on RR + mRR = np.median(np.diff(peak_idx)) + peak_range = round(mRR / 2) # Rounds to an integer + mid_p = peak_range + 1 + n_samples_fit = round( + peak_range / 8 + ) # sample fit for interpolation between fitted artifact windows + + # make sure array is long enough for PArange (if not cut off last ECG peak) + # NOTE: Here we previously checked for the last part of the window to be big enough. + while peak_idx[peak_count - 1] + peak_range > len(data): + peak_count = peak_count - 1 # reduce number of QRS complexes detected + + # build PCA matrix(heart-beat-epochs x window-length) + pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1)) # [epoch x time] + # picking out heartbeat epochs + for p in range(1, peak_count): + pcamat[p - 1, :] = data[peak_idx[p] - peak_range : peak_idx[p] + peak_range + 1] + + # detrending matrix(twice) + pcamat = detrend( + pcamat, type="constant", axis=1 + ) # [epoch x time] - detrended along the epoch + mean_effect: np.ndarray = np.mean( + pcamat, axis=0 + ) # [1 x time], contains the mean over all epochs + dpcamat = detrend(pcamat, type="constant", axis=1) # [time x epoch] + + ############################ + # Perform PCA with sklearn # + ############################ + # run PCA, perform singular value decomposition (SVD) + pca = _PCA() + pca.fit(dpcamat) + factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) + + # define selected number of components using profile likelihood + + ##################################### + # Make template of the ECG artefact # + ##################################### + mean_effect = mean_effect.reshape(-1, 1) + pca_template = np.c_[mean_effect, factor_loadings[:, :n_components]] + + ################ + # Data Fitting # + ################ + window_start_idx = [] + window_end_idx = [] + post_idx_next_peak = None + + for p in range(peak_count): + # if the current peak doesn't have enough data in the + # start of the peak_range, skip fitting the artifact + if peak_idx[p] - peak_range < 0: + continue + + # Deals with start portion of data + if p == 0: + pre_range = peak_range + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) + if post_range > peak_range: + post_range = peak_range + + fitted_art, post_idx_next_peak = _fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_next_peak, + n_samples_fit=n_samples_fit, + ) + # Appending to list instead of using counter + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + + # Deals with last edge of data + elif p == peak_count - 1: + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = peak_range + if pre_range > peak_range: + pre_range = peak_range + fitted_art, _ = _fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_next_peak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + + # Deals with middle portion of data + else: + # ---------------- Processing of central data - -------------------- + # cycle through peak artifacts identified by peakplot + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) + if pre_range >= peak_range: + pre_range = peak_range + if post_range > peak_range: + post_range = peak_range + + a_template = pca_template[ + mid_p - peak_range - 1 : mid_p + peak_range + 1, : + ] + fitted_art, post_idx_next_peak = _fit_ecg_template( + data=data, + pca_template=a_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_next_peak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + + # Actually subtract the artefact, return needs to be the same shape as input data + data -= fitted_art + return data + + +def _fit_ecg_template( + data: np.ndarray, + pca_template: np.ndarray, + a_peak_idx: int, + peak_range: int, + pre_range: int, + post_range: int, + mid_p: float, + fitted_art: np.ndarray, + post_idx_previous_peak: int | None, + n_samples_fit: int, +) -> tuple[np.ndarray, int]: + """ + Fits the heartbeat artefact found in the data. + + Returns the fitted artefact and the index of the next peak. + + Parameters + ---------- + data (ndarray): Data from the raw signal (n_channels, n_times) + pca_template (ndarray): Mean heartbeat and first N (default 4) + principal components of the heartbeat matrix + a_peak_idx (int): Sample index of current R-peak + peak_range (int): Half the median RR-interval + pre_range (int): Number of samples to fit before the R-peak + post_range (int): Number of samples to fit after the R-peak + mid_p (float): Sample index marking middle of the median RR interval + in the signal. Used to extract relevant part of PCA_template. + fitted_art (ndarray): The computed heartbeat artefact computed to + remove from the data + post_idx_previous_peak (optional int): Sample index of previous R-peak + n_samples_fit (int): Sample fit for interpolation in fitted artifact + windows. Helps reduce sharp edges at end of fitted heartbeat events + + Returns + ------- + tuple[np.ndarray, int]: the fitted artifact and the next peak index + """ + # post_idx_next_peak is passed in in PCA_OBS, used here as post_idx_previous_peak + # Then next_peak is returned at the end and the process repeats + # select window of template + template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :] + + # select window of data and detrend it + slice_ = data[a_peak_idx - peak_range : a_peak_idx + peak_range + 1] + + detrended_data = detrend(slice_, type="constant") + + # maps data on template and then maps it again back to the sensor space + least_square = np.linalg.lstsq(template, detrended_data, rcond=None) + pad_fit = np.dot(template, least_square[0]) + + # fit artifact + fitted_art[a_peak_idx - pre_range - 1 : a_peak_idx + post_range] = pad_fit[ + mid_p - pre_range - 1 : mid_p + post_range + ].T + + # if last peak, return + if post_idx_previous_peak is None: + return fitted_art, a_peak_idx + post_range + + # interpolate time between peaks + intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx - pre_range]).astype( + int + ) # interpolation window + + if intpol_window[0] < intpol_window[1]: + # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data + + # You have x_fit which is two slices on either side of the interpolation window + # endpoints + # You have y_fit which is the y vals corresponding to x values above + # You have x_interpol which is the time points between the two slices in x_fit + # that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in + # x_interpol + # points to be interpolated in pt - the gap between the endpoints of the window + x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) + # Entire range of x values in this step (taking some + # number of samples before and after the window) + x_fit = np.concatenate( + [ + np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), + np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1), + ] + ) + y_fit = fitted_art[x_fit] + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + + # make fitted artefact in the desired range equal to the completed fit above + fitted_art[post_idx_previous_peak : a_peak_idx - pre_range + 1] = y_interpol + + return fitted_art, a_peak_idx + post_range diff --git a/mne/preprocessing/artifact_detection.py b/mne/preprocessing/artifact_detection.py index 0a4c8b6a24d..8674d6e22b3 100644 --- a/mne/preprocessing/artifact_detection.py +++ b/mne/preprocessing/artifact_detection.py @@ -213,7 +213,7 @@ def annotate_movement( onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot logger.info( - "Omitting %5.1f%% (%3d segments): " "ω >= %5.1f°/s (max: %0.1f°/s)", + "Omitting %5.1f%% (%3d segments): ω >= %5.1f°/s (max: %0.1f°/s)", bad_pct, len(onsets), rotation_velocity_limit, @@ -233,7 +233,7 @@ def annotate_movement( onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot logger.info( - "Omitting %5.1f%% (%3d segments): " "v >= %5.4fm/s (max: %5.4fm/s)", + "Omitting %5.1f%% (%3d segments): v >= %5.4fm/s (max: %5.4fm/s)", bad_pct, len(onsets), translation_velocity_limit, @@ -286,7 +286,7 @@ def annotate_movement( onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot logger.info( - "Omitting %5.1f%% (%3d segments): " "disp >= %5.4fm (max: %5.4fm)", + "Omitting %5.1f%% (%3d segments): disp >= %5.4fm (max: %5.4fm)", bad_pct, len(onsets), mean_distance_limit, @@ -539,7 +539,7 @@ def annotate_break( if ignore: logger.info( f"Ignoring annotations with descriptions starting " - f'with: {", ".join(ignore)}' + f"with: {', '.join(ignore)}" ) else: annotations = annotations_from_events( diff --git a/mne/preprocessing/eog.py b/mne/preprocessing/eog.py index 20e5481f89c..13b6f2ef672 100644 --- a/mne/preprocessing/eog.py +++ b/mne/preprocessing/eog.py @@ -213,12 +213,12 @@ def _get_eog_channel_index(ch_name, inst): if not_found: raise ValueError( f"The specified EOG channel{_pl(not_found)} " - f'cannot be found: {", ".join(not_found)}' + f"cannot be found: {', '.join(not_found)}" ) eog_inds = pick_channels(inst.ch_names, include=ch_names) - logger.info(f'Using EOG channel{_pl(ch_names)}: {", ".join(ch_names)}') + logger.info(f"Using EOG channel{_pl(ch_names)}: {', '.join(ch_names)}") return eog_inds diff --git a/mne/preprocessing/hfc.py b/mne/preprocessing/hfc.py index f8a65510a9a..41bf6bbd232 100644 --- a/mne/preprocessing/hfc.py +++ b/mne/preprocessing/hfc.py @@ -68,8 +68,7 @@ def compute_proj_hfc( n_chs = len(coils[5]) if n_chs != info["nchan"]: raise ValueError( - f'Only {n_chs}/{info["nchan"]} picks could be interpreted ' - "as MEG channels." + f"Only {n_chs}/{info['nchan']} picks could be interpreted as MEG channels." ) S = _sss_basis(exp, coils) del coils diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 3ea11e0531e..f35fe24c1ee 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -560,7 +560,7 @@ def __repr__(self): """ICA fit information.""" infos = self._get_infos_for_repr() - s = f'{infos.fit_on or "no"} decomposition, method: {infos.fit_method}' + s = f"{infos.fit_on or 'no'} decomposition, method: {infos.fit_method}" if infos.fit_on is not None: s += ( @@ -568,8 +568,8 @@ def __repr__(self): f"{infos.fit_n_samples} samples), " f"{infos.fit_n_components} ICA components " f"({infos.fit_n_pca_components} PCA components available), " - f'channel types: {", ".join(infos.ch_types)}, ' - f'{len(infos.excludes) or "no"} sources marked for exclusion' + f"channel types: {', '.join(infos.ch_types)}, " + f"{len(infos.excludes) or 'no'} sources marked for exclusion" ) return f"" @@ -698,7 +698,7 @@ def fit( warn( f"The following parameters passed to ICA.fit() will be " f"ignored, as they only affect raw data (and it appears " - f'you passed epochs): {", ".join(ignored_params)}' + f"you passed epochs): {', '.join(ignored_params)}" ) picks = _picks_to_idx( @@ -875,7 +875,7 @@ def _do_proj(self, data, log_suffix=""): logger.info( f" Applying projection operator with {nproj} " f"vector{_pl(nproj)}" - f'{" " if log_suffix else ""}{log_suffix}' + f"{' ' if log_suffix else ''}{log_suffix}" ) if self.noise_cov is None: # otherwise it's in pre_whitener_ data = proj @ data @@ -1162,7 +1162,7 @@ def get_explained_variance_ratio(self, inst, *, components=None, ch_type=None): raise ValueError( f"You requested operation on the channel type " f'"{ch_type}", but only the following channel types are ' - f'supported: {", ".join(allowed_ch_types)}' + f"supported: {', '.join(allowed_ch_types)}" ) del ch_type @@ -2393,8 +2393,7 @@ def _pick_sources(self, data, include, exclude, n_pca_components): unmixing = np.dot(unmixing, pca_components) logger.info( - f" Projecting back using {_n_pca_comp} " - f"PCA component{_pl(_n_pca_comp)}" + f" Projecting back using {_n_pca_comp} PCA component{_pl(_n_pca_comp)}" ) mixing = np.eye(_n_pca_comp) mixing[: self.n_components_, : self.n_components_] = self.mixing_matrix_ @@ -3368,8 +3367,7 @@ def corrmap( is_subject = False else: raise ValueError( - "`template` must be a length-2 tuple or an array the " - "size of the ICA maps." + "`template` must be a length-2 tuple or an array the size of the ICA maps." ) template_fig, labelled_ics = None, None diff --git a/mne/preprocessing/ieeg/_volume.py b/mne/preprocessing/ieeg/_volume.py index b4997b2e3f8..af2dcf4328b 100644 --- a/mne/preprocessing/ieeg/_volume.py +++ b/mne/preprocessing/ieeg/_volume.py @@ -109,7 +109,7 @@ def _warn_missing_chs(info, dig_image, after_warp=False): if missing_ch: warn( f"Channel{_pl(missing_ch)} " - f'{", ".join(repr(ch) for ch in missing_ch)} not assigned ' + f"{', '.join(repr(ch) for ch in missing_ch)} not assigned " "voxels " + (f" after applying {after_warp}" if after_warp else "") ) diff --git a/mne/preprocessing/infomax_.py b/mne/preprocessing/infomax_.py index f0722ce5267..b445ac7116c 100644 --- a/mne/preprocessing/infomax_.py +++ b/mne/preprocessing/infomax_.py @@ -320,8 +320,7 @@ def infomax( if l_rate > min_l_rate: if verbose: logger.info( - f"... lowering learning rate to {l_rate:g}" - "\n... re-starting..." + f"... lowering learning rate to {l_rate:g}\n... re-starting..." ) else: raise ValueError( diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py index 144961f6f16..aaa3f3d536c 100644 --- a/mne/preprocessing/maxwell.py +++ b/mne/preprocessing/maxwell.py @@ -10,7 +10,7 @@ import numpy as np from scipy import linalg -from scipy.special import lpmv, sph_harm +from scipy.special import lpmv from .. import __version__ from .._fiff.compensator import make_compensator @@ -25,7 +25,7 @@ from ..annotations import _annotations_starts_stops from ..bem import _check_origin from ..channels.channels import _get_T1T2_mag_inds, fix_mag_coil_types -from ..fixes import _safe_svd, bincount +from ..fixes import _safe_svd, bincount, sph_harm_y from ..forward import _concatenate_coils, _create_meg_coils, _prep_meg_channels from ..io import BaseRaw, RawArray from ..surface import _normalize_vectors @@ -454,7 +454,7 @@ def _prep_maxwell_filter( # we purposefully stay away from shorthand notation in both and use # explicit terms (like 'azimuth' and 'polar') to avoid confusion. # See mathworld.wolfram.com/SphericalHarmonic.html for more discussion. - # Our code follows the same standard that ``scipy`` uses for ``sph_harm``. + # Our code follows the same standard that ``scipy`` uses for ``sph_harm_y``. # triage inputs ASAP to avoid late-thrown errors _validate_type(raw, BaseRaw, "raw") @@ -544,7 +544,7 @@ def _prep_maxwell_filter( extended_proj_.append(proj["data"]["data"][:, idx]) extended_proj = np.concatenate(extended_proj_) logger.info( - " Extending external SSS basis using %d projection " "vectors", + " Extending external SSS basis using %d projection vectors", len(extended_proj), ) @@ -603,8 +603,8 @@ def _prep_maxwell_filter( dist = np.sqrt(np.sum(_sq(diff))) if dist > 25.0: warn( - f'Head position change is over 25 mm ' - f'({", ".join(f"{x:0.1f}" for x in diff)}) = {dist:0.1f} mm' + f"Head position change is over 25 mm " + f"({', '.join(f'{x:0.1f}' for x in diff)}) = {dist:0.1f} mm" ) # Reconstruct raw file object with spatiotemporal processed data @@ -1652,7 +1652,7 @@ def _sss_basis_basic(exp, coils, mag_scale=100.0, method="standard"): S_in_out = list() grads_in_out = list() # Same spherical harmonic is used for both internal and external - sph = sph_harm(order, degree, az, pol) + sph = sph_harm_y(degree, order, pol, az) sph_norm = _sph_harm_norm(order, degree) # Compute complex gradient for all integration points # in spherical coordinates (Eq. 6). The gradient for rad, az, pol @@ -2746,7 +2746,7 @@ def find_bad_channels_maxwell( freq_loc = "below" if raw.info["lowpass"] < h_freq else "equal to" msg = ( f"The input data has already been low-pass filtered with a " - f'{raw.info["lowpass"]} Hz cutoff frequency, which is ' + f"{raw.info['lowpass']} Hz cutoff frequency, which is " f"{freq_loc} the requested cutoff of {h_freq} Hz. Not " f"applying low-pass filter." ) diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index 92a2e55b9fb..c17cf31110c 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -76,7 +76,7 @@ def beer_lambert_law(raw, ppf=6.0): for ki, kind in zip((ii, jj), ("hbo", "hbr")): ch = raw.info["chs"][ki] ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL) - new_name = f'{ch["ch_name"].split(" ")[0]} {kind}' + new_name = f"{ch['ch_name'].split(' ')[0]} {kind}" rename[ch["ch_name"]] = new_name raw.rename_channels(rename) diff --git a/mne/preprocessing/stim.py b/mne/preprocessing/stim.py index 7db1bab4c85..a823820988b 100644 --- a/mne/preprocessing/stim.py +++ b/mne/preprocessing/stim.py @@ -11,7 +11,7 @@ from ..event import find_events from ..evoked import Evoked from ..io import BaseRaw -from ..utils import _check_option, _check_preload, fill_doc +from ..utils import _check_option, _check_preload, _validate_type, fill_doc def _get_window(start, end): @@ -20,7 +20,9 @@ def _get_window(start, end): return window -def _fix_artifact(data, window, picks, first_samp, last_samp, mode): +def _fix_artifact( + data, window, picks, first_samp, last_samp, base_tmin, base_tmax, mode +): """Modify original data by using parameter data.""" if mode == "linear": x = np.array([first_samp, last_samp]) @@ -32,6 +34,10 @@ def _fix_artifact(data, window, picks, first_samp, last_samp, mode): data[picks, first_samp:last_samp] = ( data[picks, first_samp:last_samp] * window[np.newaxis, :] ) + if mode == "constant": + data[picks, first_samp:last_samp] = data[picks, base_tmin:base_tmax].mean( + axis=1 + )[:, None] @fill_doc @@ -41,6 +47,8 @@ def fix_stim_artifact( event_id=None, tmin=0.0, tmax=0.01, + *, + baseline=None, mode="linear", stim_channel=None, picks=None, @@ -63,10 +71,23 @@ def fix_stim_artifact( Start time of the interpolation window in seconds. tmax : float End time of the interpolation window in seconds. - mode : 'linear' | 'window' + baseline : None | tuple, shape (2,) + The baseline to use when ``mode='constant'``, in which case it + must be non-None. + + .. versionadded:: 1.8 + mode : 'linear' | 'window' | 'constant' Way to fill the artifacted time interval. - 'linear' does linear interpolation - 'window' applies a (1 - hanning) window. + + ``"linear"`` + Does linear interpolation. + ``"window"`` + Applies a ``(1 - hanning)`` window. + ``"constant"`` + Uses baseline average. baseline parameter must be provided. + + .. versionchanged:: 1.8 + Added the ``"constant"`` mode. stim_channel : str | None Stim channel to use. %(picks_all_data)s @@ -76,9 +97,22 @@ def fix_stim_artifact( inst : instance of Raw or Evoked or Epochs Instance with modified data. """ - _check_option("mode", mode, ["linear", "window"]) + _check_option("mode", mode, ["linear", "window", "constant"]) s_start = int(np.ceil(inst.info["sfreq"] * tmin)) s_end = int(np.ceil(inst.info["sfreq"] * tmax)) + if mode == "constant": + _validate_type( + baseline, (tuple, list), "baseline", extra="when mode='constant'" + ) + _check_option("len(baseline)", len(baseline), [2]) + for bi, b in enumerate(baseline): + _validate_type( + b, "numeric", f"baseline[{bi}]", extra="when mode='constant'" + ) + b_start = int(np.ceil(inst.info["sfreq"] * baseline[0])) + b_end = int(np.ceil(inst.info["sfreq"] * baseline[1])) + else: + b_start = b_end = np.nan if (mode == "window") and (s_end - s_start) < 4: raise ValueError( 'Time range is too short. Use a larger interval or set mode to "linear".' @@ -104,7 +138,11 @@ def fix_stim_artifact( for event_idx in event_start: first_samp = int(event_idx) - inst.first_samp + s_start last_samp = int(event_idx) - inst.first_samp + s_end - _fix_artifact(data, window, picks, first_samp, last_samp, mode) + base_t1 = int(event_idx) - inst.first_samp + b_start + base_t2 = int(event_idx) - inst.first_samp + b_end + _fix_artifact( + data, window, picks, first_samp, last_samp, base_t1, base_t2, mode + ) elif isinstance(inst, BaseEpochs): if inst.reject is not None: raise RuntimeError( @@ -114,14 +152,23 @@ def fix_stim_artifact( first_samp = s_start - e_start last_samp = s_end - e_start data = inst._data + base_t1 = b_start - e_start + base_t2 = b_end - e_start for epoch in data: - _fix_artifact(epoch, window, picks, first_samp, last_samp, mode) + _fix_artifact( + epoch, window, picks, first_samp, last_samp, base_t1, base_t2, mode + ) elif isinstance(inst, Evoked): first_samp = s_start - inst.first last_samp = s_end - inst.first data = inst.data - _fix_artifact(data, window, picks, first_samp, last_samp, mode) + base_t1 = b_start - inst.first + base_t2 = b_end - inst.first + + _fix_artifact( + data, window, picks, first_samp, last_samp, base_t1, base_t2, mode + ) else: raise TypeError(f"Not a Raw or Epochs or Evoked (got {type(inst)}).") diff --git a/mne/preprocessing/tests/test_fine_cal.py b/mne/preprocessing/tests/test_fine_cal.py index b25b824bae0..8b45208e848 100644 --- a/mne/preprocessing/tests/test_fine_cal.py +++ b/mne/preprocessing/tests/test_fine_cal.py @@ -20,7 +20,7 @@ ) from mne.preprocessing.tests.test_maxwell import _assert_shielding from mne.transforms import _angle_dist_between_rigid -from mne.utils import object_diff +from mne.utils import catch_logging, object_diff # Define fine calibration filepaths data_path = testing.data_path(download=False) @@ -231,8 +231,8 @@ def test_fine_cal_systems(system, tmp_path): err_limit = 6000 n_ref = 28 corrs = (0.19, 0.41, 0.49) - sfs = [0.5, 0.7, 0.9, 1.5] - corr_tol = 0.45 + sfs = [0.5, 0.7, 0.9, 1.55] + corr_tol = 0.55 elif system == "fil": raw = read_raw_fil(fil_fname, verbose="error") raw.info["bads"] = [f"G2-{a}-{b}" for a in ("MW", "DS", "DT") for b in "YZ"] @@ -289,3 +289,15 @@ def test_fine_cal_systems(system, tmp_path): got_corrs = np.corrcoef([raw_data, raw_sss_data, raw_sss_cal_data]) got_corrs = got_corrs[np.triu_indices(3, 1)] assert_allclose(got_corrs, corrs, atol=corr_tol) + if system == "fil": + with catch_logging(verbose=True) as log: + compute_fine_calibration( + raw.copy().crop(0, 0.12).pick(raw.ch_names[:12]), + t_window=0.06, # 2 segments + angle_limit=angle_limit, + err_limit=err_limit, + ext_order=2, + verbose=True, + ) + log = log.getvalue() + assert "(averaging over 2 time intervals)" in log, log diff --git a/mne/preprocessing/tests/test_maxwell.py b/mne/preprocessing/tests/test_maxwell.py index 8107b1bcb2d..2f279440d1c 100644 --- a/mne/preprocessing/tests/test_maxwell.py +++ b/mne/preprocessing/tests/test_maxwell.py @@ -12,7 +12,6 @@ import pytest from numpy.testing import assert_allclose, assert_array_equal from scipy import sparse -from scipy.special import sph_harm import mne from mne import compute_raw_covariance, concatenate_raws, pick_info, pick_types @@ -20,6 +19,7 @@ from mne.annotations import _annotations_starts_stops from mne.chpi import filter_chpi, read_head_pos from mne.datasets import testing +from mne.fixes import sph_harm_y from mne.forward import _prep_meg_channels, use_coil_def from mne.io import ( BaseRaw, @@ -486,9 +486,9 @@ def test_spherical_conversions(): az, pol = np.meshgrid(np.linspace(0, 2 * np.pi, 30), np.linspace(0, np.pi, 20)) for degree in range(1, int_order): for order in range(0, degree + 1): - sph = sph_harm(order, degree, az, pol) + sph = sph_harm_y(degree, order, pol, az) # ensure that we satisfy the conjugation property - assert_allclose(_sh_negate(sph, order), sph_harm(-order, degree, az, pol)) + assert_allclose(_sh_negate(sph, order), sph_harm_y(degree, -order, pol, az)) # ensure our conversion functions work sph_real_pos = _sh_complex_to_real(sph, order) sph_real_neg = _sh_complex_to_real(sph, -order) @@ -1067,9 +1067,9 @@ def _assert_shielding(raw_sss, erm_power, min_factor, max_factor=np.inf, meg="ma sss_power = raw_sss[picks][0].ravel() sss_power = np.sqrt(np.sum(sss_power * sss_power)) factor = erm_power / sss_power - assert ( - min_factor <= factor < max_factor - ), f"Shielding factor not {min_factor:0.3f} <= {factor:0.3f} < {max_factor:0.3f}" + assert min_factor <= factor < max_factor, ( + f"Shielding factor not {min_factor:0.3f} <= {factor:0.3f} < {max_factor:0.3f}" + ) @buggy_mkl_svd diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py new file mode 100644 index 00000000000..ee2568a2080 --- /dev/null +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -0,0 +1,107 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +from pathlib import Path + +import numpy as np +import pytest + +from mne.io import read_raw_fif +from mne.io.fiff.raw import Raw +from mne.preprocessing import apply_pca_obs + +data_path = Path(__file__).parents[2] / "io" / "tests" / "data" +raw_fname = data_path / "test_raw.fif" + + +@pytest.fixture() +def short_raw_data(): + """Create a short, picked raw instance.""" + return read_raw_fif(raw_fname, preload=True) + + +def test_heart_artifact_removal(short_raw_data: Raw): + """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" + pd = pytest.importorskip("pandas") + + # copy the original raw. heart artifact is removed in-place + orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True) + + # fake some random qrs events in the window of the raw data + # remove first and last samples and cast to integer for indexing + ecg_event_times = np.linspace(0, orig_df["time"].iloc[-1], 20)[1:-1] + + # perform heart artifact removal + short_raw_data = apply_pca_obs( + raw=short_raw_data, picks=["eeg"], qrs_times=ecg_event_times, n_jobs=1 + ) + + # compare processed df to original df + removed_heart_artifact_df: pd.DataFrame = short_raw_data.to_data_frame() + + # ensure all column names remain the same + pd.testing.assert_index_equal( + orig_df.columns, + removed_heart_artifact_df.columns, + ) + + # ensure every column starting with EEG has been altered + altered_cols = [c for c in orig_df.columns if c.startswith("EEG")] + for col in altered_cols: + with pytest.raises( + AssertionError + ): # make sure that error is raised when we check equal + pd.testing.assert_series_equal( + orig_df[col], + removed_heart_artifact_df[col], + ) + + # ensure every column not starting with EEG has not been altered + unaltered_cols = [c for c in orig_df.columns if not c.startswith("EEG")] + pd.testing.assert_frame_equal( + orig_df[unaltered_cols], + removed_heart_artifact_df[unaltered_cols], + ) + + +# test that various nonsensical inputs raise the proper errors +@pytest.mark.parametrize( + ("picks", "qrs_times", "error", "exception"), + [ + ( + ["eeg"], + np.array([[0, 1], [2, 3]]), + "qrs_times must be a 1d array", + ValueError, + ), + ( + ["eeg"], + [2, 3, 4], + "qrs_times must be an instance of ndarray, got instead.", + TypeError, + ), + ( + ["eeg"], + np.array([None, "foo", 2]), + "qrs_times must be an array of either integers or floats", + ValueError, + ), + ( + ["eeg"], + np.array([-1, 0, 3]), + "qrs_times must be strictly positive", + ValueError, + ), + ], +) +def test_pca_obs_bad_input( + short_raw_data: Raw, + picks: list[str], + qrs_times: np.ndarray, + error: str, + exception: type[Exception], +): + """Test if bad input data raises the proper errors in the function sanity checks.""" + with pytest.raises(exception, match=error): + apply_pca_obs(raw=short_raw_data, picks=picks, qrs_times=qrs_times) diff --git a/mne/preprocessing/tests/test_stim.py b/mne/preprocessing/tests/test_stim.py index 7ae1c4418b4..2d463a4af5a 100644 --- a/mne/preprocessing/tests/test_stim.py +++ b/mne/preprocessing/tests/test_stim.py @@ -55,6 +55,18 @@ def test_fix_stim_artifact(): data_from_epochs_fix = epochs.get_data(copy=False)[:, :, tmin_samp:tmax_samp] assert not np.all(data_from_epochs_fix != 0) + baseline = (-0.1, -0.05) + epochs = fix_stim_artifact( + epochs, tmin=tmin, tmax=tmax, baseline=baseline, mode="constant" + ) + b_start = int(np.ceil(epochs.info["sfreq"] * baseline[0])) + b_end = int(np.ceil(epochs.info["sfreq"] * baseline[1])) + base_t1 = b_start - e_start + base_t2 = b_end - e_start + baseline_mean = epochs.get_data()[:, :, base_t1:base_t2].mean(axis=2)[0][0] + data = epochs.get_data()[:, :, tmin_samp:tmax_samp] + assert data[0][0][0] == baseline_mean + # use window before stimulus in raw event_idx = np.where(events[:, 2] == 1)[0][0] tmin, tmax = -0.045, -0.015 @@ -81,8 +93,22 @@ def test_fix_stim_artifact(): raw, events, event_id=1, tmin=tmin, tmax=tmax, mode="window" ) data, times = raw[:, (tidx + tmin_samp) : (tidx + tmax_samp)] + assert np.all(data) == 0.0 + raw = fix_stim_artifact( + raw, + events, + event_id=1, + tmin=tmin, + tmax=tmax, + baseline=baseline, + mode="constant", + ) + data, times = raw[:, (tidx + tmin_samp) : (tidx + tmax_samp)] + baseline_mean, _ = raw[:, (tidx + b_start) : (tidx + b_end)] + assert baseline_mean.mean(axis=1)[0] == data[0][0] + # get epochs from raw with fixed data tmin, tmax, event_id = -0.2, 0.5, 1 epochs = Epochs( @@ -117,3 +143,12 @@ def test_fix_stim_artifact(): evoked = fix_stim_artifact(evoked, tmin=tmin, tmax=tmax, mode="window") data = evoked.data[:, tmin_samp:tmax_samp] assert np.all(data) == 0.0 + + evoked = fix_stim_artifact( + evoked, tmin=tmin, tmax=tmax, baseline=baseline, mode="constant" + ) + base_t1 = int(baseline[0] * evoked.info["sfreq"]) - evoked.first + base_t2 = int(baseline[1] * evoked.info["sfreq"]) - evoked.first + data = evoked.data[:, tmin_samp:tmax_samp] + baseline_mean = evoked.data[:, base_t1:base_t2].mean(axis=1)[0] + assert data[0][0] == baseline_mean diff --git a/mne/preprocessing/xdawn.py b/mne/preprocessing/xdawn.py index 0b1132761b1..606b49370df 100644 --- a/mne/preprocessing/xdawn.py +++ b/mne/preprocessing/xdawn.py @@ -198,8 +198,7 @@ def _fit_xdawn( evals, evecs = linalg.eigh(evo_cov, signal_cov) except np.linalg.LinAlgError as exp: raise ValueError( - "Could not compute eigenvalues, ensure " - f"proper regularization ({exp})" + f"Could not compute eigenvalues, ensure proper regularization ({exp})" ) evecs = evecs[:, np.argsort(evals)[::-1]] # sort eigenvectors evecs /= np.apply_along_axis(np.linalg.norm, 0, evecs) diff --git a/mne/report/report.py b/mne/report/report.py index 7cab9774c8e..852feebc638 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -324,7 +324,7 @@ def _check_tags(tags) -> tuple[str]: raise TypeError( f"All tags must be strings without spaces or special characters, " f"but got the following instead: " - f'{", ".join([str(tag) for tag in bad_tags])}' + f"{', '.join([str(tag) for tag in bad_tags])}" ) # Check for invalid characters @@ -338,7 +338,7 @@ def _check_tags(tags) -> tuple[str]: if bad_tags: raise ValueError( f"The following tags contained invalid characters: " - f'{", ".join(repr(tag) for tag in bad_tags)}' + f"{', '.join(repr(tag) for tag in bad_tags)}" ) return tags @@ -429,8 +429,7 @@ def _fig_to_img( output = BytesIO() dpi = fig.get_dpi() logger.debug( - f"Saving figure with dimension {fig.get_size_inches()} inches with " - f"{dpi} dpi" + f"Saving figure with dimension {fig.get_size_inches()} inches with {dpi} dpi" ) # https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html @@ -913,7 +912,7 @@ def __repr__(self): if len(titles) > 0: titles = [f" {t}" for t in titles] # indent tr = max(len(s), 50) # trim to larger of opening str and 50 - titles = [f"{t[:tr - 2]} …" if len(t) > tr else t for t in titles] + titles = [f"{t[: tr - 2]} …" if len(t) > tr else t for t in titles] # then trim to the max length of all of these tr = max(len(title) for title in titles) tr = max(tr, len(s)) @@ -2761,9 +2760,7 @@ def _init_render(self, verbose=None): if inc_fname.endswith(".js"): include.append( - f'" + f'' ) elif inc_fname.endswith(".css"): include.append(f'') @@ -3415,6 +3412,7 @@ def _add_raw( init_kwargs.setdefault("fmax", fmax) plot_kwargs.setdefault("show", False) with warnings.catch_warnings(): + # SciPy warning about too short a data segment given the window size warnings.simplefilter(action="ignore", category=FutureWarning) fig = raw.compute_psd(**init_kwargs).plot(**plot_kwargs) self._add_figure( @@ -3648,7 +3646,7 @@ def _add_evoked_joint( ) ) - title = f'Time course ({_handle_default("titles")[ch_type]})' + title = f"Time course ({_handle_default('titles')[ch_type]})" self._add_figure( fig=fig, title=title, @@ -4120,7 +4118,7 @@ def _add_epochs( assert "eeg" in ch_type title_start = "ERP image" - title = f'{title_start} ({_handle_default("titles")[ch_type]})' + title = f"{title_start} ({_handle_default('titles')[ch_type]})" self._add_figure( fig=fig, diff --git a/mne/source_estimate.py b/mne/source_estimate.py index 1c0d6c1bc72..deeb3a43ede 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -764,6 +764,7 @@ def plot( transparent=True, alpha=1.0, time_viewer="auto", + *, subjects_dir=None, figure=None, views="auto", @@ -1387,8 +1388,7 @@ def transform(self, func, idx=None, tmin=None, tmax=None, copy=False): ] else: raise ValueError( - "copy must be True if transformed data has " - "more than 2 dimensions" + "copy must be True if transformed data has more than 2 dimensions" ) else: # return new or overwritten stc @@ -2256,6 +2256,7 @@ def plot( vector_alpha=1.0, scale_factor=None, time_viewer="auto", + *, subjects_dir=None, figure=None, views="lateral", @@ -2267,6 +2268,7 @@ def plot( foreground=None, initial_time=None, time_unit="s", + title=None, show_traces="auto", src=None, volume_options=1.0, @@ -2299,6 +2301,7 @@ def plot( foreground=foreground, initial_time=initial_time, time_unit=time_unit, + title=title, show_traces=show_traces, src=src, volume_options=volume_options, @@ -2767,6 +2770,7 @@ def plot_3d( vector_alpha=1.0, scale_factor=None, time_viewer="auto", + *, subjects_dir=None, figure=None, views="axial", @@ -2778,6 +2782,7 @@ def plot_3d( foreground=None, initial_time=None, time_unit="s", + title=None, show_traces="auto", src=None, volume_options=1.0, @@ -2810,6 +2815,7 @@ def plot_3d( foreground=foreground, initial_time=initial_time, time_unit=time_unit, + title=title, show_traces=show_traces, src=src, volume_options=volume_options, @@ -3626,7 +3632,7 @@ def _volume_labels(src, labels, mri_resolution): ] nnz = sum(len(v) != 0 for v in vertices) logger.info( - "%d/%d atlas regions had at least one vertex " "in the source space", + "%d/%d atlas regions had at least one vertex in the source space", nnz, len(out_labels), ) @@ -3918,22 +3924,6 @@ def stc_near_sensors( _validate_type(src, (None, SourceSpaces), "src") _check_option("mode", mode, ("sum", "single", "nearest", "weighted")) if surface == "auto": - if src is not None: - pial_fname = op.join(subjects_dir, subject, "surf", "lh.pial") - pial_rr = read_surface(pial_fname)[0] - src_surf_is_pial = ( - op.isfile(pial_fname) - and src[0]["rr"].shape == pial_rr.shape - and np.allclose(src[0]["rr"], pial_rr) - ) - if not src_surf_is_pial: - warn( - "In version 1.8, ``surface='auto'`` will be the default " - "which will use the surface in ``src`` instead of the " - "pial surface when ``src != None``. Pass ``surface='pial'`` " - "or ``surface=None`` to suppress this warning", - DeprecationWarning, - ) surface = "pial" if src is None or src.kind == "surface" else None # create a copy of Evoked using ecog, seeg and dbs @@ -4015,7 +4005,7 @@ def stc_near_sensors( min_dist = pdist(pos).min() * 1000 logger.info( - f' Minimum {"projected " if project else ""}intra-sensor distance: ' + f" Minimum {'projected ' if project else ''}intra-sensor distance: " f"{min_dist:0.1f} mm" ) @@ -4043,7 +4033,7 @@ def stc_near_sensors( if len(missing): warn( f"Channel{_pl(missing)} missing in STC: " - f'{", ".join(evoked.ch_names[mi] for mi in missing)}' + f"{', '.join(evoked.ch_names[mi] for mi in missing)}" ) nz_data = w @ evoked.data diff --git a/mne/source_space/_source_space.py b/mne/source_space/_source_space.py index f5e8b76a1fa..d64989961cf 100644 --- a/mne/source_space/_source_space.py +++ b/mne/source_space/_source_space.py @@ -743,7 +743,7 @@ def export_volume( # generate use warnings for clipping if n_diff > 0: warn( - f'{n_diff} {src["type"]} vertices lay outside of volume ' + f"{n_diff} {src['type']} vertices lay outside of volume " f"space. Consider using a larger volume space." ) # get surface id or use default value @@ -1546,7 +1546,7 @@ def setup_source_space( # pre-load ico/oct surf (once) for speed, if necessary if stype not in ("spacing", "all"): logger.info( - f'Doing the {dict(ico="icosa", oct="octa")[stype]}hedral vertex picking...' + f"Doing the {dict(ico='icosa', oct='octa')[stype]}hedral vertex picking..." ) for hemi, surf in zip(["lh", "rh"], surfs): logger.info(f"Loading {surf}...") @@ -2916,8 +2916,7 @@ def _get_vertex_map_nn( raise RuntimeError(f"vertex {one} would be used multiple times.") one = one[0] logger.info( - "Source space vertex moved from %d to %d because of " - "double occupation.", + "Source space vertex moved from %d to %d because of double occupation.", was, one, ) @@ -3167,8 +3166,7 @@ def _compare_source_spaces(src0, src1, mode="exact", nearest=True, dist_tol=1.5e assert_array_equal( s["vertno"], np.where(s["inuse"])[0], - f'src{ii}[{si}]["vertno"] != ' - f'np.where(src{ii}[{si}]["inuse"])[0]', + f'src{ii}[{si}]["vertno"] != np.where(src{ii}[{si}]["inuse"])[0]', ) assert_equal(len(s0["vertno"]), len(s1["vertno"])) agreement = np.mean(s0["inuse"] == s1["inuse"]) diff --git a/mne/surface.py b/mne/surface.py index 21432e7edfd..9e24147a080 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -214,7 +214,7 @@ def get_meg_helmet_surf(info, trans=None, *, verbose=None): ] ) logger.info( - "Getting helmet for system %s (derived from %d MEG " "channel locations)", + "Getting helmet for system %s (derived from %d MEG channel locations)", system, len(rr), ) @@ -733,7 +733,7 @@ def __init__(self, surf, *, mode="old", verbose=None): else: self._init_old() logger.debug( - f'Setting up {mode} interior check for {len(self.surf["rr"])} ' + f"Setting up {mode} interior check for {len(self.surf['rr'])} " f"points took {(time.time() - t0) * 1000:0.1f} ms" ) @@ -761,8 +761,7 @@ def _init_pyvista(self): def __call__(self, rr, n_jobs=None, verbose=None): n_orig = len(rr) logger.info( - f"Checking surface interior status for " - f'{n_orig} point{_pl(n_orig, " ")}...' + f"Checking surface interior status for {n_orig} point{_pl(n_orig, ' ')}..." ) t0 = time.time() if self.mode == "pyvista": @@ -770,7 +769,7 @@ def __call__(self, rr, n_jobs=None, verbose=None): else: inside = self._call_old(rr, n_jobs) n = inside.sum() - logger.info(f' Total {n}/{n_orig} point{_pl(n, " ")} inside the surface') + logger.info(f" Total {n}/{n_orig} point{_pl(n, ' ')} inside the surface") logger.info(f"Interior check completed in {(time.time() - t0) * 1000:0.1f} ms") return inside @@ -792,7 +791,7 @@ def _call_old(self, rr, n_jobs): n = (in_mask).sum() n_pad = str(n).rjust(prec) logger.info( - f' Found {n_pad}/{n_orig} point{_pl(n, " ")} ' + f" Found {n_pad}/{n_orig} point{_pl(n, ' ')} " f"inside an interior sphere of radius " f"{1000 * self.inner_r:6.1f} mm" ) @@ -801,7 +800,7 @@ def _call_old(self, rr, n_jobs): n = (out_mask).sum() n_pad = str(n).rjust(prec) logger.info( - f' Found {n_pad}/{n_orig} point{_pl(n, " ")} ' + f" Found {n_pad}/{n_orig} point{_pl(n, ' ')} " f"outside an exterior sphere of radius " f"{1000 * self.outer_r:6.1f} mm" ) @@ -818,7 +817,7 @@ def _call_old(self, rr, n_jobs): n_pad = str(n).rjust(prec) check_pad = str(len(del_outside)).rjust(prec) logger.info( - f' Found {n_pad}/{check_pad} point{_pl(n, " ")} outside using ' + f" Found {n_pad}/{check_pad} point{_pl(n, ' ')} outside using " "surface Qhull" ) @@ -828,7 +827,7 @@ def _call_old(self, rr, n_jobs): n_pad = str(n).rjust(prec) check_pad = str(len(solid_outside)).rjust(prec) logger.info( - f' Found {n_pad}/{check_pad} point{_pl(n, " ")} outside using ' + f" Found {n_pad}/{check_pad} point{_pl(n, ' ')} outside using " "solid angles" ) inside[idx[solid_outside]] = False diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 6b1356ae107..4d0db170e2a 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -1450,8 +1450,7 @@ def test_repr(): # long annotation repr (> 79 characters, will be shortened) r = repr(Annotations(range(14), [0] * 14, list("abcdefghijklmn"))) assert r == ( - "" + "" ) # empty Annotations diff --git a/mne/tests/test_dipole.py b/mne/tests/test_dipole.py index e93d4031646..f230eaa4256 100644 --- a/mne/tests/test_dipole.py +++ b/mne/tests/test_dipole.py @@ -214,9 +214,9 @@ def test_dipole_fitting(tmp_path): # Sanity check: do our residuals have less power than orig data? data_rms = np.sqrt(np.sum(evoked.data**2, axis=0)) resi_rms = np.sqrt(np.sum(residual.data**2, axis=0)) - assert ( - data_rms > resi_rms * 0.95 - ).all(), f"{(data_rms / resi_rms).min()} (factor: {0.95})" + assert (data_rms > resi_rms * 0.95).all(), ( + f"{(data_rms / resi_rms).min()} (factor: {0.95})" + ) # Compare to original points transform_surface_to(fwd["src"][0], "head", fwd["mri_head_t"]) diff --git a/mne/tests/test_docstring_parameters.py b/mne/tests/test_docstring_parameters.py index d32e62a454e..64f80f50b74 100644 --- a/mne/tests/test_docstring_parameters.py +++ b/mne/tests/test_docstring_parameters.py @@ -76,6 +76,7 @@ def _func_name(func, cls=None): error_ignores = { # These we do not live by: "GL01", # Docstring should start in the line immediately after the quotes + "GL02", # Closing quotes on own line (doesn't work on Python 3.13 anyway) "EX01", "EX02", # examples failed (we test them separately) "ES01", # no extended summary @@ -221,8 +222,7 @@ def test_tabs(): continue source = inspect.getsource(mod) assert "\t" not in source, ( - f'"{modname}" has tabs, please remove them ' - "or add it to the ignore list" + f'"{modname}" has tabs, please remove them or add it to the ignore list' ) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index f7558bf1cea..88f2d9cdc13 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -479,12 +479,12 @@ def test_average_movements(): def _assert_drop_log_types(drop_log): __tracebackhide__ = True assert isinstance(drop_log, tuple), "drop_log should be tuple" - assert all( - isinstance(log, tuple) for log in drop_log - ), "drop_log[ii] should be tuple" - assert all( - isinstance(s, str) for log in drop_log for s in log - ), "drop_log[ii][jj] should be str" + assert all(isinstance(log, tuple) for log in drop_log), ( + "drop_log[ii] should be tuple" + ) + assert all(isinstance(s, str) for log in drop_log for s in log), ( + "drop_log[ii][jj] should be str" + ) def test_reject(): diff --git a/mne/tests/test_filter.py b/mne/tests/test_filter.py index e259ececbce..537f1930f45 100644 --- a/mne/tests/test_filter.py +++ b/mne/tests/test_filter.py @@ -90,9 +90,9 @@ def test_estimate_ringing(): (0.0001, (30000, 60000)), ): # 37993 n_ring = estimate_ringing_samples(butter(3, thresh, output=kind)) - assert ( - lims[0] <= n_ring <= lims[1] - ), f"{kind} {thresh}: {lims[0]} <= {n_ring} <= {lims[1]}" + assert lims[0] <= n_ring <= lims[1], ( + f"{kind} {thresh}: {lims[0]} <= {n_ring} <= {lims[1]}" + ) with pytest.warns(RuntimeWarning, match="properly estimate"): assert estimate_ringing_samples(butter(4, 0.00001)) == 100000 diff --git a/mne/tests/test_parallel.py b/mne/tests/test_parallel.py index f72f3281a59..a780f32f911 100644 --- a/mne/tests/test_parallel.py +++ b/mne/tests/test_parallel.py @@ -26,7 +26,7 @@ def test_parallel_func(n_jobs): """Test Parallel wrapping.""" joblib = pytest.importorskip("joblib") if os.getenv("MNE_FORCE_SERIAL", "").lower() in ("true", "1"): - pytest.skip("MNE_FORCE_SERIAL cannot be set") + pytest.skip("MNE_FORCE_SERIAL is set") def fun(x): return x * 2 diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py index 9d24ad1d7bb..e4fa5a36b25 100644 --- a/mne/tests/test_source_estimate.py +++ b/mne/tests/test_source_estimate.py @@ -1711,8 +1711,7 @@ def test_stc_near_sensors(tmp_path): for s in src: transform_surface_to(s, "head", trans, copy=False) assert src[0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD - with pytest.warns(DeprecationWarning, match="instead of the pial"): - stc_src = stc_near_sensors(evoked, src=src, **kwargs) + stc_src = stc_near_sensors(evoked, src=src, **kwargs) assert len(stc_src.data) == 7928 with pytest.warns(RuntimeWarning, match="not included"): # some removed stc_src_full = compute_source_morph( diff --git a/mne/time_frequency/__init__.pyi b/mne/time_frequency/__init__.pyi index 0faeb7263d8..6b53c39a98b 100644 --- a/mne/time_frequency/__init__.pyi +++ b/mne/time_frequency/__init__.pyi @@ -11,6 +11,8 @@ __all__ = [ "RawTFRArray", "Spectrum", "SpectrumArray", + "combine_spectrum", + "combine_tfr", "csd_array_fourier", "csd_array_morlet", "csd_array_multitaper", @@ -61,6 +63,7 @@ from .spectrum import ( EpochsSpectrumArray, Spectrum, SpectrumArray, + combine_spectrum, read_spectrum, ) from .tfr import ( @@ -71,6 +74,7 @@ from .tfr import ( EpochsTFRArray, RawTFR, RawTFRArray, + combine_tfr, fwhm, morlet, read_tfrs, diff --git a/mne/time_frequency/_stft.py b/mne/time_frequency/_stft.py index 8fb80b43fcc..a6b6f23fff7 100644 --- a/mne/time_frequency/_stft.py +++ b/mne/time_frequency/_stft.py @@ -59,8 +59,7 @@ def stft(x, wsize, tstep=None, verbose=None): if (wsize % tstep) or (tstep % 2): raise ValueError( - "The step size must be a multiple of 2 and a " - "divider of the window length." + "The step size must be a multiple of 2 and a divider of the window length." ) if tstep > wsize / 2: diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py index c858dd52e57..4ddaa0ac6a3 100644 --- a/mne/time_frequency/csd.py +++ b/mne/time_frequency/csd.py @@ -224,8 +224,7 @@ def sum(self, fmin=None, fmax=None): """ if self._is_sum: raise RuntimeError( - "This CSD matrix already represents a mean or " - "sum across frequencies." + "This CSD matrix already represents a mean or sum across frequencies." ) # Deal with the various ways in which fmin and fmax can be specified @@ -1372,7 +1371,7 @@ def _execute_csd_function( logger.info("[done]") if ch_names is None: - ch_names = [f"SERIES{i+1:03}" for i in range(n_channels)] + ch_names = [f"SERIES{i + 1:03}" for i in range(n_channels)] return CrossSpectralDensity( csds_mean, diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 73a3308685d..1c1a3baf238 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -63,7 +63,14 @@ def dpss_windows(N, half_nbw, Kmax, *, sym=True, norm=None, low_bias=True): ---------- .. footbibliography:: """ - dpss, eigvals = sp_dpss(N, half_nbw, Kmax, sym=sym, norm=norm, return_ratios=True) + # TODO VERSION can be removed with SciPy 1.16 is min, + # workaround for https://github.com/scipy/scipy/pull/22344 + if N <= 1: + dpss, eigvals = np.ones((1, 1)), np.ones(1) + else: + dpss, eigvals = sp_dpss( + N, half_nbw, Kmax, sym=sym, norm=norm, return_ratios=True + ) if low_bias: idx = eigvals > 0.9 if not idx.any(): @@ -471,6 +478,7 @@ def tfr_array_multitaper( output="complex", n_jobs=None, *, + return_weights=False, verbose=None, ): """Compute Time-Frequency Representation (TFR) using DPSS tapers. @@ -504,6 +512,11 @@ def tfr_array_multitaper( coherence across trials. %(n_jobs)s The parallelization is implemented across channels. + return_weights : bool, default False + If True, return the taper weights. Only applies if ``output='complex'`` or + ``'phase'``. + + .. versionadded:: 1.10.0 %(verbose)s Returns @@ -520,6 +533,9 @@ def tfr_array_multitaper( If ``output`` is ``'avg_power_itc'``, the real values in ``out`` contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and + ``return_weights=True``. See Also -------- @@ -550,6 +566,7 @@ def tfr_array_multitaper( use_fft=use_fft, decim=decim, output=output, + return_weights=return_weights, n_jobs=n_jobs, verbose=verbose, ) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 23317a6c9d3..03a57010061 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -311,7 +311,7 @@ def __init__( if np.isfinite(fmax) and (fmax > self.sfreq / 2): raise ValueError( f"Requested fmax ({fmax} Hz) must not exceed ½ the sampling " - f'frequency of the data ({0.5 * inst.info["sfreq"]} Hz).' + f"frequency of the data ({0.5 * inst.info['sfreq']} Hz)." ) # method self._inst_type = type(inst) @@ -419,7 +419,9 @@ def _repr_html_(self, caption=None): inst_type_str = _get_instance_type_string(self) units = [f"{ch_type}: {unit}" for ch_type, unit in self.units().items()] t = _get_html_template("repr", "spectrum.html.jinja") - t = t.render(spectrum=self, inst_type=inst_type_str, units=units) + t = t.render( + inst=self, computed_from=inst_type_str, units=units, filenames=None + ) return t def _check_values(self): @@ -440,7 +442,7 @@ def _check_values(self): if bad_value.any(): chs = np.array(self.ch_names)[bad_value].tolist() s = _pl(bad_value.sum()) - warn(f'Zero value in spectrum for channel{s} {", ".join(chs)}', UserWarning) + warn(f"Zero value in spectrum for channel{s} {', '.join(chs)}", UserWarning) def _returns_complex_tapers(self, **method_kw): return self.method == "multitaper" and method_kw.get("output") == "complex" @@ -944,7 +946,7 @@ def save(self, fname, *, overwrite=False, verbose=None): check_fname(fname, "spectrum", (".h5", ".hdf5")) fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) out = self.__getstate__() - write_hdf5(fname, out, overwrite=overwrite, title="mnepython") + write_hdf5(fname, out, overwrite=overwrite, title="mnepython", slash="replace") @verbose def to_data_frame( @@ -1534,7 +1536,7 @@ def average(self, method="mean"): state["nave"] = state["data"].shape[0] state["data"] = method(state["data"]) state["dims"] = state["dims"][1:] - state["data_type"] = f'Averaged {state["data_type"]}' + state["data_type"] = f"Averaged {state['data_type']}" defaults = dict( method=None, fmin=None, @@ -1641,6 +1643,74 @@ def __init__( ) +def combine_spectrum(all_spectrum, weights="nave"): + """Merge spectral data by weighted addition. + + Create a new :class:`mne.time_frequency.Spectrum` instance, using a combination of + the supplied instances as its data. By default, the mean (weighted by trials) is + used. Subtraction can be performed by passing negative weights (e.g., ``[1, -1]``). + Data must have the same channels and the same frequencies. + + Parameters + ---------- + all_spectrum : list of Spectrum + The Spectrum objects. + weights : list of float | str + The weights to apply to the data of each :class:`~mne.time_frequency.Spectrum` + instance, or a string describing the weighting strategy to apply: 'nave' + computes sum-to-one weights proportional to each object’s nave attribute; + 'equal' weights each :class:`~mne.time_frequency.Spectrum` by + ``1 / len(all_spectrum)``. + + Returns + ------- + spectrum : Spectrum + The new spectral data. + + Notes + ----- + .. versionadded:: 1.10.0 + """ + spectrum = all_spectrum[0].copy() + if isinstance(weights, str): + if weights not in ("nave", "equal"): + raise ValueError('Weights must be a list of float, or "nave" or "equal"') + if weights == "nave": + for s_ in all_spectrum: + if s_.nave is None: + raise ValueError(f"The 'nave' attribute is not specified for {s_}") + weights = np.array([e.nave for e in all_spectrum], float) + weights /= weights.sum() + else: # == 'equal' + weights = [1.0 / len(all_spectrum)] * len(all_spectrum) + weights = np.array(weights, float) + if weights.ndim != 1 or weights.size != len(all_spectrum): + raise ValueError("Weights must be the same size as all_spectrum") + + ch_names = spectrum.ch_names + for s_ in all_spectrum[1:]: + assert s_.ch_names == ch_names, ( + f"{spectrum} and {s_} do not contain the same channels" + ) + assert np.max(np.abs(s_.freqs - spectrum.freqs)) < 1e-7, ( + f"{spectrum} and {s_} do not contain the same frequencies" + ) + + # use union of bad channels + bads = list( + set(spectrum.info["bads"]).union(*(s_.info["bads"] for s_ in all_spectrum[1:])) + ) + spectrum.info["bads"] = bads + + # combine spectral data + spectrum._data = sum(w * s_.data for w, s_ in zip(weights, all_spectrum)) + if spectrum.nave is not None: + spectrum._nave = max( + int(1.0 / sum(w**2 / s_.nave for w, s_ in zip(weights, all_spectrum))), 1 + ) + return spectrum + + def read_spectrum(fname): """Load a :class:`mne.time_frequency.Spectrum` object from disk. @@ -1663,7 +1733,7 @@ def read_spectrum(fname): _validate_type(fname, "path-like", "fname") fname = _check_fname(fname=fname, overwrite="read", must_exist=False) # read it in - hdf5_dict = read_hdf5(fname, title="mnepython") + hdf5_dict = read_hdf5(fname, title="mnepython", slash="replace") defaults = dict( method=None, fmin=None, diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 840aab3bcbd..927c22360c5 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -10,11 +10,15 @@ from matplotlib.colors import same_color from numpy.testing import assert_allclose, assert_array_equal -from mne import Annotations, create_info, make_fixed_length_epochs +from mne import Annotations, BaseEpochs, create_info, make_fixed_length_epochs from mne.io import RawArray from mne.time_frequency import read_spectrum from mne.time_frequency.multitaper import _psd_from_mt -from mne.time_frequency.spectrum import EpochsSpectrumArray, SpectrumArray +from mne.time_frequency.spectrum import ( + EpochsSpectrumArray, + SpectrumArray, + combine_spectrum, +) from mne.utils import _record_warnings @@ -163,13 +167,19 @@ def _get_inst(inst, request, *, evoked=None, average_tfr=None): return request.getfixturevalue(inst) -@pytest.mark.parametrize("inst", ("raw", "epochs", "evoked")) +@pytest.mark.parametrize("inst", ("raw", "epochs_full", "evoked")) def test_spectrum_io(inst, tmp_path, request, evoked): """Test save/load of spectrum objects.""" pytest.importorskip("h5io") fname = tmp_path / f"{inst}-spectrum.h5" inst = _get_inst(inst, request, evoked=evoked) + if isinstance(inst, BaseEpochs): + # fake HED-like tags (https://mne.discourse.group/t/10634) + inst.events[-2:, -1] = 2 + inst.event_id = {"foo/bar": 1, "foo/qux": 2} orig = inst.compute_psd() + if isinstance(inst, BaseEpochs): + orig = orig["foo"] orig.save(fname) loaded = read_spectrum(fname) assert orig == loaded @@ -184,6 +194,55 @@ def test_spectrum_copy(raw_spectrum): assert raw_spectrum.freqs is not None +@pytest.mark.parametrize("weights", ["nave", "equal", [1, -1]]) +def test_combine_spectrum(raw_spectrum, weights): + """Test `combine_spectrum()` works.""" + spectrum1 = raw_spectrum.copy() + spectrum2 = raw_spectrum.copy() + if weights == "nave": + spectrum1._nave = 1 + spectrum2._nave = 2 + spectrum2._data *= 2 + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, spectrum1.data * (5 / 3)) + elif weights == "equal": + spectrum2._data *= 2 + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, spectrum1.data * 1.5) + else: + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, 0) + + +def test_combine_spectrum_error_catch(raw_spectrum): + """Test `combine_spectrum()` catches errors.""" + # Test bad weights + with pytest.raises( + ValueError, match='Weights must be a list of float, or "nave" or "equal"' + ): + combine_spectrum([raw_spectrum, raw_spectrum], weights="foo") + with pytest.raises( + ValueError, match="Weights must be the same size as all_spectrum" + ): + combine_spectrum([raw_spectrum, raw_spectrum], weights=[1, 1, 1]) + + # Test bad nave + with pytest.raises(ValueError, match="The 'nave' attribute is not specified"): + combine_spectrum([raw_spectrum, raw_spectrum], weights="nave") + + # Test inconsistent channels + raw_spectrum2 = raw_spectrum.copy() + raw_spectrum2.drop_channels(raw_spectrum2.ch_names[0]) + with pytest.raises(AssertionError, match=".* do not contain the same channels"): + combine_spectrum([raw_spectrum, raw_spectrum2], weights="equal") + + # Test inconsistent frequencies + raw_spectrum2 = raw_spectrum.copy() + raw_spectrum2._freqs = raw_spectrum2._freqs + 1 + with pytest.raises(AssertionError, match=".* do not contain the same frequencies"): + combine_spectrum([raw_spectrum, raw_spectrum2], weights="equal") + + def test_spectrum_reject_by_annot(raw): """Test rejecting by annotation. @@ -627,3 +686,29 @@ def test_plot_spectrum_array_with_bads(): spectrum.get_data(exclude=()), spectrum.info, spectrum.freqs ) spectrum2.plot(spatial_colors=False) + + +@pytest.mark.parametrize("dB", (False, True)) +@pytest.mark.parametrize("amplitude", (False, True)) +def test_plot_spectrum_dB(raw_spectrum, dB, amplitude): + """Test that we properly handle amplitude/power and dB.""" + idx = 7 + power = 3 + freqs = np.linspace(1, 100, 100) + data = np.full((1, freqs.size), np.finfo(float).tiny) + data[0, idx] = power + info = create_info(ch_names=["delta"], sfreq=1000, ch_types="eeg") + psd = SpectrumArray(data=data, info=info, freqs=freqs) + with pytest.warns(RuntimeWarning, match="Channel locations not available"): + fig = psd.plot(dB=dB, amplitude=amplitude) + trace = list( + filter(lambda x: len(x.get_data()[0]) == len(freqs), fig.axes[0].lines) + )[0] + got = trace.get_data()[1][idx] + want = power * 1e12 # scaling for EEG (V → μV), squared + if amplitude: + want = np.sqrt(want) + if dB: + want = (20 if amplitude else 10) * np.log10(want) + + assert want == got, f"expected {want}, got {got}" diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index cd3a97ab90a..6adb4e361e1 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -255,20 +255,25 @@ def test_tfr_morlet(): # computed within the method. assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data) - # test that averaging power across tapers when multitaper with + # test that aggregating power across tapers when multitaper with # output='complex' gives the same as output='power' epoch_data = epochs.get_data() multitaper_power = tfr_array_multitaper( epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="power" ) - multitaper_complex = tfr_array_multitaper( - epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="complex" + multitaper_complex, weights = tfr_array_multitaper( + epoch_data, + epochs.info["sfreq"], + freqs, + n_cycles, + output="complex", + return_weights=True, ) - taper_dim = 2 - power_from_complex = (multitaper_complex * multitaper_complex.conj()).real.mean( - axis=taper_dim - ) + weights = np.expand_dims(weights, axis=(0, 1, -1)) # match shape of complex data + tfr = weights * multitaper_complex + tfr = (tfr * tfr.conj()).real.sum(axis=2) + power_from_complex = tfr * (2 / (weights * weights.conj()).real.sum(axis=2)) assert_allclose(power_from_complex, multitaper_power) print(itc) # test repr @@ -432,17 +437,21 @@ def test_tfr_morlet(): def test_dpsswavelet(): """Test DPSS tapers.""" freqs = np.arange(5, 25, 3) - Ws = _make_dpss( - 1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True + Ws, weights = _make_dpss( + 1000, + freqs=freqs, + n_cycles=freqs / 2.0, + time_bandwidth=4.0, + zero_mean=True, + return_weights=True, ) - assert len(Ws) == 3 # 3 tapers expected + assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected + assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs) # Check that zero mean is true assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5 - assert len(Ws[0]) == len(freqs) # As many wavelets as asked for - @pytest.mark.slowtest def test_tfr_multitaper(): @@ -664,6 +673,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path): with tfr.info._unlock(): tfr.info["meas_date"] = want assert tfr_loaded == tfr + # test with taper dimension and weights + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs + state = tfr.__getstate__() + state["data"] = np.repeat(np.expand_dims(tfr.data, 2), n_tapers, axis=2) # add dim + state["weights"] = weights # add weights + state["dims"] = ("epoch", "channel", "taper", "freq", "time") # update dims + tfr = EpochsTFR(inst=state) + tfr.save(fname, overwrite=True) + tfr_loaded = read_tfrs(fname) + assert tfr_loaded == tfr # test overwrite with pytest.raises(OSError, match="Destination file exists."): tfr.save(fname, overwrite=False) @@ -722,17 +742,31 @@ def test_average_tfr_init(full_evoked): AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace) -def test_epochstfr_init_errors(epochs_tfr): - """Test __init__ for EpochsTFR.""" - state = epochs_tfr.__getstate__() - with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0])) +@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr")) +def test_tfr_init_errors(inst, request, average_tfr): + """Test __init__ for {Raw,Epochs,Average}TFR.""" + # Load data + inst = _get_inst(inst, request, average_tfr=average_tfr) + state = inst.__getstate__() + # Prepare for TFRArray object instantiation + inst_name = inst.__class__.__name__ + class_mapping = dict(RawTFR=RawTFR, EpochsTFR=EpochsTFR, AverageTFR=AverageTFR) + ndims_mapping = dict( + RawTFR=("3D or 4D"), EpochsTFR=("4D or 5D"), AverageTFR=("3D or 4D") + ) + TFR = class_mapping[inst_name] + allowed_ndims = ndims_mapping[inst_name] + # Check errors caught + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=inst.data[..., 0])) + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=np.expand_dims(inst.data, axis=(0, 1)))) with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1])) + TFR(inst=state | dict(data=inst.data[..., :-1, :, :])) with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"): - EpochsTFR(inst=state | dict(times=epochs_tfr.times[:-1])) + TFR(inst=state | dict(times=inst.times[:-1])) with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"): - EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1])) + TFR(inst=state | dict(freqs=inst.freqs[:-1])) @pytest.mark.parametrize( @@ -830,6 +864,25 @@ def test_plot(): plt.close("all") +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_plot_multitaper_complex_phase(output): + """Test TFR plotting of data with a taper dimension.""" + # Create example data with a taper dimension + n_chans, n_tapers, n_freqs, n_times = (3, 4, 2, 3) + data = np.random.rand(n_chans, n_tapers, n_freqs, n_times) + if output == "complex": + data = data + np.random.rand(*data.shape) * 1j # add imaginary data + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.random.rand(n_tapers, n_freqs) + info = mne.create_info(n_chans, 1000.0, "eeg") + tfr = AverageTFRArray( + info=info, data=data, times=times, freqs=freqs, weights=weights + ) + # Check that plotting works + tfr.plot() + + @pytest.mark.parametrize( "timefreqs,title,combine", ( @@ -1154,6 +1207,15 @@ def test_averaging_epochsTFR(): ): power.average(method=np.mean) + # Check it doesn't run for taper spectra + tapered = epochs.compute_tfr( + method="multitaper", freqs=freqs, n_cycles=n_cycles, output="complex" + ) + with pytest.raises( + NotImplementedError, match=r"Averaging multitaper tapers .* is not supported." + ): + tapered.average() + def test_averaging_freqsandtimes_epochsTFR(): """Test that EpochsTFR averaging freqs methods work.""" @@ -1218,8 +1280,8 @@ def test_averaging_freqsandtimes_epochsTFR(): avgpower = power.average(method=lambda x: np.mean(x, axis=2), **kwargs) -@pytest.mark.parametrize("n_drop", (0, 2)) -def test_epochstfr_getitem(epochs_full, n_drop): +@pytest.mark.parametrize("n_drop, as_tfr_array", ((0, False), (0, True), (2, False))) +def test_epochstfr_getitem(epochs_full, n_drop, as_tfr_array): """Test EpochsTFR.__getitem__().""" pd = pytest.importorskip("pandas") from pandas.testing import assert_frame_equal @@ -1227,16 +1289,20 @@ def test_epochstfr_getitem(epochs_full, n_drop): epochs_full.metadata = pd.DataFrame(dict(foo=list("aaaabbb"), bar=np.arange(7))) epochs_full.drop(np.arange(n_drop)) tfr = epochs_full.compute_tfr(method="morlet", freqs=freqs_linspace) - # check that various attributes are preserved - assert_frame_equal(tfr.metadata, epochs_full.metadata) - assert epochs_full.drop_log == tfr.drop_log - for attr in ("events", "selection", "times"): - assert_array_equal(getattr(epochs_full, attr), getattr(tfr, attr)) - # test pandas query - foo_a = tfr["foo == 'a'"] - bar_3 = tfr["bar <= 3"] - assert foo_a == bar_3 - assert foo_a.shape[0] == 4 - n_drop + if not as_tfr_array: # check that various attributes are preserved + assert_frame_equal(tfr.metadata, epochs_full.metadata) + assert epochs_full.drop_log == tfr.drop_log + for attr in ("events", "selection", "times"): + assert_array_equal(getattr(epochs_full, attr), getattr(tfr, attr)) + # test pandas query + foo_a = tfr["foo == 'a'"] + bar_3 = tfr["bar <= 3"] + assert foo_a == bar_3 + assert foo_a.shape[0] == 4 - n_drop + else: # repackage to check __getitem__ also works with unspecified events, etc... + tfr = EpochsTFRArray( + info=tfr.info, data=tfr.data, times=tfr.times, freqs=tfr.freqs + ) # test integer and slice subset_ints = tfr[[0, 1, 2]] subset_slice = tfr[:3] @@ -1254,12 +1320,15 @@ def test_to_data_frame(): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) srate = 1000.0 - freqs = np.arange(5) + freqs = np.arange(n_freqs) + tapers = np.arange(n_tapers) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 5 + n_epos) @@ -1272,6 +1341,7 @@ def test_to_data_frame(): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) # test index checking with pytest.raises(ValueError, match="options. Valid index options are"): @@ -1283,10 +1353,21 @@ def test_to_data_frame(): # test wide format df_wide = tfr.to_data_frame() assert all(np.isin(tfr.ch_names, df_wide.columns)) - assert all(np.isin(["time", "condition", "freq", "epoch"], df_wide.columns)) + assert all( + np.isin(["time", "condition", "freq", "epoch", "taper"], df_wide.columns) + ) # test long format df_long = tfr.to_data_frame(long_format=True) - expected = ("condition", "epoch", "freq", "time", "channel", "ch_type", "value") + expected = ( + "condition", + "epoch", + "freq", + "time", + "channel", + "ch_type", + "value", + "taper", + ) assert set(expected) == set(df_long.columns) assert set(tfr.ch_names) == set(df_long["channel"]) assert len(df_long) == tfr.data.size @@ -1294,21 +1375,29 @@ def test_to_data_frame(): df_long = tfr.to_data_frame(long_format=True, index=["freq"]) del df_wide, df_long # test whether data is in correct shape - df = tfr.to_data_frame(index=["condition", "epoch", "freq", "time"]) + df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"]) data = tfr.data assert_array_equal(df.values[:, 0], data[:, 0, :, :].reshape(1, -1).squeeze()) # compare arbitrary observation: assert ( - df.loc[("he", slice(None), freqs[1], times[2]), ch_names[3]].iat[0] - == data[1, 3, 1, 2] + df.loc[("he", slice(None), tapers[1], freqs[1], times[2]), ch_names[3]].iat[0] + == data[1, 3, 1, 1, 2] ) # Check also for AverageTFR: + # (remove taper dimension before averaging) + state = tfr.__getstate__() + state["data"] = state["data"][:, :, 0] + state["dims"] = ("epoch", "channel", "freq", "time") + state["weights"] = None + tfr = EpochsTFR(inst=state) tfr = tfr.average() with pytest.raises(ValueError, match="options. Valid index options are"): tfr.to_data_frame(index=["epoch", "condition"]) with pytest.raises(ValueError, match='"epoch" is not a valid option'): tfr.to_data_frame(index="epoch") + with pytest.raises(ValueError, match='"taper" is not a valid option'): + tfr.to_data_frame(index="taper") with pytest.raises(TypeError, match="index must be `None` or a string "): tfr.to_data_frame(index=np.arange(400)) # test wide format @@ -1344,11 +1433,13 @@ def test_to_data_frame_index(index): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) - freqs = np.arange(5) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 8) @@ -1361,6 +1452,7 @@ def test_to_data_frame_index(index): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) df = tfr.to_data_frame(picks=[0, 2, 3], index=index) # test index order/hierarchy preservation @@ -1368,7 +1460,7 @@ def test_to_data_frame_index(index): index = [index] assert list(df.index.names) == index # test that non-indexed data were present as columns - non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index)) + non_index = list(set(["condition", "time", "freq", "taper", "epoch"]) - set(index)) if len(non_index): assert all(np.isin(non_index, df.columns)) @@ -1534,7 +1626,8 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output): """Test Epochs.compute_tfr(output="complex"/"phase").""" tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output) - assert len(tfr.shape) == 5 + assert len(tfr.shape) == 5 # epoch x channel x taper x freq x time + assert tfr.weights.shape == tfr.shape[2:4] # check weights and coeffs shapes match @pytest.mark.parametrize("copy", (False, True)) @@ -1546,6 +1639,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy): assert avgs[0].comment == str(epochs_tfr.events[0, -1]) +@pytest.mark.parametrize("obj_type", ("raw", "epochs", "evoked")) +def test_tfrarray_tapered_spectra(obj_type): + """Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra.""" + # Create example data with a taper dimension + n_epochs, n_chans, n_tapers, n_freqs, n_times = (5, 3, 4, 2, 6) + data_shape = (n_chans, n_tapers, n_freqs, n_times) + if obj_type == "epochs": + data_shape = (n_epochs,) + data_shape + data = np.random.rand(*data_shape) + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.random.rand(n_tapers, n_freqs) + info = mne.create_info(n_chans, 1000.0, "eeg") + # Prepare for TFRArray object instantiation + defaults = dict(info=info, data=data, times=times, freqs=freqs) + class_mapping = dict(raw=RawTFRArray, epochs=EpochsTFRArray, evoked=AverageTFRArray) + TFRArray = class_mapping[obj_type] + # Check TFRArray instantiation runs with good data + TFRArray(**defaults, weights=weights) + # Check taper dimension but no weights caught + with pytest.raises( + ValueError, match="Taper dimension in data, but no weights found." + ): + TFRArray(**defaults) + # Check mismatching n_taper in weights caught + with pytest.raises( + ValueError, match=r"Taper axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:-1]) + # Check mismatching n_freq in weights caught + with pytest.raises( + ValueError, match=r"Frequency axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:, :-1]) + + def test_tfr_proj(epochs): """Test `compute_tfr(proj=True)`.""" epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True) @@ -1727,3 +1856,52 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request): assert re.match( rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title() ) + + +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_tfr_topo_plotting_multitaper_complex_phase(output, evoked): + """Test plot_joint/topo/topomap() for data with a taper dimension.""" + # Compute TFR with taper dimension + tfr = evoked.compute_tfr( + method="multitaper", freqs=freqs_linspace, n_cycles=4, output=output + ) + # Check that plotting works + tfr.plot_joint(topomap_args=dict(res=8, contours=0, sensors=False)) # for speed + tfr.plot_topo() + tfr.plot_topomap() + + +def test_combine_tfr_error_catch(average_tfr): + """Test combine_tfr() catches errors.""" + # check unrecognised weights string caught + with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'): + combine_tfr([average_tfr, average_tfr], weights="foo") + # check bad weights size caught + with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"): + combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1]) + # check different channel names caught + state = average_tfr.__getstate__() + new_info = average_tfr.info.copy() + average_tfr_bad = AverageTFR( + inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"})) + ) + with pytest.raises(AssertionError, match=".* do not contain the same channels"): + combine_tfr([average_tfr, average_tfr_bad]) + # check different times caught + average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1)) + with pytest.raises( + AssertionError, match=".* do not contain the same time instants" + ): + combine_tfr([average_tfr, average_tfr_bad]) + # check taper dim caught + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs + state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1) + state["weights"] = weights + state["dims"] = ("channel", "taper", "freq", "time") + average_tfr_taper = AverageTFR(inst=state) + with pytest.raises( + NotImplementedError, + match="Aggregating multitaper tapers across TFR datasets is not supported.", + ): + combine_tfr([average_tfr_taper, average_tfr_taper]) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index eaf173092bb..f4a01e87895 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -264,8 +264,11 @@ def _make_dpss( ------- Ws : list of array The wavelets time series. + Cs : list of array + The concentration weights. Only returned if return_weights=True. """ Ws = list() + Cs = list() freqs = np.array(freqs) if np.any(freqs <= 0): @@ -281,6 +284,7 @@ def _make_dpss( for m in range(n_taps): Wm = list() + Cm = list() for k, f in enumerate(freqs): if len(n_cycles) != 1: this_n_cycles = n_cycles[k] @@ -302,12 +306,15 @@ def _make_dpss( real_offset = Wk.mean() Wk -= real_offset Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) + Ck = np.sqrt(conc[m]) Wm.append(Wk) + Cm.append(Ck) Ws.append(Wm) + Cs.append(Cm) if return_weights: - return Ws, conc + return Ws, Cs return Ws @@ -428,6 +435,7 @@ def _compute_tfr( use_fft=True, decim=1, output="complex", + return_weights=False, n_jobs=None, *, verbose=None, @@ -478,7 +486,9 @@ def _compute_tfr( * 'itc' : inter-trial coherence. * 'avg_power_itc' : average of single trial power and inter-trial coherence across trials. - + return_weights : bool, default False + Whether to return the taper weights. Only applies if method='multitaper' and + output='complex' or 'phase'. %(n_jobs)s The number of epochs to process at the same time. The parallelization is implemented across channels. @@ -495,6 +505,9 @@ def _compute_tfr( n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the real values in the ``output`` contain average power' and the imaginary values contain the ITC: ``out = avg_power + i * itc``. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if method='multitaper', output='complex' or + 'phase', and return_weights=True. """ # Check data epoch_data = np.asarray(epoch_data) @@ -516,6 +529,9 @@ def _compute_tfr( decim, output, ) + return_weights = ( + return_weights and method == "multitaper" and output in ["complex", "phase"] + ) decim = _ensure_slice(decim) if (freqs > sfreq / 2.0).any(): @@ -529,21 +545,25 @@ def _compute_tfr( if method == "morlet": W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean) Ws = [W] # to have same dimensionality as the 'multitaper' case + weights = None # no tapers for Morlet estimates elif method == "multitaper": - Ws = _make_dpss( + Ws, weights = _make_dpss( sfreq, freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, zero_mean=zero_mean, + return_weights=True, # required for converting complex → power ) + weights = np.asarray(weights) # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: raise ValueError( "At least one of the wavelets is longer than the " - "signal. Use a longer signal or shorter wavelets." + f"signal ({len(Ws[0][0])} > {epoch_data.shape[2]} samples). " + "Use a longer signal or shorter wavelets." ) # Initialize output @@ -560,7 +580,7 @@ def _compute_tfr( if ("avg_" in output) or ("itc" in output): out = np.empty((n_chans, n_freqs, n_times), dtype) elif output in ["complex", "phase"] and method == "multitaper": - out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) + out = np.empty((n_chans, n_epochs, n_tapers, n_freqs, n_times), dtype) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -571,7 +591,7 @@ def _compute_tfr( # Parallelization is applied across channels. tfrs = parallel( - my_cwt(channel, Ws, output, use_fft, "same", decim, method) + my_cwt(channel, Ws, output, use_fft, "same", decim, weights) for channel in epoch_data.transpose(1, 0, 2) ) @@ -581,10 +601,10 @@ def _compute_tfr( if ("avg_" not in output) and ("itc" not in output): # This is to enforce that the first dimension is for epochs - if output in ["complex", "phase"] and method == "multitaper": - out = out.transpose(2, 0, 1, 3, 4) - else: - out = out.transpose(1, 0, 2, 3) + out = np.moveaxis(out, 1, 0) + + if return_weights: + return out, weights return out @@ -598,8 +618,7 @@ def _check_tfr_param( freqs = np.asarray(freqs, dtype=float) if freqs.ndim != 1: raise ValueError( - f"freqs must be of shape (n_freqs,), got {np.array(freqs.shape)} " - "instead." + f"freqs must be of shape (n_freqs,), got {np.array(freqs.shape)} instead." ) # Check sfreq @@ -658,7 +677,7 @@ def _check_tfr_param( return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim -def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): +def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None): """Aux. function to _compute_tfr. Loops time-frequency transform across wavelets and epochs. @@ -685,9 +704,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): See numpy.convolve. decim : slice The decimation slice: e.g. power[:, decim] - method : str | None - Used only for multitapering to create tapers dimension in the output - if ``output in ['complex', 'phase']``. + weights : array, shape (n_tapers, n_wavelets) | None + Concentration weights for each taper in the wavelets, if present. """ # Set output type dtype = np.float64 @@ -701,10 +719,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): n_freqs = len(Ws[0]) if ("avg_" in output) or ("itc" in output): tfrs = np.zeros((n_freqs, n_times), dtype=dtype) - elif output in ["complex", "phase"] and method == "multitaper": - tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype) + elif output in ["complex", "phase"] and weights is not None: + tfrs = np.zeros((n_epochs, n_tapers, n_freqs, n_times), dtype=dtype) else: tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype) + if weights is not None: + weights = np.expand_dims(weights, axis=-1) # add singleton time dimension # Loops across tapers. for taper_idx, W in enumerate(Ws): @@ -719,6 +739,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): # Loop across epochs for epoch_idx, tfr in enumerate(coefs): # Transform complex values + if output not in ["complex", "phase"] and weights is not None: + tfr = weights[taper_idx] * tfr # weight each taper estimate if output in ["power", "avg_power"]: tfr = (tfr * tfr.conj()).real # power elif output == "phase": @@ -734,8 +756,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): # Stack or add if ("avg_" in output) or ("itc" in output): tfrs += tfr - elif output in ["complex", "phase"] and method == "multitaper": - tfrs[taper_idx, epoch_idx] += tfr + elif output in ["complex", "phase"] and weights is not None: + tfrs[epoch_idx, taper_idx] += tfr else: tfrs[epoch_idx] += tfr @@ -749,9 +771,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): if ("avg_" in output) or ("itc" in output): tfrs /= n_epochs - # Normalization by number of taper - if n_tapers > 1 and output not in ["complex", "phase"]: - tfrs /= n_tapers + # Normalization by taper weights + if n_tapers > 1 and output not in ["complex", "phase", "itc"]: + if "avg_" not in output: # add singleton epochs dimension to weights + weights = np.expand_dims(weights, axis=0) + tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3) + if output == "avg_power_itc": # weight itc by the number of tapers + tfrs.imag = tfrs.imag / n_tapers + return tfrs @@ -1184,12 +1211,9 @@ def __init__( classname = "EpochsTFR" # end TODO raise ValueError( - f'{classname} got unsupported parameter value{_pl(problem)} ' - f'{" and ".join(problem)}.' + f"{classname} got unsupported parameter value{_pl(problem)} " + f"{' and '.join(problem)}." ) - # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) - if method == "morlet": - method_kw.setdefault("zero_mean", True) # check method valid_methods = ["morlet", "multitaper"] if isinstance(inst, BaseEpochs): @@ -1203,6 +1227,9 @@ def __init__( method_kw.setdefault("output", "power") self._freqs = np.asarray(freqs, dtype=np.float64) del freqs + # always store weights for per-taper outputs + if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]: + method_kw["return_weights"] = True # check validity of kwargs manually to save compute time if any are invalid tfr_funcs = dict( morlet=tfr_array_morlet, @@ -1224,6 +1251,7 @@ def __init__( self._method = method self._inst_type = type(inst) self._baseline = None + self._weights = None self.preload = True # needed for __getitem__, never False for TFRs # self._dims may also get updated by child classes self._dims = ["channel", "freq", "time"] @@ -1382,6 +1410,7 @@ def __getstate__(self): info=self.info, baseline=self._baseline, decim=self._decim, + weights=self._weights, ) def __setstate__(self, state): @@ -1392,7 +1421,6 @@ def __setstate__(self, state): defaults = dict( method="unknown", - dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :], baseline=None, decim=1, data_type="TFR", @@ -1410,12 +1438,13 @@ def __setstate__(self, state): self._decim = defaults["decim"] self.preload = True self._set_times(self._raw_times) + self._weights = state.get("weights") # objs saved before #12910 won't have # Handle instance type. Prior to gh-11282, Raw was not a possibility so if # `inst_type_str` is missing it must be Epochs or Evoked unknown_class = Epochs if "epoch" in self._dims else Evoked inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class) self._inst_type = inst_types[defaults["inst_type_str"]] - # sanity check data/freqs/times/info agreement + # sanity check data/freqs/times/info/weights agreement self._check_state() def __repr__(self): @@ -1468,18 +1497,29 @@ def _check_compatibility(self, other): raise RuntimeError(msg.format(problem, extra)) def _check_state(self): - """Check data/freqs/times/info agreement during __setstate__.""" + """Check data/freqs/times/info/weights agreement during __setstate__.""" msg = "{} axis of data ({}) doesn't match {} attribute ({})" n_chan_info = len(self.info["chs"]) n_chan = self._data.shape[self._dims.index("channel")] n_freq = self._data.shape[self._dims.index("freq")] n_time = self._data.shape[self._dims.index("time")] + n_taper = ( + self._data.shape[self._dims.index("taper")] + if "taper" in self._dims + else None + ) + if n_taper is not None and self._weights is None: + raise ValueError("Taper dimension in data, but no weights found.") if n_chan_info != n_chan: msg = msg.format("Channel", n_chan, "info", n_chan_info) elif n_freq != len(self.freqs): msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) elif n_time != len(self.times): msg = msg.format("Time", n_time, "times", self.times.size) + elif n_taper is not None and n_taper != self._weights.shape[0]: + msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0]) + elif n_taper is not None and n_freq != self._weights.shape[1]: + msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1]) else: return raise ValueError(msg) @@ -1499,7 +1539,7 @@ def _check_values(self, negative_ok=False): s = _pl(negative_values.sum()) warn( f"Negative value in time-frequency decomposition for channel{s} " - f'{", ".join(chs)}', + f"{', '.join(chs)}", UserWarning, ) @@ -1516,6 +1556,10 @@ def _compute_tfr(self, data, n_jobs, verbose): if self.method == "stockwell": self._data, self._itc, freqs = result assert np.array_equal(self._freqs, freqs) + elif self.method == "multitaper" and self._tfr_func.keywords.get( + "output", "" + ) in ["complex", "phase"]: + self._data, self._weights = result elif self._tfr_func.keywords.get("output", "").endswith("_itc"): self._data, self._itc = result.real, result.imag else: @@ -1616,6 +1660,7 @@ def _onselect( fmax=fmax, baseline=baseline, mode=mode, + taper_weights=self.weights, verbose=verbose, ) # average over times and freqs @@ -1694,6 +1739,11 @@ def times(self): """The time points present in the data (in seconds).""" return self._times_readonly + @property + def weights(self): + """The weights used for each taper in the time-frequency estimates.""" + return self._weights + @fill_doc def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): """Crop data to a given time interval in place. @@ -1788,6 +1838,7 @@ def get_data( tmax=None, return_times=False, return_freqs=False, + return_tapers=False, ): """Get time-frequency data in NumPy array format. @@ -1803,6 +1854,10 @@ def get_data( return_freqs : bool Whether to return the frequency bin values for the requested frequency range. Default is ``False``. + return_tapers : bool + Whether to return the taper numbers. Default is ``False``. + + .. versionadded:: 1.10.0 Returns ------- @@ -1814,6 +1869,9 @@ def get_data( freqs : array The frequency values for the requested data range. Only returned if ``return_freqs`` is ``True``. + tapers : array | None + The taper numbers. Only returned if ``return_tapers`` is ``True``. Will be + ``None`` if a taper dimension is not present in the data. Notes ----- @@ -1851,7 +1909,13 @@ def get_data( if return_freqs: freqs = self._freqs[fmin_idx:fmax_idx] out.append(freqs) - if not return_times and not return_freqs: + if return_tapers: + if "taper" in self._dims: + tapers = np.arange(self.shape[self._dims.index("taper")]) + else: + tapers = None + out.append(tapers) + if not return_times and not return_freqs and not return_tapers: return out[0] return tuple(out) @@ -1963,6 +2027,7 @@ def plot( baseline=baseline, mode=mode, dB=dB, + taper_weights=self.weights, verbose=verbose, ) # shape @@ -1973,6 +2038,9 @@ def plot( want_shape[ch_axis] = len(idx_picks) if combine is None else 1 want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping + want_shape = [ + n for dim, n in zip(self._dims, want_shape) if dim != "taper" + ] # tapers must be aggregated over by now want_shape = tuple(want_shape) # combine combine_was_none = combine is None @@ -2316,6 +2384,7 @@ def plot_joint( fmax=_fmax, baseline=baseline, mode=mode, + taper_weights=self.weights, verbose=verbose, ) _data = _data.mean(axis=(-1, -2)) # avg over times and freqs @@ -2464,23 +2533,23 @@ def plot_topo( info, data = _prepare_picks(info, data, picks, axis=0) del picks - # TODO this is the only remaining call to _preproc_tfr; should be refactored - # (to use _prep_data_for_plot?) - data, times, freqs, vmin, vmax = _preproc_tfr( + # baseline, crop, convert complex to power, aggregate tapers, and dB scaling + data, times, freqs = _prep_data_for_plot( data, times, freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - dB, - info["sfreq"], + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + dB=dB, + taper_weights=self.weights, + verbose=verbose, ) + # get vlims + vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) if layout is None: from mne import find_layout @@ -2627,21 +2696,21 @@ def to_data_frame( ): """Export data in tabular structure as a pandas DataFrame. - Channels are converted to columns in the DataFrame. By default, - additional columns ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` (epoch event description) are added, unless ``index`` - is not ``None`` (in which case the columns specified in ``index`` will - be used to form the DataFrame's index instead). ``'epoch'``, and - ``'condition'`` are not supported for ``AverageTFR``. + Channels are converted to columns in the DataFrame. By default, additional + columns ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, and ``'condition'`` + (epoch event description) are added, unless ``index`` is not ``None`` (in which + case the columns specified in ``index`` will be used to form the DataFrame's + index instead). ``'epoch'``, and ``'condition'`` are not supported for + ``AverageTFR``. ``'taper'`` is only supported when a taper dimensions is + present, such as for complex or phase multitaper data. Parameters ---------- %(picks_all)s %(index_df_epo)s - Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'`` - for ``AverageTFR``. - Defaults to ``None``. + Valid string values are ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, + and ``'condition'`` for ``EpochsTFR`` and ``'time'``, ``'freq'``, and + ``'taper'`` for ``AverageTFR``. Defaults to ``None``. %(long_format_df_epo)s %(time_format_df)s @@ -2654,42 +2723,58 @@ def to_data_frame( """ # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa + # triage for Epoch-derived or unaggregated spectra + from_epo = isinstance(self, EpochsTFR) + unagg_mt = "taper" in self._dims # arg checking valid_index_args = ["time", "freq"] - if isinstance(self, EpochsTFR): + if from_epo: valid_index_args.extend(["epoch", "condition"]) + if unagg_mt: + valid_index_args.append("taper") valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) time_format = _check_time_format(time_format, valid_time_formats) # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) - data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) - axis = self._dims.index("channel") - if not isinstance(self, EpochsTFR): + data, times, freqs, tapers = self.get_data( + picks, return_times=True, return_freqs=True, return_tapers=True + ) + ch_axis = self._dims.index("channel") + if not from_epo: data = data[np.newaxis] # add singleton "epochs" axis - axis += 1 - n_epochs, n_picks, n_freqs, n_times = data.shape - # reshape to (epochs*freqs*times) x signals - data = np.moveaxis(data, axis, -1) - data = data.reshape(n_epochs * n_freqs * n_times, n_picks) + ch_axis += 1 + if not unagg_mt: + data = np.expand_dims(data, -3) # add singleton "tapers" axis + n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape + # reshape to (epochs*tapers*freqs*times) x signals + data = np.moveaxis(data, ch_axis, -1) + data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks) # prepare extra columns / multiindex mindex = list() + default_index = list() times = _convert_times(times, time_format, self.info["meas_date"]) - times = np.tile(times, n_epochs * n_freqs) - freqs = np.tile(np.repeat(freqs, n_times), n_epochs) + times = np.tile(times, n_epochs * n_freqs * n_tapers) + freqs = np.tile(np.repeat(freqs, n_times), n_epochs * n_tapers) mindex.append(("time", times)) mindex.append(("freq", freqs)) - if isinstance(self, EpochsTFR): - mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) + if from_epo: + mindex.append( + ("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers)) + ) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) + mindex.append( + ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers)) + ) + default_index.extend(["condition", "epoch"]) + if unagg_mt: + tapers = np.repeat(np.tile(tapers, n_epochs), n_freqs * n_times) + mindex.append(("taper", tapers)) + default_index.append("taper") + default_index.extend(["freq", "time"]) assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) # build DataFrame - if isinstance(self, EpochsTFR): - default_index = ["condition", "epoch", "freq", "time"] - else: - default_index = ["freq", "time"] df = _build_data_frame( self, data, picks, long_format, mindex, index, default_index=default_index ) @@ -2736,6 +2821,7 @@ class AverageTFR(BaseTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2852,6 +2938,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack AverageTFR from serialized format.""" + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._comment = state.get("comment", "") self._nave = state.get("nave", 1) @@ -2895,6 +2990,7 @@ class AverageTFRArray(AverageTFR): The number of averaged TFRs. %(comment_averagetfr_attr)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -2907,6 +3003,7 @@ class AverageTFRArray(AverageTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2917,12 +3014,22 @@ class AverageTFRArray(AverageTFR): """ def __init__( - self, info, data, times, freqs, *, nave=None, comment=None, method=None + self, + info, + data, + times, + freqs, + *, + nave=None, + comment=None, + method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - for name, optional in dict(nave=nave, comment=comment, method=method).items(): - if optional is not None: - state[name] = optional + optional = dict(nave=nave, comment=comment, method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) @@ -2946,48 +3053,6 @@ class EpochsTFR(BaseTFR, GetEpochsMixin): %(picks_good_data_noref)s %(proj_psd)s %(decim_tfr)s - %(events_epochstfr)s - - .. deprecated:: 1.7 - Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use - :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. - %(event_id_epochstfr)s - - .. deprecated:: 1.7 - Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use - :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. - selection : array - List of indices of selected events (not dropped or ignored etc.). For - example, if the original event array had 4 events and the second event - has been dropped, this attribute would be np.array([0, 2, 3]). - - .. deprecated:: 1.7 - Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use - :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. - drop_log : tuple of tuple - A tuple of the same length as the event array used to initialize the - ``EpochsTFR`` object. If the i-th original event is still part of the - selection, drop_log[i] will be an empty tuple; otherwise it will be - a tuple of the reasons the event is not longer in the selection, e.g.: - - - ``'IGNORED'`` - If it isn't part of the current subset defined by the user - - ``'NO_DATA'`` or ``'TOO_SHORT'`` - If epoch didn't contain enough data names of channels that - exceeded the amplitude threshold - - ``'EQUALIZED_COUNTS'`` - See :meth:`~mne.Epochs.equalize_event_counts` - - ``'USER'`` - For user-defined reasons (see :meth:`~mne.Epochs.drop`). - - .. deprecated:: 1.7 - Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use - :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. - %(metadata_epochstfr)s - - .. deprecated:: 1.7 - Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use - :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. %(n_jobs)s %(verbose)s %(method_kw_tfr)s @@ -3007,6 +3072,7 @@ class EpochsTFR(BaseTFR, GetEpochsMixin): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3031,11 +3097,6 @@ def __init__( picks=None, proj=False, decim=1, - events=None, - event_id=None, - selection=None, - drop_log=None, - metadata=None, n_jobs=None, verbose=None, **method_kw, @@ -3091,8 +3152,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack EpochsTFR from serialized format.""" - if state["data"].ndim != 4: - raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.") + if state["data"].ndim not in [4, 5]: + raise ValueError( + f"EpochsTFR data should be 4D or 5D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("epoch", "channel") + if state["data"].ndim == 5: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._metadata = state.get("metadata", None) n_epochs = self.shape[0] @@ -3105,8 +3173,14 @@ def __setstate__(self, state): ).squeeze(axis=0) self.events = state.get("events", _ensure_events(fake_events)) self.event_id = state.get("event_id", _check_event_id(None, self.events)) - self.drop_log = state.get("drop_log", tuple()) self.selection = state.get("selection", np.arange(n_epochs)) + self.drop_log = state.get( + "drop_log", + tuple( + () if k in self.selection else ("IGNORED",) + for k in range(max(len(self.events), max(self.selection) + 1)) + ), + ) self._bad_dropped = True # always true, need for `equalize_event_counts()` def __next__(self, return_event_id=False): @@ -3196,7 +3270,16 @@ def average(self, method="mean", *, dim="epochs", copy=False): See discussion here: https://github.com/scipy/scipy/pull/12676#issuecomment-783370228 + + Averaging is not supported for data containing a taper dimension. """ + if "taper" in self._dims: + raise NotImplementedError( + "Averaging multitaper tapers across epochs, frequencies, or times is " + "not supported. If averaging across epochs, consider averaging the " + "epochs before computing the complex/phase spectrum." + ) + _check_option("dim", dim, ("epochs", "freqs", "times")) axis = self._dims.index(dim[:-1]) # self._dims entries aren't plural @@ -3568,6 +3651,7 @@ class EpochsTFRArray(EpochsTFR): %(selection)s %(drop_log)s %(metadata_epochstfr)s + %(weights_tfr_array)s Attributes ---------- @@ -3584,6 +3668,7 @@ class EpochsTFRArray(EpochsTFR): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3606,6 +3691,7 @@ def __init__( selection=None, drop_log=None, metadata=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) optional = dict( @@ -3616,6 +3702,7 @@ def __init__( selection=selection, drop_log=drop_log, metadata=metadata, + weights=weights, ) for name, value in optional.items(): if value is not None: @@ -3658,6 +3745,7 @@ class RawTFR(BaseTFR): method : str The method used to compute the spectra (``'morlet'``, ``'multitaper'`` or ``'stockwell'``). + %(weights_tfr_attr)s See Also -------- @@ -3707,6 +3795,19 @@ def __init__( **method_kw, ) + def __setstate__(self, state): + """Unpack RawTFR from serialized format.""" + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") + super().__setstate__(state) + def __getitem__(self, item): """Get RawTFR data. @@ -3772,6 +3873,7 @@ class RawTFRArray(RawTFR): %(times)s %(freqs_tfr_array)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -3782,6 +3884,7 @@ class RawTFRArray(RawTFR): %(method_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3799,20 +3902,23 @@ def __init__( freqs, *, method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - if method is not None: - state["method"] = method + optional = dict(method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) def combine_tfr(all_tfr, weights="nave"): """Merge AverageTFR data by weighted addition. - Create a new AverageTFR instance, using a combination of the supplied - instances as its data. By default, the mean (weighted by trials) is used. - Subtraction can be performed by passing negative weights (e.g., [1, -1]). - Data must have the same channels and the same time instants. + Create a new :class:`mne.time_frequency.AverageTFR` instance, using a combination of + the supplied instances as its data. By default, the mean (weighted by trials) is + used. Subtraction can be performed by passing negative weights (e.g., [1, -1]). Data + must have the same channels and the same time instants. Parameters ---------- @@ -3830,8 +3936,16 @@ def combine_tfr(all_tfr, weights="nave"): Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ + if any("taper" in tfr._dims for tfr in all_tfr): + raise NotImplementedError( + "Aggregating multitaper tapers across TFR datasets is not supported." + ) + tfr = all_tfr[0].copy() if isinstance(weights, str): if weights not in ("nave", "equal"): @@ -3847,10 +3961,10 @@ def combine_tfr(all_tfr, weights="nave"): ch_names = tfr.ch_names for t_ in all_tfr[1:]: - assert t_.ch_names == ch_names, ValueError( + assert t_.ch_names == ch_names, ( f"{tfr} and {t_} do not contain the same channels" ) - assert np.max(np.abs(t_.times - tfr.times)) < 1e-7, ValueError( + assert np.max(np.abs(t_.times - tfr.times)) < 1e-7, ( f"{tfr} and {t_} do not contain the same time instants" ) @@ -3905,62 +4019,6 @@ def _centered(arr, newsize): return arr[tuple(myslice)] -def _preproc_tfr( - data, - times, - freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - dB, - sfreq, - copy=None, -): - """Aux Function to prepare tfr computation.""" - if copy is None: - copy = baseline is not None - data = rescale(data, times, baseline, mode, copy=copy) - - if np.iscomplexobj(data): - # complex amplitude → real power (for plotting); if data are - # real-valued they should already be power - data = (data * data.conj()).real - - # crop time - itmin, itmax = None, None - idx = np.where(_time_mask(times, tmin, tmax, sfreq=sfreq))[0] - if tmin is not None: - itmin = idx[0] - if tmax is not None: - itmax = idx[-1] + 1 - - times = times[itmin:itmax] - - # crop freqs - ifmin, ifmax = None, None - idx = np.where(_time_mask(freqs, fmin, fmax, sfreq=sfreq))[0] - if fmin is not None: - ifmin = idx[0] - if fmax is not None: - ifmax = idx[-1] + 1 - - freqs = freqs[ifmin:ifmax] - - # crop data - data = data[:, ifmin:ifmax, itmin:itmax] - - if dB: - data = 10 * np.log10(data) - - vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) - return data, times, freqs, vmin, vmax - - def _ensure_slice(decim): """Aux function checking the decim parameter.""" _validate_type(decim, ("int-like", slice), "decim") @@ -4105,7 +4163,7 @@ def _read_multiple_tfrs(tfr_data, condition=None, *, verbose=None): if len(out) == 0: raise ValueError( f'Cannot find condition "{condition}" in this file. ' - f'The file contains conditions {", ".join(keys)}' + f"The file contains conditions {', '.join(keys)}" ) if len(out) == 1: out = out[0] @@ -4195,6 +4253,7 @@ def _prep_data_for_plot( baseline=None, mode=None, dB=False, + taper_weights=None, verbose=None, ): # baseline @@ -4208,9 +4267,43 @@ def _prep_data_for_plot( freqs = freqs[freq_mask] # crop data data = data[..., freq_mask, :][..., time_mask] - # complex amplitude → real power; real-valued data is already power (or ITC) + # handle unaggregated multitaper (complex or phase multitaper data) + if taper_weights is not None: # assumes a taper dimension + logger.info("Aggregating multitaper estimates before plotting...") + if np.iscomplexobj(data): # complex coefficients → power + data = _tfr_from_mt(data, taper_weights) + else: # tapered phase data → weighted phase data + # channels, tapers, freqs, time + assert data.ndim == 4 + # weights as a function of (tapers, freqs) + assert taper_weights.ndim == 2 + data = (data * taper_weights[np.newaxis, :, :, np.newaxis]).mean(axis=1) + # handle remaining complex amplitude → real power if np.iscomplexobj(data): data = (data * data.conj()).real if dB: data = 10 * np.log10(data) return data, times, freqs + + +def _tfr_from_mt(x_mt, weights): + """Aggregate complex multitaper coefficients over tapers and convert to power. + + Parameters + ---------- + x_mt : array, shape (n_channels, n_tapers, n_freqs, n_times) + The complex-valued multitaper coefficients. + weights : array, shape (n_tapers, n_freqs) + The weights to use to combine the tapered estimates. + + Returns + ------- + tfr : array, shape (n_channels, n_freqs, n_times) + The time-frequency power estimates. + """ + weights = weights[np.newaxis, :, :, np.newaxis] # add singleton channel & time dims + tfr = weights * x_mt + tfr *= tfr.conj() + tfr = tfr.real.sum(axis=1) + tfr *= 2 / (weights * weights.conj()).real.sum(axis=1) + return tfr diff --git a/mne/transforms.py b/mne/transforms.py index c85c31964b6..7072ea25124 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -12,14 +12,13 @@ import numpy as np from scipy import linalg from scipy.spatial.distance import cdist -from scipy.special import sph_harm from ._fiff.constants import FIFF from ._fiff.open import fiff_open from ._fiff.tag import read_tag from ._fiff.write import start_and_end_file, write_coord_trans from .defaults import _handle_default -from .fixes import _get_img_fdata, jit +from .fixes import _get_img_fdata, jit, sph_harm_y from .utils import ( _check_fname, _check_option, @@ -926,7 +925,7 @@ def _compute_sph_harm(order, az, pol): # _deg_ord_idx(0, 0) = -1 so we're actually okay to use it here for degree in range(order + 1): for order_ in range(degree + 1): - sph = sph_harm(order_, degree, az, pol) + sph = sph_harm_y(degree, order_, pol, az) out[:, _deg_ord_idx(degree, order_)] = _sh_complex_to_real(sph, order_) if order_ > 0: out[:, _deg_ord_idx(degree, -order_)] = _sh_complex_to_real( diff --git a/mne/utils/_logging.py b/mne/utils/_logging.py index 68963feaf61..f4d19655bbf 100644 --- a/mne/utils/_logging.py +++ b/mne/utils/_logging.py @@ -511,7 +511,7 @@ def _frame_info(n): except KeyError: # in our verbose dec pass else: - infos.append(f'{name.lstrip("mne.")}:{frame.f_lineno}') + infos.append(f"{name.lstrip('mne.')}:{frame.f_lineno}") frame = frame.f_back if frame is None: break diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index 323b530a641..63e0d1036b9 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -179,9 +179,9 @@ def assert_and_remove_boundary_annot(annotations, n=1): annotations.delete(idx) -def assert_object_equal(a, b, *, err_msg="Object mismatch"): +def assert_object_equal(a, b, *, err_msg="Object mismatch", allclose=False): """Assert two objects are equal.""" - d = object_diff(a, b) + d = object_diff(a, b, allclose=allclose) assert d == "", f"{err_msg}\n{d}" diff --git a/mne/utils/check.py b/mne/utils/check.py index 973fa33fe79..085c51b6996 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -317,8 +317,7 @@ def _check_subject( _validate_type(second, "str", "subject input") if first is not None and first != second: raise ValueError( - f"{first_kind} ({repr(first)}) did not match " - f"{second_kind} ({second})" + f"{first_kind} ({repr(first)}) did not match {second_kind} ({second})" ) return second elif first is not None: @@ -385,7 +384,7 @@ def _check_compensation_grade(info1, info2, name1, name2="data", ch_names=None): ) -def _soft_import(name, purpose, strict=True): +def _soft_import(name, purpose, strict=True, *, min_version=None): """Import soft dependencies, providing informative errors on failure. Parameters @@ -398,11 +397,6 @@ def _soft_import(name, purpose, strict=True): strict : bool Whether to raise an error if module import fails. """ - - # so that error msg lines are aligned - def indent(x): - return x.rjust(len(x) + 14) - # Mapping import namespaces to their pypi package name pip_name = dict( sklearn="scikit-learn", @@ -415,27 +409,31 @@ def indent(x): pyvista="pyvistaqt", ).get(name, name) + got_version = None try: mod = import_module(name) - return mod except (ImportError, ModuleNotFoundError): - if strict: - raise RuntimeError( - f"For {purpose} to work, the {name} module is needed, " - + "but it could not be imported.\n" - + "\n".join( - ( - indent( - "use the following installation method " - "appropriate for your environment:" - ), - indent(f"'pip install {pip_name}'"), - indent(f"'conda install -c conda-forge {pip_name}'"), - ) - ) - ) - else: - return False + mod = False + else: + have, got_version = check_version( + name, + min_version=min_version, + return_version=True, + ) + if not have: + mod = False + if mod is False and strict: + extra = "" if min_version is None else f">={min_version}" + if got_version is not None: + extra += f" (found version {got_version})" + raise RuntimeError( + f"For {purpose} to work, the module {name}{extra} is needed, " + "but it could not be imported. Use the following installation method " + "appropriate for your environment:\n\n" + f" pip install {pip_name}\n" + f" conda install -c conda-forge {pip_name}" + ) + return mod def _check_pandas_installed(strict=True): @@ -1072,8 +1070,7 @@ def _check_sphere(sphere, info=None, sphere_units="m"): del ch_pos["FPz"] elif "Fpz" not in ch_pos and "Oz" in ch_pos: logger.info( - "Approximating Fpz location by mirroring Oz along " - "the X and Y axes." + "Approximating Fpz location by mirroring Oz along the X and Y axes." ) # This assumes Fpz and Oz have the same Z coordinate ch_pos["Fpz"] = ch_pos["Oz"] * [-1, -1, 1] @@ -1083,7 +1080,7 @@ def _check_sphere(sphere, info=None, sphere_units="m"): msg = ( f'sphere="eeglab" requires digitization points of ' f"the following electrode locations in the data: " - f'{", ".join(horizon_ch_names)}, but could not find: ' + f"{', '.join(horizon_ch_names)}, but could not find: " f"{ch_name}" ) if ch_name == "Fpz": @@ -1264,8 +1261,7 @@ def _to_rgb(*args, name="color", alpha=False): except ValueError: args = args[0] if len(args) == 1 else args raise ValueError( - f'Invalid RGB{"A" if alpha else ""} argument(s) for {name}: ' - f"{repr(args)}" + f"Invalid RGB{'A' if alpha else ''} argument(s) for {name}: {repr(args)}" ) from None @@ -1289,5 +1285,5 @@ def _check_method_kwargs(func, kwargs, msg=None): if msg is None: msg = f'function "{func}"' raise TypeError( - f'Got unexpected keyword argument{s} {", ".join(invalid_kw)} for {msg}.' + f"Got unexpected keyword argument{s} {', '.join(invalid_kw)} for {msg}." ) diff --git a/mne/utils/config.py b/mne/utils/config.py index a817886c3f0..c28373fcb93 100644 --- a/mne/utils/config.py +++ b/mne/utils/config.py @@ -185,8 +185,7 @@ def set_memmap_min_size(memmap_min_size): "triggers automated memory mapping, e.g., 1M or 0.5G" ), "MNE_REPR_HTML": ( - "bool, represent some of our objects with rich HTML in a notebook " - "environment" + "bool, represent some of our objects with rich HTML in a notebook environment" ), "MNE_SKIP_NETWORK_TESTS": ( "bool, used in a test decorator (@requires_good_network) to skip " @@ -203,8 +202,7 @@ def set_memmap_min_size(memmap_min_size): ), "MNE_USE_CUDA": "bool, use GPU for filtering/resampling", "MNE_USE_NUMBA": ( - "bool, use Numba just-in-time compiler for some of our intensive " - "computations" + "bool, use Numba just-in-time compiler for some of our intensive computations" ), "SUBJECTS_DIR": "path-like, directory of freesurfer MRI files for each subject", } @@ -583,9 +581,9 @@ def _get_numpy_libs(): for pool in pools: if pool["internal_api"] in ("openblas", "mkl"): return ( - f'{rename[pool["internal_api"]]} ' - f'{pool["version"]} with ' - f'{pool["num_threads"]} thread{_pl(pool["num_threads"])}' + f"{rename[pool['internal_api']]} " + f"{pool['version']} with " + f"{pool['num_threads']} thread{_pl(pool['num_threads'])}" ) return bad_lib @@ -874,7 +872,7 @@ def sys_info( pre = "│ " else: pre = " | " - out(f'\n{pre}{" " * ljust}{op.dirname(mod.__file__)}') + out(f"\n{pre}{' ' * ljust}{op.dirname(mod.__file__)}") out("\n") if not mne_version_good: diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 5d6851e3a0d..8ae2fcd36df 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1502,19 +1502,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["export_fmt_support_epochs"] = """\ Supported formats: - - EEGLAB (``.set``, uses :mod:`eeglabio`) + +- EEGLAB (``.set``, uses :mod:`eeglabio`) """ docdict["export_fmt_support_evoked"] = """\ Supported formats: - - MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) + +- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) """ docdict["export_fmt_support_raw"] = """\ Supported formats: - - BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) - - EEGLAB (``.set``, uses :mod:`eeglabio`) - - EDF (``.edf``, uses `edfio `_) + +- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) +- EEGLAB (``.set``, uses :mod:`eeglabio`) +- EDF (``.edf``, uses `edfio `_) """ # noqa: E501 docdict["export_warning"] = """\ @@ -4664,6 +4667,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): The title of the generated figure. If ``None`` (default), no title is displayed. """ + +docdict["title_stc"] = """ +title : str | None + Title for the figure window. If ``None``, the subject name will be used. +""" + docdict["title_tfr_plot"] = """ title : str | 'auto' | None Title for the plot. If ``"auto"``, will use the channel name (if ``combine`` is @@ -5016,6 +5025,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): solution. """ +docdict["weights_tfr_array"] = """ +weights : array, shape (n_tapers, n_freqs) | None + The weights for each taper. Must be provided if ``data`` has a taper dimension, such + as for complex or phase multitaper data. + + .. versionadded:: 1.10.0 +""" +docdict["weights_tfr_attr"] = """ +weights : array, shape (n_tapers, n_freqs) | None + The weights used for each taper in the time-frequency estimates. +""" + docdict["window_psd"] = """\ window : str | float | tuple Windowing function to use. See :func:`scipy.signal.get_window`. @@ -5117,6 +5138,8 @@ def copy_doc(source): This is useful when inheriting from a class and overloading a method. This decorator can be used to copy the docstring of the original method. + Docstrings are processed by :func:`python:inspect.cleandoc` before being used. + Parameters ---------- source : function @@ -5139,7 +5162,8 @@ def copy_doc(source): ... ''' this gets appended''' ... pass >>> print(B.m1.__doc__) - Docstring for m1 this gets appended + Docstring for m1 + this gets appended """ def wrapper(func): @@ -5147,7 +5171,7 @@ def wrapper(func): raise ValueError("Cannot copy docstring: docstring was empty.") doc = source.__doc__ if func.__doc__ is not None: - doc += func.__doc__ + doc += f"\n{inspect.cleandoc(func.__doc__)}" func.__doc__ = doc return func @@ -5166,6 +5190,10 @@ def copy_function_doc_to_method_doc(source): function. This pattern is prevalent in for example the plotting functions of MNE. + Docstrings are parsed by :func:`python:inspect.cleandoc` before being used. + If indentation and newlines are important, make the first line ``.``, and the dot + will be removed and all following lines dedented jointly. + Parameters ---------- source : function @@ -5201,7 +5229,8 @@ def copy_function_doc_to_method_doc(source): >>> class A: ... @copy_function_doc_to_method_doc(plot_function) ... def plot(self, a, b): - ... ''' + ... '''. + ... ... Notes ... ----- ... .. versionadded:: 0.13.0 @@ -5210,26 +5239,31 @@ def copy_function_doc_to_method_doc(source): >>> print(A.plot.__doc__) Docstring for plotting function. - Parameters - ---------- - a : int - Some parameter - b : int - Some parameter - - Notes - ----- - .. versionadded:: 0.13.0 + Parameters + ---------- + a : int + Some parameter + b : int + Some parameter + Notes + ----- + .. versionadded:: 0.13.0 """ # noqa: D410, D411, D214, D215 def wrapper(func): - doc = source.__doc__.split("\n") + # Work with cleandoc'ed sources (py3.13-compat) + doc = inspect.cleandoc(source.__doc__).split("\n") + if func.__doc__ is not None: + func_doc = inspect.cleandoc(func.__doc__) + if func_doc[:2] == ".\n": + func_doc = func_doc[2:] + func_doc = f"\n{func_doc}" + else: + func_doc = "" + if len(doc) == 1: - doc = doc[0] - if func.__doc__ is not None: - doc += func.__doc__ - func.__doc__ = doc + func.__doc__ = f"{doc[0]}{func_doc}" return func # Find parameter block @@ -5277,7 +5311,7 @@ def wrapper(func): break else: # End of docstring reached - first_parameter_end = line + first_parameter_end = line + 1 first_parameter = parameter_block # Copy the docstring, but remove the first parameter @@ -5286,9 +5320,7 @@ def wrapper(func): + "\n" + "\n".join(doc[first_parameter_end:]) ) - if func.__doc__ is not None: - doc += func.__doc__ - func.__doc__ = doc + func.__doc__ = f"{doc}{func_doc}" return func return wrapper diff --git a/mne/utils/linalg.py b/mne/utils/linalg.py index 4e4c9ba23c5..9382aad50f2 100644 --- a/mne/utils/linalg.py +++ b/mne/utils/linalg.py @@ -190,7 +190,7 @@ def _sym_mat_pow(A, power, rcond=1e-7, reduce_rank=False, return_s=False): return out -# SciPy deprecation of pinv + pinvh rcond (never worked properly anyway) +# SciPy pinv + pinvh rcond never worked properly anyway so we roll our own def pinvh(a, rtol=None): """Compute a pseudo-inverse of a Hermitian matrix. diff --git a/mne/utils/misc.py b/mne/utils/misc.py index bb3e3ee5cab..343761aee24 100644 --- a/mne/utils/misc.py +++ b/mne/utils/misc.py @@ -379,7 +379,7 @@ def _assert_no_instances(cls, when=""): check = False if check: if cls.__name__ == "Brain": - ref.append(f'Brain._cleaned = {getattr(obj, "_cleaned", None)}') + ref.append(f"Brain._cleaned = {getattr(obj, '_cleaned', None)}") rr = gc.get_referrers(obj) count = 0 for r in rr: diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index c287fb42305..11ba0ecb487 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -35,7 +35,7 @@ check_random_state, ) from .docs import fill_doc -from .misc import _empty_hash +from .misc import _empty_hash, _pl def split_list(v, n, idx=False): @@ -479,7 +479,8 @@ def _time_mask( extra = "" if include_tmax else "when include_tmax=False " raise ValueError( f"No samples remain when using tmin={orig_tmin} and tmax={orig_tmax} " - f"{extra}(original time bounds are [{times[0]}, {times[-1]}])" + f"{extra}(original time bounds are [{times[0]}, {times[-1]}] containing " + f"{len(times)} sample{_pl(times)})" ) return mask @@ -515,55 +516,65 @@ def _freq_mask(freqs, sfreq, fmin=None, fmax=None, raise_error=True): def grand_average(all_inst, interpolate_bads=True, drop_bads=True): - """Make grand average of a list of Evoked or AverageTFR data. + """Make grand average of a list of Evoked, AverageTFR, or Spectrum data. - For :class:`mne.Evoked` data, the function interpolates bad channels based - on the ``interpolate_bads`` parameter. If ``interpolate_bads`` is True, - the grand average file will contain good channels and the bad channels - interpolated from the good MEG/EEG channels. - For :class:`mne.time_frequency.AverageTFR` data, the function takes the - subset of channels not marked as bad in any of the instances. + For :class:`mne.Evoked` data, the function interpolates bad channels based on the + ``interpolate_bads`` parameter. If ``interpolate_bads`` is True, the grand average + file will contain good channels and the bad channels interpolated from the good + MEG/EEG channels. + For :class:`mne.time_frequency.AverageTFR` and :class:`mne.time_frequency.Spectrum` + data, the function takes the subset of channels not marked as bad in any of the + instances. - The ``grand_average.nave`` attribute will be equal to the number - of evoked datasets used to calculate the grand average. + The ``grand_average.nave`` attribute will be equal to the number of datasets used to + calculate the grand average. - .. note:: A grand average evoked should not be used for source - localization. + .. note:: A grand average evoked should not be used for source localization. Parameters ---------- - all_inst : list of Evoked or AverageTFR - The evoked datasets. + all_inst : list of Evoked, AverageTFR or Spectrum + The datasets. + + .. versionchanged:: 1.10.0 + Added support for :class:`~mne.time_frequency.Spectrum` objects. + interpolate_bads : bool If True, bad MEG and EEG channels are interpolated. Ignored for - AverageTFR. + :class:`~mne.time_frequency.AverageTFR` and + :class:`~mne.time_frequency.Spectrum` data. drop_bads : bool - If True, drop all bad channels marked as bad in any data set. - If neither interpolate_bads nor drop_bads is True, in the output file, - every channel marked as bad in at least one of the input files will be - marked as bad, but no interpolation or dropping will be performed. + If True, drop all bad channels marked as bad in any data set. If neither + ``interpolate_bads`` nor ``drop_bads`` is `True`, in the output file, every + channel marked as bad in at least one of the input files will be marked as bad, + but no interpolation or dropping will be performed. Returns ------- - grand_average : Evoked | AverageTFR + grand_average : Evoked | AverageTFR | Spectrum The grand average data. Same type as input. Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ # check if all elements in the given list are evoked data from ..channels.channels import equalize_channels from ..evoked import Evoked - from ..time_frequency import AverageTFR + from ..time_frequency import AverageTFR, Spectrum if not all_inst: - raise ValueError("Please pass a list of Evoked or AverageTFR objects.") + raise ValueError( + "Please pass a list of Evoked, AverageTFR, or Spectrum objects." + ) elif len(all_inst) == 1: warn("Only a single dataset was passed to mne.grand_average().") inst_type = type(all_inst[0]) - _validate_type(all_inst[0], (Evoked, AverageTFR), "All elements") + _validate_type(all_inst[0], (Evoked, AverageTFR, Spectrum), "All elements") for inst in all_inst: _validate_type(inst, inst_type, "All elements", "of the same type") @@ -578,6 +589,8 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): for inst in all_inst ] from ..evoked import combine_evoked as combine + elif isinstance(all_inst[0], Spectrum): + from ..time_frequency.spectrum import combine_spectrum as combine else: # isinstance(all_inst[0], AverageTFR): from ..time_frequency.tfr import combine_tfr as combine @@ -588,9 +601,9 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): inst.drop_channels(bads) equalize_channels(all_inst, copy=False) - # make grand_average object using combine_[evoked/tfr] + # make grand_average object using combine_[evoked/tfr/spectrum] grand_average = combine(all_inst, weights="equal") - # change the grand_average.nave to the number of Evokeds + # change the grand_average.nave to the number of datasets grand_average.nave = len(all_inst) # change comment field grand_average.comment = f"Grand average (n = {grand_average.nave})" @@ -859,6 +872,9 @@ def fit_transform(self, X, y=None): return U + def fit(self, X): + self._fit(X) + def _fit(self, X): if self.n_components is None: n_components = min(X.shape) diff --git a/mne/utils/tests/test_check.py b/mne/utils/tests/test_check.py index 4b519198e39..2cfea767aa8 100644 --- a/mne/utils/tests/test_check.py +++ b/mne/utils/tests/test_check.py @@ -28,6 +28,7 @@ _path_like, _record_warnings, _safe_input, + _soft_import, _suggest, _validate_type, catch_logging, @@ -372,3 +373,9 @@ def test_check_sphere_verbose(): _check_sphere("auto", info) with mne.use_log_level("error"): _check_sphere("auto", info) + + +def test_soft_import(): + """Test _soft_import.""" + with pytest.raises(RuntimeError, match=r".* the module mne>=999 \(found version.*"): + _soft_import("mne", "testing", min_version="999") diff --git a/mne/utils/tests/test_config.py b/mne/utils/tests/test_config.py index e7611081b55..4426eae0fc7 100644 --- a/mne/utils/tests/test_config.py +++ b/mne/utils/tests/test_config.py @@ -129,7 +129,7 @@ def test_sys_info_complete(): pyproject = tomllib.loads(pyproject.read_text("utf-8")) deps = pyproject["project"]["optional-dependencies"]["test_extra"] for dep in deps: - dep = dep.split("[")[0].split(">")[0] + dep = dep.split("[")[0].split(">")[0].strip() assert f" {dep}" in out, f"Missing in dev config: {dep}" diff --git a/mne/utils/tests/test_docs.py b/mne/utils/tests/test_docs.py index 64ebc2c6916..ea355820f57 100644 --- a/mne/utils/tests/test_docs.py +++ b/mne/utils/tests/test_docs.py @@ -122,7 +122,7 @@ def m1(): def test_copy_function_doc_to_method_doc(): - """Test decorator for re-using function docstring as method docstrings.""" + """Test decorator for reusing function docstring as method docstrings.""" def f1(obj, a, b, c): """Docstring for f1. @@ -195,28 +195,29 @@ def method_f3(self): assert ( A.method_f1.__doc__ - == """Docstring for f1. - - Parameters - ---------- - a : int - Parameter a - b : int - Parameter b - """ + == """\ +Docstring for f1. + +Parameters +---------- +a : int + Parameter a +b : int + Parameter b""" ) assert ( A.method_f2.__doc__ - == """Docstring for f2. + == """\ +Docstring for f2. - Returns - ------- - nothing. - method_f3 own docstring""" +Returns +------- +nothing. +method_f3 own docstring""" ) - assert A.method_f3.__doc__ == "Docstring for f3.\n\n " + assert A.method_f3.__doc__ == "Docstring for f3.\n\n" pytest.raises(ValueError, copy_function_doc_to_method_doc(f5), A.method_f1) diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index eb51e75a349..4a893b7c017 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -2354,6 +2354,7 @@ def plot_source_estimates( transparent=True, alpha=1.0, time_viewer="auto", + *, subjects_dir=None, figure=None, views="auto", @@ -2463,8 +2464,7 @@ def plot_source_estimates( Defaults to 'oct6'. .. versionadded:: 0.15.0 - title : str | None - Title for the figure. If None, the subject name will be used. + %(title_stc)s .. versionadded:: 0.17.0 %(show_traces)s @@ -2543,6 +2543,7 @@ def plot_source_estimates( view_layout=view_layout, add_data_kwargs=add_data_kwargs, brain_kwargs=brain_kwargs, + title=title, **kwargs, ) @@ -2578,6 +2579,7 @@ def _plot_stc( view_layout, add_data_kwargs, brain_kwargs, + title, ): from ..source_estimate import _BaseVolSourceEstimate from .backends.renderer import _get_3d_backend, get_brain_class @@ -2620,7 +2622,9 @@ def _plot_stc( if overlay_alpha == 0: smoothing_steps = 1 # Disable smoothing to save time. - title = subject if len(hemis) > 1 else f"{subject} - {hemis[0]}" + sub_info = subject if len(hemis) > 1 else f"{subject} - {hemis[0]}" + title = title if title is not None else sub_info + kwargs = { "subject": subject, "hemi": hemi, @@ -2917,11 +2921,9 @@ def _plot_and_correct(*, params, cut_coords): params["axes"].clear() if params.get("fig_anat") is not None and plot_kwargs["colorbar"]: params["fig_anat"]._cbar.ax.clear() - with warnings.catch_warnings(record=True): # nilearn bug; ax recreated - warnings.simplefilter("ignore", DeprecationWarning) - params["fig_anat"] = nil_func( - params["img_idx"], cut_coords=cut_coords, **plot_kwargs - ) + params["fig_anat"] = nil_func( + params["img_idx"], cut_coords=cut_coords, **plot_kwargs + ) params["fig_anat"]._cbar.outline.set_visible(False) for key in "xyz": params.update({"ax_" + key: params["fig_anat"].axes[key].ax}) @@ -3253,6 +3255,7 @@ def plot_vector_source_estimates( vector_alpha=1.0, scale_factor=None, time_viewer="auto", + *, subjects_dir=None, figure=None, views="lateral", @@ -3264,6 +3267,7 @@ def plot_vector_source_estimates( foreground=None, initial_time=None, time_unit="s", + title=None, show_traces="auto", src=None, volume_options=1.0, @@ -3341,6 +3345,9 @@ def plot_vector_source_estimates( time_unit : 's' | 'ms' Whether time is represented in seconds ("s", default) or milliseconds ("ms"). + %(title_stc)s + + .. versionadded:: 1.9 %(show_traces)s %(src_volume_options)s %(view_layout)s @@ -3387,6 +3394,7 @@ def plot_vector_source_estimates( cortex=cortex, foreground=foreground, size=size, + title=title, scale_factor=scale_factor, show_traces=show_traces, src=src, diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 247c0840858..778700c99a7 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -4072,28 +4072,28 @@ def _update_monotonic(lims, fmin, fmid, fmax): if fmin is not None: lims["fmin"] = fmin if lims["fmax"] < fmin: - logger.debug(f' Bumping fmax = {lims["fmax"]} to {fmin}') + logger.debug(f" Bumping fmax = {lims['fmax']} to {fmin}") lims["fmax"] = fmin if lims["fmid"] < fmin: - logger.debug(f' Bumping fmid = {lims["fmid"]} to {fmin}') + logger.debug(f" Bumping fmid = {lims['fmid']} to {fmin}") lims["fmid"] = fmin assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] if fmid is not None: lims["fmid"] = fmid if lims["fmin"] > fmid: - logger.debug(f' Bumping fmin = {lims["fmin"]} to {fmid}') + logger.debug(f" Bumping fmin = {lims['fmin']} to {fmid}") lims["fmin"] = fmid if lims["fmax"] < fmid: - logger.debug(f' Bumping fmax = {lims["fmax"]} to {fmid}') + logger.debug(f" Bumping fmax = {lims['fmax']} to {fmid}") lims["fmax"] = fmid assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] if fmax is not None: lims["fmax"] = fmax if lims["fmin"] > fmax: - logger.debug(f' Bumping fmin = {lims["fmin"]} to {fmax}') + logger.debug(f" Bumping fmin = {lims['fmin']} to {fmax}") lims["fmin"] = fmax if lims["fmid"] > fmax: - logger.debug(f' Bumping fmid = {lims["fmid"]} to {fmax}') + logger.debug(f" Bumping fmid = {lims['fmid']} to {fmax}") lims["fmid"] = fmax assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 2a1c943250b..5d092c21713 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -850,14 +850,6 @@ def tiny(tmp_path): def test_brain_screenshot(renderer_interactive_pyvistaqt, tmp_path, brain_gc): """Test time viewer screenshot.""" # This is broken on Conda + GHA for some reason - from qtpy import API_NAME - - if ( - os.getenv("CONDA_PREFIX", "") != "" - and os.getenv("GITHUB_ACTIONS", "") == "true" - or API_NAME.lower() == "pyside6" - ): - pytest.skip("Test is unreliable on GitHub Actions conda runs and pyside6") tiny_brain, ratio = tiny(tmp_path) img_nv = tiny_brain.screenshot(time_viewer=False) want = (_TINY_SIZE[1] * ratio, _TINY_SIZE[0] * ratio, 3) @@ -875,9 +867,9 @@ def _assert_brain_range(brain, rng): for key, mesh in layerer._overlays.items(): if key == "curv": continue - assert ( - mesh._rng == rng - ), f"_layered_meshes[{repr(hemi)}][{repr(key)}]._rng != {rng}" + assert mesh._rng == rng, ( + f"_layered_meshes[{repr(hemi)}][{repr(key)}]._rng != {rng}" + ) @testing.requires_testing_data @@ -1245,9 +1237,9 @@ def test_brain_scraper(renderer_interactive_pyvistaqt, brain_gc, tmp_path): w = img.shape[1] w0 = size[0] # On Linux+conda we get a width of 624, similar tweak in test_brain_init above - assert np.isclose(w, w0, atol=30) or np.isclose( - w, w0 * 2, atol=30 - ), f"w ∉ {{{w0}, {2 * w0}}}" # HiDPI + assert np.isclose(w, w0, atol=30) or np.isclose(w, w0 * 2, atol=30), ( + f"w ∉ {{{w0}, {2 * w0}}}" + ) # HiDPI @testing.requires_testing_data diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index b63d2a395e2..f492c4b7fde 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -500,11 +500,11 @@ def _create_ch_location_fig(self, pick): show=False, ) # highlight desired channel & disable interactivity - inds = np.isin(fig.lasso.ch_names, [ch_name]) + fig.lasso.selection_inds = np.isin(fig.lasso.names, [ch_name]) fig.lasso.disconnect() - fig.lasso.alpha_other = 0.3 + fig.lasso.alpha_nonselected = 0.3 fig.lasso.linewidth_selected = 3 - fig.lasso.style_sensors(inds) + fig.lasso.style_objects() return fig diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index 2e552bd4012..3987b641dff 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -1536,7 +1536,7 @@ def _update_selection(self): def _update_highlighted_sensors(self): """Update the sensor plot to show what is selected.""" inds = np.isin( - self.mne.fig_selection.lasso.ch_names, self.mne.ch_names[self.mne.picks] + self.mne.fig_selection.lasso.names, self.mne.ch_names[self.mne.picks] ).nonzero()[0] self.mne.fig_selection.lasso.select_many(inds) diff --git a/mne/viz/_proj.py b/mne/viz/_proj.py index 5d21afb0594..6e0cb9a4143 100644 --- a/mne/viz/_proj.py +++ b/mne/viz/_proj.py @@ -90,8 +90,7 @@ def plot_projs_joint( missing = (~used.astype(bool)).sum() if missing: warn( - f"{missing} projector{_pl(missing)} had no channel names " - "present in epochs" + f"{missing} projector{_pl(missing)} had no channel names present in epochs" ) del projs ch_types = list(proj_by_type) # reduce to number we actually need diff --git a/mne/viz/backends/_abstract.py b/mne/viz/backends/_abstract.py index c847da4486b..81b21cbef3c 100644 --- a/mne/viz/backends/_abstract.py +++ b/mne/viz/backends/_abstract.py @@ -991,7 +991,7 @@ def _set_size(self, width=None, height=None): # ------------------------------------ # Non-object-based Widget Abstractions # ------------------------------------ -# These are planned to be deprecated in favor of the simpler, object- +# These are planned to be removed in favor of the simpler, object- # oriented abstractions above when time allows. diff --git a/mne/viz/backends/_notebook.py b/mne/viz/backends/_notebook.py index 38ff6784bc5..2a6a20278f7 100644 --- a/mne/viz/backends/_notebook.py +++ b/mne/viz/backends/_notebook.py @@ -18,7 +18,7 @@ Button, Checkbox, Dropdown, - # non-object-based-abstraction-only widgets, deprecate + # non-object-based-abstraction-only widgets, remove FloatSlider, GridBox, HBox, @@ -806,7 +806,7 @@ def show(self): # ------------------------------------ # Non-object-based Widget Abstractions # ------------------------------------ -# These are planned to be deprecated in favor of the simpler, object- +# These are planned to be removed in favor of the simpler, object- # oriented abstractions above when time allows. diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index 85060837729..ee5b62404d3 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -837,7 +837,7 @@ def _toggle_antialias(self): """Enable it everywhere except on systems with problematic OpenGL.""" # MESA can't seem to handle MSAA and depth peeling simultaneously, see # https://github.com/pyvista/pyvista/issues/4867 - bad_system = _is_mesa(self.plotter) + bad_system = _is_osmesa(self.plotter) for plotter in self._all_plotters: if bad_system or not self.antialias: plotter.disable_anti_aliasing() @@ -1096,9 +1096,7 @@ def _3d_to_2d(plotter, xyz): def _close_all(): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - close_all() + close_all() _FIGURES.clear() @@ -1319,10 +1317,11 @@ def _disabled_depth_peeling(): depth_peeling["enabled"] = depth_peeling_enabled -def _is_mesa(plotter): +def _is_osmesa(plotter): # MESA (could use GPUInfo / _get_gpu_info here, but it takes # > 700 ms to make a new window + report capabilities!) # CircleCI's is: "Mesa 20.0.8 via llvmpipe (LLVM 10.0.0, 256 bits)" + # and a working Nouveau is: "Mesa 24.2.3-1ubuntu1 via NVE6" if platform.system() == "Darwin": # segfaults on macOS sometimes return False gpu_info_full = plotter.ren_win.ReportCapabilities() @@ -1331,8 +1330,8 @@ def _is_mesa(plotter): gpu_info_full, ) gpu_info = " ".join(gpu_info).lower() - is_mesa = "mesa" in gpu_info.split() - if is_mesa: + is_osmesa = "mesa" in gpu_info.split() + if is_osmesa: # Try to warn if it's ancient version = re.findall("mesa ([0-9.]+)[ -].*", gpu_info) or re.findall( "OpenGL version string: .* Mesa ([0-9.]+)\n", gpu_info_full @@ -1345,7 +1344,8 @@ def _is_mesa(plotter): "surface rendering, consider upgrading to 18.3.6 or " "later." ) - return is_mesa + is_osmesa = "via llvmpipe" in gpu_info + return is_osmesa class _SafeBackgroundPlotter(BackgroundPlotter): diff --git a/mne/viz/backends/_qt.py b/mne/viz/backends/_qt.py index b1cd43788ef..78b02a6d05b 100644 --- a/mne/viz/backends/_qt.py +++ b/mne/viz/backends/_qt.py @@ -25,7 +25,7 @@ QObject, Qt, QTimer, - # non-object-based-abstraction-only, deprecate + # non-object-based-abstraction-only, remove Signal, ) from qtpy.QtGui import QCursor, QIcon, QKeyEvent @@ -33,7 +33,7 @@ QButtonGroup, QCheckBox, QComboBox, - # non-object-based-abstraction-only, deprecate + # non-object-based-abstraction-only, remove QDockWidget, QDoubleSpinBox, QFileDialog, @@ -102,7 +102,6 @@ _check_3d_figure, # noqa: F401 _close_3d_figure, # noqa: F401 _close_all, # noqa: F401 - _is_mesa, # noqa: F401 _PyVistaRenderer, _set_3d_title, # noqa: F401 _set_3d_view, # noqa: F401 @@ -843,7 +842,7 @@ def _clean(self): # ------------------------------------ # Non-object-based Widget Abstractions # ------------------------------------ -# These are planned to be deprecated in favor of the simpler, object- +# These are planned to be removed in favor of the simpler, object- # oriented abstractions above when time allows. diff --git a/mne/viz/backends/_utils.py b/mne/viz/backends/_utils.py index c415d83e456..467f5cb15e7 100644 --- a/mne/viz/backends/_utils.py +++ b/mne/viz/backends/_utils.py @@ -317,8 +317,7 @@ def _qt_get_stylesheet(theme): file = open(theme) except OSError: warn( - "Requested theme file not found, will use light instead: " - f"{repr(theme)}" + f"Requested theme file not found, will use light instead: {repr(theme)}" ) else: with file as fid: diff --git a/mne/viz/backends/tests/_utils.py b/mne/viz/backends/tests/_utils.py deleted file mode 100644 index 5e668a4dc75..00000000000 --- a/mne/viz/backends/tests/_utils.py +++ /dev/null @@ -1,44 +0,0 @@ -# Authors: The MNE-Python contributors. -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. - -import warnings - -import pytest - - -def has_pyvista(): - """Check that PyVista is installed.""" - try: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - import pyvista # noqa: F401 - return True - except ImportError: - return False - - -def has_pyvistaqt(): - """Check that PyVistaQt is installed.""" - try: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - import pyvistaqt # noqa: F401 - return True - except ImportError: - return False - - -def has_imageio_ffmpeg(): - """Check if imageio-ffmpeg is installed.""" - try: - import imageio_ffmpeg # noqa: F401 - - return True - except ImportError: - return False - - -skips_if_not_pyvistaqt = pytest.mark.skipif( - not has_pyvistaqt(), reason="requires pyvistaqt" -) diff --git a/mne/viz/backends/tests/test_abstract.py b/mne/viz/backends/tests/test_abstract.py index bfb3663146a..365f4b70fd9 100644 --- a/mne/viz/backends/tests/test_abstract.py +++ b/mne/viz/backends/tests/test_abstract.py @@ -5,7 +5,6 @@ from pathlib import Path from mne.viz.backends.renderer import _get_backend -from mne.viz.backends.tests._utils import skips_if_not_pyvistaqt def _do_widget_tests(backend): @@ -106,7 +105,6 @@ def callback(x=None): window._close() -@skips_if_not_pyvistaqt def test_widget_abstraction_pyvistaqt(renderer_pyvistaqt): """Test the GUI widgets abstraction.""" backend = _get_backend() diff --git a/mne/viz/backends/tests/test_renderer.py b/mne/viz/backends/tests/test_renderer.py index a52942c804b..88506bae210 100644 --- a/mne/viz/backends/tests/test_renderer.py +++ b/mne/viz/backends/tests/test_renderer.py @@ -13,18 +13,19 @@ from mne.viz import Figure3D, get_3d_backend, set_3d_backend from mne.viz.backends._utils import ALLOWED_QUIVER_MODES from mne.viz.backends.renderer import _get_renderer -from mne.viz.backends.tests._utils import skips_if_not_pyvistaqt @pytest.mark.parametrize( "backend", [ - pytest.param("pyvistaqt", marks=skips_if_not_pyvistaqt), + pytest.param("pyvistaqt"), pytest.param("foo", marks=pytest.mark.xfail(raises=ValueError)), ], ) def test_backend_environment_setup(backend, monkeypatch): """Test set up 3d backend based on env.""" + if backend == "pyvistaqt": + pytest.importorskip("pyvistaqt") monkeypatch.setenv("MNE_3D_BACKEND", backend) monkeypatch.setattr("mne.viz.backends.renderer.MNE_3D_BACKEND", None) assert os.environ["MNE_3D_BACKEND"] == backend # just double-check @@ -217,16 +218,21 @@ def fail(x): def test_3d_warning(renderer_pyvistaqt, monkeypatch): """Test that warnings are emitted for old Mesa.""" fig = renderer_pyvistaqt.create_3d_figure((800, 600)) - _is_mesa = renderer_pyvistaqt.backend._is_mesa + from mne.viz.backends._pyvista import _is_osmesa + plotter = fig.plotter - good = "OpenGL renderer string: OpenGL 3.3 (Core Profile) Mesa 20.0.8 via llvmpipe (LLVM 10.0.0, 256 bits)\n" # noqa - bad = "OpenGL renderer string: OpenGL 3.3 (Core Profile) Mesa 18.3.4 via llvmpipe (LLVM 7.0, 256 bits)\n" # noqa + pre = "OpenGL renderer string: " + good = f"{pre}OpenGL 3.3 (Core Profile) Mesa 20.0.8 via llvmpipe (LLVM 10.0.0, 256 bits)\n" # noqa + bad = f"{pre}OpenGL 3.3 (Core Profile) Mesa 18.3.4 via llvmpipe (LLVM 7.0, 256 bits)\n" # noqa monkeypatch.setattr(platform, "system", lambda: "Linux") # avoid short-circuit monkeypatch.setattr(plotter.ren_win, "ReportCapabilities", lambda: good) - assert _is_mesa(plotter) + assert _is_osmesa(plotter) monkeypatch.setattr(plotter.ren_win, "ReportCapabilities", lambda: bad) with pytest.warns(RuntimeWarning, match=r"18\.3\.4 is too old"): - assert _is_mesa(plotter) - non = "OpenGL 4.1 Metal - 76.3 via Apple M1 Pro\n" + assert _is_osmesa(plotter) + non = f"{pre}OpenGL 4.1 Metal - 76.3 via Apple M1 Pro\n" + monkeypatch.setattr(plotter.ren_win, "ReportCapabilities", lambda: non) + assert not _is_osmesa(plotter) + non = f"{pre}OpenGL 4.5 (Core Profile) Mesa 24.2.3-1ubuntu1 via NVE6\n" monkeypatch.setattr(plotter.ren_win, "ReportCapabilities", lambda: non) - assert not _is_mesa(plotter) + assert not _is_osmesa(plotter) diff --git a/mne/viz/circle.py b/mne/viz/circle.py index fdcbd5a26bb..67a47c0d5fd 100644 --- a/mne/viz/circle.py +++ b/mne/viz/circle.py @@ -6,6 +6,7 @@ from functools import partial from itertools import cycle +from types import SimpleNamespace import numpy as np @@ -371,6 +372,7 @@ def _plot_connectivity_circle( cb_yticks = plt.getp(cb.ax.axes, "yticklabels") cb.ax.tick_params(labelsize=fontsize_colorbar) plt.setp(cb_yticks, color=textcolor) + fig.mne = SimpleNamespace(colorbar=cb) # Add callback for interaction if interactive: diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 1e4b1dfb4c6..96ee0684e6e 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -27,6 +27,7 @@ _clean_names, _is_numeric, _pl, + _time_mask, _to_rgb, _validate_type, fill_doc, @@ -1152,6 +1153,7 @@ def plot_evoked_topo( background_color="w", noise_cov=None, exclude="bads", + select=False, show=True, ): """Plot 2D topography of evoked responses. @@ -1217,6 +1219,15 @@ def plot_evoked_topo( exclude : list of str | ``'bads'`` Channels names to exclude from the plot. If ``'bads'``, the bad channels are excluded. By default, exclude is set to ``'bads'``. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 + exclude : list of str | ``'bads'`` + Channels names to exclude from the plot. If ``'bads'``, the + bad channels are excluded. By default, exclude is set to ``'bads'``. show : bool Show figure if True. @@ -1273,10 +1284,11 @@ def plot_evoked_topo( font_color=font_color, merge_channels=merge_grads, legend=legend, + noise_cov=noise_cov, axes=axes, exclude=exclude, + select=select, show=show, - noise_cov=noise_cov, ) @@ -1988,10 +2000,18 @@ def plot_evoked_joint( contours = topomap_args.get("contours", 6) ch_type = ch_types.pop() # set should only contain one element # Since the data has all the ch_types, we get the limits from the plot. - vmin, vmax = ts_ax.get_ylim() + vmin, vmax = (None, None) norm = ch_type == "grad" vmin = 0 if norm else vmin - vmin, vmax = _setup_vmin_vmax(evoked.data, vmin, vmax, norm) + time_idx = [ + np.where( + _time_mask(evoked.times, tmin=t, tmax=None, sfreq=evoked.info["sfreq"]) + )[0][0] + for t in times_sec + ] + scalings = topomap_args["scalings"] if "scalings" in topomap_args else None + scaling = _handle_default("scalings", scalings)[ch_type] + vmin, vmax = _setup_vmin_vmax(evoked.data[:, time_idx] * scaling, vmin, vmax, norm) if not isinstance(contours, list | np.ndarray): locator, contours = _set_contour_locator(vmin, vmax, contours) else: @@ -2009,7 +2029,7 @@ def plot_evoked_joint( from matplotlib import ticker cbar = fig.colorbar(map_ax[0].images[0], ax=map_ax, cax=cbar_ax, shrink=0.8) - cbar.ax.grid(False) # auto-removal deprecated as of 2021/10/05 + cbar.ax.grid(False) if isinstance(contours, list | np.ndarray): cbar.set_ticks(contours) else: diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index 2a93febba4e..cf5a9996216 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -380,45 +380,36 @@ def _configure_dock(self): # Fieldline configuration layout = r._dock_add_group_box("Fieldlines") - if self._show_density: - r._dock_add_label(value="max value", align=True, layout=layout) - - @_auto_weakref - def _callback(vmax, kind, scaling): - self.set_vmax(vmax / scaling, kind=kind) - - for surf_map in self._surf_maps: - if surf_map["map_kind"] == "meg": - scaling = DEFAULTS["scalings"]["grad"] - else: - scaling = DEFAULTS["scalings"]["eeg"] - rng = [0, np.max(np.abs(surf_map["data"])) * scaling] - hlayout = r._dock_add_layout(vertical=False) - - self._widgets[f"vmax_slider_{surf_map['map_kind']}"] = ( - r._dock_add_slider( - name=surf_map["map_kind"].upper(), - value=surf_map["map_vmax"] * scaling, - rng=rng, - callback=partial( - _callback, kind=surf_map["map_kind"], scaling=scaling - ), - double=True, - layout=hlayout, - ) - ) - self._widgets[f"vmax_spin_{surf_map['map_kind']}"] = ( - r._dock_add_spin_box( - name="", - value=surf_map["map_vmax"] * scaling, - rng=rng, - callback=partial( - _callback, kind=surf_map["map_kind"], scaling=scaling - ), - layout=hlayout, - ) - ) - r._layout_add_widget(layout, hlayout) + r._dock_add_label(value="max value", align=True, layout=layout) + + @_auto_weakref + def _callback(vmax, kind, scaling): + self.set_vmax(vmax / scaling, kind=kind) + + for surf_map in self._surf_maps: + if surf_map["map_kind"] == "meg": + scaling = DEFAULTS["scalings"]["grad"] + else: + scaling = DEFAULTS["scalings"]["eeg"] + rng = [0, np.max(np.abs(surf_map["data"])) * scaling] + hlayout = r._dock_add_layout(vertical=False) + + self._widgets[f"vmax_slider_{surf_map['map_kind']}"] = r._dock_add_slider( + name=surf_map["map_kind"].upper(), + value=surf_map["map_vmax"] * scaling, + rng=rng, + callback=partial(_callback, kind=surf_map["map_kind"], scaling=scaling), + double=True, + layout=hlayout, + ) + self._widgets[f"vmax_spin_{surf_map['map_kind']}"] = r._dock_add_spin_box( + name="", + value=surf_map["map_vmax"] * scaling, + rng=rng, + callback=partial(_callback, kind=surf_map["map_kind"], scaling=scaling), + layout=hlayout, + ) + r._layout_add_widget(layout, hlayout) hlayout = r._dock_add_layout(vertical=False) r._dock_add_label( diff --git a/mne/viz/misc.py b/mne/viz/misc.py index ed2636d3961..c83a4dfe717 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -443,7 +443,7 @@ def _plot_mri_contours( if src[0]["coord_frame"] != FIFF.FIFFV_COORD_MRI: raise ValueError( "Source space must be in MRI coordinates, got " - f'{_frame_to_str[src[0]["coord_frame"]]}' + f"{_frame_to_str[src[0]['coord_frame']]}" ) for src_ in src: points = src_["rr"][src_["inuse"].astype(bool)] @@ -708,8 +708,7 @@ def plot_bem( src = read_source_spaces(src) elif src is not None and not isinstance(src, SourceSpaces): raise TypeError( - "src needs to be None, path-like or SourceSpaces instance, " - f"not {repr(src)}" + f"src needs to be None, path-like or SourceSpaces instance, not {repr(src)}" ) if len(surfaces) == 0: diff --git a/mne/viz/montage.py b/mne/viz/montage.py index b7f1724cb87..221cc21f7a0 100644 --- a/mne/viz/montage.py +++ b/mne/viz/montage.py @@ -19,8 +19,7 @@ def plot_montage( montage, *, - scale=None, - scale_factor=None, + scale=1.0, show_names=True, kind="topomap", show=True, @@ -36,9 +35,7 @@ def plot_montage( The montage to visualize. scale : float Determines the scale of the channel points and labels; values < 1 will scale - down, whereas values > 1 will scale up. Default to None, which implies 1. - scale_factor : float - Determines the size of the points. Deprecated, use scale instead. + down, whereas values > 1 will scale up. show_names : bool | list Whether to display all channel names. If a list, only the channel names in the list are shown. Defaults to True. @@ -61,16 +58,7 @@ def plot_montage( from ..channels import DigMontage, make_dig_montage - if scale_factor is not None: - msg = "scale_factor has been deprecated and will be removed. Use scale instead." - if scale is not None: - raise ValueError( - " ".join(["scale and scale_factor cannot be used together.", msg]) - ) - logger.info(msg) - if scale is None: - scale = 1 - + _validate_type(scale, "numeric", "scale") _check_option("kind", kind, ["topomap", "3d"]) _validate_type(montage, DigMontage, item_name="montage") ch_names = montage.ch_names @@ -112,11 +100,7 @@ def plot_montage( axes=axes, ) - if scale_factor is not None: - # scale points - collection = fig.axes[0].collections[0] - collection.set_sizes([scale_factor]) - elif scale is not None: + if scale != 1.0: # scale points collection = fig.axes[0].collections[0] collection.set_sizes([scale * 10]) diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 6f109b9490b..34022d59768 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -893,7 +893,7 @@ def test_plot_alignment_fnirs(renderer, tmp_path): with catch_logging() as log: fig = plot_alignment(info, **kwargs) log = log.getvalue() - assert f'fnirs_cw_amplitude: {info["nchan"]}' in log + assert f"fnirs_cw_amplitude: {info['nchan']}" in log _assert_n_actors(fig, renderer, info["nchan"]) fig = plot_alignment(info, fnirs=["channels", "sources", "detectors"], **kwargs) diff --git a/mne/viz/tests/test_circle.py b/mne/viz/tests/test_circle.py index a26379f6ccc..c5f3719746b 100644 --- a/mne/viz/tests/test_circle.py +++ b/mne/viz/tests/test_circle.py @@ -17,7 +17,10 @@ def test_plot_channel_labels_circle(): fig, axes = plot_channel_labels_circle( dict(brain=["big", "great", "smart"]), colors=dict(big="r", great="y", smart="b"), + colorbar=True, ) + # check that colorbar handle is returned + assert isinstance(fig.mne.colorbar, matplotlib.colorbar.Colorbar) texts = [ child.get_text() for child in axes.get_children() diff --git a/mne/viz/tests/test_evoked.py b/mne/viz/tests/test_evoked.py index 035008dd87f..964acae2b31 100644 --- a/mne/viz/tests/test_evoked.py +++ b/mne/viz/tests/test_evoked.py @@ -339,8 +339,10 @@ def test_plot_evoked_image(): ch_names = evoked.ch_names[3:5] picks = [evoked.ch_names.index(ch) for ch in ch_names] - evoked.plot_image(show_names="all", time_unit="s", picks=picks) - yticklabels = plt.gca().get_yticklabels() + fig = evoked.plot_image(show_names="all", time_unit="s", picks=picks) + fig.canvas.draw_idle() + yticklabels = fig.axes[0].get_yticklabels() + assert len(yticklabels) == len(ch_names) for tick_target, tick_observed in zip(ch_names, yticklabels): assert tick_target in str(tick_observed) evoked.plot_image(show_names=True, time_unit="s") diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index a2565927feb..caa09ae4d07 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -862,16 +862,14 @@ def test_remove_annotations(raw, hide_which, browser_backend): assert len(raw.annotations) == len(hide_which) -def test_merge_annotations(raw, browser_backend): +def test_merge_annotations(raw, pg_backend): """Test merging of annotations in the Qt backend. Let's not bother in figuring out on which sample the _fake_click actually dropped the annotation, especially with the 600.614 Hz weird sampling rate. -> atol = 10 / raw.info["sfreq"] """ - if browser_backend.name == "matplotlib": - pytest.skip("The MPL backend does not support draggable annotations.") - elif not check_version("mne_qt_browser", "0.5.3"): + if not check_version("mne_qt_browser", "0.5.3"): pytest.xfail("mne_qt_browser < 0.5.3 does not merge annotations properly") annot = Annotations( onset=[1, 3, 4, 5, 7, 8], @@ -970,7 +968,7 @@ def test_plot_raw_psd(raw, raw_orig): """Test plotting of raw psds.""" raw_unchanged = raw.copy() spectrum = raw.compute_psd() - # deprecation change handler + # change handler old_defaults = dict(picks="data", exclude="bads") fig = spectrum.plot(average=False, amplitude=False) # normal mode @@ -1090,36 +1088,25 @@ def test_plot_sensors(raw): pytest.raises(TypeError, plot_sensors, raw) # needs to be info pytest.raises(ValueError, plot_sensors, raw.info, kind="sasaasd") plt.close("all") + + # Test lasso selection. fig, sels = raw.plot_sensors("select", show_names=True) ax = fig.axes[0] - - # Click with no sensors - _fake_click(fig, ax, (0.0, 0.0), xform="data") - _fake_click(fig, ax, (0, 0.0), xform="data", kind="release") - assert fig.lasso.selection == [] - - # Lasso with 1 sensor (upper left) - _fake_click(fig, ax, (0, 1), xform="ax") - fig.canvas.draw() - assert fig.lasso.selection == [] - _fake_click(fig, ax, (0.65, 1), xform="ax", kind="motion") - _fake_click(fig, ax, (0.65, 0.7), xform="ax", kind="motion") - _fake_keypress(fig, "control") - _fake_click(fig, ax, (0, 0.7), xform="ax", kind="release", key="control") + # Lasso a single sensor. + _fake_click(fig, ax, (-0.13, 0.13), xform="data") + _fake_click(fig, ax, (-0.11, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.11, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="release") assert fig.lasso.selection == ["MEG 0121"] - # check that point appearance changes - fc = fig.lasso.collection.get_facecolors() - ec = fig.lasso.collection.get_edgecolors() - assert (fc[:, -1] == [0.5, 1.0, 0.5]).all() - assert (ec[:, -1] == [0.25, 1.0, 0.25]).all() - - _fake_click(fig, ax, (0.7, 1), xform="ax", kind="motion", key="control") - xy = ax.collections[0].get_offsets() - _fake_click(fig, ax, xy[2], xform="data", key="control") # single sel + # Add another sensor with a single click. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (-0.1278, 0.0318), xform="data") + _fake_click(fig, ax, (-0.1278, 0.0318), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") assert fig.lasso.selection == ["MEG 0121", "MEG 0131"] - _fake_click(fig, ax, xy[2], xform="data", key="control") # deselect - assert fig.lasso.selection == ["MEG 0121"] plt.close("all") raw.info["dev_head_t"] = None # like empty room diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 85b4b43dcf8..48d031739b9 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -23,7 +23,7 @@ ) from mne.viz.evoked import _line_plot_onselect from mne.viz.topo import _imshow_tfr, _plot_update_evoked_topo_proj, iter_topography -from mne.viz.utils import _fake_click +from mne.viz.utils import _fake_click, _fake_keypress base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" evoked_fname = base_dir / "test-ave.fif" @@ -231,6 +231,16 @@ def test_plot_topo(): break plt.close("all") + # Test plot_topo with selection of channels enabled. + fig = evoked.plot_topo(select=True) + ax = fig.axes[0] + _fake_click(fig, ax, (0.05, 0.62), xform="data") + _fake_click(fig, ax, (0.2, 0.62), xform="data", kind="motion") + _fake_click(fig, ax, (0.2, 0.7), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.7), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.7), xform="data", kind="release") + assert fig.lasso.selection == ["MEG 0113", "MEG 0112", "MEG 0111"] + def test_plot_topo_nirs(fnirs_evoked): """Test plotting of ERP topography for nirs data.""" @@ -296,6 +306,30 @@ def test_plot_topo_image_epochs(): assert qm_cmap[0] is cmap +def test_plot_topo_select(): + """Test selecting sensors in an ERP topography plot.""" + # Show topography + evoked = _get_epochs().average() + fig = plot_evoked_topo(evoked, select=True) + ax = fig.axes[0] + + # Lasso select 3 out of the 6 sensors. + _fake_click(fig, ax, (0.05, 0.5), xform="data") + _fake_click(fig, ax, (0.2, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.2, 0.6), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.6), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.5), xform="data", kind="release") + assert fig.lasso.selection == ["MEG 0132", "MEG 0133", "MEG 0131"] + + # Add another sensor with a single click. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (0.11, 0.65), xform="data") + _fake_click(fig, ax, (0.21, 0.65), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") + assert fig.lasso.selection == ["MEG 0111", "MEG 0132", "MEG 0133", "MEG 0131"] + + def test_plot_tfr_topo(): """Test plotting of TFR data.""" epochs = _get_epochs() diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index afa9341c00e..b87d0d39f89 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -44,7 +44,7 @@ compute_bridged_electrodes, compute_current_source_density, ) -from mne.time_frequency.tfr import AverageTFRArray +from mne.time_frequency.tfr import AverageTFR, AverageTFRArray from mne.viz import plot_evoked_topomap, plot_projs_topomap, topomap from mne.viz.tests.test_raw import _proj_status from mne.viz.topomap import ( @@ -610,6 +610,29 @@ def test_plot_tfr_topomap(): ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 ) + # test data with taper dimension (real) + data = np.expand_dims(data, axis=1) + weights = np.random.rand(1, n_freqs) + tfr = AverageTFRArray( + info=info, + data=data, + times=times, + freqs=np.arange(n_freqs), + nave=nave, + weights=weights, + ) + tfr.plot_topomap( + ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 + ) + # test data with taper dimension (complex) + state = tfr.__getstate__() + tfr = AverageTFR(inst=state | dict(data=data * (1 + 1j))) + tfr.plot_topomap( + ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 + ) + # remove taper dim before proceeding + data = data[:, 0] + # test real numbers tfr = AverageTFRArray( info=info, data=data, times=times, freqs=np.arange(n_freqs), nave=nave diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 59e2976e464..55dc0f1e65c 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -16,6 +16,7 @@ from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap from mne.viz.ui_events import ColormapRange, link, subscribe from mne.viz.utils import ( + SelectFromCollection, _compute_scalings, _fake_click, _fake_keypress, @@ -274,3 +275,71 @@ def callback(event): cmap_new1 = fig.axes[0].CB.mappable.get_cmap().name cmap_new2 = fig2.axes[0].CB.mappable.get_cmap().name assert cmap_new1 == cmap_new2 == cmap_want != cmap_old + + +def test_select_from_collection(): + """Test the lasso selector for matplotlib figures.""" + fig, ax = plt.subplots() + collection = ax.scatter([1, 2, 2, 1], [1, 1, 0, 0], color="black", edgecolor="red") + ax.set_xlim(-1, 4) + ax.set_ylim(-1, 2) + lasso = SelectFromCollection(ax, collection, names=["A", "B", "C", "D"]) + assert lasso.selection == [] + + # Make a selection with no patches inside of it. + _fake_click(fig, ax, (0, 0), xform="data") + _fake_click(fig, ax, (0.5, 0), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1), xform="data", kind="release") + assert lasso.selection == [] + + # Doing a single click on a patch should not select it. + _fake_click(fig, ax, (1, 1), xform="data") + assert lasso.selection == [] + + # Make a selection with two patches in it. + _fake_click(fig, ax, (0, 0.5), xform="data") + _fake_click(fig, ax, (3, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (3, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 0.5), xform="data", kind="release") + assert lasso.selection == ["A", "B"] + + # Use Control key to lasso an additional patch. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (0.5, -0.5), xform="data") + _fake_click(fig, ax, (1.5, -0.5), xform="data", kind="motion") + _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") + assert lasso.selection == ["A", "B", "D"] + + # Use CTRL+SHIFT to remove a patch. + _fake_keypress(fig, "ctrl+shift") + _fake_click(fig, ax, (0.5, 0.5), xform="data") + _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (1.5, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="release") + _fake_keypress(fig, "ctrl+shift", kind="release") + assert lasso.selection == ["B", "D"] + + # Check that the two selected patches have a different appearance. + fc = lasso.collection.get_facecolors() + ec = lasso.collection.get_edgecolors() + assert (fc[:, -1] == [0.5, 1.0, 0.5, 1.0]).all() + assert (ec[:, -1] == [0.25, 1.0, 0.25, 1.0]).all() + + # Test adding and removing single channels. + lasso.select_one(2) # should not do anything without modifier keys + assert lasso.selection == ["B", "D"] + _fake_keypress(fig, "control") + lasso.select_one(2) # add to selection + _fake_keypress(fig, "control", kind="release") + assert lasso.selection == ["B", "C", "D"] + _fake_keypress(fig, "ctrl+shift") + lasso.select_one(1) # remove from selection + assert lasso.selection == ["C", "D"] + _fake_keypress(fig, "ctrl+shift", kind="release") diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 3364a455aed..5c43d4de48e 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -13,8 +13,10 @@ from .._fiff.pick import _picks_to_idx, channel_type, pick_types from ..defaults import _handle_default from ..utils import Bunch, _check_option, _clean_names, _is_numeric, _to_rgb, fill_doc +from .ui_events import ChannelsSelect, publish, subscribe from .utils import ( DraggableColorbar, + SelectFromCollection, _check_cov, _check_delayed_ssp, _draw_proj_checkbox, @@ -37,6 +39,7 @@ def iter_topography( axis_spinecolor="k", layout_scale=None, legend=False, + select=False, ): """Create iterator over channel positions. @@ -72,6 +75,12 @@ def iter_topography( If True, an additional axis is created in the bottom right corner that can be used to, e.g., construct a legend. The index of this axis will be -1. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 Returns ------- @@ -93,6 +102,7 @@ def iter_topography( axis_spinecolor, layout_scale, legend=legend, + select=select, ) @@ -128,6 +138,7 @@ def _iter_topography( img=False, axes=None, legend=False, + select=False, ): """Iterate over topography. @@ -193,8 +204,11 @@ def format_coord_multiaxis(x, y, ch_name=None): under_ax.set(xlim=[0, 1], ylim=[0, 1]) axs = list() + + shown_ch_names = [] for idx, name in iter_ch: ch_idx = ch_names.index(name) + shown_ch_names.append(name) if not unified: # old, slow way ax = plt.axes(pos[idx]) ax.patch.set_facecolor(axis_facecolor) @@ -226,24 +240,48 @@ def format_coord_multiaxis(x, y, ch_name=None): if unified: under_ax._mne_axs = axs # Create a PolyCollection for the axis backgrounds + sel_pos = pos[[i[0] for i in iter_ch]] verts = np.transpose( [ - pos[:, :2], - pos[:, :2] + pos[:, 2:] * [1, 0], - pos[:, :2] + pos[:, 2:], - pos[:, :2] + pos[:, 2:] * [0, 1], + sel_pos[:, :2], + sel_pos[:, :2] + sel_pos[:, 2:] * [1, 0], + sel_pos[:, :2] + sel_pos[:, 2:], + sel_pos[:, :2] + sel_pos[:, 2:] * [0, 1], ], [1, 0, 2], ) - if not img: - under_ax.add_collection( - collections.PolyCollection( - verts, - facecolor=axis_facecolor, - edgecolor=axis_spinecolor, - linewidth=1.0, + if not img: # Not needed for image plots. + collection = collections.PolyCollection( + verts, + facecolor=axis_facecolor, + edgecolor=axis_spinecolor, + linewidth=1.0, + ) + under_ax.add_collection(collection) + + if select: + # Configure the lasso-selection tool + fig.lasso = SelectFromCollection( + ax=under_ax, + collection=collection, + names=shown_ch_names, + alpha_nonselected=0, + alpha_selected=1, + linewidth_nonselected=0, + linewidth_selected=0.7, ) - ) # Not needed for image plots. + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + selection_inds = np.flatnonzero( + np.isin(shown_ch_names, event.ch_names) + ) + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) for ax in axs: yield ax, ax._mne_ch_idx @@ -270,6 +308,7 @@ def _plot_topo( unified=False, img=False, axes=None, + select=False, ): """Plot on sensor layout.""" import matplotlib.pyplot as plt @@ -322,6 +361,7 @@ def _plot_topo( unified=unified, img=img, axes=axes, + select=select, ) for ax, ch_idx in my_topo_plot: @@ -340,8 +380,17 @@ def _plot_topo( def _plot_topo_onpick(event, show_func): """Onpick callback that shows a single channel in a new figure.""" - # make sure that the swipe gesture in OS-X doesn't open many figures orig_ax = event.inaxes + fig = orig_ax.figure + + # If we are doing lasso select, allow it to handle the click instead. + if hasattr(fig, "lasso") and event.key in ["control", "ctrl+shift"]: + return + + # make sure that the swipe gesture in OS-X doesn't open many figures + if fig.canvas._key in ["shift", "alt"]: + return + import matplotlib.pyplot as plt try: @@ -838,9 +887,10 @@ def _plot_evoked_topo( merge_channels=False, legend=True, axes=None, + noise_cov=None, exclude="bads", + select=False, show=True, - noise_cov=None, ): """Plot 2D topography of evoked responses. @@ -912,6 +962,10 @@ def _plot_evoked_topo( exclude : list of str | 'bads' Channels names to exclude from being shown. If 'bads', the bad channels are excluded. By default, exclude is set to 'bads'. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. show : bool Show figure if True. @@ -1091,6 +1145,7 @@ def _plot_evoked_topo( y_label=y_label, unified=True, axes=axes, + select=select, ) add_background_image(fig, fig_background) @@ -1098,7 +1153,10 @@ def _plot_evoked_topo( if legend is not False: legend_loc = 0 if legend is True else legend labels = [e.comment if e.comment else "Unknown" for e in evoked] - handles = fig.axes[0].lines[: len(evoked)] + if select: + handles = fig.axes[0].lines[1 : len(evoked) + 1] + else: + handles = fig.axes[0].lines[: len(evoked)] legend = plt.legend( labels=labels, handles=handles, loc=legend_loc, prop={"size": 10} ) @@ -1157,6 +1215,7 @@ def plot_topo_image_epochs( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): """Plot Event Related Potential / Fields image on topographies. @@ -1204,6 +1263,12 @@ def plot_topo_image_epochs( :func:`matplotlib.pyplot.imshow`. Defaults to ``None``. font_color : color The color of tick labels in the colorbar. Defaults to white. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 show : bool Whether to show the figure. Defaults to ``True``. @@ -1293,6 +1358,7 @@ def plot_topo_image_epochs( y_label="Epoch", unified=True, img=True, + select=select, ) add_background_image(fig, fig_background) plt_show(show) diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index 147919a9c9d..bb180a3f299 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -910,8 +910,7 @@ def _get_pos_outlines(info, picks, sphere, to_sphere=True): orig_sphere = sphere sphere, clip_origin = _adjust_meg_sphere(sphere, info, ch_type) logger.debug( - "Generating pos outlines with sphere " - f"{sphere} from {orig_sphere} for {ch_type}" + f"Generating pos outlines with sphere {sphere} from {orig_sphere} for {ch_type}" ) pos = _find_topomap_coords( info, picks, ignore_overlap=True, to_sphere=to_sphere, sphere=sphere @@ -1262,7 +1261,7 @@ def _plot_topomap( if len(data) != len(pos): raise ValueError( "Data and pos need to be of same length. Got data of " - f"length {len(data)}, pos of length { len(pos)}" + f"length {len(data)}, pos of length {len(pos)}" ) norm = min(data) >= 0 @@ -1409,8 +1408,7 @@ def _plot_ica_topomap( sphere = _check_sphere(sphere, ica.info) if not isinstance(axes, Axes): raise ValueError( - "axis has to be an instance of matplotlib Axes, " - f"got {type(axes)} instead." + f"axis has to be an instance of matplotlib Axes, got {type(axes)} instead." ) ch_type = _get_plot_ch_type(ica, ch_type, allow_ref_meg=ica.allow_ref_meg) if ch_type == "ref_meg": @@ -1882,7 +1880,7 @@ def plot_tfr_topomap( tfr, ch_type, sphere=sphere ) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) - data = tfr.data[picks, :, :] + data = tfr.data[picks] # merging grads before rescaling makes ERDs visible if merge_channels: @@ -1890,6 +1888,18 @@ def plot_tfr_topomap( data = rescale(data, tfr.times, baseline, mode, copy=True) + # handle unaggregated multitaper (complex or phase multitaper data) + if tfr.weights is not None: # assumes a taper dimension + logger.info("Aggregating multitaper estimates before plotting...") + weights = tfr.weights[np.newaxis, :, :, np.newaxis] # add channel & time dims + data = weights * data + if np.iscomplexobj(data): # complex coefficients → power + data *= data.conj() + data = data.real.sum(axis=1) + data *= 2 / (weights * weights.conj()).real.sum(axis=1) + else: # tapered phase data → weighted phase data + data = data.mean(axis=1) + # handle remaining complex amplitude → real power if np.iscomplexobj(data): data = np.sqrt((data * data.conj()).real) @@ -2104,6 +2114,22 @@ def plot_evoked_topomap( :ref:`gridspec ` interface to adjust the colorbar size yourself. + The defaults for ``contours`` and ``vlim`` are handled as follows: + + * When neither ``vlim`` nor a list of ``contours`` is passed, MNE sets + ``vlim`` at ± the maximum absolute value of the data and then chooses + contours within those bounds. + + * When ``vlim`` but not a list of ``contours`` is passed, MNE chooses + contours to be within the ``vlim``. + + * When a list of ``contours`` but not ``vlim`` is passed, MNE chooses + ``vlim`` to encompass the ``contours`` and the maximum absolute value of the + data. + + * When both a list of ``contours`` and ``vlim`` are passed, MNE uses them + as-is. + When ``time=="interactive"``, the figure will publish and subscribe to the following UI events: @@ -2179,8 +2205,7 @@ def plot_evoked_topomap( space = 1 / (2.0 * evoked.info["sfreq"]) if max(times) > max(evoked.times) + space or min(times) < min(evoked.times) - space: raise ValueError( - f"Times should be between {evoked.times[0]:0.3} and " - f"{evoked.times[-1]:0.3}." + f"Times should be between {evoked.times[0]:0.3} and {evoked.times[-1]:0.3}." ) # create axes want_axes = n_times + int(colorbar) @@ -2287,11 +2312,17 @@ def plot_evoked_topomap( _vlim = [ _setup_vmin_vmax(data[:, i], *vlim, norm=merge_channels) for i in range(n_times) ] - _vlim = (np.min(_vlim), np.max(_vlim)) + _vlim = [np.min(_vlim), np.max(_vlim)] cmap = _setup_cmap(cmap, n_axes=n_times, norm=_vlim[0] >= 0) # set up contours if not isinstance(contours, list | np.ndarray): _, contours = _set_contour_locator(*_vlim, contours) + else: + if vlim[0] is None and np.any(contours < _vlim[0]): + _vlim[0] = contours[0] + if vlim[1] is None and np.any(contours > _vlim[1]): + _vlim[1] = contours[-1] + # prepare for main loop over times kwargs = dict( sensors=sensors, @@ -2779,8 +2810,7 @@ def plot_psds_topomap( # convert legacy list-of-tuple input to a dict bands = {band[-1]: band[:-1] for band in bands} logger.info( - "converting legacy list-of-tuples input to a dict for the " - "`bands` parameter" + "converting legacy list-of-tuples input to a dict for the `bands` parameter" ) # upconvert single freqs to band upper/lower edges as needed bin_spacing = np.diff(freqs)[0] @@ -3340,6 +3370,7 @@ def _set_contour_locator(vmin, vmax, contours): # correct number of bins is equal to contours + 1. locator = ticker.MaxNLocator(nbins=contours + 1) contours = locator.tick_values(vmin, vmax) + contours = contours[1:-1] return locator, contours @@ -3456,11 +3487,9 @@ def _trigradient(x, y, z): """Take gradients of z on a mesh.""" from matplotlib.tri import CubicTriInterpolator, Triangulation - with warnings.catch_warnings(): # catch matplotlib warnings - warnings.filterwarnings("ignore", category=DeprecationWarning) - tri = Triangulation(x, y) - tci = CubicTriInterpolator(tri, z) - dx, dy = tci.gradient(tri.x, tri.y) + tri = Triangulation(x, y) + tci = CubicTriInterpolator(tri, z) + dx, dy = tci.gradient(tri.x, tri.y) return dx, dy diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index 256d5741ad3..b8b3fe29a4d 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -212,6 +212,26 @@ class Contours(UIEvent): contours: list[str] +@dataclass +@fill_doc +class ChannelsSelect(UIEvent): + """Indicates that the user has selected one or more channels. + + Parameters + ---------- + ch_names : list of str + The names of the channels that were selected. + + Attributes + ---------- + %(ui_event_name_source)s + ch_names : list of str + The names of the channels that were selected. + """ + + ch_names: list[str] + + def _get_event_channel(fig): """Get the event channel associated with a figure. diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 675b89b2852..f9d64c49ec8 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -58,7 +58,7 @@ warn, ) from ..utils.misc import _identity_function -from .ui_events import ColormapRange, publish, subscribe +from .ui_events import ChannelsSelect, ColormapRange, publish, subscribe _channel_type_prettyprint = { "eeg": "EEG channel", @@ -807,12 +807,12 @@ def _fake_click(fig, ax, point, xform="ax", button=1, kind="press", key=None): ) -def _fake_keypress(fig, key): +def _fake_keypress(fig, key, kind="press"): from matplotlib import backend_bases fig.canvas.callbacks.process( - "key_press_event", - backend_bases.KeyEvent(name="key_press_event", canvas=fig.canvas, key=key), + f"key_{kind}_event", + backend_bases.KeyEvent(name=f"key_{kind}_event", canvas=fig.canvas, key=key), ) @@ -952,7 +952,7 @@ def plot_sensors( Whether to plot the sensors as 3d, topomap or as an interactive sensor selection dialog. Available options ``'topomap'``, ``'3d'``, ``'select'``. If ``'select'``, a set of channels can be selected - interactively by using lasso selector or clicking while holding control + interactively by using lasso selector or clicking while holding the control key. The selected channels are returned along with the figure instance. Defaults to ``'topomap'``. ch_type : None | str @@ -1163,10 +1163,10 @@ def _onpick_sensor(event, fig, ax, pos, ch_names, show_names): if event.mouseevent.inaxes != ax: return - if event.mouseevent.key == "control" and fig.lasso is not None: + if fig.lasso is not None and event.mouseevent.key in ["control", "ctrl+shift"]: + # Add the sensor to the selection instead of showing its name. for ind in event.ind: fig.lasso.select_one(ind) - return if show_names: return # channel names already visible @@ -1272,7 +1272,17 @@ def _plot_sensors_2d( lw=linewidth, ) if kind == "select": - fig.lasso = SelectFromCollection(ax, pts, ch_names) + fig.lasso = SelectFromCollection(ax, pts, names=ch_names) + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + selection_inds = np.flatnonzero(np.isin(ch_names, event.ch_names)) + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) else: fig.lasso = None @@ -1595,11 +1605,14 @@ def _update(self): class SelectFromCollection: - """Select channels from a matplotlib collection using ``LassoSelector``. + """Select objects from a matplotlib collection using ``LassoSelector``. - Selected channels are saved in the ``selection`` attribute. This tool - highlights selected points by fading other points out (i.e., reducing their - alpha values). + The names of the selected objects are saved in the ``selection`` attribute. + This tool highlights selected objects by fading other objects out (i.e., + reducing their alpha values). + + Holding down the Control key will add to the current selection, and holding down + Control+Shift will remove from the current selection. Parameters ---------- @@ -1607,112 +1620,144 @@ class SelectFromCollection: Axes to interact with. collection : instance of matplotlib collection Collection you want to select from. - alpha_other : 0 <= float <= 1 - To highlight a selection, this tool sets all selected points to an - alpha value of 1 and non-selected points to ``alpha_other``. - Defaults to 0.3. - linewidth_other : float - Linewidth to use for non-selected sensors. Default is 1. + names : list of str + The names of the object. The selection is returned as a subset of these names. + alpha_selected : float + Alpha for selected objects (0=tranparant, 1=opaque). + alpha_nonselected : float + Alpha for non-selected objects (0=tranparant, 1=opaque). + linewidth_selected : float + Linewidth for the borders of selected objects. + linewidth_nonselected : float + Linewidth for the borders of non-selected objects. Notes ----- - This tool selects collection objects based on their *origins* - (i.e., ``offsets``). Calls all callbacks in self.callbacks when selection - is ready. + This tool selects collection objects which bounding boxes intersect with a lasso + path. Calls all callbacks in self.callbacks when selection is ready. """ def __init__( self, ax, collection, - ch_names, - alpha_other=0.5, - linewidth_other=0.5, + *, + names, alpha_selected=1, + alpha_nonselected=0.5, linewidth_selected=1, + linewidth_nonselected=0.5, + verbose=None, ): from matplotlib.widgets import LassoSelector + self.fig = ax.figure self.canvas = ax.figure.canvas self.collection = collection - self.ch_names = ch_names - self.alpha_other = alpha_other - self.linewidth_other = linewidth_other + self.names = names self.alpha_selected = alpha_selected + self.alpha_nonselected = alpha_nonselected self.linewidth_selected = linewidth_selected + self.linewidth_nonselected = linewidth_nonselected - self.xys = collection.get_offsets() - self.Npts = len(self.xys) + from matplotlib.collections import PolyCollection + from matplotlib.path import Path - # Ensure that we have separate colors for each object + if isinstance(collection, PolyCollection): + self.paths = collection.get_paths() + else: + self.paths = [Path([point]) for point in collection.get_offsets()] + self.Npts = len(self.paths) + if self.Npts != len(names): + raise ValueError( + f"Number of names ({len(names)}) does not match the number of objects " + f"in the collection ({self.Npts})." + ) + + # Ensure that we have colors for each object. self.fc = collection.get_facecolors() self.ec = collection.get_edgecolors() - self.lw = collection.get_linewidths() if len(self.fc) == 0: raise ValueError("Collection must have a facecolor") elif len(self.fc) == 1: self.fc = np.tile(self.fc, self.Npts).reshape(self.Npts, -1) + if len(self.ec) == 0: + self.ec = np.zeros((self.Npts, 4)) # all black + elif len(self.ec) == 1: self.ec = np.tile(self.ec, self.Npts).reshape(self.Npts, -1) - self.fc[:, -1] = self.alpha_other # deselect in the beginning - self.ec[:, -1] = self.alpha_other - self.lw = np.full(self.Npts, self.linewidth_other) + self.lw = np.full(self.Npts, float(self.linewidth_nonselected)) + # Initialize the lasso selector self.lasso = LassoSelector( ax, onselect=self.on_select, props=dict(color="red", linewidth=0.5) ) self.selection = list() + self.selection_inds = np.array([], dtype="int") self.callbacks = list() + # Deselect everything in the beginning. + self.style_objects() + + # For backwards compatibility + @property + def ch_names(self): + return self.names + + def notify(self): + """Notify listeners that a selection has been made.""" + logger.info(f"Selected channels: {self.selection}") + for callback in self.callbacks: + callback() + def on_select(self, verts): """Select a subset from the collection.""" from matplotlib.path import Path - if len(verts) <= 3: # Seems to be a good way to exclude single clicks. + # Don't respond to single clicks without extra keys being hold down. + # Figures like plot_evoked_topo want to do something else with them. + if len(verts) <= 3 and self.canvas._key not in ["control", "ctrl+shift"]: return path = Path(verts) - inds = np.nonzero([path.contains_point(xy) for xy in self.xys])[0] + inds = np.nonzero([path.intersects_path(p) for p in self.paths])[0] if self.canvas._key == "control": # Appending selection. - sels = [np.where(self.ch_names == c)[0][0] for c in self.selection] - inters = set(inds) - set(sels) - inds = list(inters.union(set(sels) - set(inds))) - - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + self.selection_inds = np.union1d(self.selection_inds, inds).astype("int") + elif self.canvas._key == "ctrl+shift": + self.selection_inds = np.setdiff1d(self.selection_inds, inds).astype("int") + else: + self.selection_inds = inds + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() self.notify() def select_one(self, ind): """Select or deselect one sensor.""" - ch_name = self.ch_names[ind] - if ch_name in self.selection: - sel_ind = self.selection.index(ch_name) - self.selection.pop(sel_ind) + if self.canvas._key == "control": + self.selection_inds = np.union1d(self.selection_inds, [ind]) + elif self.canvas._key == "ctrl+shift": + self.selection_inds = np.setdiff1d(self.selection_inds, [ind]) else: - self.selection.append(ch_name) - inds = np.isin(self.ch_names, self.selection).nonzero()[0] - self.style_sensors(inds) + return # don't notify() + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() self.notify() - def notify(self): - """Notify listeners that a selection has been made.""" - for callback in self.callbacks: - callback() - def select_many(self, inds): """Select many sensors using indices (for predefined selections).""" - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + self.selection_inds = inds + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() - def style_sensors(self, inds): + def style_objects(self): """Style selected sensors as "active".""" # reset - self.fc[:, -1] = self.alpha_other - self.ec[:, -1] = self.alpha_other / 2 - self.lw[:] = self.linewidth_other + self.fc[:, -1] = self.alpha_nonselected + self.ec[:, -1] = self.alpha_nonselected / 2 + self.lw[:] = self.linewidth_nonselected # style sensors at `inds` - self.fc[inds, -1] = self.alpha_selected - self.ec[inds, -1] = self.alpha_selected - self.lw[inds] = self.linewidth_selected + self.fc[self.selection_inds, -1] = self.alpha_selected + self.ec[self.selection_inds, -1] = self.alpha_selected + self.lw[self.selection_inds] = self.linewidth_selected self.collection.set_facecolors(self.fc) self.collection.set_edgecolors(self.ec) self.collection.set_linewidths(self.lw) @@ -2356,7 +2401,7 @@ def _gfp(data): except KeyError: raise ValueError( f'"combine" must be None, a callable, or one of "{", ".join(valid)}"; ' - f'got {combine}' + f"got {combine}" ) return combine @@ -2390,14 +2435,16 @@ def _convert_psds( np.sqrt(psds, out=psds) psds *= scaling ylabel = rf"$\mathrm{{{unit}/\sqrt{{Hz}}}}$" + coef = 20 else: psds *= scaling * scaling if "/" in unit: unit = f"({unit})" ylabel = rf"$\mathrm{{{unit}²/Hz}}$" + coef = 10 if dB: np.log10(np.maximum(psds, np.finfo(float).tiny), out=psds) - psds *= 10 + psds *= coef ylabel = r"$\mathrm{dB}\ $" + ylabel ylabel = "Power (" + ylabel if estimate == "power" else "Amplitude (" + ylabel ylabel += ")" diff --git a/pyproject.toml b/pyproject.toml index 4199e8cd8e0..f20c495a2bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,11 +23,11 @@ dependencies = [ "decorator", "jinja2", "lazy_loader >= 0.3", - "matplotlib >= 3.6", - "numpy >= 1.23,<3", + "matplotlib >= 3.7", # 2023/02/13 + "numpy >= 1.25,<3", # 2023/06/17 "packaging", "pooch >= 1.5", - "scipy >= 1.9", + "scipy >= 1.11", # 2023/06/25 "tqdm", ] description = "MNE-Python project for MEG and EEG data analysis." @@ -63,6 +63,7 @@ doc = [ "mne-gui-addons", "neo", "numpydoc", + "openneuro-py", "psutil", "pydata_sphinx_theme >= 0.15.2", "pygments >= 2.13", @@ -88,7 +89,7 @@ full = ["mne[full-no-qt]", "PyQt6 != 6.6.0", "PyQt6-Qt6 != 6.6.0, != 6.7.0"] # We also offter two more variants: mne[full-qt6] (which is equivalent to mne[full]), # and mne[full-pyside6], which will install PySide6 instead of PyQt6. full-no-qt = [ - "antio >= 0.4.0", + "antio >= 0.5.0", "darkdetect", "defusedxml", "dipy", @@ -111,7 +112,7 @@ full-no-qt = [ "nilearn", "numba", "openmeeg >= 2.5.5", - "pandas", + "pandas >= 2.0", # 2023/04/03 "pillow", # for `Brain.save_image` and `mne.Report` "pyarrow", # only needed to avoid a deprecation warning in pandas "pybv", @@ -269,7 +270,6 @@ addopts = """--durations=20 --doctest-modules -rfEXs --cov-report= --tb=short \ --ignore=mne/gui/_*.py --ignore=mne/icons --ignore=tools \ --ignore=mne/report/js_and_css \ --color=yes --capture=sys""" -junit_family = "xunit2" [tool.rstcheck] ignore_directives = [ diff --git a/tools/azure_dependencies.sh b/tools/azure_dependencies.sh index 28d0b16f91c..8880e6478fa 100755 --- a/tools/azure_dependencies.sh +++ b/tools/azure_dependencies.sh @@ -5,10 +5,11 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) STD_ARGS="--progress-bar off --upgrade" python -m pip install $STD_ARGS pip setuptools wheel if [ "${TEST_MODE}" == "pip" ]; then - python -m pip install $STD_ARGS --only-binary="numba,llvmlite,numpy,scipy,vtk,dipy" -e .[test,full] + python -m pip install $STD_ARGS --only-binary="numba,llvmlite,numpy,scipy,vtk,dipy,openmeeg" -e .[test,full] elif [ "${TEST_MODE}" == "pip-pre" ]; then ${SCRIPT_DIR}/install_pre_requirements.sh python -m pip install $STD_ARGS --pre -e .[test_extra] + echo "##vso[task.setvariable variable=MNE_TEST_ALLOW_SKIP].*(Requires (spm|brainstorm) dataset|Requires MNE-C|CUDA not|Numba not| on Windows|MNE_FORCE_SERIAL|PySide6 causes segfaults).*" else echo "Unknown run type ${TEST_MODE}" exit 1 diff --git a/tools/circleci_dependencies.sh b/tools/circleci_dependencies.sh index bd9ec823871..b306bb528f4 100755 --- a/tools/circleci_dependencies.sh +++ b/tools/circleci_dependencies.sh @@ -1,12 +1,6 @@ #!/bin/bash -ef python -m pip install --upgrade "pip!=20.3.0" build -# This can be removed once dipy > 1.9.0 is released -python -m pip install --upgrade --progress-bar off \ - numpy scipy h5py -python -m pip install --pre --progress-bar off \ - --extra-index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" \ - "dipy>1.9" python -m pip install --upgrade --progress-bar off \ --only-binary "numpy,dipy,scipy,matplotlib,pandas,statsmodels" \ -ve .[full,test,doc] "numpy>=2" \ @@ -17,6 +11,6 @@ python -m pip install --upgrade --progress-bar off \ alphaCSC autoreject bycycle conpy emd fooof meggie \ mne-ari mne-bids-pipeline mne-faster mne-features \ mne-icalabel mne-lsl mne-microstates mne-nirs mne-rsa \ - neurodsp neurokit2 niseq nitime openneuro-py pactools \ + neurodsp neurokit2 niseq nitime pactools \ plotly pycrostates pyprep pyriemann python-picard sesameeg \ - sleepecg tensorpac yasa meegkit + sleepecg tensorpac yasa meegkit eeg_positions wfdb diff --git a/tools/dev/Makefile b/tools/dev/Makefile index e2f61735a7e..9128d758289 100644 --- a/tools/dev/Makefile +++ b/tools/dev/Makefile @@ -25,3 +25,6 @@ feature-requests-12-to-2-months-old.json: clean: @rm bug-reports-12-to-2-months-old.json || true + +dep: + cd ../../ && git grep -iI "\(deprecat\|futurewarning\)" -- ':!*.js' ':!*/conftest.py' ':!*/docs.py' ':!*/test_docs.py' ':!*/utils/__init__.pyi' mne/ diff --git a/tools/dev/ensure_headers.py b/tools/dev/ensure_headers.py index b5b425b5900..a4095d82b42 100644 --- a/tools/dev/ensure_headers.py +++ b/tools/dev/ensure_headers.py @@ -156,15 +156,15 @@ def _ensure_copyright(lines, path): lines[insert] = COPYRIGHT_LINE else: lines.insert(insert, COPYRIGHT_LINE) - assert ( - lines.count(COPYRIGHT_LINE) == 1 - ), f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + assert lines.count(COPYRIGHT_LINE) == 1, ( + f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + ) def _ensure_blank(lines, path): - assert ( - lines.count(COPYRIGHT_LINE) == 1 - ), f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + assert lines.count(COPYRIGHT_LINE) == 1, ( + f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + ) insert = lines.index(COPYRIGHT_LINE) + 1 if lines[insert].strip(): # actually has content lines.insert(insert, "") diff --git a/tools/dev/update_credit_json.py b/tools/dev/update_credit_json.py index e131fd3f33c..de96c040604 100644 --- a/tools/dev/update_credit_json.py +++ b/tools/dev/update_credit_json.py @@ -20,7 +20,9 @@ g = Github(auth=auth, per_page=100) out_path = Path(__file__).parents[2] / "doc" / "sphinxext" / "prs" out_path.mkdir(exist_ok=True) -oldest_pr = 6915 # can update this when the oldest open PR changes to speed things up +# manually update this when the oldest open PR changes to speed things up +# (don't need to look any farther back than this) +oldest_pr = 9176 # JSON formatting json_kwargs = dict(indent=2, ensure_ascii=False, sort_keys=False) diff --git a/tools/environment_old.yml b/tools/environment_old.yml index 4515f9cd611..46826bdbcf4 100644 --- a/tools/environment_old.yml +++ b/tools/environment_old.yml @@ -1,17 +1,18 @@ +# THIS FILE IS AUTO-GENERATED BY tools/hooks/update_environment_file.py AND WILL BE OVERWRITTEN name: mne channels: - conda-forge dependencies: - - python=3.10 - - numpy=1.24 - - scipy=1.10 - - matplotlib=3.6 - - pandas=1.5.2 - - scikit-learn=1.2 - - nibabel # whichever one works + - python =3.10 + - numpy =1.25 + - scipy =1.11 + - matplotlib =3.7 + - pandas =2.0 + - scikit-learn + - nibabel - tqdm - - pooch + - pooch =1.5 - decorator - packaging - jinja2 - - lazy_loader + - lazy_loader =0.3 diff --git a/tools/get_minimal_commands.sh b/tools/get_minimal_commands.sh index 4e28fdf9e7b..8190f331075 100755 --- a/tools/get_minimal_commands.sh +++ b/tools/get_minimal_commands.sh @@ -11,7 +11,7 @@ export MNE_ROOT="${PWD}/minimal_cmds" export PATH=${MNE_ROOT}/bin:$PATH if [ "${GITHUB_ACTIONS}" == "true" ]; then echo "Setting MNE_ROOT for GHA" - echo "MNE_ROOT=${MNE_ROOT}" >> $GITHUB_ENV; + echo "MNE_ROOT=${MNE_ROOT}" | tee -a $GITHUB_ENV; echo "${MNE_ROOT}/bin" >> $GITHUB_PATH; elif [ "${AZURE_CI}" == "true" ]; then echo "Setting MNE_ROOT for Azure" @@ -33,9 +33,9 @@ if [[ "${CI_OS_NAME}" != "macos"* ]]; then export NEUROMAG2FT_ROOT="${PWD}/minimal_cmds/bin" export FREESURFER_HOME="${MNE_ROOT}" if [ "${GITHUB_ACTIONS}" == "true" ]; then - echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" >> "$GITHUB_ENV"; - echo "NEUROMAG2FT_ROOT=${NEUROMAG2FT_ROOT}" >> "$GITHUB_ENV"; - echo "FREESURFER_HOME=${FREESURFER_HOME}" >> "$GITHUB_ENV"; + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" | tee -a "$GITHUB_ENV"; + echo "NEUROMAG2FT_ROOT=${NEUROMAG2FT_ROOT}" | tee -a "$GITHUB_ENV"; + echo "FREESURFER_HOME=${FREESURFER_HOME}" | tee -a "$GITHUB_ENV"; fi; if [ "${AZURE_CI}" == "true" ]; then echo "##vso[task.setvariable variable=LD_LIBRARY_PATH]${LD_LIBRARY_PATH}" @@ -57,7 +57,7 @@ else export DYLD_LIBRARY_PATH=${MNE_ROOT}/lib:$DYLD_LIBRARY_PATH if [ "${GITHUB_ACTIONS}" == "true" ]; then echo "Setting variables for GHA" - echo "DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH}" >> "$GITHUB_ENV"; + echo "DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH}" | tee -a "$GITHUB_ENV"; set -x wget https://github.com/XQuartz/XQuartz/releases/download/XQuartz-2.7.11/XQuartz-2.7.11.dmg sudo hdiutil attach XQuartz-2.7.11.dmg diff --git a/tools/get_testing_version.sh b/tools/get_testing_version.sh index 44ff28addb4..aaf703dbddd 100755 --- a/tools/get_testing_version.sh +++ b/tools/get_testing_version.sh @@ -6,7 +6,7 @@ TESTING_VERSION=`grep -o "testing=\"[0-9.]\+\"" mne/datasets/config.py | cut -d # This can be incremented to start fresh when the cache misbehaves, e.g.: # TESTING_VERSION=${TESTING_VERSION}-1 if [ ! -z $GITHUB_ENV ]; then - echo "TESTING_VERSION="$TESTING_VERSION >> $GITHUB_ENV + echo "TESTING_VERSION="$TESTING_VERSION | tee -a $GITHUB_ENV elif [ ! -z $AZURE_CI ]; then echo "##vso[task.setvariable variable=testing_version]$TESTING_VERSION" elif [ ! -z $CIRCLECI ]; then diff --git a/tools/github_actions_dependencies.sh b/tools/github_actions_dependencies.sh index 1b74ce14e99..d47d9070f8b 100755 --- a/tools/github_actions_dependencies.sh +++ b/tools/github_actions_dependencies.sh @@ -20,6 +20,11 @@ if [ ! -z "$CONDA_ENV" ]; then INSTALL_KIND="test" STD_ARGS="--progress-bar off" fi +elif [[ "${MNE_CI_KIND}" == "pip" ]]; then + # Only used for 3.13 at the moment, just get test deps plus a few extras + # that we know are available + INSTALL_ARGS="nibabel scikit-learn numpydoc PySide6 mne-qt-browser pandas h5io mffpy defusedxml numba" + INSTALL_KIND="test" else test "${MNE_CI_KIND}" == "pip-pre" STD_ARGS="$STD_ARGS --pre" diff --git a/tools/github_actions_env_vars.sh b/tools/github_actions_env_vars.sh index 8b9fc8560e6..9f424ae5f48 100755 --- a/tools/github_actions_env_vars.sh +++ b/tools/github_actions_env_vars.sh @@ -4,23 +4,31 @@ set -eo pipefail -x # old and minimal use conda if [[ "$MNE_CI_KIND" == "pip"* ]]; then echo "Setting pip env vars for $MNE_CI_KIND" - echo "MNE_QT_BACKEND=PyQt6" >> $GITHUB_ENV - # We should test an eager import somewhere, might as well be here - echo "EAGER_IMPORT=true" >> $GITHUB_ENV + if [[ "$MNE_CI_KIND" == "pip-pre" ]]; then + echo "MNE_QT_BACKEND=PyQt6" | tee -a $GITHUB_ENV + # We should test an eager import somewhere, might as well be here + echo "EAGER_IMPORT=true" | tee -a $GITHUB_ENV + # Make sure nothing unexpected is skipped + echo "MNE_TEST_ALLOW_SKIP=.*(Requires (spm|brainstorm) dataset|CUDA not|Numba not|PySide6 causes segfaults).*" | tee -a $GITHUB_ENV + else + echo "MNE_QT_BACKEND=PySide6" | tee -a $GITHUB_ENV + fi else # conda-like echo "Setting conda env vars for $MNE_CI_KIND" if [[ "$MNE_CI_KIND" == "old" ]]; then - echo "CONDA_ENV=tools/environment_old.yml" >> $GITHUB_ENV - echo "MNE_IGNORE_WARNINGS_IN_TESTS=true" >> $GITHUB_ENV - echo "MNE_SKIP_NETWORK_TESTS=1" >> $GITHUB_ENV - echo "MNE_QT_BACKEND=PyQt5" >> $GITHUB_ENV + echo "CONDA_ENV=tools/environment_old.yml" | tee -a $GITHUB_ENV + echo "MNE_IGNORE_WARNINGS_IN_TESTS=true" | tee -a $GITHUB_ENV + echo "MNE_SKIP_NETWORK_TESTS=1" | tee -a $GITHUB_ENV + echo "MNE_QT_BACKEND=PyQt5" | tee -a $GITHUB_ENV elif [[ "$MNE_CI_KIND" == "minimal" ]]; then - echo "CONDA_ENV=tools/environment_minimal.yml" >> $GITHUB_ENV - echo "MNE_QT_BACKEND=PySide6" >> $GITHUB_ENV + echo "CONDA_ENV=tools/environment_minimal.yml" | tee -a $GITHUB_ENV + echo "MNE_QT_BACKEND=PySide6" | tee -a $GITHUB_ENV else # conda, mamba (use warning level for completeness) - echo "CONDA_ENV=environment.yml" >> $GITHUB_ENV - echo "MNE_LOGGING_LEVEL=warning" >> $GITHUB_ENV - echo "MNE_QT_BACKEND=PySide6" >> $GITHUB_ENV + echo "CONDA_ENV=environment.yml" | tee -a $GITHUB_ENV + echo "MNE_LOGGING_LEVEL=warning" | tee -a $GITHUB_ENV + echo "MNE_QT_BACKEND=PySide6" | tee -a $GITHUB_ENV + # TODO: Also need "|unreliable on GitHub Actions conda" on macOS, but omit for now to make sure the failure actually shows up + echo "MNE_TEST_ALLOW_SKIP=.*(Requires (spm|brainstorm) dataset|CUDA not|PySide6 causes segfaults|Accelerate|Flakey verbose behavior).*" | tee -a $GITHUB_ENV fi fi set +x diff --git a/tools/github_actions_test.sh b/tools/github_actions_test.sh index 4cdd202223f..4fe8756bd50 100755 --- a/tools/github_actions_test.sh +++ b/tools/github_actions_test.sh @@ -13,19 +13,25 @@ else USE_DIRS="mne/" fi JUNIT_PATH="junit-results.xml" -if [[ ! -z "$CONDA_ENV" ]] && [[ "${RUNNER_OS}" != "Windows" ]]; then - JUNIT_PATH="$(pwd)/${JUNIT_PATH}" +if [[ ! -z "$CONDA_ENV" ]] && [[ "${RUNNER_OS}" != "Windows" ]] && [[ "${MNE_CI_KIND}" != "minimal" ]] && [[ "${MNE_CI_KIND}" != "old" ]]; then + PROJ_PATH="$(pwd)" + JUNIT_PATH="$PROJ_PATH/${JUNIT_PATH}" # Use the installed version after adding all (excluded) test files - cd .. + cd ~ # so that "import mne" doesn't just import the checked-out data INSTALL_PATH=$(python -c "import mne, pathlib; print(str(pathlib.Path(mne.__file__).parents[1]))") - echo "Copying tests from $(pwd)/mne-python/mne/ to ${INSTALL_PATH}/mne/" - echo "::group::rsync" - rsync -a --partial --progress --prune-empty-dirs --exclude="*.pyc" --include="**/" --include="**/tests/*" --include="**/tests/data/**" --exclude="**" ./mne-python/mne/ ${INSTALL_PATH}/mne/ + echo "Copying tests from ${PROJ_PATH}/mne-python/mne/ to ${INSTALL_PATH}/mne/" + echo "::group::rsync mne" + rsync -a --partial --progress --prune-empty-dirs --exclude="*.pyc" --include="**/" --include="**/tests/*" --include="**/tests/data/**" --exclude="**" ${PROJ_PATH}/mne/ ${INSTALL_PATH}/mne/ echo "::endgroup::" + echo "::group::rsync doc" + mkdir -p ${INSTALL_PATH}/doc/ + rsync -a --partial --progress --prune-empty-dirs --include="**/" --include="**/api/*" --exclude="**" ${PROJ_PATH}/doc/ ${INSTALL_PATH}/doc/ + test -f ${INSTALL_PATH}/doc/api/reading_raw_data.rst cd $INSTALL_PATH - echo "Executing from $(pwd)" + cp -av $PROJ_PATH/pyproject.toml . + echo "::endgroup::" fi set -x -pytest -m "${CONDITION}" --tb=short --cov=mne --cov-report xml --color=yes --junit-xml=$JUNIT_PATH -vv ${USE_DIRS} -set +x +pytest -m "${CONDITION}" --cov=mne --cov-report xml --color=yes --continue-on-collection-errors --junit-xml=$JUNIT_PATH -vv ${USE_DIRS} +echo "Exited with code $?" diff --git a/tools/hooks/sync_dependencies.py b/tools/hooks/sync_dependencies.py index 0878a5f56eb..1ff6d7f8712 100755 --- a/tools/hooks/sync_dependencies.py +++ b/tools/hooks/sync_dependencies.py @@ -4,7 +4,12 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import difflib import re + +# NB here we use metadata from the latest stable release because this goes in our +# README, which should apply to the latest release (rather than dev). +# For oldest supported dev dependencies, see update_environment_file.py. from importlib.metadata import metadata from pathlib import Path @@ -77,7 +82,6 @@ def _prettify_pin(pin): f"- `{key} <{url}>`__{core_deps_pins[key.lower()]}" for key, url in CORE_DEPS_URLS.items() ] -core_deps_rst = "\n" + "\n".join(core_deps_bullets) + "\n" # rewrite the README file lines = README_PATH.read_text("utf-8").splitlines() @@ -87,9 +91,14 @@ def _prettify_pin(pin): if line.strip() == BEGIN: skip = True out_lines.append(line) - out_lines.append(core_deps_rst) + out_lines.extend(["", *core_deps_bullets, ""]) if line.strip() == END: skip = False if not skip: out_lines.append(line) -README_PATH.write_text("\n".join(out_lines) + "\n", encoding="utf-8") +new = "\n".join(out_lines) + "\n" +old = README_PATH.read_text("utf-8") +if new != old: + diff = "\n".join(difflib.unified_diff(old.splitlines(), new.splitlines())) + print(f"Updating {README_PATH} with diff:\n{diff}") + README_PATH.write_text(new, encoding="utf-8") diff --git a/tools/hooks/update_environment_file.py b/tools/hooks/update_environment_file.py index bfedca39d1a..0b5380a16b5 100755 --- a/tools/hooks/update_environment_file.py +++ b/tools/hooks/update_environment_file.py @@ -22,7 +22,7 @@ deps |= set(section_deps) recursive_deps = set(d for d in deps if d.startswith("mne[")) deps -= recursive_deps -deps |= {"pip"} +deps |= {"pip", "mamba", "nomkl"} def remove_spaces(version_spec): @@ -56,9 +56,9 @@ def split_dep(dep): # `environment.yaml` breaks the solver if package_name == "PySide6": version_spec = version_spec.replace("!=6.7.0,", "") - # openmeeg 2.5.12=*_2 is broken, so pin to 2.5.12=*_1 - if package_name == "openmeeg": - version_spec = "=2.5.12=*_1" + elif package_name == "vtk": + # TODO VERSION remove once we support VTK 9.4 + version_spec = "=9.3.1=qt_*" # rstrip output line in case `version_spec` == "" line = f" - {package_name} {version_spec}".rstrip() # use pip for packages needing e.g. `platform_system` or `python_version` triaging @@ -80,7 +80,7 @@ def split_dep(dep): pip_section = pip_section if len(pip_deps) else "" # prepare the env file env = f"""\ -# THIS FILE IS AUTO-GENERATED BY {'/'.join(Path(__file__).parts[-3:])} AND WILL BE OVERWRITTEN +# THIS FILE IS AUTO-GENERATED BY {"/".join(Path(__file__).parts[-3:])} AND WILL BE OVERWRITTEN name: mne channels: - conda-forge diff --git a/tools/install_pre_requirements.sh b/tools/install_pre_requirements.sh index 4f9f0165fe2..c717b1b477b 100755 --- a/tools/install_pre_requirements.sh +++ b/tools/install_pre_requirements.sh @@ -15,15 +15,22 @@ echo "PyQt6 and scientific-python-nightly-wheels dependencies" python -m pip install $STD_ARGS pip setuptools packaging \ threadpoolctl cycler fonttools kiwisolver pyparsing pillow python-dateutil \ patsy pytz tzdata nibabel tqdm trx-python joblib numexpr "$QT_BINDING" \ - py-cpuinfo blosc2 + py-cpuinfo blosc2 hatchling echo "NumPy/SciPy/pandas etc." python -m pip uninstall -yq numpy +python -m pip install --upgrade matplotlib # TODO: Until https://github.com/matplotlib/matplotlib/pull/29427 lands python -m pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 \ --index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" \ "numpy>=2.1.0.dev0" "scikit-learn>=1.6.dev0" "scipy>=1.15.0.dev0" \ - "statsmodels>=0.15.0.dev0" "pandas>=3.0.0.dev0" "matplotlib>=3.10.0.dev0" \ + "pandas>=3.0.0.dev0" \ "h5py>=3.12.1" "dipy>=1.10.0.dev0" "pyarrow>=19.0.0.dev0" "tables>=3.10.2.dev0" +# statsmodels requires formulaic@main so we need to use --extra-index-url +echo "statsmodels" +python -m pip install $STD_ARGS --only-binary ":all:" \ + --extra-index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" \ + "statsmodels>=0.15.0.dev0" + # No Numba because it forces an old NumPy version echo "pymatreader" @@ -42,7 +49,7 @@ python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https:/ python -c "import vtk" echo "PyVista" -python -m pip install $STD_ARGS "git+https://github.com/pyvista/pyvista" +python -m pip install $STD_ARGS "git+https://github.com/pyvista/pyvista" trame trame-vtk trame-vuetify jupyter ipyevents ipympl echo "picard" python -m pip install $STD_ARGS git+https://github.com/pierreablin/picard @@ -51,7 +58,7 @@ echo "pyvistaqt" pip install $STD_ARGS git+https://github.com/pyvista/pyvistaqt echo "imageio-ffmpeg, xlrd, mffpy" -pip install $STD_ARGS imageio-ffmpeg xlrd mffpy traitlets pybv eeglabio +pip install $STD_ARGS imageio-ffmpeg xlrd mffpy traitlets pybv eeglabio defusedxml antio echo "mne-qt-browser" pip install $STD_ARGS git+https://github.com/mne-tools/mne-qt-browser @@ -70,13 +77,11 @@ echo "edfio" # https://github.com/mne-tools/mne-python/pull/12609#issuecomment-2115639369 GIT_CLONE_PROTECTION_ACTIVE=false pip install $STD_ARGS git+https://github.com/the-siesta-group/edfio -if [[ "${PLATFORM}" == "Linux" ]]; then - echo "h5io" - pip install $STD_ARGS git+https://github.com/h5io/h5io +echo "h5io" +pip install $STD_ARGS git+https://github.com/h5io/h5io - echo "pysnirf2" - pip install $STD_ARGS git+https://github.com/BUNPC/pysnirf2 -fi +echo "pysnirf2" +pip install $STD_ARGS git+https://github.com/BUNPC/pysnirf2 # Make sure we're on a NumPy 2.0 variant echo "Checking NumPy version" diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index ce5e95124d4..9d0e215ee80 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -33,8 +33,6 @@ captions_new comments_new items_new -has_imageio_ffmpeg -has_pyvista f4 set_channel_types_eyetrack _use_test_3d_backend @@ -43,6 +41,8 @@ # Decoding _._more_tags +_.multi_class +_.preserves_dtype deep # Backward compat or rarely used diff --git a/tutorials/forward/20_source_alignment.py b/tutorials/forward/20_source_alignment.py index dd26f610907..c8cf981dce9 100644 --- a/tutorials/forward/20_source_alignment.py +++ b/tutorials/forward/20_source_alignment.py @@ -115,11 +115,11 @@ mne.viz.set_3d_view(fig, 45, 90, distance=0.6, focalpoint=(0.0, 0.0, 0.0)) print( "Distance from head origin to MEG origin: " - f"{1000 * np.linalg.norm(raw.info["dev_head_t"]["trans"][:3, 3]):.1f} mm" + f"{1000 * np.linalg.norm(raw.info['dev_head_t']['trans'][:3, 3]):.1f} mm" ) print( "Distance from head origin to MRI origin: " - f"{1000 * np.linalg.norm(trans["trans"][:3, 3]):.1f} mm" + f"{1000 * np.linalg.norm(trans['trans'][:3, 3]):.1f} mm" ) dists = mne.dig_mri_distances(raw.info, trans, "sample", subjects_dir=subjects_dir) print( diff --git a/tutorials/forward/30_forward.py b/tutorials/forward/30_forward.py index 6c55d0bfe3c..72731982962 100644 --- a/tutorials/forward/30_forward.py +++ b/tutorials/forward/30_forward.py @@ -255,7 +255,7 @@ # or ``inv['src']`` so that this removal is adequately accounted for. print(f"Before: {src}") -print(f'After: {fwd["src"]}') +print(f"After: {fwd['src']}") # %% # We can explore the content of ``fwd`` to access the numpy array that contains diff --git a/tutorials/intro/15_inplace.py b/tutorials/intro/15_inplace.py index 0c68843d4c8..01e8c1f7eb0 100644 --- a/tutorials/intro/15_inplace.py +++ b/tutorials/intro/15_inplace.py @@ -60,9 +60,9 @@ # Another group of methods where data is modified in-place are the # channel-picking methods. For example: -print(f'original data had {original_raw.info["nchan"]} channels.') +print(f"original data had {original_raw.info['nchan']} channels.") original_raw.pick("eeg") # selects only the EEG channels -print(f'after picking, it has {original_raw.info["nchan"]} channels.') +print(f"after picking, it has {original_raw.info['nchan']} channels.") # %% diff --git a/tutorials/intro/70_report.py b/tutorials/intro/70_report.py index d23bd54aebf..fe87c0f3a44 100644 --- a/tutorials/intro/70_report.py +++ b/tutorials/intro/70_report.py @@ -7,17 +7,16 @@ :class:`mne.Report` is a way to create interactive HTML summaries of your data. These reports can show many different visualizations for one or multiple participants. -A common use case is creating diagnostic summaries to check data -quality at different stages in the processing pipeline. The report can show -things like plots of data before and after each preprocessing step, epoch -rejection statistics, MRI slices with overlaid BEM shells, all the way up to -plots of estimated cortical activity. - -Compared to a Jupyter notebook, :class:`mne.Report` is easier to deploy, as the -HTML pages it generates are self-contained and do not require a running Python -environment. However, it is less flexible as you can't change code and re-run -something directly within the browser. This tutorial covers the basics of -building a report. As usual, we will start by importing the modules and data we need: +A common use case is creating diagnostic summaries to check data quality at different +stages in the processing pipeline. The report can show things like plots of data before +and after each preprocessing step, epoch rejection statistics, MRI slices with overlaid +BEM shells, all the way up to plots of estimated cortical activity. + +Compared to a Jupyter notebook, :class:`mne.Report` is easier to deploy, as the HTML +pages it generates are self-contained and do not require a running Python environment. +However, it is less flexible as you can't change code and re-run something directly +within the browser. This tutorial covers the basics of building a report. As usual, +we will start by importing the modules and data we need: """ # Authors: The MNE-Python contributors. diff --git a/tutorials/inverse/20_dipole_fit.py b/tutorials/inverse/20_dipole_fit.py index 2b640aa8fc2..e72e76dd0fd 100644 --- a/tutorials/inverse/20_dipole_fit.py +++ b/tutorials/inverse/20_dipole_fit.py @@ -87,6 +87,7 @@ # %% # Calculate and visualise magnetic field predicted by dipole with maximum GOF # and compare to the measured data, highlighting the ipsilateral (right) source + fwd, stc = make_forward_dipole(dip, fname_bem, evoked.info, fname_trans) pred_evoked = simulate_evoked(fwd, stc, evoked.info, cov=None, nave=np.inf) diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py b/tutorials/preprocessing/40_artifact_correction_ica.py index 5eeb7b79d64..257b1f85051 100644 --- a/tutorials/preprocessing/40_artifact_correction_ica.py +++ b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -291,8 +291,7 @@ # This time, print as percentage. ratio_percent = round(100 * explained_var_ratio["eeg"]) print( - f"Fraction of variance in EEG signal explained by first component: " - f"{ratio_percent}%" + f"Fraction of variance in EEG signal explained by first component: {ratio_percent}%" ) # %% diff --git a/tutorials/preprocessing/45_projectors_background.py b/tutorials/preprocessing/45_projectors_background.py index 128229e516a..3c83d49d8c3 100644 --- a/tutorials/preprocessing/45_projectors_background.py +++ b/tutorials/preprocessing/45_projectors_background.py @@ -488,7 +488,7 @@ def setup_3d_axes(): # for this recommendation: # # 1. It is computationally cheaper to apply projectors to data *after* the -# data have been reducted to just the segments of interest (the epochs) +# data have been reduced to just the segments of interest (the epochs) # # 2. If you are applying amplitude-based rejection criteria to epochs, it is # preferable to reject based on the signal *after* projectors have been diff --git a/tutorials/preprocessing/50_artifact_correction_ssp.py b/tutorials/preprocessing/50_artifact_correction_ssp.py index 57be25803d5..28dee357f9a 100644 --- a/tutorials/preprocessing/50_artifact_correction_ssp.py +++ b/tutorials/preprocessing/50_artifact_correction_ssp.py @@ -390,6 +390,13 @@ # # See the documentation of each function for further details. # +# .. note:: +# In situations only limited electrodes are available for analysis, removing the +# cardiac artefact using techniques which rely on the availability of spatial +# information (such as SSP) may not be possible. In these instances, it may be of +# use to consider algorithms which require information only regarding heartbeat +# instances in the time domain, such as :func:`mne.preprocessing.apply_pca_obs`. +# # # Repairing EOG artifacts with SSP # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -520,7 +527,7 @@ evoked_eeg.plot(proj=proj, axes=ax, spatial_colors=True) parts = ax.get_title().split("(") ylabel = ( - f'{parts[0]} ({ax.get_ylabel()})\n{parts[1].replace(")", "")}' + f"{parts[0]} ({ax.get_ylabel()})\n{parts[1].replace(')', '')}" if pi == 0 else "" ) @@ -535,6 +542,7 @@ # reduced the amplitude of our signals in sensor space, but that it should not # bias the amplitudes in source space. # +# # References # ^^^^^^^^^^ # diff --git a/tutorials/time-freq/10_spectrum_class.py b/tutorials/time-freq/10_spectrum_class.py index 62f6103bfdb..9ba463e18b5 100644 --- a/tutorials/time-freq/10_spectrum_class.py +++ b/tutorials/time-freq/10_spectrum_class.py @@ -35,11 +35,10 @@ raw.compute_psd() # %% -# By default, the spectral estimation method will be the -# :footcite:t:`Welch1967` method for continuous data, and the multitaper -# method :footcite:`Slepian1978` for epoched or averaged data. This default can -# be overridden by passing ``method='welch'`` or ``method='multitaper'`` to the -# :meth:`~mne.io.Raw.compute_psd` method. +# By default, the spectral estimation method will be the :footcite:t:`Welch1967` method +# for continuous data, and the multitaper method :footcite:`Slepian1978` for epoched or +# averaged data. This default can be overridden by passing ``method='welch'`` or +# ``method='multitaper'`` to the :meth:`~mne.io.Raw.compute_psd` method. # # There are many other options available as well; for example we can compute a # spectrum from a given span of times, for a chosen frequency range, and for a diff --git a/tutorials/time-freq/50_ssvep.py b/tutorials/time-freq/50_ssvep.py index a0d130f1d35..a625a001d9e 100644 --- a/tutorials/time-freq/50_ssvep.py +++ b/tutorials/time-freq/50_ssvep.py @@ -641,7 +641,7 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1): ].mean(axis=1) fig, ax = plt.subplots(1) -ax.boxplot(window_snrs, tick_labels=window_lengths, vert=True) +ax.boxplot(window_snrs, tick_labels=window_lengths, orientation="vertical") ax.set( title="Effect of trial duration on 12 Hz SNR", ylabel="Average SNR",