From 6b67aba7548989635a8c6ba96196237593aa076d Mon Sep 17 00:00:00 2001 From: Danny Farrell <16297104+danpf@users.noreply.github.com> Date: Fri, 15 Dec 2023 13:24:26 -0800 Subject: [PATCH] Add Python pybind11 bindings + Upgrade ci/cd This PR adds python bindings as well as updates the backend build process of the mmtf-cpp library significantly. These improvements are mainly from a convenience perspective and include: - Removing all build-based submodules - Moving to cmake fetchcontent build - Simplify CMakeLists with better linking procedures - Upgrade msgpack-c - Upgrade catch2 - Move to github actions for ci/cd - Use cibuildwheel for wheel cd Pybind11 library: The pybind11 library utilizes the c++ code of mmtf-cpp in order to build an extremely fast cpp layer underneath the python interface. You have to keep in mind that moving between c++ and python is slow, but this is still much faster than the previously existing python library. see this example: time to load a single mmtf file 1000x cpp bare 0.29s this library 0.44s python og 4.34s --- .github/workflows/cpp.yml | 62 +++ .github/workflows/emscripten.yml | 48 +++ .github/workflows/pip.yml | 39 ++ .github/workflows/wheels.yml | 74 ++++ .gitignore | 169 ++++++++ .gitmodules | 10 +- .travis.yml | 68 ---- CHANGELOG.md | 30 ++ CMakeLists.txt | 45 ++- README.md | 49 ++- appveyor.yml | 40 -- ci/build_and_run_tests.sh | 15 +- ci/setup-appveyor.ps1 | 18 - ci/setup-travis.sh | 44 --- ci/travis-test-example.sh | 22 -- examples/CMakeLists.txt | 8 +- include/mmtf/structure_data.hpp | 2 +- pyproject.toml | 91 +++++ src/python/CMakeLists.txt | 21 + src/python/bindings.cpp | 515 +++++++++++++++++++++++++ src/python/mmtf_cppy/__init__.py | 38 ++ src/python/mmtf_cppy/structure_data.py | 495 ++++++++++++++++++++++++ src/python/tests/tests.py | 196 ++++++++++ submodules/Catch2 | 1 - submodules/mmtf_spec | 2 +- submodules/msgpack-c | 1 - tests/CMakeLists.txt | 18 +- tests/mmtf_tests.cpp | 8 +- 28 files changed, 1882 insertions(+), 
247 deletions(-) create mode 100644 .github/workflows/cpp.yml create mode 100644 .github/workflows/emscripten.yml create mode 100644 .github/workflows/pip.yml create mode 100644 .github/workflows/wheels.yml delete mode 100644 .travis.yml delete mode 100644 appveyor.yml delete mode 100644 ci/setup-appveyor.ps1 delete mode 100644 ci/setup-travis.sh delete mode 100644 ci/travis-test-example.sh create mode 100644 pyproject.toml create mode 100644 src/python/CMakeLists.txt create mode 100644 src/python/bindings.cpp create mode 100644 src/python/mmtf_cppy/__init__.py create mode 100644 src/python/mmtf_cppy/structure_data.py create mode 100644 src/python/tests/tests.py delete mode 160000 submodules/Catch2 delete mode 160000 submodules/msgpack-c diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml new file mode 100644 index 0000000..fa5a3eb --- /dev/null +++ b/.github/workflows/cpp.yml @@ -0,0 +1,62 @@ +--- +name: cpp +'on': + workflow_dispatch: null + pull_request: null + push: + branches: + - master +concurrency: + group: '${{ github.workflow }}-${{ github.ref }}' + cancel-in-progress: true +jobs: + build: + name: Build and test cpp + strategy: + fail-fast: false + matrix: + include: + - os: ubuntu-22.04 + cc: gcc + cxx: g++ + - os: ubuntu-22.04 + cc: gcc + cxx: g++ + env_list: EMSCRIPTEN=ON + - os: ubuntu-22.04 + cc: clang + cxx: clang++ + - os: ubuntu-22.04 + cc: gcc + cxx: g++ + cmake_args: "-DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS='-march=native'" + - os: macos-latest + cc: clang + cxx: clang++ + - os: windows-latest + cc: '' + cxx: '' + runs-on: '${{ matrix.os }}' + env: + CC: '${{ matrix.cc }}' + CXX: '${{ matrix.cxx }}' + CMAKE_ARGS: '${{ matrix.cmake_args }}' + steps: + - uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + - name: Set environment list variables + run: | + env_vars="${{ matrix.env_list }}" + for var in $env_vars; do + echo "$var" >> $GITHUB_ENV + done + if: matrix.os != 'windows-latest' + - name: Setup cmake + 
uses: jwlawson/actions-setup-cmake@v1.13 + with: + cmake-version: 3.16.x + - uses: seanmiddleditch/gha-setup-ninja@master + - name: build and test + run: ./ci/build_and_run_tests.sh diff --git a/.github/workflows/emscripten.yml b/.github/workflows/emscripten.yml new file mode 100644 index 0000000..fe5f679 --- /dev/null +++ b/.github/workflows/emscripten.yml @@ -0,0 +1,48 @@ +--- +name: WASM +'on': + workflow_dispatch: null + push: + branches: + - master +concurrency: + group: '${{ github.workflow }}-${{ github.ref }}' + cancel-in-progress: true +jobs: + build-wasm-emscripten: + name: Pyodide + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: '3.11' + - name: Install pyodide-build + run: pip install pyodide-build==0.23.4 + - name: Compute emsdk version + id: compute-emsdk-version + run: | + pyodide xbuildenv install --download + EMSCRIPTEN_VERSION=$(pyodide config get emscripten_version) + echo "emsdk-version=$EMSCRIPTEN_VERSION" >> $GITHUB_OUTPUT + - uses: mymindstorm/setup-emsdk@v12 + with: + version: '${{ steps.compute-emsdk-version.outputs.emsdk-version }}' + actions-cache-folder: emsdk-cache + - name: Build + run: CFLAGS=-fexceptions LDFLAGS=-fexceptions pyodide build + - uses: actions/upload-artifact@v3 + with: + path: dist/*.whl + - uses: actions/setup-node@v4 + with: + node-version: 18 + - name: Set up Pyodide virtual environment + run: | + pyodide venv .venv-pyodide + .venv-pyodide/bin/pip install $(echo -n dist/*.whl) + - name: Test + run: .venv-pyodide/bin/python -m unittest src/python/tests/tests.py diff --git a/.github/workflows/pip.yml b/.github/workflows/pip.yml new file mode 100644 index 0000000..936e306 --- /dev/null +++ b/.github/workflows/pip.yml @@ -0,0 +1,39 @@ +--- +name: Pip +'on': + workflow_dispatch: null + pull_request: null + push: + branches: + - master +concurrency: + group: '${{ github.workflow }}-${{ github.ref }}' + 
cancel-in-progress: true +jobs: + build: + name: Build with Pip + runs-on: '${{ matrix.platform }}' + strategy: + fail-fast: false + matrix: + platform: + - windows-latest + - macos-latest + - ubuntu-latest + python-version: + - '3.8' + - '3.11' + - '3.12' + - pypy-3.8 + steps: + - uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: '${{ matrix.python-version }}' + - name: Build and install + run: pip install --verbose . + - name: Test + run: python src/python/tests/tests.py diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml new file mode 100644 index 0000000..b2b7b23 --- /dev/null +++ b/.github/workflows/wheels.yml @@ -0,0 +1,74 @@ +--- +name: Wheels +'on': + workflow_dispatch: null + pull_request: null + push: + branches: + - master + release: + types: + - published +env: + FORCE_COLOR: 3 +concurrency: + group: '${{ github.workflow }}-${{ github.ref }}' + cancel-in-progress: true +jobs: + build_sdist: + name: Build SDist + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - name: Build SDist + run: pipx run build --sdist + - name: Check metadata + run: pipx run twine check dist/* + - uses: actions/upload-artifact@v3 + with: + path: dist/*.tar.gz + build_wheels: + name: 'Wheels on ${{ matrix.os }}' + runs-on: '${{ matrix.os }}' + strategy: + fail-fast: false + matrix: + os: + - ubuntu-latest + - macos-latest + - windows-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - uses: pypa/cibuildwheel@v2.16.2 + env: + CIBW_ARCHS_MACOS: universal2 + CIBW_ARCHS_WINDOWS: auto ARM64 + CMAKE_GENERATOR: '${{ env.CMAKE_GENERATOR }}' + - name: Verify clean directory + run: git diff --exit-code + shell: bash + - uses: actions/upload-artifact@v3 + with: + path: wheelhouse/*.whl + upload_all: + name: Upload if release + needs: + - build_wheels + - build_sdist + runs-on: ubuntu-latest + if: github.event_name == 
'release' && github.event.action == 'published' + steps: + - uses: actions/setup-python@v4 + with: + python-version: 3.x + - uses: actions/download-artifact@v3 + with: + name: artifact + path: dist + - uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: '${{ secrets.pypi_password }}' diff --git a/.gitignore b/.gitignore index a8de19a..38cf470 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,172 @@ build/* docs/html/* examples/out/* examples/out_json_ref/* + +# python eggs +src/python/*.egg-info +**/__pycache__ +**/*.pyc + + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ +**/_version.py + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + diff --git a/.gitmodules b/.gitmodules index 0bc74f1..12a258f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,3 @@ -[submodule "Catch2"] - path = submodules/Catch2 - url = https://github.com/catchorg/Catch2 -[submodule "msgpack-c"] - path = submodules/msgpack-c - url = https://github.com/msgpack/msgpack-c -[submodule "mmtf_spec"] +[submodule "submodules/mmtf_spec"] path = submodules/mmtf_spec - url = https://github.com/rcsb/mmtf + url = https://github.com/rcsb/mmtf.git diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 144daea..0000000 --- a/.travis.yml +++ /dev/null @@ -1,68 +0,0 @@ -language: cpp -sudo: false -dist: trusty - -linux64_addons: - addons: &linux64 - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - g++-4.8 - -linux32_addons: - addons: &linux32 - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - g++-4.8 - - g++-4.8-multilib - - linux-libc-dev:i386 - - libc6-dev-i386 - -linux64_cpp17addons: - addons: &linux64cpp17 - apt: - sources: - - ubuntu-toolchain-r-test - -# Set empty values for allow_failures to work -env: TEST_COMMAND=$TRAVIS_BUILD_DIR/ci/build_and_run_tests.sh - -matrix: - fast_finish: true - include: - - os: linux - env: EMSCRIPTEN=ON TEST_COMMAND=$TRAVIS_BUILD_DIR/ci/build_and_run_tests.sh - addons: *linux64 - - os: linux - compiler: clang - addons: *linux64 - - os: linux - compiler: gcc - env: ARCH=x86 CMAKE_EXTRA=-DHAVE_LIBM=/lib32/libm.so.6 TEST_COMMAND=$TRAVIS_BUILD_DIR/ci/build_and_run_tests.sh - addons: *linux32 - - os: osx - compiler: clang - - os: linux - compiler: gcc - env: CMAKE_EXTRA="-DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS='-march=native'" TEST_COMMAND=$TRAVIS_BUILD_DIR/ci/build_and_run_tests.sh - addons: *linux64cpp17 - dist: bionic - - os: linux - compiler: gcc - addons: *linux64cpp17 - dist: bionic - - -before_install: - # Setting environement - - cd $TRAVIS_BUILD_DIR - - source ci/setup-travis.sh - - $CC --version - - $CXX --version - -script: - - echo 
$TEST_COMMAND - - (eval "$TEST_COMMAND") diff --git a/CHANGELOG.md b/CHANGELOG.md index f4da72e..f728368 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,36 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/), and this project adheres to [Semantic Versioning](https://semver.org/). + +## [Released] +## v1.1.1 - 2023-12-15 +Add Python pybind11 bindings + Upgrade ci/cd +The project can be found on pypi as mmtf_cppy + +### Added +This PR adds python bindings as well as updates the backend build +process of the mmtf-cpp library significantly. These improvements are +mainly from a convenience perspective and include: +- Removing all build-based submodules +- Moving to cmake fetchcontent build +- Simplify CMakeLists with better linking procedures +- Upgrade msgpack-c +- Upgrade catch2 +- Move to github actions for ci/cd +- Use cibuildwheel for wheel cd + +Pybind11 library: +The pybind11 library utilizes the c++ code of mmtf-cpp in order to build +an extremely fast cpp layer underneath the python interface. You have +to keep in mind that moving between c++ and python is slow, but this is +still much faster than the previously existing python library. 
see this +example: + +time to load a single mmtf file 1000x +cpp bare 0.29s +this library 0.44s +python og 4.34s + ## [Unreleased] ## v1.1.0 - 2022-10-03 diff --git a/CMakeLists.txt b/CMakeLists.txt index 7acffed..0ada8c9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,23 @@ +cmake_minimum_required(VERSION 3.15...3.26 FATAL_ERROR) -cmake_minimum_required(VERSION 3.5 FATAL_ERROR) -project(mmtf-cpp VERSION 1.0.0 LANGUAGES CXX) +# Version based on repo tag +execute_process( + COMMAND git describe --exact-match --tags + OUTPUT_VARIABLE GIT_VERSION + ERROR_QUIET +) +string(STRIP "${GIT_VERSION}" GIT_VERSION) +string(REGEX REPLACE "^v" "" GIT_VERSION "${GIT_VERSION}") +if ("${GIT_VERSION}" STREQUAL "") + set(GIT_VERSION "0.0.0") +endif() +project(mmtf-cpp VERSION ${GIT_VERSION} LANGUAGES CXX) +message("Using git to tag version as: ${GIT_VERSION}") -option(mmtf_build_local "Use the submodule dependencies for building" OFF) option(mmtf_build_examples "Build the examples" OFF) +SET(MSGPACK_USE_BOOST OFF CACHE BOOL "msgpack-c uses boost by default, lets keep that off") +# option(MSGPACK_USE_BOOST "msgpack-c uses boost by default, lets keep that off" OFF) add_library(MMTFcpp INTERFACE) target_compile_features(MMTFcpp INTERFACE cxx_auto_type) @@ -13,23 +26,23 @@ target_include_directories(MMTFcpp INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/include ) -if (mmtf_build_local) - # use header only - set(MSGPACKC_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/submodules/msgpack-c/include) - add_library(msgpackc INTERFACE) - target_include_directories(msgpackc INTERFACE ${MSGPACKC_INCLUDE_DIR}) - if (BUILD_TESTS) - set(CATCH_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/submodules/Catch2/single_include) - add_library(Catch INTERFACE) - target_include_directories(Catch INTERFACE ${CATCH_INCLUDE_DIR}) - endif() +include(FetchContent) +FetchContent_Declare( + msgpack-cxx + GIT_REPOSITORY https://github.com/msgpack/msgpack-c.git + GIT_TAG cpp-6.1.0) +FetchContent_MakeAvailable(msgpack-cxx) + 
+if(WIN32) + target_link_libraries(MMTFcpp INTERFACE ws2_32) endif() -if (NOT TARGET msgpackc) - find_package(msgpack) +target_link_libraries(MMTFcpp INTERFACE msgpack-cxx) + +if (build_py) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/python) endif() -target_link_libraries(MMTFcpp INTERFACE msgpackc) if (BUILD_TESTS) enable_testing() diff --git a/README.md b/README.md index 5679e52..4fa1df8 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,53 @@ Here, `` and `` are the paths to the For your more complicated projects, a `CMakeLists.txt` is included for you. + +### Python bindings + +The C++ MMTF library can now build Python bindings using pybind11. To use them +you must have A) a C++11-compatible compiler and B) Python >= 3.8 + +To install from a source checkout, run `pip install .` + +(the package is named `mmtf_cppy` on PyPI, i.e. `pip install mmtf_cppy`) + +```python +from mmtf_cppy import StructureData +import numpy as np +import math + + +def rotation_matrix(axis, theta): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by theta radians. 
+ from https://stackoverflow.com/a/6802723 + """ + axis = np.asarray(axis) + axis = axis / math.sqrt(np.dot(axis, axis)) + a = math.cos(theta / 2.0) + b, c, d = -axis * math.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + + +theta = 1.2 +axis = [0, 0, 1] + +sd = StructureData("my_favorite_structure.mmtf") +sd.atomProperties["pymol_colorList"] = [1 if x % 2 == 0 else 5 for x in sd.xCoordList] +xyz = np.column_stack((sd.xCoordList, sd.yCoordList, sd.zCoordList)) +xyz_rot = rotation_matrix(axis, theta).dot(xyz.T).T +sd.xCoordList, sd.yCoordList, sd.zCoordList = np.hsplit(xyz_rot, 3) +sd.write_to_file("my_favorite_structure_rot.mmtf") + +``` + + + ## Installation You can also perform a system wide installation with `cmake` and `ninja` (or `make`). To do so: @@ -72,7 +119,7 @@ To build the tests + examples we recommend using the following lines: git submodule update --init --recursive mkdir build cd build -cmake -G Ninja -DBUILD_TESTS=ON -Dmmtf_build_local=ON -Dmmtf_build_examples=ON .. +cmake -G Ninja -DBUILD_TESTS=ON -Dmmtf_build_examples=ON .. 
ninja chmod +x ./tests/mmtf_tests ./tests/mmtf_tests diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index 3210ecf..0000000 --- a/appveyor.yml +++ /dev/null @@ -1,40 +0,0 @@ -version: "{build}" - -os: Visual Studio 2015 - -environment: - matrix: - - generator: MinGW Makefiles - CXX_PATH: 'C:\mingw-w64\x86_64-6.3.0-posix-seh-rt_v5-rev1\mingw64\bin' - ARCH: x64 - - generator: MinGW Makefiles - CXX_PATH: 'C:\mingw-w64\i686-6.3.0-posix-dwarf-rt_v5-rev1\mingw32\bin' - ARCH: x86 - - generator: Visual Studio 14 2015 Win64 - ARCH: x64 - - generator: Visual Studio 14 2015 - ARCH: x86 - -clone_folder: c:\mmtf-cpp - -# Uncomment the following lines to enable remote desktop access to Appveyor -# after a failed build. -# init: -# - ps: iex ((new-object net.webclient).DownloadString('https://raw.githubusercontent.com/appveyor/ci/master/scripts/enable-rdp.ps1')) -# on_failure: -# - ps: $blockRdp = $true; iex ((new-object net.webclient).DownloadString('https://raw.githubusercontent.com/appveyor/ci/master/scripts/enable-rdp.ps1')) - -install: - - git submodule update --init --recursive - - ps: . .\ci\setup-appveyor.ps1 - -build_script: - - cd C:\mmtf-cpp - - mkdir build - - cd build - - ps: echo $env:CMAKE_ARGUMENTS - - cmake %CMAKE_ARGUMENTS% .. - - cmake --build . --config Debug -- %BUILD_ARGUMENTS% - -test_script: - - ctest --build-config Debug --timeout 300 --output-on-failure diff --git a/ci/build_and_run_tests.sh b/ci/build_and_run_tests.sh index de8e67c..b2cbeb9 100755 --- a/ci/build_and_run_tests.sh +++ b/ci/build_and_run_tests.sh @@ -1,8 +1,11 @@ set -e -cd $TRAVIS_BUILD_DIR mkdir build && cd build -$CMAKE_CONFIGURE cmake $CMAKE_ARGS $CMAKE_EXTRA .. -make -j2 -ctest -j2 --output-on-failure -bash $TRAVIS_BUILD_DIR/ci/travis-test-example.sh -cd $TRAVIS_BUILD_DIR +cmake $CMAKE_ARGS -G Ninja -Dmmtf_build_examples=ON -DBUILD_TESTS=ON -DCMAKE_BUILD_TYPE=debug .. 
+ninja +./tests/mmtf_tests +./tests/multi_cpp_test +./examples/mmtf_demo ../submodules/mmtf_spec/test-suite/mmtf/173D.mmtf +./examples/traverse ../submodules/mmtf_spec/test-suite/mmtf/173D.mmtf +./examples/traverse ../submodules/mmtf_spec/test-suite/mmtf/173D.mmtf json +./examples/traverse ../submodules/mmtf_spec/test-suite/mmtf/173D.mmtf print +./examples/print_as_pdb ../submodules/mmtf_spec/test-suite/mmtf/173D.mmtf diff --git a/ci/setup-appveyor.ps1 b/ci/setup-appveyor.ps1 deleted file mode 100644 index 569e838..0000000 --- a/ci/setup-appveyor.ps1 +++ /dev/null @@ -1,18 +0,0 @@ -if ("$env:CXX_PATH" -ne "") { - $env:PATH += ";$env:CXX_PATH" -} - -$env:CMAKE_ARGUMENTS = "-G `"$env:generator`"" -$env:CMAKE_ARGUMENTS += " -Dmmtf_build_local=ON" -$env:CMAKE_ARGUMENTS += " -DBUILD_TESTS=ON" - -if ($env:generator -Match "Visual Studio") { - $env:BUILD_ARGUMENTS="/verbosity:minimal /m:2" -} else { - $env:BUILD_ARGUMENTS="-j2" -} - -if ($env:generator -eq "MinGW Makefiles") { - # Remove sh.exe from git in the PATH for MinGW to work - $env:PATH = ($env:PATH.Split(';') | Where-Object { $_ -ne 'C:\Program Files\Git\usr\bin' }) -join ';' -} diff --git a/ci/setup-travis.sh b/ci/setup-travis.sh deleted file mode 100644 index 0fb5b25..0000000 --- a/ci/setup-travis.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash - -export CMAKE_ARGS="-DCMAKE_BUILD_TYPE=debug -DBUILD_TESTS=ON -Dmmtf_build_local=ON" - -if [[ "$EMSCRIPTEN" == "ON" ]]; then - # Install a Travis compatible emscripten SDK - wget https://github.com/chemfiles/emscripten-sdk/archive/master.tar.gz - tar xf master.tar.gz - ./emscripten-sdk-master/emsdk activate - source ./emscripten-sdk-master/emsdk_env.sh - - export CMAKE_CONFIGURE='emcmake' - export CMAKE_ARGS="$CMAKE_ARGS -DTEST_RUNNER=node -DCMAKE_BUILD_TYPE=release" - - # Install a modern cmake - cd $HOME - wget https://cmake.org/files/v3.9/cmake-3.9.3-Linux-x86_64.tar.gz - tar xf cmake-3.9.3-Linux-x86_64.tar.gz - export PATH=$HOME/cmake-3.9.3-Linux-x86_64/bin:$PATH - 
- export CC=emcc - export CXX=em++ - - return -fi - -if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then - if [[ "$TRAVIS_DIST" == "trusty" ]]; then - if [[ "$CC" == "gcc" ]]; then - export CC=gcc-4.8 - export CXX=g++-4.8 - fi - fi - if [[ "$TRAVIS_DIST" == "bionic" ]]; then - if [[ "$CC" == "gcc" ]]; then - export CC=gcc-7 - export CXX=g++-7 - fi - fi -fi - -if [[ "$ARCH" == "x86" ]]; then - export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CXX_FLAGS=-m32 -DCMAKE_C_FLAGS=-m32" -fi diff --git a/ci/travis-test-example.sh b/ci/travis-test-example.sh deleted file mode 100644 index 6772ac3..0000000 --- a/ci/travis-test-example.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -# Test compilation of example code using C++03 if possible -# -> only expected to work in Linux or Mac with CXX set to g++ or clang++ -# -> expects TRAVIS_BUILD_DIR, EMSCRIPTEN, CXX to be set - -# abort on error and exit with proper exit code -set -e -# test example -cd $TRAVIS_BUILD_DIR/examples -if [ -z "$EMSCRIPTEN" ]; then - # Compile with C++03 forced - $CXX -I"../submodules/msgpack-c/include" -I"../include" -std=c++03 -O2 \ - -o read_and_write read_and_write.cpp - ./read_and_write ../submodules/mmtf_spec/test-suite/mmtf/3NJW.mmtf test.mmtf -else - # Cannot do C++03 here and need to embed input file for running it with node - cp ../submodules/mmtf_spec/test-suite/mmtf/3NJW.mmtf . 
- $CXX -I"../submodules/msgpack-c/include" -I"../include" -O2 \ - -o read_and_write.js read_and_write.cpp --embed-file 3NJW.mmtf - node read_and_write.js 3NJW.mmtf test.mmtf -fi diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 10eb2b8..d14fcb0 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,4 +1,3 @@ - cmake_minimum_required(VERSION 3.5 FATAL_ERROR) SET(executables mmtf_demo traverse print_as_pdb tableexport read_and_write) @@ -6,10 +5,5 @@ SET(executables mmtf_demo traverse print_as_pdb tableexport read_and_write) foreach(exe ${executables}) add_executable(${exe} ${exe}.cpp) target_compile_features(${exe} PRIVATE cxx_auto_type) - if(WIN32) - target_link_libraries(${exe} MMTFcpp ws2_32) - else() - target_link_libraries(${exe} MMTFcpp) - endif() + target_link_libraries(${exe} MMTFcpp) endforeach(exe) - diff --git a/include/mmtf/structure_data.hpp b/include/mmtf/structure_data.hpp index 0d37e32..c16cfe6 100644 --- a/include/mmtf/structure_data.hpp +++ b/include/mmtf/structure_data.hpp @@ -163,7 +163,7 @@ struct StructureData { std::string title; std::string depositionDate; std::string releaseDate; - std::vector > ncsOperatorList; + std::vector> ncsOperatorList; std::vector bioAssemblyList; std::vector entityList; std::vector experimentalMethods; diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..5347ccd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,91 @@ +[project] +name = "mmtf_cppy" +description="A minimal example package (with pybind11)" +readme = "README.md" +url = "https://github.com/rcsb/mmtf-cpp" +authors = [ + { name = "Danny Farrell", email = "16297104+danpf@users.noreply.github.com" }, +] + +requires-python = ">=3.8" +classifiers = [ + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: 
Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +dependencies = [ + "numpy", + "msgpack" +] +dynamic = ["version"] + +[build-system] +requires = ["scikit-build-core>=0.3.3", "pybind11"] +build-backend = "scikit_build_core.build" + +[tool.scikit-build] +wheel.expand-macos-universal-tags = true +cmake.minimum-version = "3.15" +cmake.build-type = "Release" +# cmake.build-type = "Debug" +cmake.source-dir = "." +wheel.packages = ["src/python/mmtf_cppy"] +# Uncomment during development +build-dir = "build/{wheel_tag}" +metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" +sdist.include = ["src/python/mmtf_cppy/_version.py"] + +[tool.scikit-build.cmake.define] +build_py = "ON" +MSGPACK_USE_BOOST = "OFF" + +[tool.setuptools_scm] +write_to = "src/python/mmtf_cppy/_version.py" + +[tool.cibuildwheel] +test-command = "python {project}/src/python/tests/tests.py" +test-skip = ["*universal2:arm64"] +skip = "*-win32 pp*win* *musllinux*" +build-verbosity = 1 + + +[tool.ruff] +src = ["src/python"] + +[tool.ruff.lint] +extend-select = [ + "B", # flake8-bugbear + "I", # isort + "ARG", # flake8-unused-arguments + "C4", # flake8-comprehensions + "EM", # flake8-errmsg + "ICN", # flake8-import-conventions + "G", # flake8-logging-format + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PL", # pylint + "PTH", # flake8-use-pathlib + "RET", # flake8-return + "RUF", # Ruff-specific + "SIM", # flake8-simplify + "T20", # flake8-print + "UP", # pyupgrade + "YTT", # flake8-2020 + "EXE", # flake8-executable + "NPY", # NumPy specific rules + "PD", # pandas-vet +] +ignore = [ + "PLR", # Design related pylint codes + "PT", # flake8-pytest-style +] +isort.required-imports = ["from __future__ import annotations"] + +[tool.ruff.per-file-ignores] +"src/python/tests/**" = ["T20"] + diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt new file mode 100644 index 0000000..83ee2d1 --- /dev/null +++ b/src/python/CMakeLists.txt @@ -0,0 +1,21 @@ +if(NOT 
CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +FetchContent_Declare( + pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.11.1) +FetchContent_MakeAvailable(pybind11) +find_package(pybind11 CONFIG REQUIRED) + +find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) + +Python_add_library(mmtf_bindings MODULE bindings.cpp WITH_SOABI) + +target_include_directories(mmtf_bindings PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../../include +) +target_link_libraries(mmtf_bindings PRIVATE pybind11::headers MMTFcpp) +target_compile_definitions(mmtf_bindings PRIVATE VERSION_INFO=${SKBUILD_PROJECT_VERSION}) +install(TARGETS mmtf_bindings DESTINATION mmtf_cppy) diff --git a/src/python/bindings.cpp b/src/python/bindings.cpp new file mode 100644 index 0000000..55914d2 --- /dev/null +++ b/src/python/bindings.cpp @@ -0,0 +1,515 @@ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace py = pybind11; + +/// CPP -> PY FUNCTIONS + +/* Notes + * We destroy original data because it is much faster to apply move + * than it is to copy the data. 
+ */ + +// This destroys the original data +template< typename T > +py::array +array1d_from_vector(std::vector & m) { + if (m.empty()) return py::array_t(); + std::vector* ptr = new std::vector(std::move(m)); + auto capsule = py::capsule(ptr, [](void* p) { + delete reinterpret_cast*>(p); + }); + return py::array_t( + ptr->size(), // shape of array + ptr->data(), // c-style contiguous strides for Sequence + capsule // numpy array references this parent + ); +} + + +template<> +py::array +array1d_from_vector(std::vector & m) { + //if (m.empty()) return py::array_t(); + std::vector* ptr = new std::vector(std::move(m)); + auto capsule = py::capsule(ptr, [](void* p) { + delete reinterpret_cast*>(p); + }); + return py::array( + py::dtype("size()}, // shape of array + {}, + ptr->data(), // c-style contiguous strides for Sequence + capsule // numpy array references this parent + ); +} + +template< > +py::array +array1d_from_vector(std::vector & m) { + return py::array(py::cast(std::move(m))); +} + +template +std::vector +flatten2D(std::vector> const & v) { + std::size_t total_size = 0; + for (auto const & x : v) + total_size += x.size(); + std::vector result; + result.reserve(total_size); + for (auto const & subv : v) + result.insert(result.end(), subv.begin(), subv.end()); + return result; +} + + +// would be nice if this was faster +template< typename T > +py::array +array2D_from_vector(std::vector> const & m) { + if (m.empty()) return py::array_t(); + std::vector* ptr = new std::vector(flatten2D(m)); + auto capsule = py::capsule(ptr, [](void* p) { + delete reinterpret_cast*>(p); + }); + return py::array_t( + {m.size(), m.at(0).size()}, // shape of array + {m.at(0).size()*sizeof(T), sizeof(T)}, // c-style contiguous strides + ptr->data(), + capsule); +} + +// This destroys the original data +py::list +dump_bio_assembly_list(mmtf::StructureData & sd) { + py::object py_ba_class = py::module::import("mmtf_cppy").attr("BioAssembly"); + py::object py_t_class = 
py::module::import("mmtf_cppy").attr("Transform"); + py::list bal; + for (mmtf::BioAssembly & cba : sd.bioAssemblyList) { + py::list transform_list; + for (mmtf::Transform & trans : cba.transformList) { + std::vector matrix(std::begin(trans.matrix), std::end(trans.matrix)); + transform_list.append( + py_t_class( + array1d_from_vector(trans.chainIndexList), + array1d_from_vector(matrix) + ) + ); + } + bal.append( + py_ba_class( + transform_list, + py::str(cba.name) + ) + ); + } + return bal; +} + +// This destroys the original data +py::list +dump_entity_list(std::vector & cpp_el) { + py::object entity = py::module::import("mmtf_cppy").attr("Entity"); + py::list el; + for (mmtf::Entity & e : cpp_el) { + el.append( + entity( + array1d_from_vector(e.chainIndexList), + e.description, + e.type, + e.sequence) + ); + } + return el; +} + +py::bytes +raw_properties(mmtf::StructureData const & sd) { + std::stringstream bytes; + std::map< std::string, std::map< std::string, msgpack::object > > objs({ + {"bondProperties", sd.bondProperties}, + {"atomProperties", sd.atomProperties}, + {"groupProperties", sd.groupProperties}, + {"chainProperties", sd.chainProperties}, + {"modelProperties", sd.modelProperties}, + {"extraProperties", sd.extraProperties}}); + msgpack::pack(bytes, objs); + return py::bytes(bytes.str()); +} + + +std::vector +make_transformList(py::list const & l) { + std::vector tl; + for (auto const & trans : l) { + mmtf::Transform t; + t.chainIndexList = trans.attr("chainIndexList").cast>(); + py::list pymatrix(trans.attr("matrix")); + std::size_t count(0); + for (auto const & x : pymatrix) { + t.matrix[count] = x.cast(); + ++count; + } + tl.push_back(t); + } + return tl; +} + + +void +set_bioAssemblyList(py::list const & obj, mmtf::StructureData & sd) { + std::vector bioAs; + for (auto const & py_bioAssembly : obj ) { + mmtf::BioAssembly bioA; + bioA.name = py::str(py_bioAssembly.attr("name")); + py::list py_transform_list(py_bioAssembly.attr("transformList")); + 
std::vector transform_list = make_transformList(py_transform_list); + bioA.transformList = transform_list; + bioAs.push_back(bioA); + } + sd.bioAssemblyList = bioAs; +} + + +void +set_entityList(py::list const & obj, mmtf::StructureData & sd) { + std::vector entities; + for (auto const & py_entity : obj ) { + mmtf::Entity entity; + entity.chainIndexList = py_entity.attr("chainIndexList").cast>(); + entity.description = py_entity.attr("description").cast(); + entity.type = py_entity.attr("type").cast(); + entity.sequence = py_entity.attr("sequence").cast(); + entities.push_back(entity); + } + sd.entityList = entities; +} + + +void +set_groupList(py::list const & obj, mmtf::StructureData & sd) { + std::vector groups; + for (auto const & py_group : obj ) { + mmtf::GroupType group; + group.formalChargeList = py_group.attr("formalChargeList").cast>(); + group.atomNameList = py_group.attr("atomNameList").cast>(); + group.elementList = py_group.attr("elementList").cast>(); + group.bondAtomList = py_group.attr("bondAtomList").cast>(); + group.bondOrderList = py_group.attr("bondOrderList").cast>(); + group.bondResonanceList = py_group.attr("bondResonanceList").cast>(); + group.groupName = py_group.attr("groupName").cast(); + group.singleLetterCode = py_group.attr("singleLetterCode").cast(); + group.chemCompType = py_group.attr("chemCompType").cast(); + groups.push_back(group); + } + sd.groupList = groups; +} + + +// This destroys the original data +py::list +dump_group_list(std::vector & gtl) { + py::object py_gt_class = py::module::import("mmtf_cppy").attr("GroupType"); + py::list gl; + for (mmtf::GroupType & gt : gtl) { + gl.append( + py_gt_class( + array1d_from_vector(gt.formalChargeList), + gt.atomNameList, + gt.elementList, + array1d_from_vector(gt.bondAtomList), + array1d_from_vector(gt.bondOrderList), + array1d_from_vector(gt.bondResonanceList), + gt.groupName, + std::string(1, gt.singleLetterCode), + gt.chemCompType + ) + ); + } + return gl; +} + +template< typename 
T> +std::vector +py_array_to_vector(py::array_t const & array_in) { + std::vector vec_array(array_in.size()); + std::memcpy(vec_array.data(), array_in.data(), array_in.size()*sizeof(T)); + return vec_array; +} + +template<> +std::vector +py_array_to_vector(py::array_t const & array_in) { + std::string tmpstr(array_in.data(), array_in.size()); + std::vector vec_array(tmpstr.begin(), tmpstr.end()); + return vec_array; +} + +/* This isn't really necessary, but lets make the interface anyway + */ +py::bytes +py_encodeInt8ToByte(py::array_t const & array_in) { + std::vector cpp_vec(mmtf::encodeInt8ToByte(py_array_to_vector(array_in))); + return py::bytes(std::string(cpp_vec.begin(), cpp_vec.end())); +} + +py::bytes +py_encodeFourByteInt(py::array_t const & array_in) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeFourByteInt(cpp_vec)); + return py::bytes(encoded.data(), encoded.size()); +} + +py::bytes +py_encodeRunLengthChar(py::array_t const & array_in) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeRunLengthChar(cpp_vec)); + return py::bytes(encoded.data(), encoded.size()); +} + +py::bytes +py_encodeRunLengthDeltaInt(py::array_t const & array_in) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeRunLengthDeltaInt(cpp_vec)); + return py::bytes(encoded.data(), encoded.size()); +} + +py::bytes +py_encodeDeltaRecursiveFloat(py::array_t const & array_in, int32_t const multiplier = 1000) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeDeltaRecursiveFloat(cpp_vec, multiplier)); + return py::bytes(encoded.data(), encoded.size()); +} + +py::bytes +py_encodeRunLengthFloat(py::array_t const & array_in, int32_t const multiplier = 1000) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeRunLengthFloat(cpp_vec, multiplier)); + return 
py::bytes(encoded.data(), encoded.size()); +} + + + +py::bytes +py_encodeRunLengthInt8(py::array_t const & array_in) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeRunLengthInt8(cpp_vec)); + return py::bytes(encoded.data(), encoded.size()); +} + +// TODO pyarray POD types to numpy array? seems hard +//py::bytes +//py_encodeStringVector4(py::array_t> const & array_in) { +// using np_str_t = std::array; +// pybind11::array_t cstring_array(vector.size()); +// const char * data = +// np_str_t* array_of_cstr_ptr = reinterpret_cast(cstring_array.request().ptr); +// +// +///* std::vector tobuild; */ +///* std::vector const cpp_vec(py_array_to_vector(array_in)); */ +///* std::vector encoded(mmtf::encodeStringVector(cpp_vec, max_string_size)); */ +///* return py::bytes(encoded.data(), encoded.size()); */ +//} + + +std::vector +char_vector_to_string_vector(std::vector const & cvec) { + std::vector ret(cvec.size()); + for (std::size_t i=0; i > tmp_target; + tmp_object.convert(tmp_target); + sd.bondProperties = tmp_target["bondProperties"]; + sd.atomProperties = tmp_target["atomProperties"]; + sd.groupProperties = tmp_target["groupProperties"]; + sd.chainProperties = tmp_target["chainProperties"]; + sd.modelProperties = tmp_target["modelProperties"]; + sd.extraProperties = tmp_target["extraProperties"]; +} + + +py::array +binary_decode_int32(py::bytes const & bytes) { + using namespace pybind11::literals; + std::string const tmpstr(bytes); + std::vector tmp; + mmtf::BinaryDecoder bd(tmpstr); + bd.decode(tmp); + return array1d_from_vector(tmp); +} + +py::array +binary_decode_int16(py::bytes const & bytes) { + using namespace pybind11::literals; + std::string const tmpstr(bytes); + std::vector tmp; + mmtf::BinaryDecoder bd(tmpstr); + bd.decode(tmp); + return array1d_from_vector(tmp); +} + +py::array +binary_decode_int8(py::bytes const & bytes) { + using namespace pybind11::literals; + std::string const tmpstr(bytes); + std::vector 
tmp; + mmtf::BinaryDecoder bd(tmpstr); + bd.decode(tmp); + return array1d_from_vector(tmp); +} + + +py::array +binary_decode_char(py::bytes const & bytes) { + using namespace pybind11::literals; + std::string const tmpstr(bytes); + std::vector tmp; + mmtf::BinaryDecoder bd(tmpstr); + bd.decode(tmp); + return array1d_from_vector(tmp); +} + +py::array +binary_decode_float(py::bytes const & bytes) { + using namespace pybind11::literals; + std::string const tmpstr(bytes); + std::vector tmp; + mmtf::BinaryDecoder bd(tmpstr); + bd.decode(tmp); + return array1d_from_vector(tmp); +} + + +PYBIND11_MODULE(mmtf_bindings, m) { + m.def("decode_int32", &binary_decode_int32, "decode array[int32_t]"); + m.def("decode_int16", &binary_decode_int16, "decode array[int16_t]"); + m.def("decode_int8", &binary_decode_int8, "decode array[int8_t]"); + m.def("decode_char", &binary_decode_char, "decode array[char]"); + m.def("decode_float", &binary_decode_float, "decode array[float]"); + // new stuff here + py::class_(m, "CPPStructureData") + .def( pybind11::init( [](){ return new mmtf::StructureData(); } ) ) + .def( pybind11::init( [](mmtf::StructureData const &o){ return new mmtf::StructureData(o); } ) ) + .def_readwrite("mmtfVersion", &mmtf::StructureData::mmtfVersion) + .def_readwrite("mmtfProducer", &mmtf::StructureData::mmtfProducer) + .def("unitCell", [](mmtf::StructureData &m){return array1d_from_vector(m.unitCell);}) + .def_readwrite("unitCell_io", &mmtf::StructureData::unitCell) + .def_readwrite("spaceGroup", &mmtf::StructureData::spaceGroup) + .def_readwrite("structureId", &mmtf::StructureData::structureId) + .def_readwrite("title", &mmtf::StructureData::title) + .def_readwrite("depositionDate", &mmtf::StructureData::depositionDate) + .def_readwrite("releaseDate", &mmtf::StructureData::releaseDate) + .def("ncsOperatorList", [](mmtf::StructureData &m){return array2D_from_vector(m.ncsOperatorList);}) + .def_readwrite("ncsOperatorList_io", &mmtf::StructureData::ncsOperatorList, 
py::return_value_policy::move) + .def("bioAssemblyList", [](mmtf::StructureData &m){return dump_bio_assembly_list(m);}) + .def_readwrite("bioAssemblyList_io", &mmtf::StructureData::bioAssemblyList) + .def("entityList", [](mmtf::StructureData &m){return dump_entity_list(m.entityList);}) + .def_readwrite("entityList_io", &mmtf::StructureData::entityList) + .def_readwrite("experimentalMethods", &mmtf::StructureData::experimentalMethods) + .def_readwrite("resolution", &mmtf::StructureData::resolution) + .def_readwrite("rFree", &mmtf::StructureData::rFree) + .def_readwrite("rWork", &mmtf::StructureData::rWork) + .def_readwrite("numBonds", &mmtf::StructureData::numBonds) + .def_readwrite("numAtoms", &mmtf::StructureData::numAtoms) + .def_readwrite("numGroups", &mmtf::StructureData::numGroups) + .def_readwrite("numChains", &mmtf::StructureData::numChains) + .def_readwrite("numModels", &mmtf::StructureData::numModels) + .def("groupList", [](mmtf::StructureData &m){return dump_group_list(m.groupList);}) + .def_readwrite("groupList_io", &mmtf::StructureData::groupList) + .def("unitCell", [](mmtf::StructureData &m){return array1d_from_vector(m.unitCell);}) + .def_readwrite("unitCell_io", &mmtf::StructureData::unitCell) + .def("bondAtomList", [](mmtf::StructureData &m){return array1d_from_vector(m.bondAtomList);}) + .def_readwrite("bondAtomList_io", &mmtf::StructureData::bondAtomList) + .def("bondOrderList", [](mmtf::StructureData &m){return array1d_from_vector(m.bondOrderList);}) + .def_readwrite("bondOrderList_io", &mmtf::StructureData::bondOrderList) + .def("bondResonanceList", [](mmtf::StructureData &m){return array1d_from_vector(m.bondResonanceList);}) + .def_readwrite("bondResonanceList_io", &mmtf::StructureData::bondResonanceList) + .def("xCoordList", [](mmtf::StructureData &m){return array1d_from_vector(m.xCoordList);}) + .def_readwrite("xCoordList_io", &mmtf::StructureData::xCoordList) + .def("yCoordList", [](mmtf::StructureData &m){return 
array1d_from_vector(m.yCoordList);}) + .def_readwrite("yCoordList_io", &mmtf::StructureData::yCoordList) + .def("zCoordList", [](mmtf::StructureData &m){return array1d_from_vector(m.zCoordList);}) + .def_readwrite("zCoordList_io", &mmtf::StructureData::zCoordList) + .def("bFactorList", [](mmtf::StructureData &m){return array1d_from_vector(m.bFactorList);}) + .def_readwrite("bFactorList_io", &mmtf::StructureData::bFactorList) + .def("atomIdList", [](mmtf::StructureData &m){return array1d_from_vector(m.atomIdList);}) + .def_readwrite("atomIdList_io", &mmtf::StructureData::atomIdList) + .def("altLocList", [](mmtf::StructureData &m) { + return array1d_from_vector(m.altLocList); + /* std::vector tmp(char_vector_to_string_vector(m.altLocList)); */ + /* return array1d_from_vector(tmp); */ + }) + .def("set_altLocList", [](mmtf::StructureData &m, py::array_t const & st) { + m.altLocList = std::vector(st.data(), st.data()+st.size()); + }) + .def("occupancyList", [](mmtf::StructureData &m){return array1d_from_vector(m.occupancyList);}) + .def_readwrite("occupancyList_io", &mmtf::StructureData::occupancyList) + .def("groupIdList", [](mmtf::StructureData &m){return array1d_from_vector(m.groupIdList);}) + .def_readwrite("groupIdList_io", &mmtf::StructureData::groupIdList) + .def("groupTypeList", [](mmtf::StructureData &m){return array1d_from_vector(m.groupTypeList);}) + .def_readwrite("groupTypeList_io", &mmtf::StructureData::groupTypeList) + .def("secStructList", [](mmtf::StructureData &m){return array1d_from_vector(m.secStructList);}) + .def_readwrite("secStructList_io", &mmtf::StructureData::secStructList) + .def("insCodeList", [](mmtf::StructureData &m) { + std::vector tmp(char_vector_to_string_vector(m.insCodeList)); + return array1d_from_vector(tmp); + }) + .def("set_insCodeList", [](mmtf::StructureData &m, py::array_t const & st) { + m.insCodeList = std::vector(st.data(), st.data()+st.size()); + }) + .def("sequenceIndexList", [](mmtf::StructureData &m){return 
array1d_from_vector(m.sequenceIndexList);}) + .def_readwrite("sequenceIndexList_io", &mmtf::StructureData::sequenceIndexList) + .def("chainIdList", [](mmtf::StructureData &m){return array1d_from_vector(m.chainIdList);}) + .def_readwrite("chainIdList_io", &mmtf::StructureData::chainIdList) + .def("chainNameList", [](mmtf::StructureData &m){return array1d_from_vector(m.chainNameList);}) + .def_readwrite("chainNameList_io", &mmtf::StructureData::chainNameList) + .def("groupsPerChain", [](mmtf::StructureData &m){return array1d_from_vector(m.groupsPerChain);}) + .def_readwrite("groupsPerChain_io", &mmtf::StructureData::groupsPerChain) + .def("chainsPerModel", [](mmtf::StructureData &m){return array1d_from_vector(m.chainsPerModel);}) + .def_readwrite("chainsPerModel_io", &mmtf::StructureData::chainsPerModel) + .def("set_properties", [](mmtf::StructureData & sd, py::bytes const & bytes_in){set_properties(sd, bytes_in);}) + .def("raw_properties", [](mmtf::StructureData const &m){return raw_properties(m);}); + + // I think it would be ideal to not pass in the sd, but it is still very + // fast this way. 
+ m.def("set_bioAssemblyList", [](py::list const & i, mmtf::StructureData & sd){return set_bioAssemblyList(i, sd);}); + m.def("set_entityList", [](py::list const & i, mmtf::StructureData & sd){return set_entityList(i, sd);}); + m.def("set_groupList", [](py::list const & i, mmtf::StructureData & sd){return set_groupList(i, sd);}); + m.def("decodeFromFile", &mmtf::decodeFromFile, "decode a mmtf::StructureData from a file"); + m.def("decodeFromBuffer", &mmtf::decodeFromBuffer, "decode a mmtf::StructureData from bytes"); + m.def("encodeToFile", [](mmtf::StructureData const &m, std::string const & fn){mmtf::encodeToFile(m, fn);}); + m.def("encodeToStream", [](mmtf::StructureData const &m){ + std::stringstream ss; + mmtf::encodeToStream(m, ss); + return py::bytes(ss.str()); + }); + // encoders + m.def("encodeInt8ToByte", &py_encodeInt8ToByte); + m.def("encodeFourByteInt", &py_encodeFourByteInt); + m.def("encodeRunLengthChar", &py_encodeRunLengthChar); + m.def("encodeRunLengthDeltaInt", &py_encodeRunLengthDeltaInt); + m.def("encodeDeltaRecursiveFloat", &py_encodeDeltaRecursiveFloat); + m.def("encodeRunLengthFloat", &py_encodeRunLengthFloat); + m.def("encodeRunLengthInt8", &py_encodeRunLengthInt8); + //m.def("encodeStringVector", &py_encodeStringVector); +} diff --git a/src/python/mmtf_cppy/__init__.py b/src/python/mmtf_cppy/__init__.py new file mode 100644 index 0000000..8961490 --- /dev/null +++ b/src/python/mmtf_cppy/__init__.py @@ -0,0 +1,38 @@ +try: + from ._version import version as __version__ + from ._version import version_tuple as version_tuple +except ImportError: + __version__ = "unknown version" + version_tuple = (0, 0, "unknown version") + +from .structure_data import ( + Entity as Entity, + GroupType as GroupType, + Transform as Transform, + BioAssembly as BioAssembly, + StructureData as StructureData, +) + + +from .mmtf_bindings import ( + CPPStructureData as CPPStructureData, + decode_int8 as decode_int8, + encodeRunLengthInt8 as encodeRunLengthInt8, + 
decodeFromBuffer as decodeFromBuffer, + encodeDeltaRecursiveFloat as encodeDeltaRecursiveFloat, + encodeToFile as encodeToFile, + decodeFromFile as decodeFromFile, + encodeFourByteInt as encodeFourByteInt, + encodeToStream as encodeToStream, + decode_char as decode_char, + encodeInt8ToByte as encodeInt8ToByte, + set_bioAssemblyList as set_bioAssemblyList, + decode_float as decode_float, + encodeRunLengthChar as encodeRunLengthChar, + set_entityList as set_entityList, + decode_int16 as decode_int16, + encodeRunLengthDeltaInt as encodeRunLengthDeltaInt, + set_groupList as set_groupList, + decode_int32 as decode_int32, + encodeRunLengthFloat as encodeRunLengthFloat, +) diff --git a/src/python/mmtf_cppy/structure_data.py b/src/python/mmtf_cppy/structure_data.py new file mode 100644 index 0000000..76a9e02 --- /dev/null +++ b/src/python/mmtf_cppy/structure_data.py @@ -0,0 +1,495 @@ +from os import PathLike +from typing import Dict, List, Union + +import msgpack +import numpy as np + +from . import mmtf_bindings +from .mmtf_bindings import ( + CPPStructureData, + decodeFromBuffer, + decodeFromFile, + encodeToFile, + encodeToStream, +) + + +class Entity: + def __init__(self, chainIndexList: np.ndarray, description: str, type_: str, sequence: str): + self.chainIndexList = chainIndexList + self.description = description + self.type = type_ + self.sequence = sequence + + def __repr__(self): + return ( + f"chainIndexList: {self.chainIndexList}" + f"description: {self.description}" + f"type: {self.type}" + f"sequence: {self.sequence}" + ) + + def __eq__(self, other: "Entity"): + return ( + (self.chainIndexList == other.chainIndexList).all() + and self.description == other.description + and self.type == other.type + and self.sequence == other.sequence + ) + + +class GroupType: + def __init__( + self, + formalChargeList: np.ndarray, + atomNameList: np.ndarray, + elementList: np.ndarray, + bondAtomList: np.ndarray, + bondOrderList: np.ndarray, + bondResonanceList: np.ndarray, + 
groupName: str, + singleLetterCode: str, + chemCompType: str, + ): + self.formalChargeList = formalChargeList + self.atomNameList = atomNameList + self.elementList = elementList + self.bondAtomList = bondAtomList + self.bondOrderList = bondOrderList + self.bondResonanceList = bondResonanceList + self.groupName = groupName + self.singleLetterCode = singleLetterCode + self.chemCompType = chemCompType + + def __repr__(self): + return ( + f"formalChargeList: {self.formalChargeList}" + f" atomNameList: {self.atomNameList}" + f" elementList: {self.elementList}" + f" bondAtomList: {self.bondAtomList}" + f" bondOrderList: {self.bondOrderList}" + f" bondResonanceList: {self.bondResonanceList}" + f" groupName: {self.groupName}" + f" singleLetterCode: {self.singleLetterCode}" + f" chemCompType: {self.chemCompType}" + ) + + def __eq__(self, other: "GroupType"): + return ( + (self.formalChargeList == other.formalChargeList).all() + and self.atomNameList == other.atomNameList + and self.elementList == other.elementList + and (self.bondAtomList == other.bondAtomList).all() + and (self.bondOrderList == other.bondOrderList).all() + and (self.bondResonanceList == other.bondResonanceList).all() + and self.groupName == other.groupName + and self.singleLetterCode == other.singleLetterCode + and self.chemCompType == other.chemCompType + ) + + +class Transform: + def __init__(self, chainIndexList: np.ndarray, matrix: np.ndarray): + self.chainIndexList = chainIndexList + self.matrix = matrix + + def __eq__(self, other: "Transform"): + return (self.chainIndexList == other.chainIndexList).all() and (self.matrix == other.matrix).all() + + +class BioAssembly: + def __init__(self, transformList: List[Transform], name: str): + self.transformList = transformList + self.name = name + + def __eq__(self, other: "BioAssembly"): + return self.transformList == other.transformList and self.name == other.name + + +def cppSD_from_SD(sd: "StructureData"): + cppsd = CPPStructureData() + cppsd.mmtfVersion = 
sd.mmtfVersion + cppsd.mmtfProducer = sd.mmtfProducer + cppsd.unitCell_io = sd.unitCell + cppsd.spaceGroup = sd.spaceGroup + cppsd.structureId = sd.structureId + cppsd.title = sd.title + cppsd.depositionDate = sd.depositionDate + cppsd.releaseDate = sd.releaseDate + cppsd.ncsOperatorList_io = sd.ncsOperatorList + mmtf_bindings.set_bioAssemblyList(sd.bioAssemblyList, cppsd) + mmtf_bindings.set_entityList(sd.entityList, cppsd) + cppsd.experimentalMethods = sd.experimentalMethods + cppsd.resolution = sd.resolution + cppsd.rFree = sd.rFree + cppsd.rWork = sd.rWork + cppsd.numBonds = sd.numBonds + cppsd.numAtoms = sd.numAtoms + cppsd.numGroups = sd.numGroups + cppsd.numChains = sd.numChains + cppsd.numModels = sd.numModels + mmtf_bindings.set_groupList(sd.groupList, cppsd) + cppsd.bondAtomList_io = sd.bondAtomList + cppsd.bondOrderList_io = sd.bondOrderList + cppsd.bondResonanceList_io = sd.bondResonanceList + cppsd.xCoordList_io = sd.xCoordList + cppsd.yCoordList_io = sd.yCoordList + cppsd.zCoordList_io = sd.zCoordList + cppsd.bFactorList_io = sd.bFactorList + cppsd.atomIdList_io = sd.atomIdList + assert sd.altLocList is not None + tmp_altLocList = np.array([ord(x) if x else 0 for x in sd.altLocList], dtype=np.int8) + cppsd.set_altLocList(tmp_altLocList) + del tmp_altLocList + cppsd.occupancyList_io = sd.occupancyList + cppsd.groupIdList_io = sd.groupIdList + cppsd.groupTypeList_io = sd.groupTypeList + cppsd.secStructList_io = sd.secStructList + assert sd.insCodeList is not None + tmp_insCodeList = np.array([ord(x) if x else 0 for x in sd.insCodeList], dtype=np.int8) + cppsd.set_insCodeList(np.int8(tmp_insCodeList)) + del tmp_insCodeList + cppsd.sequenceIndexList_io = sd.sequenceIndexList + cppsd.chainIdList_io = sd.chainIdList + cppsd.chainNameList_io = sd.chainNameList + cppsd.groupsPerChain_io = sd.groupsPerChain + cppsd.chainsPerModel_io = sd.chainsPerModel + packed_data = msgpack.packb( + { + "bondProperties": sd.bondProperties, + "atomProperties": 
sd.atomProperties, + "groupProperties": sd.groupProperties, + "chainProperties": sd.chainProperties, + "modelProperties": sd.modelProperties, + "extraProperties": sd.extraProperties, + }, + use_bin_type=True, + ) + cppsd.set_properties(packed_data) + return cppsd + + +class StructureData: + def __init__( + self, + mmtfVersion=str, + mmtfProducer=str, + unitCell=List[float], + spaceGroup=str, + structureId=str, + title=str, + depositionDate=str, + releaseDate=str, + ncsOperatorList=List[List[float]], + bioAssemblyList=List[BioAssembly], + entityList=List[Entity], + experimentalMethods=str, + resolution=float, + rFree=float, + rWork=float, + numBonds=int, + numAtoms=int, + numGroups=int, + numChains=int, + numModels=int, + groupList=List[GroupType], + bondAtomList=np.ndarray, + bondOrderList=np.ndarray, # type? + bondResonanceList=np.ndarray, + xCoordList=np.ndarray, + yCoordList=np.ndarray, + zCoordList=np.ndarray, + bFactorList=np.ndarray, + atomIdList=np.ndarray, + altLocList=np.ndarray, + occupancyList=np.ndarray, + groupIdList=List[int], + groupTypeList=List[int], + secStructList=List[int], + insCodeList=List[str], + sequenceIndexList=List[int], + chainIdList=List[str], + chainNameList=List[str], + groupsPerChain=List[int], + chainsPerModel=List[int], + bondProperties=Dict[str, bytes], + atomProperties=Dict[str, bytes], + groupProperties=Dict[str, bytes], + chainProperties=Dict[str, bytes], + modelProperties=Dict[str, bytes], + extraProperties=Dict[str, bytes], + ): + """ + Recommended to use `StructureData.init_from_bytes` or `StructureData.init_from_file_name` + + Note: + file and bytes are separated because it will be faster + if you just let c++ handle the file (rather than have + python read the bytes itself and pass them to c++) + """ + self.mmtfVersion = mmtfVersion + self.mmtfProducer = mmtfProducer + self.unitCell = unitCell + self.spaceGroup = spaceGroup + self.structureId = structureId + self.title = title + self.depositionDate = depositionDate + 
self.releaseDate = releaseDate + self.ncsOperatorList = ncsOperatorList + self.bioAssemblyList = bioAssemblyList + self.entityList = entityList + self.experimentalMethods = experimentalMethods + self.resolution = resolution + self.rFree = rFree + self.rWork = rWork + self.numBonds = numBonds + self.numAtoms = numAtoms + self.numGroups = numGroups + self.numChains = numChains + self.numModels = numModels + self.groupList = groupList + self.bondAtomList = bondAtomList + self.bondOrderList = bondOrderList + self.bondResonanceList = bondResonanceList + self.xCoordList = xCoordList + self.yCoordList = yCoordList + self.zCoordList = zCoordList + self.bFactorList = bFactorList + self.atomIdList = atomIdList + self.altLocList = altLocList + self.occupancyList = occupancyList + self.groupIdList = groupIdList + self.groupTypeList = groupTypeList + self.secStructList = secStructList + self.insCodeList = insCodeList + self.sequenceIndexList = sequenceIndexList + self.chainIdList = chainIdList + self.chainNameList = chainNameList + self.groupsPerChain = groupsPerChain + self.chainsPerModel = chainsPerModel + self.bondProperties = bondProperties + self.atomProperties = atomProperties + self.groupProperties = groupProperties + self.chainProperties = chainProperties + self.modelProperties = modelProperties + self.extraProperties = extraProperties + + @classmethod + def init_from_bytes(cls, file_bytes: bytes) -> "StructureData": + cppsd = CPPStructureData() + decodeFromBuffer(cppsd, file_bytes, len(file_bytes)) + return cls.init_from_cppsd(cppsd) + + @classmethod + def init_from_file_name(cls, file_name: Union[str, PathLike]) -> "StructureData": + cppsd = CPPStructureData() + decodeFromFile(cppsd, str(file_name)) + return cls.init_from_cppsd(cppsd) + + @classmethod + def init_from_cppsd(cls, cppsd: "CPPStructureData") -> "StructureData": + raw_properties = cppsd.raw_properties() + raw_properties = msgpack.unpackb(raw_properties, raw=False) + return cls( + 
mmtfVersion=cppsd.mmtfVersion, + mmtfProducer=cppsd.mmtfProducer, + unitCell=cppsd.unitCell(), + spaceGroup=cppsd.spaceGroup, + structureId=cppsd.structureId, + title=cppsd.title, + depositionDate=cppsd.depositionDate, + releaseDate=cppsd.releaseDate, + ncsOperatorList=cppsd.ncsOperatorList(), + bioAssemblyList=cppsd.bioAssemblyList(), + entityList=cppsd.entityList(), + experimentalMethods=cppsd.experimentalMethods, + resolution=cppsd.resolution, + rFree=cppsd.rFree, + rWork=cppsd.rWork, + numBonds=cppsd.numBonds, + numAtoms=cppsd.numAtoms, + numGroups=cppsd.numGroups, + numChains=cppsd.numChains, + numModels=cppsd.numModels, + groupList=cppsd.groupList(), + bondAtomList=cppsd.bondAtomList(), + bondOrderList=cppsd.bondOrderList(), + bondResonanceList=cppsd.bondResonanceList(), + xCoordList=cppsd.xCoordList(), + yCoordList=cppsd.yCoordList(), + zCoordList=cppsd.zCoordList(), + bFactorList=cppsd.bFactorList(), + atomIdList=cppsd.atomIdList(), + altLocList=cppsd.altLocList(), + occupancyList=cppsd.occupancyList(), + groupIdList=cppsd.groupIdList(), + groupTypeList=cppsd.groupTypeList(), + secStructList=cppsd.secStructList(), + insCodeList=cppsd.insCodeList(), + sequenceIndexList=cppsd.sequenceIndexList(), + chainIdList=cppsd.chainIdList(), + chainNameList=cppsd.chainNameList(), + groupsPerChain=cppsd.groupsPerChain(), + chainsPerModel=cppsd.chainsPerModel(), + bondProperties=raw_properties["bondProperties"], + atomProperties=raw_properties["atomProperties"], + groupProperties=raw_properties["groupProperties"], + chainProperties=raw_properties["chainProperties"], + modelProperties=raw_properties["modelProperties"], + extraProperties=raw_properties["extraProperties"], + ) + + def write_to_file(self, filename: Union[str, PathLike]): + cppsd = cppSD_from_SD(self) + encodeToFile(cppsd, str(filename)) + + def write_to_bytes(self): + cppsd = cppSD_from_SD(self) + return encodeToStream(cppsd) + + def check_equals(self, other: "StructureData"): + """ + A debugging function, to 
check and see what parts of each StructureData are not identical + """ + if not (self.mmtfVersion == other.mmtfVersion): + print("NOT self.mmtfVersion == other.mmtfVersion") + if not (self.mmtfProducer == other.mmtfProducer): + print("NOT self.mmtfProducer == other.mmtfProducer") + if not ((self.unitCell == other.unitCell).all()): + print("NOT (self.unitCell == other.unitCell).all()") + if not (self.spaceGroup == other.spaceGroup): + print("NOT self.spaceGroup == other.spaceGroup") + if not (self.structureId == other.structureId): + print("NOT self.structureId == other.structureId") + if not (self.title == other.title): + print("NOT self.title == other.title") + if not (self.depositionDate == other.depositionDate): + print("NOT self.depositionDate == other.depositionDate") + if not (self.releaseDate == other.releaseDate): + print("NOT self.releaseDate == other.releaseDate") + if not ((self.ncsOperatorList == other.ncsOperatorList).all()): + print("NOT (self.ncsOperatorList == other.ncsOperatorList).all()") + if not (self.bioAssemblyList == other.bioAssemblyList): + print("NOT self.bioAssemblyList == other.bioAssemblyList") + if not (self.entityList == other.entityList): + print("NOT self.entityList == other.entityList") + if not (self.experimentalMethods == other.experimentalMethods): + print("NOT self.experimentalMethods == other.experimentalMethods") + if not (self.resolution == other.resolution): + print("NOT self.resolution == other.resolution") + if not (self.rFree == other.rFree): + print("NOT self.rFree == other.rFree") + if not (self.rWork == other.rWork): + print("NOT self.rWork == other.rWork") + if not (self.numBonds == other.numBonds): + print("NOT self.numBonds == other.numBonds") + if not (self.numAtoms == other.numAtoms): + print("NOT self.numAtoms == other.numAtoms") + if not (self.numGroups == other.numGroups): + print("NOT self.numGroups == other.numGroups") + if not (self.numChains == other.numChains): + print("NOT self.numChains == 
other.numChains") + if not (self.numModels == other.numModels): + print("NOT self.numModels == other.numModels") + if not (self.groupList == other.groupList): + print("NOT self.groupList == other.groupList") + if not ((self.bondAtomList == other.bondAtomList).all()): + print("NOT (self.bondAtomList == other.bondAtomList).all()") + if not ((self.bondOrderList == other.bondOrderList).all()): + print("NOT (self.bondOrderList == other.bondOrderList).all()") + if not ((self.bondResonanceList == other.bondResonanceList).all()): + print("NOT (self.bondResonanceList == other.bondResonanceList).all()") + if not ((self.xCoordList == other.xCoordList).all()): + print("NOT (self.xCoordList == other.xCoordList).all()") + if not ((self.yCoordList == other.yCoordList).all()): + print("NOT (self.yCoordList == other.yCoordList).all()") + if not ((self.zCoordList == other.zCoordList).all()): + print("NOT (self.zCoordList == other.zCoordList).all()") + if not ((self.bFactorList == other.bFactorList).all()): + print("NOT (self.bFactorList == other.bFactorList).all()") + if not ((self.atomIdList == other.atomIdList).all()): + print("NOT (self.atomIdList == other.atomIdList).all()") + if not ((self.altLocList == other.altLocList).all()): + print("NOT (self.altLocList == other.altLocList).all()") + if not ((self.occupancyList == other.occupancyList).all()): + print("NOT (self.occupancyList == other.occupancyList).all()") + if not ((self.groupIdList == other.groupIdList).all()): + print("NOT (self.groupIdList == other.groupIdList).all()") + if not ((self.groupTypeList == other.groupTypeList).all()): + print("NOT (self.groupTypeList == other.groupTypeList).all()") + if not ((self.secStructList == other.secStructList).all()): + print("NOT (self.secStructList == other.secStructList).all()") + if not ((self.insCodeList == other.insCodeList).all()): + print("NOT (self.insCodeList == other.insCodeList).all()") + if not ((self.sequenceIndexList == other.sequenceIndexList).all()): + print("NOT 
(self.sequenceIndexList == other.sequenceIndexList).all()")
+        if not ((self.chainIdList == other.chainIdList).all()):
+            print("NOT (self.chainIdList == other.chainIdList).all()")
+        if not ((self.chainNameList == other.chainNameList).all()):
+            print("NOT (self.chainNameList == other.chainNameList).all()")
+        if not ((self.groupsPerChain == other.groupsPerChain).all()):
+            print("NOT (self.groupsPerChain == other.groupsPerChain).all()")
+        if not ((self.chainsPerModel == other.chainsPerModel).all()):
+            print("NOT (self.chainsPerModel == other.chainsPerModel).all()")
+        if not (self.bondProperties == other.bondProperties):
+            print("NOT self.bondProperties == other.bondProperties")
+        if not (self.atomProperties == other.atomProperties):
+            print("NOT self.atomProperties == other.atomProperties")
+        if not (self.groupProperties == other.groupProperties):
+            print("NOT self.groupProperties == other.groupProperties")
+        if not (self.chainProperties == other.chainProperties):
+            print("NOT self.chainProperties == other.chainProperties")
+        if not (self.modelProperties == other.modelProperties):
+            print("NOT self.modelProperties == other.modelProperties")
+        if not (self.extraProperties == other.extraProperties):
+            print("NOT self.extraProperties == other.extraProperties")
+
+    def __eq__(self, other: "StructureData"):
+        """Return True when every MMTF field of ``self`` matches ``other``.
+
+        Fields are checked in the order listed below and the comparison stops
+        at the first mismatch. Array-valued fields are compared with
+        ``numpy.array_equal``, so arrays of different shapes compare unequal
+        instead of failing in ``==`` or on the missing ``.all()`` of a plain
+        ``bool``. Every other field is compared with ``==``.
+
+        Returns ``NotImplemented`` when ``other`` is not a ``StructureData``,
+        so ``sd == None`` evaluates to ``False`` rather than raising
+        ``AttributeError``.
+        """
+        import numpy as np  # imported locally so the method is self-contained
+
+        if not isinstance(other, StructureData):
+            return NotImplemented
+        # Fields that were compared with ``(a == b).all()``, i.e. numpy arrays.
+        array_fields = frozenset((
+            "unitCell", "ncsOperatorList", "bondAtomList", "bondOrderList",
+            "bondResonanceList", "xCoordList", "yCoordList", "zCoordList",
+            "bFactorList", "atomIdList", "altLocList", "occupancyList",
+            "groupIdList", "groupTypeList", "secStructList", "insCodeList",
+            "sequenceIndexList", "chainIdList", "chainNameList", "groupsPerChain",
+            "chainsPerModel",
+        ))
+        ordered_fields = (
+            "mmtfVersion", "mmtfProducer", "unitCell", "spaceGroup", "structureId",
+            "title", "depositionDate", "releaseDate", "ncsOperatorList",
+            "bioAssemblyList", "entityList", "experimentalMethods", "resolution",
+            "rFree", "rWork", "numBonds", "numAtoms", "numGroups", "numChains",
+            "numModels", "groupList", "bondAtomList", "bondOrderList",
+            "bondResonanceList", "xCoordList", "yCoordList", "zCoordList",
+            "bFactorList", "atomIdList", "altLocList", "occupancyList",
+            "groupIdList", "groupTypeList", "secStructList", "insCodeList",
+            "sequenceIndexList", "chainIdList", "chainNameList", "groupsPerChain",
+            "chainsPerModel", "bondProperties", "atomProperties",
+            "groupProperties", "chainProperties", "modelProperties",
+            "extraProperties",
+        )
+        for name in ordered_fields:
+            mine, theirs = getattr(self, name), getattr(other, name)
+            if name in array_fields:
+                same = np.array_equal(mine, theirs)
+            else:
+                same = mine == theirs
+            if not same:
+                return False
+        return True
diff --git a/src/python/tests/tests.py b/src/python/tests/tests.py
new file mode 100644
index 0000000..c6f7e0c
--- /dev/null
+++ b/src/python/tests/tests.py
@@ -0,0 +1,196 @@
+from pathlib import Path
+import tempfile
+import unittest
+
+import mmtf_cppy
+from mmtf_cppy import StructureData
+import numpy as np
+
+MMTF_SPEC_DIR = Path(__file__).parent / "../../../submodules/mmtf_spec"
+EXTRA_TEST_DATA_DIR =
Path(__file__).parent / "../../../temporary_test_data"
+
+
+class TestMMTF(unittest.TestCase):  # exercises the mmtf_cppy StructureData API against MMTF fixture files
+    def setUp(self):  # every test gets its own scratch directory
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):  # remove the scratch directory created in setUp
+        self.temp_dir.cleanup()
+
+    def test_eq_operator(self):  # same file loaded twice compares equal; a different file does not
+        s1 = StructureData.init_from_file_name(MMTF_SPEC_DIR / "test-suite/mmtf/173D.mmtf")
+        s2 = StructureData.init_from_file_name(MMTF_SPEC_DIR / "test-suite/mmtf/173D.mmtf")
+        s3 = StructureData.init_from_file_name(MMTF_SPEC_DIR / "test-suite/mmtf/1AUY.mmtf")
+        assert s1 == s2
+        assert s1 != s3
+
+    def test_roundtrip(self):  # load -> write -> reload must yield an equal StructureData for every fixture
+        files = [
+            MMTF_SPEC_DIR / "test-suite/mmtf/173D.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1AA6.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1AUY.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1BNA.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1CAG.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1HTQ.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1IGT.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1L2Q.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1LPV.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1MSH.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1O2F.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1R9V.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/1SKM.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/3NJW.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/3ZYB.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/4CK4.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/4CUP.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/4OPJ.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/4P3R.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/4V5A.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/4Y60.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/5EMG.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/5ESW.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/empty-all0.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/empty-numChains1.mmtf",
+            MMTF_SPEC_DIR / "test-suite/mmtf/empty-numModels1.mmtf",
+            EXTRA_TEST_DATA_DIR / "all_canoncial.mmtf",
+            EXTRA_TEST_DATA_DIR / "1PEF_with_resonance.mmtf",
+        ]
+        test_tmp_mmtf_filename = Path(self.temp_dir.name) / "test_mmtf.mmtf"  # one output path, overwritten for each fixture
+        for filename in files:
+            s1 = StructureData.init_from_file_name(filename)
+            s1.write_to_file(test_tmp_mmtf_filename)
+            s2 = StructureData.init_from_file_name(test_tmp_mmtf_filename)
+            s1.check_equals(s2)  # runs before the assert so per-field "NOT ..." mismatches are printed first
+            assert s1 == s2
+
+    def test_bad_mmtf(self):  # loading the listed fixture is expected to raise RuntimeError
+        doesnt_work = [MMTF_SPEC_DIR / "test-suite/mmtf/empty-mmtfVersion99999999.mmtf"]
+        for filename in doesnt_work:
+            with self.assertRaises(RuntimeError):
+                _ = StructureData.init_from_file_name(filename)
+
+    def test_various_throws(self):  # each tampered copy of a valid structure must fail to write with RuntimeError
+        working_mmtf_fn = MMTF_SPEC_DIR / "test-suite/mmtf/173D.mmtf"
+        temporary_file = Path(self.temp_dir.name) / "wrk.mmtf"
+
+        sd = StructureData.init_from_file_name(working_mmtf_fn)  # reloaded before every case so only one field is altered at a time
+        sd.xCoordList = np.append(sd.xCoordList, 0.334)  # one extra element appended
+        with self.assertRaises(RuntimeError):
+            sd.write_to_file(temporary_file)
+
+        sd = StructureData.init_from_file_name(working_mmtf_fn)
+        sd.yCoordList = np.append(sd.yCoordList, 0.334)
+        with self.assertRaises(RuntimeError):
+            sd.write_to_file(temporary_file)
+
+        sd = StructureData.init_from_file_name(working_mmtf_fn)
+        sd.zCoordList = np.append(sd.zCoordList, 0.334)
+        with self.assertRaises(RuntimeError):
+            sd.write_to_file(temporary_file)
+
+        sd = StructureData.init_from_file_name(working_mmtf_fn)
+        sd.bFactorList = np.append(sd.bFactorList, 0.334)
+        with self.assertRaises(RuntimeError):
+            sd.write_to_file(temporary_file)
+
+        sd = StructureData.init_from_file_name(working_mmtf_fn)
+        sd.numAtoms = 20  # presumably no longer matches the per-atom array lengths of 173D - TODO confirm
+        with self.assertRaises(RuntimeError):
+            sd.write_to_file(temporary_file)
+
+        sd = StructureData.init_from_file_name(working_mmtf_fn)
+        sd.chainIdList = np.append(sd.chainIdList, "xsz")
+        with self.assertRaises(RuntimeError):
+            sd.write_to_file(temporary_file)
+
+        sd = StructureData.init_from_file_name(working_mmtf_fn)  # NOTE(review): the statement on the next line is truncated (unterminated string); text after it appears to be missing - restore from upstream
+        sd.chainIdList = sd.chainIdList.astype(" c++ vector string
+    # def test_encodeStringVector():
+    #     encoded_data = b'\x00\x00\x00\x05\x00\x00\x00\x06\x00\x00\x00\x04B\x00\x00\x00A\x00\x00\x00C\x00\x00\x00A\x00\x00\x00A\x00\x00\x00A\x00\x00\x00'
+    #     decoded_data = np.array(("B", "A", "C", "A", "A", "A"))
+    #     ret = mmtf_cppy.encodeStringVector(decoded_data, 4)
+    #     assert ret == encoded_data
+
+    def test_atomProperties(self):  # custom atomProperties entries (plain list and encoded form) must survive a write/read cycle
+        working_mmtf_fn = MMTF_SPEC_DIR / "test-suite/mmtf/173D.mmtf"
+        tmp_output_fn = Path(self.temp_dir.name) / "properties.mmtf"
+        sd = StructureData.init_from_file_name(working_mmtf_fn)
+        random_data = list(range(256))  # despite the name, the fixed sequence 0..255
+        encoded_random_data = mmtf_cppy.encodeRunLengthDeltaInt(list(range(256)))
+        sd.atomProperties["256_atomColorList"] = random_data
+        sd.atomProperties["256_atomColorList_encoded"] = encoded_random_data
+        sd.write_to_file(tmp_output_fn)
+        sd2 = StructureData.init_from_file_name(tmp_output_fn)
+        assert sd2.atomProperties["256_atomColorList"] == random_data
+        assert sd2.atomProperties["256_atomColorList_encoded"] == encoded_random_data
+        assert (mmtf_cppy.decode_int32(sd2.atomProperties["256_atomColorList_encoded"]) == np.array(random_data)).all()  # decoding the stored blob gives back the original values
+if __name__ == "__main__":
+    unittest.main()
diff --git a/submodules/Catch2 b/submodules/Catch2
deleted file mode 160000
index cf4b7ee..0000000
--- a/submodules/Catch2
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit cf4b7eead92773932f32c7efd2612e9d27b07557
diff --git a/submodules/mmtf_spec b/submodules/mmtf_spec
index 8c88834..e4aaae5 160000
--- a/submodules/mmtf_spec
+++ b/submodules/mmtf_spec
@@ -1 +1 @@
-Subproject commit 8c8883457e54fb460908a57d801212c56a603aec
+Subproject commit e4aaae5d2f273d073e5482db61f74ad93f6e8ab4
diff --git a/submodules/msgpack-c b/submodules/msgpack-c
deleted file mode 160000
index 7a98138..0000000
--- a/submodules/msgpack-c
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 7a98138f27f27290e680bf8fbf1f8d1b089bf138
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 77fa418..5c1bf36 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -5,22 +5,24 @@
if(EMSCRIPTEN)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s TOTAL_MEMORY=150994944 -s DISABLE_EXCEPTION_CATCHING=0")
 endif()

+FetchContent_Declare(  # Catch2 is declared as a FetchContent dependency, cloned from the git repository below
+  Catch2
+  GIT_REPOSITORY https://github.com/catchorg/Catch2.git
+  GIT_TAG        v3.4.0)  # pinned to the v3.4.0 tag
+FetchContent_MakeAvailable(Catch2)
+
 add_executable(mmtf_tests mmtf_tests.cpp)
 target_compile_features(mmtf_tests PRIVATE cxx_auto_type)
-if(WIN32)
-    target_link_libraries(mmtf_tests Catch msgpackc MMTFcpp ws2_32)
+if(EMSCRIPTEN)  # Emscripten links the plain Catch2 target; every other platform links Catch2WithMain
+    target_link_libraries(mmtf_tests Catch2::Catch2 MMTFcpp)
 else()
-    target_link_libraries(mmtf_tests Catch msgpackc MMTFcpp)
+    target_link_libraries(mmtf_tests Catch2::Catch2WithMain MMTFcpp)
 endif()

 # test for multi-linking
 add_executable(multi_cpp_test multi_cpp_test.cpp multi_cpp_test_helper.cpp)
 target_compile_features(multi_cpp_test PRIVATE cxx_auto_type)
-if(WIN32)
-    target_link_libraries(multi_cpp_test MMTFcpp ws2_32)
-else()
-    target_link_libraries(multi_cpp_test MMTFcpp)
-endif()
+target_link_libraries(multi_cpp_test MMTFcpp)  # single link line on all platforms; the WIN32-only ws2_32 branch is removed

 set(TEST_RUNNER "none" CACHE STRING "External runner for the tests")

diff --git a/tests/mmtf_tests.cpp b/tests/mmtf_tests.cpp
index d84a392..993262b 100644
--- a/tests/mmtf_tests.cpp
+++ b/tests/mmtf_tests.cpp
@@ -1,11 +1,9 @@
 #ifdef __EMSCRIPTEN__
 #define CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS
-#define CATCH_CONFIG_RUNNER
-#else
-#define CATCH_CONFIG_MAIN
 #endif
-#include "catch.hpp"
+#include  // NOTE(review): header name is missing on this line (angle-bracket text appears stripped) - restore before applying
+#include  // NOTE(review): header name is missing on this line as well - restore before applying

 #include
 #include

@@ -17,7 +15,7 @@ template
 bool approx_equal_vector(const T& a, const T& b, float eps = 0.00001) {
   if (a.size() != b.size()) return false;
   for (std::size_t i=0; i < a.size(); ++i) {
-    if (a[i] != Approx(b[i]).margin(eps)) return false;
+    if (a[i] != Catch::Approx(b[i]).margin(eps)) return false;  // Approx is now spelled with the Catch:: namespace; margin(eps) is unchanged
   }
   return true;
 }