diff --git a/.codespell_ignore_list b/.codespell_ignore_list new file mode 100644 index 00000000..75abb4a1 --- /dev/null +++ b/.codespell_ignore_list @@ -0,0 +1,4 @@ +punctuations +Bellow +Coo +Patter diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 812fc3b1..00000000 --- a/.coveragerc +++ /dev/null @@ -1,4 +0,0 @@ -[report] -omit = - */python?.?/* - */site-packages/nose/* diff --git a/.github/environment-ci.yml b/.github/environment-ci.yml new file mode 100644 index 00000000..86a930c4 --- /dev/null +++ b/.github/environment-ci.yml @@ -0,0 +1,18 @@ +name: test +channels: + - conda-forge + - defaults +dependencies: + - pip + - numpy >=1.20.0 + - jsonschema >=4.0.1 + - pandas >=1.2.0 + - mir_eval >=0.8.2 + - matplotlib-base>=3.4.1 + - sortedcontainers >=2.1.0 + - six + - decorator + - pytest + - coverage + - pytest-cov + - pytest-mpl diff --git a/.github/environment-lint.yml b/.github/environment-lint.yml new file mode 100644 index 00000000..ba2e6838 --- /dev/null +++ b/.github/environment-lint.yml @@ -0,0 +1,21 @@ +name: lint +channels: + - conda-forge + - defaults +dependencies: + # required + - pip + - bandit + - codespell + - flake8 + - pytest + - pydocstyle + + # Dependencies for velin + - numpydoc>=1.1.0 + - sphinx>=5.1.0 + - pygments + - black + + - pip: + - velin diff --git a/.github/environment-minimal.yml b/.github/environment-minimal.yml new file mode 100644 index 00000000..8eda0212 --- /dev/null +++ b/.github/environment-minimal.yml @@ -0,0 +1,18 @@ +name: test +channels: + - conda-forge + - defaults +dependencies: + - pip + - numpy ==1.20.0 + - jsonschema ==4.0.1 + - pandas ==1.2.0 + - mir_eval ==0.8.2 + - matplotlib-base ==3.4.1 + - sortedcontainers ==2.1.0 + - six + - decorator + - pytest + - coverage + - pytest-cov + - pytest-mpl diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e8d3f4e9..cce0cb33 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: JSON Annotated Music Specification for Reproducible MIR Research +name: CI Testing on: push: @@ -8,25 +8,91 @@ on: branches: - master +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: True + jobs: - install: - runs-on: ubuntu-latest + test: + name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" + runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: - python-version: - - "3.9" - - "3.10" - - "3.11" - - "3.12" - timeout-minutes: 5 + include: + - os: ubuntu-latest + python-version: "3.9" + channel-priority: "strict" + envfile: ".github/environment-ci.yml" + + - os: ubuntu-latest + python-version: "3.10" + channel-priority: "strict" + envfile: ".github/environment-ci.yml" + + - os: ubuntu-latest + python-version: "3.11" + channel-priority: "strict" + envfile: ".github/environment-ci.yml" + + - os: ubuntu-latest + python-version: "3.12" + channel-priority: "strict" + envfile: ".github/environment-ci.yml" + + - os: ubuntu-latest + python-version: "3.13" + channel-priority: "strict" + envfile: ".github/environment-ci.yml" + + - python-version: "3.13" + os: macos-latest + channel-priority: "strict" + envfile: ".github/environment-ci.yml" + + - python-version: "3.13" + os: windows-latest + envfile: ".github/environment-ci.yml" + + - os: ubuntu-latest + python-version: "3.9" + envfile: ".github/environment-minimal.yml" + channel-priority: "flexible" + name: "Minimal dependencies" + steps: - uses: actions/checkout@v4 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + - name: Create conda environment + uses: conda-incubator/setup-miniconda@v3 with: python-version: ${{ matrix.python-version }} + auto-activate-base: false + channel-priority: ${{ matrix.channel-priority }} + environment-file: ${{ matrix.envfile }} + # Disabling bz2 to get more recent dependencies. + # NOTE: this breaks cache support, so CI will be slower. + use-only-tar-bz2: False # IMPORTANT: This needs to be set for caching to work properly! - name: Install jams - run: pip install -e .[display,tests] - - name: Run tests + shell: bash -l {0} + run: python -m pip install --upgrade-strategy=only-if-needed -e .[display,tests] + - name: Log installed packages for debugging + shell: bash -l {0} run: | - pytest -v --cov-report term-missing --cov jams + python -c "import sys; print(sys.version)" + conda info -a + conda list + + - name: Run tests + shell: bash -l {0} + run: pytest + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: ./coverage.xml + flags: unittests + env_vars: OS,PYTHON + name: codecov-umbrella + fail_ci_if_error: true + verbose: true diff --git a/.github/workflows/lint_python.yml b/.github/workflows/lint_python.yml new file mode 100644 index 00000000..c923da53 --- /dev/null +++ b/.github/workflows/lint_python.yml @@ -0,0 +1,71 @@ +name: lint_python +on: [pull_request, push] +jobs: + lint_python: + name: "Lint and code analysis" + runs-on: ubuntu-latest + strategy: + fail-fast: true + matrix: + include: + - os: ubuntu-latest + python-version: "3.11" + channel-priority: "flexible" + envfile: ".github/environment-lint.yml" + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Cache conda + uses: actions/cache@v4 + env: + CACHE_NUMBER: 0 + with: + path: ~/conda_pkgs_dir + key: ${{ runner.os }}-${{ matrix.python-version }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles( matrix.envfile ) }} + - name: Install conda environmnent + uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: false + python-version: ${{ matrix.python-version }} + add-pip-as-python-dependency: true + auto-activate-base: false + activate-environment: lint + # mamba-version: "*" + channel-priority: ${{ matrix.channel-priority }} + environment-file: ${{ matrix.envfile }} + use-only-tar-bz2: false + + - name: Conda info + shell: bash -l {0} + run: | + conda info -a + conda list + + - name: Spell check package + shell: bash -l {0} + run: codespell --ignore-words .codespell_ignore_list jams + + - name: Security check + shell: bash -l {0} + run: bandit --recursive --skip B101,B110 . + + - name: Style check package + shell: bash -l {0} + run: python -m flake8 jams + + - name: Format check package + shell: bash -l {0} + run: python -m black --check jams + + - name: Format check tests + shell: bash -l {0} + run: python -m black --check tests + + - name: Docstring check + shell: bash -l {0} + run: python -m velin --check jams + + - name: Docstring style check + shell: bash -l {0} + run: python -m pydocstyle jams diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..ebf3637f --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,43 @@ +name: Publish Python 🐍 distributions 📦 to PyPI + +on: + release: + types: [created] + + +jobs: + pypi-publish: + name: Upload release to PyPI + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/jams + permissions: + id-token: write # IMPORTANT: this permission is mandatory for trusted publishing + + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: 3.12 + + - name: Install pypa/build + run: >- + python -m + pip install + build + --user + + - name: Build a binary wheel and a source tarball + run: >- + python -m + build + --sdist + --wheel + --outdir dist/ + . + + - name: Publish package distributions to PyPI + if: startsWith(github.ref, 'refs/tags') + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index 4fe082e3..43cab823 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ pip-log.txt # Unit test / coverage reports .coverage +coverage.xml .tox #Translations diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9a9f68e4..cceccffc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -41,48 +41,6 @@ review. This will send an email to the committers. (If any of the above seems like magic to you, then look up the [Git documentation](http://git-scm.com/documentation) on the web.) -It is recommended to check that your contribution complies with the -following rules before submitting a pull request: - -- All public methods should have informative docstrings with sample - usage presented. - -You can also check for common programming errors with the following -tools: - -- Code with good unittest coverage (at least 80%), check with: - - $ pip install nose coverage - $ nosetests --with-coverage --cover-package=jams -w jams/tests/ - -- No pyflakes warnings, check with: - - $ pip install pyflakes - $ pyflakes path/to/module.py - -- No PEP8 warnings, check with: - - $ pip install pep8 - $ pep8 path/to/module.py - -- AutoPEP8 can help you fix some of the easy redundant errors: - - $ pip install autopep8 - $ autopep8 path/to/pep8.py - - -Documentation -------------- - -You can edit the documentation using any text editor and then generate -the HTML output by typing ``make html`` from the docs/ directory. -The resulting HTML files will be placed in _build/html/ and are viewable -in a web browser. See the README file in the doc/ directory for more information. - -For building the documentation, you will need -[sphinx](http://sphinx.pocoo.org/), -[matplotlib](http://matplotlib.sourceforge.net/), and [scikit-learn](http://scikit-learn.org/). - Note ---- diff --git a/README.md b/README.md index 7adbbf0f..15904a9f 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ jams ==== [![PyPI](https://img.shields.io/pypi/v/jams.svg)](https://pypi.python.org/pypi/jams) [![License](https://img.shields.io/pypi/l/jams.svg)](https://github.com/marl/jams/blob/master/LICENSE.md) -[![Build Status](https://travis-ci.org/marl/jams.svg?branch=master)](https://travis-ci.org/marl/jams) -[![Coverage Status](https://coveralls.io/repos/marl/jams/badge.svg?branch=master)](https://coveralls.io/r/marl/jams?branch=master) +[![Tests](https://github.com/marl/jams/actions/workflows/ci.yml/badge.svg)](https://github.com/marl/jams/actions/workflows/ci.yml) +[![codecov](https://codecov.io/gh/marl/jams/graph/badge.svg?token=dRCLnaNTTO)](https://codecov.io/gh/marl/jams) [![Dependency Status](https://dependencyci.com/github/marl/jams/badge)](https://dependencyci.com/github/marl/jams) A JSON Annotated Music Specification for Reproducible MIR Research. @@ -20,7 +20,7 @@ We provide: * The ability to store multiple annotations per file * Schema definitions for a wide range of annotation types (beats, chords, segments, tags, etc.) * Error detection and validation for annotations -* A translation layer to interface with [mir eval](https://craffel.github.io/mir_eval) +* A translation layer to interface with [mir eval](https://mir-eval.readthedocs.io/latest/) for evaluating annotations Why diff --git a/jams/__init__.py b/jams/__init__.py index 1da10ca8..304029c4 100644 --- a/jams/__init__.py +++ b/jams/__init__.py @@ -23,14 +23,21 @@ if sys.version_info < (3, 10): from pkg_resources import resource_filename - for _ in util.find_with_extension(resource_filename(__name__, schema.NS_SCHEMA_DIR), 'json'): + for _ in util.find_with_extension( + resource_filename(__name__, schema.NS_SCHEMA_DIR), "json" + ): schema.add_namespace(_) else: - for ns in chain(*map(lambda p: p.rglob('*.json'), resources.files('jams.schemata.namespaces').iterdir())): + for ns in chain( + *map( + lambda p: p.rglob("*.json"), + resources.files("jams.schemata.namespaces").iterdir(), + ) + ): schema.add_namespace(ns) # Populate local namespaces -if 'JAMS_SCHEMA_DIR' in os.environ: - for ns in util.find_with_extension(os.environ['JAMS_SCHEMA_DIR'], 'json'): +if "JAMS_SCHEMA_DIR" in os.environ: + for ns in util.find_with_extension(os.environ["JAMS_SCHEMA_DIR"], "json"): schema.add_namespace(ns) diff --git a/jams/core.py b/jams/core.py index 2f0160ee..6b9dbcc1 100644 --- a/jams/core.py +++ b/jams/core.py @@ -3,7 +3,7 @@ ------------------ This library provides an interface for reading JAMS into Python, or creating -them programatically. +them programmatically. .. currentmodule:: jams @@ -53,30 +53,37 @@ from .exceptions import JamsError, SchemaError, ParameterError -__all__ = ['load', - 'JObject', 'Sandbox', - 'Annotation', 'Curator', 'AnnotationMetadata', - 'FileMetadata', 'AnnotationArray', 'JAMS', - 'Observation'] +__all__ = [ + "load", + "JObject", + "Sandbox", + "Annotation", + "Curator", + "AnnotationMetadata", + "FileMetadata", + "AnnotationArray", + "JAMS", + "Observation", +] def deprecated(version, version_removed): - '''This is a decorator which can be used to mark functions - as deprecated. + """Decorate a function to be marked as deprecated. - It will result in a warning being emitted when the function is used.''' + It will result in a warning being emitted when the function is used. + """ def __wrapper(func, *args, **kwargs): - '''Warn the user, and then proceed.''' + """Warn the user, and then proceed.""" code = six.get_function_code(func) warnings.warn_explicit( "{:s}.{:s}\n\tDeprecated as of JAMS version {:s}." - "\n\tIt will be removed in JAMS version {:s}." - .format(func.__module__, func.__name__, - version, version_removed), + "\n\tIt will be removed in JAMS version {:s}.".format( + func.__module__, func.__name__, version, version_removed + ), category=DeprecationWarning, filename=code.co_filename, - lineno=code.co_firstlineno + 1 + lineno=code.co_firstlineno + 1, ) return func(*args, **kwargs) @@ -84,8 +91,8 @@ def __wrapper(func, *args, **kwargs): @contextlib.contextmanager -def _open(name_or_fdesc, mode='r', fmt='auto'): - '''An intelligent wrapper for ``open``. +def _open(name_or_fdesc, mode="r", fmt="auto"): + """Provide intelligent wrapping for ``open``. Parameters ---------- @@ -105,26 +112,21 @@ def _open(name_or_fdesc, mode='r', fmt='auto'): Otherwise, use the specified coding. - See Also -------- open gzip.open - ''' - - open_map = {'jams': open, - 'json': open, - 'jamz': gzip.open, - 'gz': gzip.open} + """ + open_map = {"jams": open, "json": open, "jamz": gzip.open, "gz": gzip.open} # If we've been given an open descriptor, do the right thing - if hasattr(name_or_fdesc, 'read') or hasattr(name_or_fdesc, 'write'): + if hasattr(name_or_fdesc, "read") or hasattr(name_or_fdesc, "write"): yield name_or_fdesc elif isinstance(name_or_fdesc, six.string_types): # Infer the opener from the extension - if fmt == 'auto': + if fmt == "auto": _, ext = os.path.splitext(name_or_fdesc) # Pull off the extension separator @@ -136,26 +138,25 @@ def _open(name_or_fdesc, mode='r', fmt='auto'): ext = ext.lower() # Force text mode if we're using gzip - if ext in ['jamz', 'gz'] and 't' not in mode: - mode = '{:s}t'.format(mode) + if ext in ["jamz", "gz"] and "t" not in mode: + mode = "{:s}t".format(mode) with open_map[ext](name_or_fdesc, mode=mode) as fdesc: yield fdesc except KeyError: - raise ParameterError('Unknown JAMS extension ' - 'format: "{:s}"'.format(ext)) + raise ParameterError("Unknown JAMS extension " 'format: "{:s}"'.format(ext)) else: # Don't know how to handle this. Raise a parameter error - raise ParameterError('Invalid filename or ' - 'descriptor: {}'.format(name_or_fdesc)) + raise ParameterError( + "Invalid filename or " "descriptor: {}".format(name_or_fdesc) + ) -def load(path_or_file, validate=True, strict=True, fmt='auto'): +def load(path_or_file, validate=True, strict=True, fmt="auto"): r"""Load a JAMS Annotation from a file. - Parameters ---------- path_or_file : str or file-like @@ -177,25 +178,21 @@ def load(path_or_file, validate=True, strict=True, fmt='auto'): If the input is an open file handle, `jams` encoding is used. - Returns ------- jam : JAMS The loaded JAMS object - Raises ------ SchemaError if `validate == True`, `strict==True`, and validation fails - - See also + See Also -------- JAMS.validate JAMS.save - Examples -------- >>> # Load a jams object from a file name @@ -208,8 +205,7 @@ def load(path_or_file, validate=True, strict=True, fmt='auto'): >>> # No validation at all >>> J = jams.load('data.jams', validate=False) """ - - with _open(path_or_file, mode='r', fmt=fmt) as fdesc: + with _open(path_or_file, mode="r", fmt=fmt) as fdesc: jam = JAMS(**json.load(fdesc)) if validate: @@ -227,12 +223,13 @@ class JObject(object): By setting the `type` attribute to a defined schema entry, only the fields allowed by the schema are permitted as attributes. """ + def __init__(self, **kwargs): - '''Construct a new JObject + """Construct a new JObject Parameters ---------- - kwargs + **kwargs Each keyword argument becomes an attribute with the specified value Examples @@ -242,7 +239,7 @@ def __init__(self, **kwargs): 5 >>> dict(J) {'foo': 5} - ''' + """ super(JObject, self).__init__() for name, value in six.iteritems(kwargs): @@ -250,13 +247,13 @@ def __init__(self, **kwargs): @property def __schema__(self): - '''The schema definition for this JObject, if it exists. + """The schema definition for this JObject, if it exists. Returns ------- schema : dict or None - ''' - return schema.JAMS_SCHEMA['definitions'].get(self.type, None) + """ + return schema.JAMS_SCHEMA["definitions"].get(self.type, None) @property def __json__(self): @@ -267,10 +264,10 @@ def __json__(self): filtered_dict = dict() for k, item in six.iteritems(self.__dict__): - if k.startswith('_'): + if k.startswith("_"): continue - if hasattr(item, '__json__'): + if hasattr(item, "__json__"): filtered_dict[k] = item.__json__ else: filtered_dict[k] = serialize_obj(item) @@ -283,10 +280,11 @@ def __json_init__(cls, **kwargs): return cls(**kwargs) def __eq__(self, other): - return (isinstance(other, self.__class__) and - (self.__dict__ == other.__dict__)) + """Equality operator for JObject""" + return isinstance(other, self.__class__) and (self.__dict__ == other.__dict__) def __nonzero__(self): + """Return True if the JObject has any attributes""" return bool(self.__json__) def __getitem__(self, key): @@ -294,35 +292,37 @@ def __getitem__(self, key): return self.__dict__[key] def __setattr__(self, name, value): + """Set an attribute on the JObject.""" if self.__schema__ is not None: - props = self.__schema__['properties'] + props = self.__schema__["properties"] if name not in props: - raise SchemaError("Attribute {} not in {}" - .format(name, props.keys())) + raise SchemaError("Attribute {} not in {}".format(name, props.keys())) self.__dict__[name] = value def __contains__(self, key): + """Dict-style interface""" return key in self.__dict__ def __len__(self): + """Return the number of attributes in the JObject.""" return len(self.keys()) def __repr__(self): """Render the object alongside its attributes.""" indent = len(self.type) + 2 - jstr = ',\n' + ' ' * indent + jstr = ",\n" + " " * indent props = self._display_properties() - params = jstr.join('{:}={:}'.format(p, summary(self[p], - indent=indent)) - for (p, dp) in props) - return '<{}({:})>'.format(self.type, params) + params = jstr.join( + "{:}={:}".format(p, summary(self[p], indent=indent)) for (p, dp) in props + ) + return "<{}({:})>".format(self.type, params) def _display_properties(self): - '''Returns a list of tuples (key, display_name) - for properties of this object''' - + """Return a list of tuples (key, display_name) + for properties of this object + """ return sorted([(k, k) for k in self.__dict__]) def _repr_html_(self): @@ -330,24 +330,23 @@ def _repr_html_(self): props = self._display_properties() if not props: - return '' + return "" out = '
' - for (prop, dprop) in props: + for prop, dprop in props: content = summary_html(self[prop]) - prop_class = 'default' + prop_class = "default" if not content: - prop_class = 'danger' + prop_class = "danger" out += '
'.format(prop_class) - if (isinstance(self[prop], (JObject, AnnotationArray, dict)) - and content): + if isinstance(self[prop], (JObject, AnnotationArray, dict)) and content: # These classes should have collapses div_id = _get_divid(self[prop]) - out += r'''""" if content: - out += r'''
{1}
-
'''.format(div_id, content) +
""".format( + div_id, content + ) else: - out += r'''
+ out += r"""
{}  {} -
'''.format(dprop, content) - out += '
' - out += '
' + """.format( + dprop, content + ) + out += "" + out += "" return out def __summary__(self): - return '<{}(...)>'.format(self.type) + """Return a summary of the JObject.""" + return "<{}(...)>".format(self.type) def __str__(self): + """Return a JSON string representation of the JObject.""" return json.dumps(self.__json__, indent=2) def dumps(self, **kwargs): - '''Serialize the JObject to a string. + """Serialize the JObject to a string. Parameters ---------- - kwargs + **kwargs Keyword arguments to json.dumps Returns @@ -414,7 +423,7 @@ def dumps(self, **kwargs): >>> J.dumps() '{"foo": 5, "bar": "baz"}' - ''' + """ return json.dumps(self.__json__, **kwargs) def keys(self): @@ -434,11 +443,11 @@ def keys(self): return self.__dict__.keys() def update(self, **kwargs): - '''Update the attributes of a JObject. + """Update the attributes of a JObject. Parameters ---------- - kwargs + **kwargs Keyword arguments of the form `attribute=new_value` Examples @@ -449,18 +458,18 @@ def update(self, **kwargs): >>> J.update(bar='baz') >>> J.dumps() '{"foo": 5, "bar": "baz"}' - ''' + """ for name, value in six.iteritems(kwargs): setattr(self, name, value) @property def type(self): - '''The type (class name) of a derived JObject type''' + """The type (class name) of a derived JObject type""" return self.__class__.__name__ @classmethod def loads(cls, string): - '''De-serialize a JObject + """De-serialize a JObject Parameters ---------- @@ -484,15 +493,15 @@ def loads(cls, string): '{"foo": 5, "bar": "baz"}' >>> jams.JObject.loads(J.dumps()) - ''' + """ return cls.__json_init__(**json.loads(string)) def search(self, **kwargs): - '''Query this object (and its descendants). + """Query this object (and its descendants). Parameters ---------- - kwargs + **kwargs Each `(key, value)` pair encodes a search field in `key` and a target value in `value`. @@ -528,8 +537,7 @@ def search(self, **kwargs): True >>> J.search(foo=lambda x: x > 10) False - ''' - + """ match = False r_query = {} @@ -559,7 +567,7 @@ def search(self, **kwargs): return match def validate(self, strict=True): - '''Validate a JObject against its schema + """Validate a JObject against its schema Parameters ---------- @@ -576,8 +584,7 @@ def validate(self, strict=True): ------ SchemaError If `strict==True` and `jam` fails validation - ''' - + """ valid = True try: @@ -594,24 +601,32 @@ def validate(self, strict=True): return valid -Observation = namedtuple('Observation', - ['time', 'duration', 'value', 'confidence']) -'''Core observation type: (time, duration, value, confidence).''' +Observation = namedtuple("Observation", ["time", "duration", "value", "confidence"]) +"""Core observation type: (time, duration, value, confidence).""" class Sandbox(JObject): """Sandbox (unconstrained) Functionally identical to JObjects, but the class hierarchy might be - confusing if all objects inherit from Sandboxes.""" + confusing if all objects inherit from Sandboxes. + """ + pass class Annotation(JObject): """Annotation base class.""" - def __init__(self, namespace, data=None, annotation_metadata=None, - sandbox=None, time=0, duration=None): + def __init__( + self, + namespace, + data=None, + annotation_metadata=None, + sandbox=None, + time=0, + duration=None, + ): """Create an Annotation. Note that, if an argument is None, an empty Annotation is created in @@ -622,23 +637,17 @@ def __init__(self, namespace, data=None, annotation_metadata=None, ---------- namespace : str The namespace for this annotation - data : dict of lists, list of dicts, or list of Observations Data for the new annotation - annotation_metadata : AnnotationMetadata (or dict), default=None. Metadata corresponding to this Annotation. - sandbox : Sandbox (dict), default=None Miscellaneous information; keep to native datatypes if possible. - time : non-negative number The starting time for this annotation - duration : non-negative number The duration of this annotation """ - super(Annotation, self).__init__() if annotation_metadata is None: @@ -665,22 +674,27 @@ def __init__(self, namespace, data=None, annotation_metadata=None, self.duration = duration def _display_properties(self): - return [('namespace', 'Namespace'), - ('time', 'Time'), - ('duration', 'Duration'), - ('annotation_metadata', 'Annotation metadata'), - ('data', 'Data'), - ('sandbox', 'Sandbox')] + return [ + ("namespace", "Namespace"), + ("time", "Time"), + ("duration", "Duration"), + ("annotation_metadata", "Annotation metadata"), + ("data", "Data"), + ("sandbox", "Sandbox"), + ] def append(self, time=None, duration=None, value=None, confidence=None): - '''Append an observation to the data field + """Append an observation to the data field Parameters ---------- time : float >= 0 + duration : float >= 0 The time and duration of the new observation, in seconds + value + confidence The value and confidence of the new observations. @@ -691,15 +705,18 @@ def append(self, time=None, duration=None, value=None, confidence=None): -------- >>> ann = jams.Annotation(namespace='chord') >>> ann.append(time=3, duration=2, value='E#') - ''' - - self.data.add(Observation(time=float(time), - duration=float(duration), - value=value, - confidence=confidence)) + """ + self.data.add( + Observation( + time=float(time), + duration=float(duration), + value=value, + confidence=confidence, + ) + ) def append_records(self, records): - '''Add observations from row-major storage. + """Add observations from row-major storage. This is primarily useful for deserializing sparsely packed data. @@ -707,7 +724,7 @@ def append_records(self, records): ---------- records : iterable of dicts or Observations Each element of `records` corresponds to one observation. - ''' + """ for obs in records: if isinstance(obs, Observation): self.append(**obs._asdict()) @@ -715,7 +732,7 @@ def append_records(self, records): self.append(**obs) def append_columns(self, columns): - '''Add observations from column-major storage. + """Add observations from column-major storage. This is primarily used for deserializing densely packed data. @@ -725,16 +742,21 @@ def append_columns(self, columns): Keys must be `time, duration, value, confidence`, and each much be a list of equal length. - ''' - self.append_records([dict(time=t, duration=d, value=v, confidence=c) - for (t, d, v, c) - in six.moves.zip(columns['time'], - columns['duration'], - columns['value'], - columns['confidence'])]) + """ + self.append_records( + [ + dict(time=t, duration=d, value=v, confidence=c) + for (t, d, v, c) in six.moves.zip( + columns["time"], + columns["duration"], + columns["value"], + columns["confidence"], + ) + ] + ) def validate(self, strict=True): - '''Validate this annotation object against the JAMS schema, + """Validate this annotation object against the JAMS schema, and its data against the namespace schema. Parameters @@ -758,16 +780,16 @@ def validate(self, strict=True): See Also -------- JObject.validate - ''' - + """ # Get the schema for this annotation ann_schema = schema.namespace_array(self.namespace) valid = True try: - schema.VALIDATOR.validate(self.__json_light__(data=False), - schema.JAMS_SCHEMA) + schema.VALIDATOR.validate( + self.__json_light__(data=False), schema.JAMS_SCHEMA + ) # validate each record in the frame data_ser = [serialize_obj(obs) for obs in self.data] @@ -783,7 +805,7 @@ def validate(self, strict=True): return valid def trim(self, start_time, end_time, strict=False): - ''' + """ Trim the annotation and return as a new `Annotation` object. Trimming will result in the new annotation only containing observations @@ -864,11 +886,10 @@ def trim(self, start_time, end_time, strict=False): >>> ann_trim_strict.to_dataframe() time duration value confidence 0 6 2 three None - ''' + """ # Check for basic start_time and end_time validity if end_time <= start_time: - raise ParameterError( - 'end_time must be greater than start_time.') + raise ParameterError("end_time must be greater than start_time.") # If the annotation does not have a set duration value, we'll assume # trimming is possible (up to the user to ensure this is valid). @@ -878,7 +899,8 @@ def trim(self, start_time, end_time, strict=False): warnings.warn( "Annotation.duration is not defined, cannot check " "for temporal intersection, assuming the annotation " - "is valid between start_time and end_time.") + "is valid between start_time and end_time." + ) else: orig_time = self.time orig_duration = self.duration @@ -888,9 +910,10 @@ def trim(self, start_time, end_time, strict=False): # appropriately. if start_time > (orig_time + orig_duration) or (end_time < orig_time): warnings.warn( - 'Time range defined by [start_time,end_time] does not ' - 'intersect with the time range spanned by this annotation, ' - 'the trimmed annotation will be empty.') + "Time range defined by [start_time,end_time] does not " + "intersect with the time range spanned by this annotation, " + "the trimmed annotation will be empty." + ) trim_start = self.time trim_end = trim_start else: @@ -905,7 +928,8 @@ def trim(self, start_time, end_time, strict=False): annotation_metadata=self.annotation_metadata, sandbox=self.sandbox, time=trim_start, - duration=trim_end - trim_start) + duration=trim_end - trim_start, + ) # Selectively add observations based on their start time / duration # We do this rather than copying and directly manipulating the @@ -917,32 +941,47 @@ def trim(self, start_time, end_time, strict=False): obs_end = obs_start + obs.duration # Special-case here handles duration=0 as a closed interval - if obs_start < trim_end and (obs_end > trim_start or obs_start == obs_end >= trim_start): + if obs_start < trim_end and ( + obs_end > trim_start or obs_start == obs_end >= trim_start + ): new_start = max(obs_start, trim_start) new_end = min(obs_end, trim_end) new_duration = new_end - new_start - if ((not strict) or - (new_start == obs_start and new_end == obs_end)): - ann_trimmed.append(time=new_start, - duration=new_duration, - value=obs.value, - confidence=obs.confidence) + if (not strict) or (new_start == obs_start and new_end == obs_end): + ann_trimmed.append( + time=new_start, + duration=new_duration, + value=obs.value, + confidence=obs.confidence, + ) - if 'trim' not in ann_trimmed.sandbox.keys(): + if "trim" not in ann_trimmed.sandbox.keys(): ann_trimmed.sandbox.update( - trim=[{'start_time': start_time, 'end_time': end_time, - 'trim_start': trim_start, 'trim_end': trim_end}]) + trim=[ + { + "start_time": start_time, + "end_time": end_time, + "trim_start": trim_start, + "trim_end": trim_end, + } + ] + ) else: ann_trimmed.sandbox.trim.append( - {'start_time': start_time, 'end_time': end_time, - 'trim_start': trim_start, 'trim_end': trim_end}) + { + "start_time": start_time, + "end_time": end_time, + "trim_start": trim_start, + "trim_end": trim_end, + } + ) return ann_trimmed def slice(self, start_time, end_time, strict=False): - ''' + """ Slice the annotation and return as a new `Annotation` object. Slicing has the same effect as trimming (see `Annotation.trim`) except @@ -1013,7 +1052,7 @@ def slice(self, start_time, end_time, strict=False): >>> ann_slice_strict.to_dataframe() time duration value confidence 0 1.0 2.0 three None - ''' + """ # start by trimming the annotation sliced_ann = self.trim(start_time, end_time, strict=strict) raw_data = sliced_ann.pop_data() @@ -1027,23 +1066,37 @@ def slice(self, start_time, end_time, strict=False): # duration doesn't change # if obs.time < start_time, # duration shrinks by start_time - obs.time - sliced_ann.append(time=new_time, - duration=obs.duration, - value=obs.value, - confidence=obs.confidence) + sliced_ann.append( + time=new_time, + duration=obs.duration, + value=obs.value, + confidence=obs.confidence, + ) ref_time = sliced_ann.time slice_start = ref_time slice_end = ref_time + sliced_ann.duration - if 'slice' not in sliced_ann.sandbox.keys(): + if "slice" not in sliced_ann.sandbox.keys(): sliced_ann.sandbox.update( - slice=[{'start_time': start_time, 'end_time': end_time, - 'slice_start': slice_start, 'slice_end': slice_end}]) + slice=[ + { + "start_time": start_time, + "end_time": end_time, + "slice_start": slice_start, + "slice_end": slice_end, + } + ] + ) else: sliced_ann.sandbox.slice.append( - {'start_time': start_time, 'end_time': end_time, - 'slice_start': slice_start, 'slice_end': slice_end}) + { + "start_time": start_time, + "end_time": end_time, + "slice_start": slice_start, + "slice_end": slice_end, + } + ) # Update the timing for the sliced annotation sliced_ann.time = max(0, ref_time - start_time) @@ -1051,20 +1104,19 @@ def slice(self, start_time, end_time, strict=False): return sliced_ann def pop_data(self): - '''Replace this observation's data with a fresh container. + """Replace this observation's data with a fresh container. Returns ------- annotation_data : SortedKeyList The original annotation data container - ''' - + """ data = self.data self.data = SortedKeyList(key=self._key) return data def to_interval_values(self): - '''Extract observation data in a `mir_eval`-friendly format. + """Extract observation data in a `mir_eval`-friendly format. Returns ------- @@ -1072,11 +1124,9 @@ def to_interval_values(self): Start- and end-times of all valued intervals `intervals[i, :] = [time[i], time[i] + duration[i]]` - labels : list List view of value field. - ''' - + """ ints, vals = [], [] for obs in self.data: ints.append([obs.time, obs.time + obs.duration]) @@ -1088,16 +1138,15 @@ def to_interval_values(self): return np.array(ints), vals def to_event_values(self): - '''Extract observation data in a `mir_eval`-friendly format. + """Extract observation data in a `mir_eval`-friendly format. Returns ------- times : np.ndarray [shape=(n,), dtype=float] Start-time of all observations - labels : list List view of value field. - ''' + """ ints, vals = [], [] for obs in self.data: ints.append(obs.time) @@ -1106,7 +1155,7 @@ def to_event_values(self): return np.array(ints), vals def to_dataframe(self): - '''Convert this annotation to a pandas dataframe. + """Convert this annotation to a pandas dataframe. Returns ------- @@ -1114,19 +1163,18 @@ def to_dataframe(self): Columns are `time, duration, value, confidence`. Each row is an observation, and rows are sorted by ascending `time`. - ''' - return pd.DataFrame.from_records(list(self.data), - columns=['time', 'duration', - 'value', 'confidence']) + """ + return pd.DataFrame.from_records( + list(self.data), columns=["time", "duration", "value", "confidence"] + ) def to_samples(self, times, confidence=False): - '''Sample the annotation at specified times. + """Sample the annotation at specified times. Parameters ---------- times : np.ndarray, non-negative, ndim=1 The times (in seconds) to sample the annotation - confidence : bool If `True`, return both values and confidences. If `False` (default) only return values. @@ -1136,13 +1184,12 @@ def to_samples(self, times, confidence=False): values : list `values[i]` is a list of observation values for intervals that cover `times[i]`. - confidence : list (optional) `confidence` values corresponding to `values` - ''' + """ times = np.asarray(times) if times.ndim != 1 or np.any(times < 0): - raise ParameterError('times must be 1-dimensional and non-negative') + raise ParameterError("times must be 1-dimensional and non-negative") idx = np.argsort(times) samples = times[idx] @@ -1152,7 +1199,7 @@ def to_samples(self, times, confidence=False): for obs in self.data: start = np.searchsorted(samples, obs.time) - end = np.searchsorted(samples, obs.time + obs.duration, side='right') + end = np.searchsorted(samples, obs.time + obs.duration, side="right") for i in range(start, end): values[idx[i]].append(obs.value) confidences[idx[i]].append(obs.confidence) @@ -1163,21 +1210,22 @@ def to_samples(self, times, confidence=False): return values def __iter__(self): + """Iterate over the observations in this annotation.""" return iter(self.data) def to_html(self, max_rows=None): - '''Render this annotation list in HTML + """Render this annotation list in HTML Returns ------- rendered : str An HTML table containing this annotation's data. - ''' + """ n = len(self.data) div_id = _get_divid(self) - out = r'''
+ out = r"""
'''.format(div_id, self.namespace, n) +
""".format( + div_id, self.namespace, n + ) - out += r'''
-
'''.format(div_id) +
""".format( + div_id + ) - out += r'''
+ out += r"""
{} -
'''.format(self.annotation_metadata._repr_html_()) - out += r'''
+
""".format( + self.annotation_metadata._repr_html_() + ) + out += r"""
{} -
'''.format(self.sandbox._repr_html_()) +
""".format( + self.sandbox._repr_html_() + ) # -- Annotation content starts here - out += r'''
+ out += r"""
@@ -1213,53 +1269,58 @@ def to_html(self, max_rows=None): - '''.format(self.namespace, n) + """.format( + self.namespace, n + ) - out += r'''''' + out += r"""""" if max_rows is None or n <= max_rows: out += self._fmt_rows(0, n) else: - out += self._fmt_rows(0, max_rows//2) - out += r''' + out += self._fmt_rows(0, max_rows // 2) + out += r""" - ''' - out += self._fmt_rows(n-max_rows//2, n) + """ + out += self._fmt_rows(n - max_rows // 2, n) - out += r'''''' + out += r"""""" - out += r'''
value confidence
... ... ... ... ...
''' + out += r"""
""" - out += r'''
''' + out += r"""""" return out def _fmt_rows(self, start, end): - out = '' + out = "" for i, obs in enumerate(self.data[start:end], start): - out += r''' + out += r""" {:d} {:0.3f} {:0.3f} {:} {:} - '''.format(i, - obs.time, - obs.duration, - summary_html(obs.value), - summary_html(obs.confidence)) + """.format( + i, + obs.time, + obs.duration, + summary_html(obs.value), + summary_html(obs.confidence), + ) return out def _repr_html_(self, max_rows=25): - '''Render annotation as HTML. See also: `to_html()`''' + """Render annotation as HTML. See also: `to_html()`""" return self.to_html(max_rows=max_rows) @property def __json__(self): + """Return a JSON-serializable representation of this object.""" return self.__json_light__(data=True) def __json_light__(self, data=True): @@ -1270,15 +1331,15 @@ def __json_light__(self, data=True): filtered_dict = dict() for k, item in six.iteritems(self.__dict__): - if k.startswith('_'): + if k.startswith("_"): continue - elif k == 'data': + elif k == "data": if data: filtered_dict[k] = self.__json_data__ else: filtered_dict[k] = [] - elif hasattr(item, '__json__'): + elif hasattr(item, "__json__"): filtered_dict[k] = item.__json__ else: filtered_dict[k] = item @@ -1304,9 +1365,9 @@ def __json_data__(self): @classmethod def _key(cls, obs): - '''Provides sorting index for Observation objects''' + """Provide sorting index for Observation objects""" if not isinstance(obs, Observation): - raise JamsError('{} must be of type jams.Observation'.format(obs)) + raise JamsError("{} must be of type jams.Observation".format(obs)) return obs.time @@ -1316,15 +1377,15 @@ class Curator(JObject): Container object for curator metadata. """ - def __init__(self, name='', email=''): + + def __init__(self, name="", email=""): """Create a Curator. Parameters ---------- - name: str, default='' + name : str, default='' Common name of the curator. - - email: str, default='' + email : str, default='' An email address corresponding to the curator. """ super(Curator, self).__init__() @@ -1332,7 +1393,7 @@ def __init__(self, name='', email=''): self.email = email def _display_properties(self): - return [('name', 'Name'), ('email', 'Email')] + return [("name", "Name"), ("email", "Email")] class AnnotationMetadata(JObject): @@ -1340,37 +1401,39 @@ class AnnotationMetadata(JObject): Data structure for metadata corresponding to a specific annotation. """ - def __init__(self, curator=None, version='', corpus='', annotator=None, - annotation_tools='', annotation_rules='', validation='', - data_source=''): + + def __init__( + self, + curator=None, + version="", + corpus="", + annotator=None, + annotation_tools="", + annotation_rules="", + validation="", + data_source="", + ): """Create an AnnotationMetadata object. Parameters ---------- - curator: Curator, default=None + curator : Curator, default=None Object documenting a name and email address for the person of correspondence. - - version: string, default='' + version : string, default='' Version of this annotation. - - annotator: dict, default=None + annotator : dict, default=None Sandbox for information about the specific annotator, such as musical experience, skill level, principal instrument, etc. - - corpus: str, default='' + corpus : str, default='' Collection assignment. - - annotation_tools: str, default='' + annotation_tools : str, default='' Description of the tools used to create the annotation. - - annotation_rules: str, default='' + annotation_rules : str, default='' Description of the rules provided to the annotator. - - validation: str, default='' + validation : str, default='' Methods for validating the integrity of the data. - - data_source: str, default='' + data_source : str, default='' Description of where the data originated, e.g. 'Manual Annotation'. """ super(AnnotationMetadata, self).__init__() @@ -1392,40 +1455,45 @@ def __init__(self, curator=None, version='', corpus='', annotator=None, self.data_source = data_source def _display_properties(self): - return [('annotator', 'Annotator'), - ('version', 'Version'), - ('corpus', 'Corpus'), - ('curator', 'Curator'), - ('annotation_tools', 'Annotation tools'), - ('annotation_rules', 'Annotation rules'), - ('data_source', 'Data source'), - ('validation', 'Validation')] + return [ + ("annotator", "Annotator"), + ("version", "Version"), + ("corpus", "Corpus"), + ("curator", "Curator"), + ("annotation_tools", "Annotation tools"), + ("annotation_rules", "Annotation rules"), + ("data_source", "Data source"), + ("validation", "Validation"), + ] class FileMetadata(JObject): """Metadata for a given audio file.""" - def __init__(self, title='', artist='', release='', duration=None, - identifiers=None, jams_version=None): + + def __init__( + self, + title="", + artist="", + release="", + duration=None, + identifiers=None, + jams_version=None, + ): """Create a file-level Metadata object. Parameters ---------- - title: str + title : str Name of the recording. - - artist: str + artist : str Name of the artist / musician. - - release: str + release : str Name of the release - - duration: number >= 0 + duration : number >= 0 Time duration of the file, in seconds. - identifiers : jams.Sandbox Sandbox of identifier keys (eg, musicbrainz ids) - - jams_version: str + jams_version : str Version of the JAMS Schema. """ super(FileMetadata, self).__init__() @@ -1444,12 +1512,14 @@ def __init__(self, title='', artist='', release='', duration=None, self.jams_version = jams_version def _display_properties(self): - return [('artist', 'Artist'), - ('title', 'Title'), - ('release', 'Release'), - ('duration', 'Duration (s)'), - ('jams_version', 'JAMS version'), - ('identifiers', 'Identifiers')] + return [ + ("artist", "Artist"), + ("title", "Title"), + ("release", "Release"), + ("duration", "Duration (s)"), + ("jams_version", "JAMS version"), + ("identifiers", "Identifiers"), + ] class AnnotationArray(list): @@ -1486,13 +1556,14 @@ class AnnotationArray(list): >>> # Retrieve everything after the second salami annotation >>> seg_anns = jam.annotations['segment_salami_.*', 2:] """ + def __init__(self, annotations=None): """Create an AnnotationArray. Parameters ---------- - annotations: list - List of Annotations, or appropriately formated dicts + annotations : list + List of Annotations, or appropriately formatted dicts is consistent with Annotation. """ super(AnnotationArray, self).__init__() @@ -1503,13 +1574,12 @@ def __init__(self, annotations=None): self.extend([Annotation(**obj) for obj in annotations]) def search(self, **kwargs): - '''Filter the annotation array down to only those Annotation + """Filter the annotation array down to only those Annotation objects matching the query. - Parameters ---------- - kwargs : search parameters + **kwargs : search parameters See JObject.search Returns @@ -1520,8 +1590,7 @@ def search(self, **kwargs): See Also -------- JObject.search - ''' - + """ results = AnnotationArray() for annotation in self: @@ -1531,8 +1600,7 @@ def search(self, **kwargs): return results def __getitem__(self, idx): - '''Overloaded getitem for syntactic search sugar''' - + """Overloaded getitem for syntactic search sugar""" # if we have only one argument, it can be an int, slice or query if isinstance(idx, (int, slice)): return list.__getitem__(self, idx) @@ -1540,21 +1608,21 @@ def __getitem__(self, idx): return self.search(namespace=idx) elif isinstance(idx, tuple): return self.search(namespace=idx[0])[idx[1]] - raise IndexError('Invalid index: {}'.format(idx)) + raise IndexError("Invalid index: {}".format(idx)) @property def __json__(self): + """Return a JSON-serializable representation of this object.""" return [item.__json__ for item in self] def trim(self, start_time, end_time, strict=False): - ''' + """ Trim every annotation contained in the annotation array using `Annotation.trim` and return as a new `AnnotationArray`. See `Annotation.trim` for details about trimming. This function does not modify the annotations in the original annotation array. - Parameters ---------- start_time : float @@ -1574,7 +1642,7 @@ def trim(self, start_time, end_time, strict=False): ------- trimmed_array : AnnotationArray An annotation array where every annotation has been trimmed. - ''' + """ trimmed_array = AnnotationArray() for ann in self: trimmed_array.append(ann.trim(start_time, end_time, strict=strict)) @@ -1582,7 +1650,7 @@ def trim(self, start_time, end_time, strict=False): return trimmed_array def slice(self, start_time, end_time, strict=False): - ''' + """ Slice every annotation contained in the annotation array using `Annotation.slice` and return as a new AnnotationArray @@ -1609,7 +1677,7 @@ def slice(self, start_time, end_time, strict=False): ------- sliced_array : AnnotationArray An annotation array where every annotation has been sliced. - ''' + """ sliced_array = AnnotationArray() for ann in self: sliced_array.append(ann.slice(start_time, end_time, strict=strict)) @@ -1617,15 +1685,17 @@ def slice(self, start_time, end_time, strict=False): return sliced_array def __repr__(self): + """Return a string representation of this annotation array.""" n = len(self) if n == 1: - return '[1 annotation]' + return "[1 annotation]" else: - return '[{:d} annotations]'.format(n) + return "[{:d} annotations]".format(n) def _repr_html_(self): - out = '' + """Render this annotation array as HTML.""" + out = "" for ann in self: out += '
{}
'.format(ann._repr_html_()) return out @@ -1641,13 +1711,10 @@ def __init__(self, annotations=None, file_metadata=None, sandbox=None): ---------- annotations : list of Annotations Zero or more Annotation objects - file_metadata : FileMetadata (or dict), default=None Metadata corresponding to the audio file. - sandbox : Sandbox (or dict), default=None Unconstrained global sandbox for additional information. - """ super(JAMS, self).__init__() @@ -1664,15 +1731,18 @@ def __init__(self, annotations=None, file_metadata=None, sandbox=None): self.sandbox = Sandbox(**sandbox) def _display_properties(self): - return [('file_metadata', 'File Metadata'), - ('annotations', 'Annotations'), - ('sandbox', 'Sandbox')] + return [ + ("file_metadata", "File Metadata"), + ("annotations", "Annotations"), + ("sandbox", "Sandbox"), + ] @property def __schema__(self): + """Return the JAMS schema for this object.""" return schema.JAMS_SCHEMA - def add(self, jam, on_conflict='fail'): + def add(self, jam, on_conflict="fail"): """Add the contents of another jam to this object. Note that, by default, this method fails if file_metadata is not @@ -1682,10 +1752,9 @@ def add(self, jam, on_conflict='fail'): Parameters ---------- - jam: JAMS object + jam : JAMS object Object to add to this jam - - on_conflict: str, default='fail' + on_conflict : str, default='fail' Strategy for resolving metadata conflicts; one of ['fail', 'overwrite', or 'ignore']. @@ -1693,31 +1762,32 @@ def add(self, jam, on_conflict='fail'): ------ ParameterError if `on_conflict` is an unknown value - JamsError If a conflict is detected and `on_conflict='fail'` """ - - if on_conflict not in ['overwrite', 'fail', 'ignore']: - raise ParameterError("on_conflict='{}' is not in ['fail', " - "'overwrite', 'ignore'].".format(on_conflict)) + if on_conflict not in ["overwrite", "fail", "ignore"]: + raise ParameterError( + "on_conflict='{}' is not in ['fail', " + "'overwrite', 'ignore'].".format(on_conflict) + ) if not self.file_metadata == jam.file_metadata: - if on_conflict == 'overwrite': + if on_conflict == "overwrite": self.file_metadata = jam.file_metadata - elif on_conflict == 'fail': - raise JamsError("Metadata conflict! " - "Resolve manually or force-overwrite it.") + elif on_conflict == "fail": + raise JamsError( + "Metadata conflict! " "Resolve manually or force-overwrite it." + ) self.annotations.extend(jam.annotations) self.sandbox.update(**jam.sandbox) def search(self, **kwargs): - '''Search a JAMS object for matching objects. + """Search a JAMS object for matching objects. Parameters ---------- - kwargs : keyword arguments + **kwargs : keyword arguments Keyword query Returns @@ -1730,18 +1800,16 @@ def search(self, **kwargs): JObject.search AnnotationArray.search - Examples -------- A simple query to get all beat annotations >>> beats = my_jams.search(namespace='beat') - ''' - + """ return self.annotations.search(**kwargs) - def save(self, path_or_file, strict=True, fmt='auto'): + def save(self, path_or_file, strict=True, fmt="auto"): """Serialize annotation as a JSON formatted stream to file. Parameters @@ -1762,25 +1830,23 @@ def save(self, path_or_file, strict=True, fmt='auto'): If the input is an open file handle, `jams` encoding is used. - Raises ------ SchemaError If `strict == True` and the JAMS object fails schema or namespace validation. - See also + See Also -------- validate """ - self.validate(strict=strict) - with _open(path_or_file, mode='w', fmt=fmt) as fdesc: + with _open(path_or_file, mode="w", fmt=fmt) as fdesc: json.dump(self.__json__, fdesc, indent=2) def validate(self, strict=True): - '''Validate a JAMS object against the schema. + """Validate a JAMS object against the schema. Parameters ---------- @@ -1803,7 +1869,7 @@ def validate(self, strict=True): -------- jsonschema.validate - ''' + """ valid = True try: schema.VALIDATOR.validate(self.__json_light__, schema.JAMS_SCHEMA) @@ -1812,7 +1878,7 @@ def validate(self, strict=True): if isinstance(ann, Annotation): valid &= ann.validate(strict=strict) else: - msg = '{} is not a well-formed JAMS Annotation'.format(ann) + msg = "{} is not a well-formed JAMS Annotation".format(ann) valid = False if strict: raise SchemaError(msg) @@ -1830,7 +1896,7 @@ def validate(self, strict=True): return valid def trim(self, start_time, end_time, strict=False): - ''' + """ Trim all the annotations inside the jam and return as a new `JAMS` object. @@ -1869,42 +1935,46 @@ def trim(self, start_time, end_time, strict=False): The trimmed jam with trimmed annotations, returned as a new JAMS object. - ''' + """ # Make sure duration is set in file metadata if self.file_metadata.duration is None: raise JamsError( - 'Duration must be set (jam.file_metadata.duration) before ' - 'trimming can be performed.') + "Duration must be set (jam.file_metadata.duration) before " + "trimming can be performed." + ) # Make sure start and end times are within the file start/end times - if not (0 <= start_time <= end_time <= float( - self.file_metadata.duration)): + if not (0 <= start_time <= end_time <= float(self.file_metadata.duration)): raise ParameterError( - 'start_time and end_time must be within the original file ' - 'duration ({:f}) and end_time cannot be smaller than ' - 'start_time.'.format(float(self.file_metadata.duration))) + "start_time and end_time must be within the original file " + "duration ({:f}) and end_time cannot be smaller than " + "start_time.".format(float(self.file_metadata.duration)) + ) # Create a new jams - jam_trimmed = JAMS(annotations=None, - file_metadata=self.file_metadata, - sandbox=self.sandbox) + jam_trimmed = JAMS( + annotations=None, file_metadata=self.file_metadata, sandbox=self.sandbox + ) # trim annotations jam_trimmed.annotations = self.annotations.trim( - start_time, end_time, strict=strict) + start_time, end_time, strict=strict + ) # Document jam-level trim in top level sandbox - if 'trim' not in jam_trimmed.sandbox.keys(): + if "trim" not in jam_trimmed.sandbox.keys(): jam_trimmed.sandbox.update( - trim=[{'start_time': start_time, 'end_time': end_time}]) + trim=[{"start_time": start_time, "end_time": end_time}] + ) else: jam_trimmed.sandbox.trim.append( - {'start_time': start_time, 'end_time': end_time}) + {"start_time": start_time, "end_time": end_time} + ) return jam_trimmed def slice(self, start_time, end_time, strict=False): - ''' + """ Slice all the annotations inside the jam and return as a new `JAMS` object. @@ -1945,42 +2015,49 @@ def slice(self, start_time, end_time, strict=False): The sliced jam with sliced annotations, returned as a new JAMS object. - ''' + """ # Make sure duration is set in file metadata if self.file_metadata.duration is None: raise JamsError( - 'Duration must be set (jam.file_metadata.duration) before ' - 'slicing can be performed.') + "Duration must be set (jam.file_metadata.duration) before " + "slicing can be performed." + ) # Make sure start and end times are within the file start/end times - if (start_time < 0 or - start_time > float(self.file_metadata.duration) or - end_time < start_time or - end_time > float(self.file_metadata.duration)): + if ( + start_time < 0 + or start_time > float(self.file_metadata.duration) + or end_time < start_time + or end_time > float(self.file_metadata.duration) + ): raise ParameterError( - 'start_time and end_time must be within the original file ' - 'duration ({:f}) and end_time cannot be smaller than ' - 'start_time.'.format(float(self.file_metadata.duration))) + "start_time and end_time must be within the original file " + "duration ({:f}) and end_time cannot be smaller than " + "start_time.".format(float(self.file_metadata.duration)) + ) # Create a new jams - jam_sliced = JAMS(annotations=None, - file_metadata=self.file_metadata, - sandbox=self.sandbox) + jam_sliced = JAMS( + annotations=None, file_metadata=self.file_metadata, sandbox=self.sandbox + ) # trim annotations jam_sliced.annotations = self.annotations.slice( - start_time, end_time, strict=strict) + start_time, end_time, strict=strict + ) # adjust dutation jam_sliced.file_metadata.duration = end_time - start_time # Document jam-level trim in top level sandbox - if 'slice' not in jam_sliced.sandbox.keys(): + if "slice" not in jam_sliced.sandbox.keys(): jam_sliced.sandbox.update( - slice=[{'start_time': start_time, 'end_time': end_time}]) + slice=[{"start_time": start_time, "end_time": end_time}] + ) else: jam_sliced.sandbox.slice.append( - {'start_time': start_time, 'end_time': end_time}) + {"start_time": start_time, "end_time": end_time} + ) return jam_sliced @@ -1995,10 +2072,10 @@ def __json_light__(self): filtered_dict = dict() for k, item in six.iteritems(self.__dict__): - if k.startswith('_') or k == 'annotations': + if k.startswith("_") or k == "annotations": continue - if hasattr(item, '__json__'): + if hasattr(item, "__json__"): filtered_dict[k] = item.__json__ else: filtered_dict[k] = serialize_obj(item) @@ -2007,18 +2084,15 @@ def __json_light__(self): # -- Helper functions -- # -def query_pop(query, prefix, sep='.'): - '''Pop a prefix from a query string. - +def query_pop(query, prefix, sep="."): + """Pop a prefix from a query string. Parameters ---------- query : str The query string - prefix : str The prefix string to pop, if it exists - sep : str The string to separate fields @@ -2035,8 +2109,7 @@ def query_pop(query, prefix, sep='.'): >>> query_pop('namespace', 'Annotation') 'namespace' - ''' - + """ terms = query.split(sep) if terms[0] == prefix: @@ -2046,13 +2119,12 @@ def query_pop(query, prefix, sep='.'): def match_query(string, query): - '''Test if a string matches a query. + """Test if a string matches a query. Parameters ---------- string : str The string to test - query : string, callable, or object Either a regular expression, callable function, or object. @@ -2066,13 +2138,11 @@ def match_query(string, query): `False` otherwise - ''' - + """ if six.callable(query): return query(string) - elif (isinstance(query, six.string_types) and - isinstance(string, six.string_types)): + elif isinstance(query, six.string_types) and isinstance(string, six.string_types): return re.match(query, string) is not None else: @@ -2080,13 +2150,12 @@ def match_query(string, query): def serialize_obj(obj): - '''Custom serialization functionality for working with advanced data types. + """Serialize advanced data types. - numpy arrays are converted to lists - lists are recursively serialized element-wise - ''' - + """ if isinstance(obj, np.integer): return int(obj) @@ -2106,13 +2175,12 @@ def serialize_obj(obj): def summary(obj, indent=0): - '''Helper function to format repr strings for JObjects and friends. + """Format repr strings for JObjects and friends. Parameters ---------- obj The object to repr - indent : int >= 0 indent each new line by `indent` spaces @@ -2125,32 +2193,34 @@ def summary(obj, indent=0): of the length of the list. Otherwise, `repr(obj)`. - ''' - if hasattr(obj, '__summary__'): + """ + if hasattr(obj, "__summary__"): rep = obj.__summary__() elif isinstance(obj, SortedKeyList): - rep = '<{:d} observations>'.format(len(obj)) + rep = "<{:d} observations>".format(len(obj)) else: rep = repr(obj) - return rep.replace('\n', '\n' + ' ' * indent) + return rep.replace("\n", "\n" + " " * indent) def summary_html(obj): - if hasattr(obj, '_repr_html_'): + if hasattr(obj, "_repr_html_"): return obj._repr_html_() elif isinstance(obj, dict): out = '' for key in obj: - out += r''' + out += r""" - '''.format(key, summary_html(obj[key])) - out += '
{0} {1}
' + """.format( + key, summary_html(obj[key]) + ) + out += "" return out elif isinstance(obj, list): - return ''.join([summary_html(x) for x in obj]) + return "".join([summary_html(x) for x in obj]) else: return str(obj) @@ -2159,10 +2229,10 @@ def summary_html(obj): def _get_divid(obj): - '''Static function to get a unique id for an object. + """Get a unique id for an object. This is used in HTML rendering to ensure unique div ids for each call - to display an object''' - + to display an object + """ global __DIVID_COUNT__ __DIVID_COUNT__ += 1 - return '{}-{}'.format(id(obj), __DIVID_COUNT__) + return "{}-{}".format(id(obj), __DIVID_COUNT__) diff --git a/jams/display.py b/jams/display.py index 69924760..e43a9dfc 100644 --- a/jams/display.py +++ b/jams/display.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -r''' +r""" Display ------- @@ -8,7 +8,7 @@ display display_multi -''' +""" from collections import OrderedDict @@ -30,46 +30,44 @@ def pprint_jobject(obj, **kwargs): - '''Pretty-print a jobject. + """Pretty-print a jobject. Parameters ---------- obj : jams.JObject - - kwargs + **kwargs additional parameters to `json.dumps` Returns ------- string A simplified display of `obj` contents. - ''' - + """ obj_simple = {k: v for k, v in six.iteritems(obj.__json__) if v} string = json.dumps(obj_simple, **kwargs) # Suppress braces and quotes - string = re.sub(r'[{}"]', '', string) + string = re.sub(r'[{}"]', "", string) # Kill trailing commas - string = re.sub(r',\n', '\n', string) + string = re.sub(r",\n", "\n", string) # Kill blank lines - string = re.sub(r'^\s*$', '', string) + string = re.sub(r"^\s*$", "", string) return string def intervals(annotation, **kwargs): - '''Plotting wrapper for labeled intervals''' + """Display annotation as labeled intervals""" times, labels = annotation.to_interval_values() return mir_eval.display.labeled_intervals(times, labels, **kwargs) def hierarchy(annotation, **kwargs): - '''Plotting wrapper for hierarchical segmentations''' + """Display annotation as hierarchical segmentations""" htimes, hlabels = hierarchy_flatten(annotation) htimes = [np.asarray(_) for _ in htimes] @@ -77,31 +75,30 @@ def hierarchy(annotation, **kwargs): def pitch_contour(annotation, **kwargs): - '''Plotting wrapper for pitch contours''' - ax = kwargs.pop('ax', None) + """Display annotation as pitch contours""" + ax = kwargs.pop("ax", None) # If the annotation is empty, we need to construct a new axes ax = mir_eval.display.__get_axes(ax=ax)[0] times, values = annotation.to_interval_values() - indices = np.unique([v['index'] for v in values]) + indices = np.unique([v["index"] for v in values]) for idx in indices: - rows = [i for (i, v) in enumerate(values) if v['index'] == idx] - freqs = np.asarray([values[r]['frequency'] for r in rows]) - unvoiced = ~np.asarray([values[r]['voiced'] for r in rows]) + rows = [i for (i, v) in enumerate(values) if v["index"] == idx] + freqs = np.asarray([values[r]["frequency"] for r in rows]) + unvoiced = ~np.asarray([values[r]["voiced"] for r in rows]) freqs[unvoiced] *= -1 - ax = mir_eval.display.pitch(times[rows, 0], freqs, unvoiced=True, - ax=ax, - **kwargs) + ax = mir_eval.display.pitch( + times[rows, 0], freqs, unvoiced=True, ax=ax, **kwargs + ) return ax def event(annotation, **kwargs): - '''Plotting wrapper for events''' - + """Display annotation as events""" times, values = annotation.to_interval_values() if any(values): @@ -113,18 +110,17 @@ def event(annotation, **kwargs): def beat_position(annotation, **kwargs): - '''Plotting wrapper for beat-position data''' - + """Display annotation as beat-position data""" times, values = annotation.to_interval_values() - labels = [_['position'] for _ in values] + labels = [_["position"] for _ in values] # TODO: plot time signature, measure number return mir_eval.display.events(times, labels=labels, **kwargs) def piano_roll(annotation, **kwargs): - '''Plotting wrapper for piano rolls''' + """Display annotation as piano rolls""" times, midi = annotation.to_interval_values() return mir_eval.display.piano_roll(times, midi=midi, **kwargs) @@ -132,29 +128,27 @@ def piano_roll(annotation, **kwargs): VIZ_MAPPING = OrderedDict() -VIZ_MAPPING['segment_open'] = intervals -VIZ_MAPPING['chord'] = intervals -VIZ_MAPPING['multi_segment'] = hierarchy -VIZ_MAPPING['pitch_contour'] = pitch_contour -VIZ_MAPPING['beat_position'] = beat_position -VIZ_MAPPING['beat'] = event -VIZ_MAPPING['onset'] = event -VIZ_MAPPING['note_midi'] = piano_roll -VIZ_MAPPING['tag_open'] = intervals +VIZ_MAPPING["segment_open"] = intervals +VIZ_MAPPING["chord"] = intervals +VIZ_MAPPING["multi_segment"] = hierarchy +VIZ_MAPPING["pitch_contour"] = pitch_contour +VIZ_MAPPING["beat_position"] = beat_position +VIZ_MAPPING["beat"] = event +VIZ_MAPPING["onset"] = event +VIZ_MAPPING["note_midi"] = piano_roll +VIZ_MAPPING["tag_open"] = intervals def display(annotation, meta=True, **kwargs): - '''Visualize a jams annotation through mir_eval + """Visualize a jams annotation through mir_eval Parameters ---------- annotation : jams.Annotation The annotation to display - meta : bool If `True`, include annotation metadata in the figure - - kwargs + **kwargs Additional keyword arguments to mir_eval.display functions Returns @@ -166,8 +160,7 @@ def display(annotation, meta=True, **kwargs): ------ NamespaceError If the annotation cannot be visualized - ''' - + """ for namespace, func in six.iteritems(VIZ_MAPPING): try: ann = coerce_annotation(annotation, namespace) @@ -179,12 +172,14 @@ def display(annotation, meta=True, **kwargs): if meta: description = pprint_jobject(annotation.annotation_metadata, indent=2) - anchored_box = AnchoredText(description.strip('\n'), - loc=2, - frameon=True, - bbox_to_anchor=(1.02, 1.0), - bbox_transform=axes.transAxes, - borderpad=0.0) + anchored_box = AnchoredText( + description.strip("\n"), + loc=2, + frameon=True, + bbox_to_anchor=(1.02, 1.0), + bbox_transform=axes.transAxes, + borderpad=0.0, + ) axes.add_artist(anchored_box) axes.figure.subplots_adjust(right=0.8) @@ -193,25 +188,25 @@ def display(annotation, meta=True, **kwargs): except NamespaceError: pass - raise NamespaceError('Unable to visualize annotation of namespace="{:s}"' - .format(annotation.namespace)) + raise NamespaceError( + 'Unable to visualize annotation of namespace="{:s}"'.format( + annotation.namespace + ) + ) def display_multi(annotations, fig_kw=None, meta=True, **kwargs): - '''Display multiple annotations with shared axes + """Display multiple annotations with shared axes Parameters ---------- annotations : jams.AnnotationArray A collection of annotations to display - fig_kw : dict Keyword arguments to `plt.figure` - meta : bool If `True`, display annotation metadata for each annotation - - kwargs + **kwargs Additional keyword arguments to the `mir_eval.display` routines Returns @@ -220,14 +215,14 @@ def display_multi(annotations, fig_kw=None, meta=True, **kwargs): The created figure axs List of subplot axes corresponding to each displayed annotation - ''' + """ if fig_kw is None: fig_kw = dict() - fig_kw.setdefault('sharex', True) - fig_kw.setdefault('squeeze', True) + fig_kw.setdefault("sharex", True) + fig_kw.setdefault("squeeze", True) - # Filter down to coercable annotations first + # Filter down to coercible annotations first display_annotations = [] for ann in annotations: for namespace in VIZ_MAPPING: @@ -237,7 +232,7 @@ def display_multi(annotations, fig_kw=None, meta=True, **kwargs): # If there are no displayable annotations, fail here if not len(display_annotations): - raise ParameterError('No displayable annotations found') + raise ParameterError("No displayable annotations found") fig, axs = plt.subplots(nrows=len(display_annotations), ncols=1, **fig_kw) @@ -247,7 +242,7 @@ def display_multi(annotations, fig_kw=None, meta=True, **kwargs): axs = [axs] for ann, ax in zip(display_annotations, axs): - kwargs['ax'] = ax + kwargs["ax"] = ax display(ann, meta=meta, **kwargs) return fig, axs diff --git a/jams/eval.py b/jams/eval.py index eb95d368..be131491 100644 --- a/jams/eval.py +++ b/jams/eval.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # CREATED:2015-02-04 16:39:00 by Brian McFee -r''' +r""" Evaluation ---------- @@ -16,7 +16,7 @@ pattern hierarchy transcription -''' +""" from collections import defaultdict @@ -26,13 +26,21 @@ from .nsconvert import convert -__all__ = ['beat', 'chord', 'melody', 'onset', - 'segment', 'hierarchy', 'tempo', - 'pattern', 'transcription'] +__all__ = [ + "beat", + "chord", + "melody", + "onset", + "segment", + "hierarchy", + "tempo", + "pattern", + "transcription", +] def coerce_annotation(ann, namespace): - '''Validate that the annotation has the correct namespace, + """Validate that the annotation has the correct namespace, and is well-formed. If the annotation is not of the correct namespace, automatic conversion @@ -42,7 +50,6 @@ def coerce_annotation(ann, namespace): ---------- ann : jams.Annotation The annotation object in question - namespace : str The namespace pattern to match `ann` against @@ -55,15 +62,13 @@ def coerce_annotation(ann, namespace): ------ NamespaceError If `ann` does not match the proper namespace - SchemaError If `ann` fails schema validation See Also -------- jams.nsconvert.convert - ''' - + """ ann = convert(ann, namespace) ann.validate(strict=True) @@ -71,7 +76,7 @@ def coerce_annotation(ann, namespace): def beat(ref, est, **kwargs): - r'''Beat tracking evaluation + r"""Beat tracking evaluation Parameters ---------- @@ -79,7 +84,7 @@ def beat(ref, est, **kwargs): Reference annotation object est : jams.Annotation Estimated annotation object - kwargs + **kwargs Additional keyword arguments Returns @@ -101,9 +106,8 @@ def beat(ref, est, **kwargs): >>> ref_ann = ref_jam.search(namespace='beat')[0] >>> est_ann = est_jam.search(namespace='beat')[0] >>> scores = jams.eval.beat(ref_ann, est_ann) - ''' - - namespace = 'beat' + """ + namespace = "beat" ref = coerce_annotation(ref, namespace) est = coerce_annotation(est, namespace) @@ -114,7 +118,7 @@ def beat(ref, est, **kwargs): def onset(ref, est, **kwargs): - r'''Onset evaluation + r"""Onset evaluation Parameters ---------- @@ -122,7 +126,7 @@ def onset(ref, est, **kwargs): Reference annotation object est : jams.Annotation Estimated annotation object - kwargs + **kwargs Additional keyword arguments Returns @@ -144,8 +148,8 @@ def onset(ref, est, **kwargs): >>> ref_ann = ref_jam.search(namespace='onset')[0] >>> est_ann = est_jam.search(namespace='onset')[0] >>> scores = jams.eval.onset(ref_ann, est_ann) - ''' - namespace = 'onset' + """ + namespace = "onset" ref = coerce_annotation(ref, namespace) est = coerce_annotation(est, namespace) @@ -156,7 +160,7 @@ def onset(ref, est, **kwargs): def chord(ref, est, **kwargs): - r'''Chord evaluation + r"""Chord evaluation Parameters ---------- @@ -164,7 +168,7 @@ def chord(ref, est, **kwargs): Reference annotation object est : jams.Annotation Estimated annotation object - kwargs + **kwargs Additional keyword arguments Returns @@ -186,20 +190,20 @@ def chord(ref, est, **kwargs): >>> ref_ann = ref_jam.search(namespace='chord')[0] >>> est_ann = est_jam.search(namespace='chord')[0] >>> scores = jams.eval.chord(ref_ann, est_ann) - ''' - - namespace = 'chord' + """ + namespace = "chord" ref = coerce_annotation(ref, namespace) est = coerce_annotation(est, namespace) ref_interval, ref_value = ref.to_interval_values() est_interval, est_value = est.to_interval_values() - return mir_eval.chord.evaluate(ref_interval, ref_value, - est_interval, est_value, **kwargs) + return mir_eval.chord.evaluate( + ref_interval, ref_value, est_interval, est_value, **kwargs + ) def segment(ref, est, **kwargs): - r'''Segment evaluation + r"""Segment evaluation Parameters ---------- @@ -207,7 +211,7 @@ def segment(ref, est, **kwargs): Reference annotation object est : jams.Annotation Estimated annotation object - kwargs + **kwargs Additional keyword arguments Returns @@ -229,19 +233,20 @@ def segment(ref, est, **kwargs): >>> ref_ann = ref_jam.search(namespace='segment_.*')[0] >>> est_ann = est_jam.search(namespace='segment_.*')[0] >>> scores = jams.eval.segment(ref_ann, est_ann) - ''' - namespace = 'segment_open' + """ + namespace = "segment_open" ref = coerce_annotation(ref, namespace) est = coerce_annotation(est, namespace) ref_interval, ref_value = ref.to_interval_values() est_interval, est_value = est.to_interval_values() - return mir_eval.segment.evaluate(ref_interval, ref_value, - est_interval, est_value, **kwargs) + return mir_eval.segment.evaluate( + ref_interval, ref_value, est_interval, est_value, **kwargs + ) def hierarchy_flatten(annotation): - '''Flatten a multi_segment annotation into mir_eval style. + """Flatten a multi_segment annotation into mir_eval style. Parameters ---------- @@ -252,32 +257,30 @@ def hierarchy_flatten(annotation): ------- hier_intervalss : list A list of lists of intervals, ordered by increasing specificity. - hier_labels : list A list of lists of labels, ordered by increasing specificity. - ''' - + """ intervals, values = annotation.to_interval_values() ordering = dict() for interval, value in zip(intervals, values): - level = value['level'] + level = value["level"] if level not in ordering: ordering[level] = dict(intervals=list(), labels=list()) - ordering[level]['intervals'].append(interval) - ordering[level]['labels'].append(value['label']) + ordering[level]["intervals"].append(interval) + ordering[level]["labels"].append(value["label"]) levels = sorted(list(ordering.keys())) - hier_intervals = [ordering[level]['intervals'] for level in levels] - hier_labels = [ordering[level]['labels'] for level in levels] + hier_intervals = [ordering[level]["intervals"] for level in levels] + hier_labels = [ordering[level]["labels"] for level in levels] return hier_intervals, hier_labels def hierarchy(ref, est, **kwargs): - r'''Multi-level segmentation evaluation + r"""Multi-level segmentation evaluation Parameters ---------- @@ -285,7 +288,7 @@ def hierarchy(ref, est, **kwargs): Reference annotation object est : jams.Annotation Estimated annotation object - kwargs + **kwargs Additional keyword arguments Returns @@ -307,20 +310,20 @@ def hierarchy(ref, est, **kwargs): >>> ref_ann = ref_jam.search(namespace='multi_segment')[0] >>> est_ann = est_jam.search(namespace='multi_segment')[0] >>> scores = jams.eval.hierarchy(ref_ann, est_ann) - ''' - namespace = 'multi_segment' + """ + namespace = "multi_segment" ref = coerce_annotation(ref, namespace) est = coerce_annotation(est, namespace) ref_hier, ref_hier_lab = hierarchy_flatten(ref) est_hier, est_hier_lab = hierarchy_flatten(est) - return mir_eval.hierarchy.evaluate(ref_hier, ref_hier_lab, - est_hier, est_hier_lab, - **kwargs) + return mir_eval.hierarchy.evaluate( + ref_hier, ref_hier_lab, est_hier, est_hier_lab, **kwargs + ) def tempo(ref, est, **kwargs): - r'''Tempo evaluation + r"""Tempo evaluation Parameters ---------- @@ -328,7 +331,7 @@ def tempo(ref, est, **kwargs): Reference annotation object est : jams.Annotation Estimated annotation object - kwargs + **kwargs Additional keyword arguments Returns @@ -350,10 +353,9 @@ def tempo(ref, est, **kwargs): >>> ref_ann = ref_jam.search(namespace='tempo')[0] >>> est_ann = est_jam.search(namespace='tempo')[0] >>> scores = jams.eval.tempo(ref_ann, est_ann) - ''' - - ref = coerce_annotation(ref, 'tempo') - est = coerce_annotation(est, 'tempo') + """ + ref = coerce_annotation(ref, "tempo") + est = coerce_annotation(est, "tempo") ref_tempi = np.asarray([o.value for o in ref]) ref_weight = ref.data[0].confidence @@ -364,7 +366,7 @@ def tempo(ref, est, **kwargs): # melody def melody(ref, est, **kwargs): - r'''Melody extraction evaluation + r"""Melody extraction evaluation Parameters ---------- @@ -372,7 +374,7 @@ def melody(ref, est, **kwargs): Reference annotation object est : jams.Annotation Estimated annotation object - kwargs + **kwargs Additional keyword arguments Returns @@ -394,26 +396,23 @@ def melody(ref, est, **kwargs): >>> ref_ann = ref_jam.search(namespace='pitch_contour')[0] >>> est_ann = est_jam.search(namespace='pitch_contour')[0] >>> scores = jams.eval.melody(ref_ann, est_ann) - ''' - - namespace = 'pitch_contour' + """ + namespace = "pitch_contour" ref = coerce_annotation(ref, namespace) est = coerce_annotation(est, namespace) ref_times, ref_p = ref.to_event_values() est_times, est_p = est.to_event_values() - ref_freq = np.asarray([p['frequency'] * (-1)**(~p['voiced']) for p in ref_p]) - est_freq = np.asarray([p['frequency'] * (-1)**(~p['voiced']) for p in est_p]) + ref_freq = np.asarray([p["frequency"] * (-1) ** (~p["voiced"]) for p in ref_p]) + est_freq = np.asarray([p["frequency"] * (-1) ** (~p["voiced"]) for p in est_p]) - return mir_eval.melody.evaluate(ref_times, ref_freq, - est_times, est_freq, - **kwargs) + return mir_eval.melody.evaluate(ref_times, ref_freq, est_times, est_freq, **kwargs) # pattern detection def pattern_to_mireval(ann): - '''Convert a pattern_jku annotation object to mir_eval format. + """Convert a pattern_jku annotation object to mir_eval format. Parameters ---------- @@ -430,8 +429,7 @@ def pattern_to_mireval(ann): - `patterns[x][y][z]` contains a time-note tuple `(time, midi note)` - ''' - + """ # It's easier to work with dictionaries, since we can't assume # sequential pattern or occurrence identifiers @@ -441,9 +439,9 @@ def pattern_to_mireval(ann): for time, observation in zip(*ann.to_event_values()): - pattern_id = observation['pattern_id'] - occurrence_id = observation['occurrence_id'] - obs = (time, observation['midi_pitch']) + pattern_id = observation["pattern_id"] + occurrence_id = observation["occurrence_id"] + obs = (time, observation["midi_pitch"]) # Push this note observation into the correct pattern/occurrence patterns[pattern_id][occurrence_id].append(obs) @@ -453,7 +451,7 @@ def pattern_to_mireval(ann): def pattern(ref, est, **kwargs): - r'''Pattern detection evaluation + r"""Pattern detection evaluation Parameters ---------- @@ -461,7 +459,7 @@ def pattern(ref, est, **kwargs): Reference annotation object est : jams.Annotation Estimated annotation object - kwargs + **kwargs Additional keyword arguments Returns @@ -483,9 +481,8 @@ def pattern(ref, est, **kwargs): >>> ref_ann = ref_jam.search(namespace='pattern_jku')[0] >>> est_ann = est_jam.search(namespace='pattern_jku')[0] >>> scores = jams.eval.pattern(ref_ann, est_ann) - ''' - - namespace = 'pattern_jku' + """ + namespace = "pattern_jku" ref = coerce_annotation(ref, namespace) est = coerce_annotation(est, namespace) @@ -496,7 +493,7 @@ def pattern(ref, est, **kwargs): def transcription(ref, est, **kwargs): - r'''Note transcription evaluation + r"""Note transcription evaluation Parameters ---------- @@ -504,7 +501,7 @@ def transcription(ref, est, **kwargs): Reference annotation object est : jams.Annotation Estimated annotation object - kwargs + **kwargs Additional keyword arguments Returns @@ -527,16 +524,16 @@ def transcription(ref, est, **kwargs): >>> ref_ann = ref_jam.search(namespace='pitch_contour')[0] >>> est_ann = est_jam.search(namespace='note_hz')[0] >>> scores = jams.eval.transcription(ref_ann, est_ann) - ''' - - namespace = 'pitch_contour' + """ + namespace = "pitch_contour" ref = coerce_annotation(ref, namespace) est = coerce_annotation(est, namespace) ref_intervals, ref_p = ref.to_interval_values() est_intervals, est_p = est.to_interval_values() - ref_pitches = np.asarray([p['frequency'] * (-1)**(~p['voiced']) for p in ref_p]) - est_pitches = np.asarray([p['frequency'] * (-1)**(~p['voiced']) for p in est_p]) + ref_pitches = np.asarray([p["frequency"] * (-1) ** (~p["voiced"]) for p in ref_p]) + est_pitches = np.asarray([p["frequency"] * (-1) ** (~p["voiced"]) for p in est_p]) return mir_eval.transcription.evaluate( - ref_intervals, ref_pitches, est_intervals, est_pitches, **kwargs) + ref_intervals, ref_pitches, est_intervals, est_pitches, **kwargs + ) diff --git a/jams/exceptions.py b/jams/exceptions.py index 8d9aeb70..aca15844 100644 --- a/jams/exceptions.py +++ b/jams/exceptions.py @@ -1,23 +1,27 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -'''Exception classes for JAMS''' +"""Exception classes for JAMS""" class JamsError(Exception): - '''The root JAMS exception class''' + """The root JAMS exception class""" + pass class SchemaError(JamsError): - '''Exceptions relating to schema validation''' + """Exceptions relating to schema validation""" + pass class NamespaceError(JamsError): - '''Exceptions relating to task namespaces''' + """Exceptions relating to task namespaces""" + pass class ParameterError(JamsError): - '''Exceptions relating to function and method parameters''' + """Exceptions relating to function and method parameters""" + pass diff --git a/jams/nsconvert.py b/jams/nsconvert.py index 5a55b30b..9265f0ba 100644 --- a/jams/nsconvert.py +++ b/jams/nsconvert.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- # CREATED:2016-02-16 15:40:04 by Brian McFee -r''' +r""" Namespace conversion -------------------- @@ -9,7 +9,7 @@ :toctree: generated convert -''' +""" import numpy as np @@ -22,22 +22,22 @@ # The structure that handles all conversion mappings __CONVERSION__ = defaultdict(defaultdict) -__all__ = ['convert', 'can_convert'] +__all__ = ["convert", "can_convert"] def _conversion(target, source): - '''A decorator to register namespace conversions. + """Decorate a function to register namespace conversions. - Usage - ----- + Examples + -------- >>> @conversion('tag_open', 'tag_.*') ... def tag_to_open(annotation): ... annotation.namespace = 'tag_open' ... return annotation - ''' + """ def register(func): - '''This decorator registers func as mapping source to target''' + """Register func as mapping source to target""" __CONVERSION__[target][source] = func return func @@ -45,13 +45,12 @@ def register(func): def convert(annotation, target_namespace): - '''Convert a given annotation to the target namespace. + """Convert a given annotation to the target namespace. Parameters ---------- annotation : jams.Annotation An annotation object - target_namespace : str The target namespace @@ -68,7 +67,6 @@ def convert(annotation, target_namespace): ------ SchemaError if the input annotation fails to validate - NamespaceError if no conversion is possible @@ -81,8 +79,7 @@ def convert(annotation, target_namespace): And back to Hz >>> ann_hz2 = jams.convert(ann_midi, 'note_hz') - ''' - + """ # First, validate the input. If this fails, we can't auto-convert. annotation.validate(strict=True) @@ -100,19 +97,19 @@ def convert(annotation, target_namespace): return __CONVERSION__[target_namespace][source](annotation) # No conversion possible - raise NamespaceError('Unable to convert annotation from namespace=' - '"{0}" to "{1}"'.format(annotation.namespace, - target_namespace)) + raise NamespaceError( + "Unable to convert annotation from namespace=" + '"{0}" to "{1}"'.format(annotation.namespace, target_namespace) + ) def can_convert(annotation, target_namespace): - '''Test if an annotation can be mapped to a target namespace + """Test if an annotation can be mapped to a target namespace Parameters ---------- annotation : jams.Annotation An annotation object - target_namespace : str The target namespace @@ -121,11 +118,9 @@ def can_convert(annotation, target_namespace): True if `annotation` can be automatically converted to `target_namespace` - False otherwise - ''' - + """ # If we're already in the target namespace, do nothing if annotation.namespace == target_namespace: return True @@ -138,137 +133,148 @@ def can_convert(annotation, target_namespace): return False -@_conversion('pitch_contour', 'pitch_hz') +@_conversion("pitch_contour", "pitch_hz") def pitch_hz_to_contour(annotation): - '''Convert a pitch_hz annotation to a contour''' - annotation.namespace = 'pitch_contour' + """Convert a pitch_hz annotation to a contour""" + annotation.namespace = "pitch_contour" data = annotation.pop_data() for obs in data: - annotation.append(time=obs.time, duration=obs.duration, - confidence=obs.confidence, - value=dict(index=0, - frequency=np.abs(obs.value), - voiced=obs.value > 0)) + annotation.append( + time=obs.time, + duration=obs.duration, + confidence=obs.confidence, + value=dict(index=0, frequency=np.abs(obs.value), voiced=obs.value > 0), + ) return annotation -@_conversion('pitch_contour', 'pitch_midi') +@_conversion("pitch_contour", "pitch_midi") def pitch_midi_to_contour(annotation): - '''Convert a pitch_hz annotation to a contour''' + """Convert a pitch_hz annotation to a contour""" annotation = pitch_midi_to_hz(annotation) return pitch_hz_to_contour(annotation) -@_conversion('note_hz', 'note_midi') +@_conversion("note_hz", "note_midi") def note_midi_to_hz(annotation): - '''Convert a pitch_midi annotation to pitch_hz''' - - annotation.namespace = 'note_hz' + """Convert a pitch_midi annotation to pitch_hz""" + annotation.namespace = "note_hz" data = annotation.pop_data() for obs in data: - annotation.append(time=obs.time, duration=obs.duration, - confidence=obs.confidence, - value=440 * (2.0**((obs.value - 69.0)/12.0))) + annotation.append( + time=obs.time, + duration=obs.duration, + confidence=obs.confidence, + value=440 * (2.0 ** ((obs.value - 69.0) / 12.0)), + ) return annotation -@_conversion('note_midi', 'note_hz') +@_conversion("note_midi", "note_hz") def note_hz_to_midi(annotation): - '''Convert a pitch_hz annotation to pitch_midi''' - - annotation.namespace = 'note_midi' + """Convert a pitch_hz annotation to pitch_midi""" + annotation.namespace = "note_midi" data = annotation.pop_data() for obs in data: - annotation.append(time=obs.time, duration=obs.duration, - confidence=obs.confidence, - value=12 * (np.log2(obs.value) - np.log2(440.0)) + 69) + annotation.append( + time=obs.time, + duration=obs.duration, + confidence=obs.confidence, + value=12 * (np.log2(obs.value) - np.log2(440.0)) + 69, + ) return annotation -@_conversion('pitch_hz', 'pitch_midi') +@_conversion("pitch_hz", "pitch_midi") def pitch_midi_to_hz(annotation): - '''Convert a pitch_midi annotation to pitch_hz''' - - annotation.namespace = 'pitch_hz' + """Convert a pitch_midi annotation to pitch_hz""" + annotation.namespace = "pitch_hz" data = annotation.pop_data() for obs in data: - annotation.append(time=obs.time, duration=obs.duration, - confidence=obs.confidence, - value=440 * (2.0**((obs.value - 69.0)/12.0))) + annotation.append( + time=obs.time, + duration=obs.duration, + confidence=obs.confidence, + value=440 * (2.0 ** ((obs.value - 69.0) / 12.0)), + ) return annotation -@_conversion('pitch_midi', 'pitch_hz') +@_conversion("pitch_midi", "pitch_hz") def pitch_hz_to_midi(annotation): - '''Convert a pitch_hz annotation to pitch_midi''' - - annotation.namespace = 'pitch_midi' + """Convert a pitch_hz annotation to pitch_midi""" + annotation.namespace = "pitch_midi" data = annotation.pop_data() for obs in data: - annotation.append(time=obs.time, duration=obs.duration, - confidence=obs.confidence, - value=12 * (np.log2(obs.value) - np.log2(440.0)) + 69) + annotation.append( + time=obs.time, + duration=obs.duration, + confidence=obs.confidence, + value=12 * (np.log2(obs.value) - np.log2(440.0)) + 69, + ) return annotation -@_conversion('segment_open', 'segment_.*') +@_conversion("segment_open", "segment_.*") def segment_to_open(annotation): - '''Convert any segmentation to open label space''' - - annotation.namespace = 'segment_open' + """Convert any segmentation to open label space""" + annotation.namespace = "segment_open" return annotation -@_conversion('tag_open', 'tag_.*') +@_conversion("tag_open", "tag_.*") def tag_to_open(annotation): - '''Convert any tag annotation to open label space''' - - annotation.namespace = 'tag_open' + """Convert any tag annotation to open label space""" + annotation.namespace = "tag_open" return annotation -@_conversion('tag_open', 'scaper') +@_conversion("tag_open", "scaper") def scaper_to_tag(annotation): - '''Convert scaper annotations to tag_open''' - - annotation.namespace = 'tag_open' + """Convert scaper annotations to tag_open""" + annotation.namespace = "tag_open" data = annotation.pop_data() for obs in data: - annotation.append(time=obs.time, duration=obs.duration, - confidence=obs.confidence, value=obs.value['label']) + annotation.append( + time=obs.time, + duration=obs.duration, + confidence=obs.confidence, + value=obs.value["label"], + ) return annotation -@_conversion('beat', 'beat_position') +@_conversion("beat", "beat_position") def beat_position(annotation): - '''Convert beat_position to beat''' - - annotation.namespace = 'beat' + """Convert beat_position to beat""" + annotation.namespace = "beat" data = annotation.pop_data() for obs in data: - annotation.append(time=obs.time, duration=obs.duration, - confidence=obs.confidence, - value=obs.value['position']) + annotation.append( + time=obs.time, + duration=obs.duration, + confidence=obs.confidence, + value=obs.value["position"], + ) return annotation -@_conversion('chord', 'chord_harte') +@_conversion("chord", "chord_harte") def chordh_to_chord(annotation): - '''Convert Harte annotation to chord''' - - annotation.namespace = 'chord' + """Convert Harte annotation to chord""" + annotation.namespace = "chord" return annotation diff --git a/jams/schema.py b/jams/schema.py index 34c5c926..b52b460a 100644 --- a/jams/schema.py +++ b/jams/schema.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -r''' +r""" Namespace management -------------------- @@ -14,7 +14,7 @@ values get_dtypes list_namespaces -''' +""" from __future__ import print_function @@ -27,13 +27,20 @@ from .exceptions import NamespaceError, JamsError -__all__ = ['add_namespace', 'namespace', 'is_dense', 'values', 'get_dtypes', 'VALIDATOR'] +__all__ = [ + "add_namespace", + "namespace", + "is_dense", + "values", + "get_dtypes", + "VALIDATOR", +] __NAMESPACE__ = dict() def add_namespace(filename): - '''Add a namespace definition to our working set. + """Add a namespace definition to our working set. Namespace files consist of partial JSON schemas defining the behavior of the `value` and `confidence` fields of an Annotation. @@ -42,13 +49,13 @@ def add_namespace(filename): ---------- filename : str Path to json file defining the namespace object - ''' - with open(filename, mode='r') as fileobj: + """ + with open(filename, mode="r") as fileobj: __NAMESPACE__.update(json.load(fileobj)) def namespace(ns_key): - '''Construct a validation schema for a given namespace. + """Construct a validation schema for a given namespace. Parameters ---------- @@ -59,16 +66,15 @@ def namespace(ns_key): ------- schema : dict JSON schema of `namespace` - ''' - + """ if ns_key not in __NAMESPACE__: - raise NamespaceError('Unknown namespace: {:s}'.format(ns_key)) + raise NamespaceError("Unknown namespace: {:s}".format(ns_key)) - sch = copy.deepcopy(JAMS_SCHEMA['definitions']['SparseObservation']) + sch = copy.deepcopy(JAMS_SCHEMA["definitions"]["SparseObservation"]) - for key in ['value', 'confidence']: + for key in ["value", "confidence"]: try: - sch['properties'][key] = __NAMESPACE__[ns_key][key] + sch["properties"][key] = __NAMESPACE__[ns_key][key] except KeyError: pass @@ -76,7 +82,7 @@ def namespace(ns_key): def namespace_array(ns_key): - '''Construct a validation schema for arrays of a given namespace. + """Construct a validation schema for arrays of a given namespace. Parameters ---------- @@ -87,18 +93,17 @@ def namespace_array(ns_key): ------- schema : dict JSON schema of `namespace` observation arrays - ''' - + """ obs_sch = namespace(ns_key) - obs_sch['title'] = 'Observation' + obs_sch["title"] = "Observation" - sch = copy.deepcopy(JAMS_SCHEMA['definitions']['SparseObservationList']) - sch['items'] = obs_sch + sch = copy.deepcopy(JAMS_SCHEMA["definitions"]["SparseObservationList"]) + sch["items"] = obs_sch return sch def is_dense(ns_key): - '''Determine whether a namespace has dense formatting. + """Determine whether a namespace has dense formatting. Parameters ---------- @@ -110,16 +115,15 @@ def is_dense(ns_key): dense : bool True if `ns_key` has a dense packing False otherwise. - ''' - + """ if ns_key not in __NAMESPACE__: - raise NamespaceError('Unknown namespace: {:s}'.format(ns_key)) + raise NamespaceError("Unknown namespace: {:s}".format(ns_key)) - return __NAMESPACE__[ns_key]['dense'] + return __NAMESPACE__[ns_key]["dense"] def values(ns_key): - '''Return the allowed values for an enumerated namespace. + """Return the allowed values for an enumerated namespace. Parameters ---------- @@ -140,19 +144,18 @@ def values(ns_key): >>> jams.schema.values('tag_gtzan') ['blues', 'classical', 'country', 'disco', 'hip-hop', 'jazz', 'metal', 'pop', 'reggae', 'rock'] - ''' - + """ if ns_key not in __NAMESPACE__: - raise NamespaceError('Unknown namespace: {:s}'.format(ns_key)) + raise NamespaceError("Unknown namespace: {:s}".format(ns_key)) - if 'enum' not in __NAMESPACE__[ns_key]['value']: - raise NamespaceError('Namespace {:s} is not enumerated'.format(ns_key)) + if "enum" not in __NAMESPACE__[ns_key]["value"]: + raise NamespaceError("Namespace {:s} is not enumerated".format(ns_key)) - return copy.copy(__NAMESPACE__[ns_key]['value']['enum']) + return copy.copy(__NAMESPACE__[ns_key]["value"]["enum"]) def get_dtypes(ns_key): - '''Get the dtypes associated with the value and confidence fields + """Get the dtypes associated with the value and confidence fields for a given namespace. Parameters @@ -164,40 +167,41 @@ def get_dtypes(ns_key): ------- value_dtype, confidence_dtype : numpy.dtype Type identifiers for value and confidence fields. - ''' - + """ # First, get the schema if ns_key not in __NAMESPACE__: - raise NamespaceError('Unknown namespace: {:s}'.format(ns_key)) + raise NamespaceError("Unknown namespace: {:s}".format(ns_key)) - value_dtype = __get_dtype(__NAMESPACE__[ns_key].get('value', {})) - confidence_dtype = __get_dtype(__NAMESPACE__[ns_key].get('confidence', {})) + value_dtype = __get_dtype(__NAMESPACE__[ns_key].get("value", {})) + confidence_dtype = __get_dtype(__NAMESPACE__[ns_key].get("confidence", {})) return value_dtype, confidence_dtype def list_namespaces(): - '''Print out a listing of available namespaces''' - print('{:30s}\t{:40s}'.format('NAME', 'DESCRIPTION')) - print('-' * 78) + """Print out a listing of available namespaces""" + print("{:30s}\t{:40s}".format("NAME", "DESCRIPTION")) + print("-" * 78) for sch in sorted(__NAMESPACE__): - desc = __NAMESPACE__[sch]['description'] - desc = (desc[:44] + '..') if len(desc) > 46 else desc - print('{:30s}\t{:40s}'.format(sch, desc)) + desc = __NAMESPACE__[sch]["description"] + desc = (desc[:44] + "..") if len(desc) > 46 else desc + print("{:30s}\t{:40s}".format(sch, desc)) # Mapping of js primitives to numpy types -__TYPE_MAP__ = dict(integer=np.int_, - boolean=np.bool_, - number=np.float64, - object=np.object_, - array=np.object_, - string=np.object_, - null=np.float64) +__TYPE_MAP__ = dict( + integer=np.int_, + boolean=np.bool_, + number=np.float64, + object=np.object_, + array=np.object_, + string=np.object_, + null=np.float64, +) def __get_dtype(typespec): - '''Get the dtype associated with a jsonschema type definition + """Get the dtype associated with a jsonschema type definition Parameters ---------- @@ -208,18 +212,17 @@ def __get_dtype(typespec): ------- dtype : numpy.dtype The associated dtype - ''' - - if 'type' in typespec: - return __TYPE_MAP__.get(typespec['type'], np.object_) + """ + if "type" in typespec: + return __TYPE_MAP__.get(typespec["type"], np.object_) - elif 'enum' in typespec: + elif "enum" in typespec: # Enums map to objects return np.object_ - elif 'oneOf' in typespec: + elif "oneOf" in typespec: # Recurse - types = [__get_dtype(v) for v in typespec['oneOf']] + types = [__get_dtype(v) for v in typespec["oneOf"]] # If they're not all equal, return object if all([t == types[0] for t in types]): @@ -229,21 +232,21 @@ def __get_dtype(typespec): def __load_jams_schema(): - '''Load the schema file from the package.''' + """Load the schema file from the package.""" abs_schema_dir = os.path.join(os.path.dirname(__file__), SCHEMA_DIR) - schema_file = os.path.join(abs_schema_dir, 'jams_schema.json') - with open(schema_file, mode='r') as fdesc: + schema_file = os.path.join(abs_schema_dir, "jams_schema.json") + with open(schema_file, mode="r") as fdesc: jams_schema = json.load(fdesc) if jams_schema is None: - raise JamsError('Unable to load JAMS schema') + raise JamsError("Unable to load JAMS schema") return jams_schema # Populate the schemata -SCHEMA_DIR = 'schemata' -NS_SCHEMA_DIR = os.path.join(SCHEMA_DIR, 'namespaces') +SCHEMA_DIR = "schemata" +NS_SCHEMA_DIR = os.path.join(SCHEMA_DIR, "namespaces") JAMS_SCHEMA = __load_jams_schema() VALIDATOR = jsonschema.Draft4Validator(JAMS_SCHEMA) diff --git a/jams/schemata/validate.py b/jams/schemata/validate.py index 82308693..9b8c25c7 100755 --- a/jams/schemata/validate.py +++ b/jams/schemata/validate.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -'''Validator script for jams files''' +"""Validator script for jams files""" import argparse import sys @@ -8,41 +8,37 @@ def process_arguments(args): - '''Argument parser''' - parser = argparse.ArgumentParser(description='JAMS schema validator') + """Argument parser""" + parser = argparse.ArgumentParser(description="JAMS schema validator") - parser.add_argument('schema_file', - action='store', - help='path to the schema file') - parser.add_argument('jams_files', - action='store', - nargs='+', - help='path to one or more JAMS files') + parser.add_argument("schema_file", action="store", help="path to the schema file") + parser.add_argument( + "jams_files", action="store", nargs="+", help="path to one or more JAMS files" + ) return vars(parser.parse_args(args)) def load_json(filename): - '''Load a json file''' - with open(filename, 'r') as fdesc: + """Load a json file""" + with open(filename, "r") as fdesc: return json.load(fdesc) def validate(schema_file=None, jams_files=None): - '''Validate a jams file against a schema''' - + """Validate a jams file against a schema""" schema = load_json(schema_file) for jams_file in jams_files: try: jams = load_json(jams_file) jsonschema.validate(jams, schema) - print '{:s} was successfully validated'.format(jams_file) + print("{:s} was successfully validated".format(jams_file)) except jsonschema.ValidationError as exc: - print '{:s} was NOT successfully validated'.format(jams_file) + print("{:s} was NOT successfully validated".format(jams_file)) - print exc + print(exc) -if __name__ == '__main__': +if __name__ == "__main__": validate(**process_arguments(sys.argv[1:])) diff --git a/jams/sonify.py b/jams/sonify.py index 50b2cee0..43160f7a 100644 --- a/jams/sonify.py +++ b/jams/sonify.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # CREATED:2015-12-12 18:20:37 by Brian McFee -r''' +r""" Sonification ------------ @@ -8,7 +8,7 @@ :toctree: generated/ sonify -''' +""" from itertools import product from collections import OrderedDict, defaultdict @@ -19,40 +19,37 @@ from .eval import coerce_annotation, hierarchy_flatten from .exceptions import NamespaceError -__all__ = ['sonify'] +__all__ = ["sonify"] def mkclick(freq, sr=22050, duration=0.1): - '''Generate a click sample. + """Generate a click sample. This replicates functionality from mir_eval.sonify.clicks, but exposes the target frequency and duration. - ''' - + """ times = np.arange(int(sr * duration)) click = np.sin(2 * np.pi * times * freq / float(sr)) - click *= np.exp(- times / (1e-2 * sr)) + click *= np.exp(-times / (1e-2 * sr)) return click def clicks(annotation, sr=22050, length=None, **kwargs): - '''Sonify events with clicks. + """Sonify events with clicks. This uses mir_eval.sonify.clicks, and is appropriate for instantaneous events such as beats or segment boundaries. - ''' - + """ interval, _ = annotation.to_interval_values() - return filter_kwargs(mir_eval.sonify.clicks, interval[:, 0], - fs=sr, length=length, **kwargs) + return filter_kwargs( + mir_eval.sonify.clicks, interval[:, 0], fs=sr, length=length, **kwargs + ) def downbeat(annotation, sr=22050, length=None, **kwargs): - '''Sonify beats and downbeats together. - ''' - + """Sonify beats and downbeats together.""" beat_click = mkclick(440 * 2, sr=sr) downbeat_click = mkclick(440 * 3, sr=sr) @@ -61,7 +58,7 @@ def downbeat(annotation, sr=22050, length=None, **kwargs): beats, downbeats = [], [] for time, value in zip(intervals[:, 0], values): - if value['position'] == 1: + if value["position"] == 1: downbeats.append(time) else: beats.append(time) @@ -69,81 +66,87 @@ def downbeat(annotation, sr=22050, length=None, **kwargs): if length is None: length = int(sr * np.max(intervals)) + len(beat_click) + 1 - y = filter_kwargs(mir_eval.sonify.clicks, - np.asarray(beats), - fs=sr, length=length, click=beat_click) - - y += filter_kwargs(mir_eval.sonify.clicks, - np.asarray(downbeats), - fs=sr, length=length, click=downbeat_click) + y = filter_kwargs( + mir_eval.sonify.clicks, + np.asarray(beats), + fs=sr, + length=length, + click=beat_click, + ) + + y += filter_kwargs( + mir_eval.sonify.clicks, + np.asarray(downbeats), + fs=sr, + length=length, + click=downbeat_click, + ) return y def multi_segment(annotation, sr=22050, length=None, **kwargs): - '''Sonify multi-level segmentations''' - + """Sonify multi-level segmentations""" # Pentatonic scale, because why not - PENT = [1, 32./27, 4./3, 3./2, 16./9] + PENT = [1, 32.0 / 27, 4.0 / 3, 3.0 / 2, 16.0 / 9] DURATION = 0.1 h_int, _ = hierarchy_flatten(annotation) if length is None: - length = int(sr * (max(np.max(_) for _ in h_int) + 1. / DURATION) + 1) + length = int(sr * (max(np.max(_) for _ in h_int) + 1.0 / DURATION) + 1) y = 0.0 - for ints, (oc, scale) in zip(h_int, product(range(3, 3 + len(h_int)), - PENT)): + for ints, (oc, scale) in zip(h_int, product(range(3, 3 + len(h_int)), PENT)): click = mkclick(440.0 * scale * oc, sr=sr, duration=DURATION) - y = y + filter_kwargs(mir_eval.sonify.clicks, - np.unique(ints), - fs=sr, length=length, - click=click) + y = y + filter_kwargs( + mir_eval.sonify.clicks, np.unique(ints), fs=sr, length=length, click=click + ) return y def chord(annotation, sr=22050, length=None, **kwargs): - '''Sonify chords + """Sonify chords This uses mir_eval.sonify.chords. - ''' - + """ intervals, chords = annotation.to_interval_values() - return filter_kwargs(mir_eval.sonify.chords, - chords, intervals, - fs=sr, length=length, - **kwargs) + return filter_kwargs( + mir_eval.sonify.chords, chords, intervals, fs=sr, length=length, **kwargs + ) def pitch_contour(annotation, sr=22050, length=None, **kwargs): - '''Sonify pitch contours. + """Sonify pitch contours. This uses mir_eval.sonify.pitch_contour, and should only be applied to pitch annotations using the pitch_contour namespace. Each contour is sonified independently, and the resulting waveforms are summed together. - ''' - + """ # Map contours to lists of observations times = defaultdict(list) freqs = defaultdict(list) for obs in annotation: - times[obs.value['index']].append(obs.time) - freqs[obs.value['index']].append(obs.value['frequency'] * - (-1)**(~obs.value['voiced'])) + times[obs.value["index"]].append(obs.time) + freqs[obs.value["index"]].append( + obs.value["frequency"] * (-1) ** (~obs.value["voiced"]) + ) y_out = 0.0 for ix in times: - y_out = y_out + filter_kwargs(mir_eval.sonify.pitch_contour, - np.asarray(times[ix]), - np.asarray(freqs[ix]), - fs=sr, length=length, - **kwargs) + y_out = y_out + filter_kwargs( + mir_eval.sonify.pitch_contour, + np.asarray(times[ix]), + np.asarray(freqs[ix]), + fs=sr, + length=length, + **kwargs + ) if length is None: length = len(y_out) @@ -151,13 +154,12 @@ def pitch_contour(annotation, sr=22050, length=None, **kwargs): def piano_roll(annotation, sr=22050, length=None, **kwargs): - '''Sonify a piano-roll + """Sonify a piano-roll This uses mir_eval.sonify.time_frequency, and is appropriate for sparse transcription data, e.g., annotations in the `note_midi` namespace. - ''' - + """ intervals, pitches = annotation.to_interval_values() # Construct the pitchogram @@ -168,37 +170,40 @@ def piano_roll(annotation, sr=22050, length=None, **kwargs): for col, f in enumerate(pitches): gram[pitch_map[f], col] = 1 - return filter_kwargs(mir_eval.sonify.time_frequency, - gram, np.asarray(pitches), np.asarray(intervals), - sr, length=length, **kwargs) + return filter_kwargs( + mir_eval.sonify.time_frequency, + gram, + np.asarray(pitches), + np.asarray(intervals), + sr, + length=length, + **kwargs + ) SONIFY_MAPPING = OrderedDict() -SONIFY_MAPPING['beat_position'] = downbeat -SONIFY_MAPPING['beat'] = clicks -SONIFY_MAPPING['multi_segment'] = multi_segment -SONIFY_MAPPING['segment_open'] = clicks -SONIFY_MAPPING['onset'] = clicks -SONIFY_MAPPING['chord'] = chord -SONIFY_MAPPING['note_hz'] = piano_roll -SONIFY_MAPPING['pitch_contour'] = pitch_contour +SONIFY_MAPPING["beat_position"] = downbeat +SONIFY_MAPPING["beat"] = clicks +SONIFY_MAPPING["multi_segment"] = multi_segment +SONIFY_MAPPING["segment_open"] = clicks +SONIFY_MAPPING["onset"] = clicks +SONIFY_MAPPING["chord"] = chord +SONIFY_MAPPING["note_hz"] = piano_roll +SONIFY_MAPPING["pitch_contour"] = pitch_contour def sonify(annotation, sr=22050, duration=None, **kwargs): - '''Sonify a jams annotation through mir_eval + """Sonify a jams annotation through mir_eval Parameters ---------- annotation : jams.Annotation The annotation to sonify - - sr = : positive number + sr : positive number The sampling rate of the output waveform - duration : float (optional) Optional length (in seconds) of the output waveform - - kwargs + **kwargs Additional keyword arguments to mir_eval.sonify functions Returns @@ -210,8 +215,7 @@ def sonify(annotation, sr=22050, duration=None, **kwargs): ------ NamespaceError If the annotation has an un-sonifiable namespace - ''' - + """ length = None if duration is None: @@ -223,10 +227,7 @@ def sonify(annotation, sr=22050, duration=None, **kwargs): # If the annotation can be directly sonified, try that first if annotation.namespace in SONIFY_MAPPING: ann = coerce_annotation(annotation, annotation.namespace) - return SONIFY_MAPPING[annotation.namespace](ann, - sr=sr, - length=length, - **kwargs) + return SONIFY_MAPPING[annotation.namespace](ann, sr=sr, length=length, **kwargs) for namespace, func in six.iteritems(SONIFY_MAPPING): try: @@ -235,5 +236,6 @@ def sonify(annotation, sr=22050, duration=None, **kwargs): except NamespaceError: pass - raise NamespaceError('Unable to sonify annotation of namespace="{:s}"' - .format(annotation.namespace)) + raise NamespaceError( + 'Unable to sonify annotation of namespace="{:s}"'.format(annotation.namespace) + ) diff --git a/jams/util.py b/jams/util.py index 483809fe..ed8127a3 100644 --- a/jams/util.py +++ b/jams/util.py @@ -22,7 +22,7 @@ def import_lab(namespace, filename, infer_duration=True, **parse_options): - r'''Load a .lab file as an Annotation object. + r"""Load a .lab file as an Annotation object. .lab files are assumed to have the following format: @@ -38,7 +38,6 @@ def import_lab(namespace, filename, infer_duration=True, **parse_options): If the .lab file contains more than three columns, each row's annotation value is assigned the contents of last non-empty column. - Parameters ---------- namespace : str @@ -58,7 +57,7 @@ def import_lab(namespace, filename, infer_duration=True, **parse_options): For instantaneous event annotations (e.g., beats or onsets), this should be set to `False`. - parse_options : additional keyword arguments + **parse_options : additional keyword arguments Passed to ``pandas.DataFrame.read_csv`` Returns @@ -69,31 +68,30 @@ def import_lab(namespace, filename, infer_duration=True, **parse_options): See Also -------- pandas.DataFrame.read_csv - ''' - + """ # Create a new annotation object annotation = core.Annotation(namespace) - parse_options.setdefault('sep', r'\s+') - parse_options.setdefault('engine', 'python') - parse_options.setdefault('header', None) - parse_options.setdefault('index_col', False) + parse_options.setdefault("sep", r"\s+") + parse_options.setdefault("engine", "python") + parse_options.setdefault("header", None) + parse_options.setdefault("index_col", False) # This is a hack to handle potentially ragged .lab data - parse_options.setdefault('names', range(20)) + parse_options.setdefault("names", range(20)) data = pd.read_csv(filename, **parse_options) # Drop all-nan columns - data = data.dropna(how='all', axis=1) + data = data.dropna(how="all", axis=1) # Do we need to add a duration column? # This only applies to event annotations if len(data.columns) == 2: # Insert a column of zeros after the timing - data.insert(1, 'duration', 0) + data.insert(1, "duration", 0) if infer_duration: - data['duration'][:-1] = data.loc[:, 0].diff()[1:].values + data["duration"][:-1] = data.loc[:, 0].diff()[1:].values else: # Convert from time to duration @@ -105,10 +103,7 @@ def import_lab(namespace, filename, infer_duration=True, **parse_options): value = [x for x in row[3:] if x is not None][-1] - annotation.append(time=time, - duration=duration, - confidence=1.0, - value=value) + annotation.append(time=time, duration=duration, confidence=1.0, value=value) return annotation @@ -120,7 +115,6 @@ def expand_filepaths(base_dir, rel_paths): ---------- base_dir : str The target base directory - rel_paths : list (or list-like) Collection of relative path strings @@ -145,11 +139,10 @@ def smkdirs(dpath, mode=0o777): ---------- dpath : str Path of directory/directories to create - mode : int [default=0777] Permissions for the new directories - See also + See Also -------- os.makedirs """ @@ -213,8 +206,8 @@ def find_with_extension(in_dir, ext, depth=3, sort=True): assert depth >= 1 ext = ext.strip(os.extsep) match = list() - for n in range(1, depth+1): - wildcard = os.path.sep.join(["*"]*n) + for n in range(1, depth + 1): + wildcard = os.path.sep.join(["*"] * n) search_path = os.path.join(in_dir, os.extsep.join([wildcard, ext])) match += glob.glob(search_path) diff --git a/jams/version.py b/jams/version.py index 093385af..c5e38650 100644 --- a/jams/version.py +++ b/jams/version.py @@ -2,5 +2,5 @@ # -*- coding: utf-8 -*- """Version info""" -short_version = '0.3' -version = '0.3.5a' +short_version = "0.3" +version = "0.3.5" diff --git a/pyproject.toml b/pyproject.toml index 0d4c8318..dfe16e6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,9 +5,10 @@ build-backend = "setuptools.build_meta" [tool.pytest.ini_options] addopts = [ "-v", - "--cov-report=term-missing", + "--cov-report=xml", "--cov=jams", ] + testpaths = [ "tests" -] \ No newline at end of file +] diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..2fe87f70 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,70 @@ +[coverage:report] +show_missing = True + +[pydocstyle] +# convention = numpy +# Below is equivalent to numpy convention + D400 and D205 +ignore = D107,D203,D205,D212,D213,D400,D402,D413,D415,D416,D417 + +[flake8] +count = True +statistics = True +show_source = True +select = + E9, + F63, + F7, + F82 + +[metadata] +name = jams +version = attr: jams.version.version +description = JAMS: A JSON Audio Metadata Standard +author = JAMS development crew +url = https://github.com/marl/jams +download_url = https://github.com/marl/jams/releases +long_description = file: README.md +long_description_content_type = text/markdown; charset=UTF-8 +license = ISC +python_requires = ">=3.9" +classifiers = + Programming Language :: Python + Development Status :: 3 - Alpha + Intended Audience :: Developers + Topic :: Multimedia :: Sound/Audio :: Analysis + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 + Programming Language :: Python :: 3.13 + +[options] +packages = find: +include_package_data = True +keywords = audio music json +package_data = +scripts = + scripts/jams_to_lab.py +install_requires = + numpy >= 1.20.0 + jsonschema >= 4.0.1 + pandas >= 1.2.0 + mir_eval >= 0.8.2 + sortedcontainers >= 2.1.0 + six + decorator + +[options.package_data] +jams = + schemata/*.json + schemata/namespaces/*.json + schemata/namespaces/*/*.json + +[options.extras_require] +display = + matplotlib >= 3.4.1 +tests = + pytest ~= 8.0 + pytest-cov + matplotlib >= 3.4.1 diff --git a/setup.py b/setup.py index a08cf6ec..7f1a1763 100644 --- a/setup.py +++ b/setup.py @@ -1,59 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import setup -import importlib.util -import importlib.machinery - - -def load_source(modname, filename): - loader = importlib.machinery.SourceFileLoader(modname, filename) - spec = importlib.util.spec_from_file_location(modname, filename, loader=loader) - module = importlib.util.module_from_spec(spec) - loader.exec_module(module) - return module - - -version = load_source('jams.version', 'jams/version.py') - -setup( - name='jams', - version=version.version, - description='A JSON Annotated Music Specification for Reproducible MIR Research', - author='JAMS development crew', - url='http://github.com/marl/jams', - download_url='http://github.com/marl/jams/releases', - packages=find_packages(), - package_data={'': ['schemata/*.json', - 'schemata/namespaces/*.json', - 'schemata/namespaces/*/*.json']}, - long_description='A JSON Annotated Music Specification for Reproducible MIR Research', - classifiers=[ - "License :: OSI Approved :: ISC License (ISCL)", - "Programming Language :: Python", - "Development Status :: 3 - Alpha", - "Intended Audience :: Developers", - "Topic :: Multimedia :: Sound/Audio :: Analysis", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - ], - python_requires=">=3.9", - keywords='audio music json', - license='ISC', - install_requires=[ - 'pandas', - 'sortedcontainers>=2.0.0', - 'jsonschema>=3.0.0', - 'numpy>=1.8.0', - 'six', - 'decorator', - 'mir_eval>=0.8.2' - ], - extras_require={ - 'display': ['matplotlib>=1.5.0'], - 'tests': ['pytest ~= 8.0', 'pytest-cov', 'matplotlib>=3'], - }, - scripts=['scripts/jams_to_lab.py'] -) +if __name__ == "__main__": + setup() diff --git a/tests/test_convert.py b/tests/test_convert.py index 2041517d..2ae7d559 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -'''namespace conversion tests''' +"""namespace conversion tests""" import numpy as np @@ -11,24 +11,24 @@ def test_bad_target(): - ann = jams.Annotation(namespace='tag_open') - ann.append(time=0, duration=1, value='foo', confidence=1) + ann = jams.Annotation(namespace="tag_open") + ann.append(time=0, duration=1, value="foo", confidence=1) with pytest.raises(NamespaceError): - jams.convert(ann, 'bad namespace') + jams.convert(ann, "bad namespace") -@pytest.mark.parametrize('target', - ['pitch_hz', 'pitch_midi', 'segment_open', - 'tag_open', 'beat', 'chord']) +@pytest.mark.parametrize( + "target", ["pitch_hz", "pitch_midi", "segment_open", "tag_open", "beat", "chord"] +) def test_bad_sources(target): - ann = jams.Annotation(namespace='vector') + ann = jams.Annotation(namespace="vector") with pytest.raises(NamespaceError): jams.convert(ann, target) -@pytest.mark.parametrize('namespace', list(jams.schema.__NAMESPACE__.keys())) +@pytest.mark.parametrize("namespace", list(jams.schema.__NAMESPACE__.keys())) def test_noop(namespace): ann = jams.Annotation(namespace=namespace) @@ -38,37 +38,37 @@ def test_noop(namespace): def test_pitch_hz_to_contour(): - ann = jams.Annotation(namespace='pitch_hz') + ann = jams.Annotation(namespace="pitch_hz") values = np.linspace(110, 220, num=100) # Unvoice the first half - values[::len(values)//2] *= -1 + values[:: len(values) // 2] *= -1 times = np.linspace(0, 1, num=len(values)) for t, v in zip(times, values): ann.append(time=t, value=v, duration=0) - ann2 = jams.convert(ann, 'pitch_contour') + ann2 = jams.convert(ann, "pitch_contour") ann.validate() ann2.validate() - assert ann2.namespace == 'pitch_contour' + assert ann2.namespace == "pitch_contour" # Check index values - assert ann2.data[0].value['index'] == 0 - assert ann2.data[-1].value['index'] == 0 + assert ann2.data[0].value["index"] == 0 + assert ann2.data[-1].value["index"] == 0 # Check frequency - assert np.abs(ann2.data[0].value['frequency'] == np.abs(values[0])) - assert np.abs(ann2.data[-1].value['frequency'] == np.abs(values[-1])) + assert np.abs(ann2.data[0].value["frequency"] == np.abs(values[0])) + assert np.abs(ann2.data[-1].value["frequency"] == np.abs(values[-1])) # Check voicings - assert not ann2.data[0].value['voiced'] - assert ann2.data[-1].value['voiced'] + assert not ann2.data[0].value["voiced"] + assert ann2.data[-1].value["voiced"] def test_pitch_midi_to_contour(): - ann = jams.Annotation(namespace='pitch_midi') + ann = jams.Annotation(namespace="pitch_midi") values = np.arange(100) times = np.linspace(0, 1, num=len(values)) @@ -76,29 +76,29 @@ def test_pitch_midi_to_contour(): for t, v in zip(times, values): ann.append(time=t, value=v, duration=0) - ann2 = jams.convert(ann, 'pitch_contour') + ann2 = jams.convert(ann, "pitch_contour") ann.validate() ann2.validate() - assert ann2.namespace == 'pitch_contour' + assert ann2.namespace == "pitch_contour" # Check index values - assert ann2.data[0].value['index'] == 0 - assert ann2.data[-1].value['index'] == 0 + assert ann2.data[0].value["index"] == 0 + assert ann2.data[-1].value["index"] == 0 # Check voicings - assert ann2.data[-1].value['voiced'] + assert ann2.data[-1].value["voiced"] def test_pitch_midi_to_hz(): - ann = jams.Annotation(namespace='pitch_midi') + ann = jams.Annotation(namespace="pitch_midi") ann.append(time=0, duration=1, value=69, confidence=0.5) - ann2 = jams.convert(ann, 'pitch_hz') + ann2 = jams.convert(ann, "pitch_hz") ann.validate() ann2.validate() # Check the namespace - assert ann2.namespace == 'pitch_hz' + assert ann2.namespace == "pitch_hz" # midi 69 = 440.0 Hz assert ann2.data[0].value == 440.0 @@ -113,14 +113,14 @@ def test_pitch_midi_to_hz(): def test_pitch_hz_to_midi(): - ann = jams.Annotation(namespace='pitch_hz') + ann = jams.Annotation(namespace="pitch_hz") ann.append(time=0, duration=1, value=440.0, confidence=0.5) - ann2 = jams.convert(ann, 'pitch_midi') + ann2 = jams.convert(ann, "pitch_midi") ann.validate() ann2.validate() # Check the namespace - assert ann2.namespace == 'pitch_midi' + assert ann2.namespace == "pitch_midi" # midi 69 = 440.0 Hz assert ann2.data[0].value == 69 @@ -135,14 +135,14 @@ def test_pitch_hz_to_midi(): def test_note_midi_to_hz(): - ann = jams.Annotation(namespace='note_midi') + ann = jams.Annotation(namespace="note_midi") ann.append(time=0, duration=1, value=69, confidence=0.5) - ann2 = jams.convert(ann, 'note_hz') + ann2 = jams.convert(ann, "note_hz") ann.validate() ann2.validate() # Check the namespace - assert ann2.namespace == 'note_hz' + assert ann2.namespace == "note_hz" # midi 69 = 440.0 Hz assert ann2.data[0].value == 440.0 @@ -157,14 +157,14 @@ def test_note_midi_to_hz(): def test_note_hz_to_midi(): - ann = jams.Annotation(namespace='note_hz') + ann = jams.Annotation(namespace="note_hz") ann.append(time=0, duration=1, value=440.0, confidence=0.5) - ann2 = jams.convert(ann, 'note_midi') + ann2 = jams.convert(ann, "note_midi") ann.validate() ann2.validate() # Check the namespace - assert ann2.namespace == 'note_midi' + assert ann2.namespace == "note_midi" # midi 69 = 440.0 Hz assert ann2.data[0].value == 69 @@ -179,15 +179,15 @@ def test_note_hz_to_midi(): def test_segment_open(): - ann = jams.Annotation(namespace='segment_salami_upper') - ann.append(time=0, duration=1, value='A', confidence=0.5) - ann2 = jams.convert(ann, 'segment_open') + ann = jams.Annotation(namespace="segment_salami_upper") + ann.append(time=0, duration=1, value="A", confidence=0.5) + ann2 = jams.convert(ann, "segment_open") ann.validate() ann2.validate() # Check the namespace - assert ann.namespace == 'segment_salami_upper' - assert ann2.namespace == 'segment_open' + assert ann.namespace == "segment_salami_upper" + assert ann2.namespace == "segment_open" # Check all else is equal assert ann.data == ann2.data @@ -195,15 +195,15 @@ def test_segment_open(): def test_tag_open(): - ann = jams.Annotation(namespace='tag_gtzan') - ann.append(time=0, duration=1, value='reggae', confidence=0.5) - ann2 = jams.convert(ann, 'tag_open') + ann = jams.Annotation(namespace="tag_gtzan") + ann.append(time=0, duration=1, value="reggae", confidence=0.5) + ann2 = jams.convert(ann, "tag_open") ann.validate() ann2.validate() # Check the namespace - assert ann.namespace == 'tag_gtzan' - assert ann2.namespace == 'tag_open' + assert ann.namespace == "tag_gtzan" + assert ann2.namespace == "tag_open" # Check all else is equal assert ann.data == ann2.data @@ -211,15 +211,15 @@ def test_tag_open(): def test_chord(): - ann = jams.Annotation(namespace='chord_harte') - ann.append(time=0, duration=1, value='C:maj6', confidence=0.5) - ann2 = jams.convert(ann, 'chord') + ann = jams.Annotation(namespace="chord_harte") + ann.append(time=0, duration=1, value="C:maj6", confidence=0.5) + ann2 = jams.convert(ann, "chord") ann.validate() ann2.validate() # Check the namespace - assert ann.namespace == 'chord_harte' - assert ann2.namespace == 'chord' + assert ann.namespace == "chord_harte" + assert ann2.namespace == "chord" # Check all else is equal assert ann.data == ann2.data @@ -227,23 +227,39 @@ def test_chord(): def test_beat_position(): - ann = jams.Annotation(namespace='beat_position') - ann.append(time=0, duration=0, confidence=0.5, - value=dict(position=1, measure=0, num_beats=4, beat_units=4)) - ann.append(time=0.5, duration=0, confidence=0.5, - value=dict(position=2, measure=0, num_beats=4, beat_units=4)) - ann.append(time=1, duration=0, confidence=0.5, - value=dict(position=3, measure=0, num_beats=4, beat_units=4)) - ann.append(time=1.5, duration=0, confidence=0.5, - value=dict(position=4, measure=0, num_beats=4, beat_units=4)) - - ann2 = jams.convert(ann, 'beat') + ann = jams.Annotation(namespace="beat_position") + ann.append( + time=0, + duration=0, + confidence=0.5, + value=dict(position=1, measure=0, num_beats=4, beat_units=4), + ) + ann.append( + time=0.5, + duration=0, + confidence=0.5, + value=dict(position=2, measure=0, num_beats=4, beat_units=4), + ) + ann.append( + time=1, + duration=0, + confidence=0.5, + value=dict(position=3, measure=0, num_beats=4, beat_units=4), + ) + ann.append( + time=1.5, + duration=0, + confidence=0.5, + value=dict(position=4, measure=0, num_beats=4, beat_units=4), + ) + + ann2 = jams.convert(ann, "beat") ann.validate() ann2.validate() # Check the namespace - assert ann2.namespace == 'beat' + assert ann2.namespace == "beat" # Check all else is equal assert len(ann) == len(ann2) @@ -254,7 +270,7 @@ def test_beat_position(): def test_scaper_tag_open(): - ann = jams.Annotation(namespace='scaper') + ann = jams.Annotation(namespace="scaper") value = { "source_time": 5, @@ -263,40 +279,40 @@ def test_scaper_tag_open(): "time_stretch": 0.8455598669219283, "pitch_shift": -1.2204911976305648, "snr": 7.790682558359417, - "label": 'gun_shot', + "label": "gun_shot", "role": "foreground", - "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav" + "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav", } ann.append(time=0, duration=1, value=value) - ann2 = jams.convert(ann, 'tag_open') + ann2 = jams.convert(ann, "tag_open") ann.validate() ann2.validate() - assert ann2.namespace == 'tag_open' + assert ann2.namespace == "tag_open" assert len(ann) == len(ann2) for obs1, obs2 in zip(ann.data, ann2.data): assert obs1.time == obs2.time assert obs1.duration == obs2.duration assert obs1.confidence == obs2.confidence - assert obs1.value['label'] == obs2.value + assert obs1.value["label"] == obs2.value def test_can_convert_equal(): - ann = jams.Annotation(namespace='chord') - assert jams.nsconvert.can_convert(ann, 'chord') + ann = jams.Annotation(namespace="chord") + assert jams.nsconvert.can_convert(ann, "chord") def test_can_convert_cast(): - ann = jams.Annotation(namespace='tag_gtzan') - assert jams.nsconvert.can_convert(ann, 'tag_open') + ann = jams.Annotation(namespace="tag_gtzan") + assert jams.nsconvert.can_convert(ann, "tag_open") def test_can_convert_fail(): - ann = jams.Annotation(namespace='tag_gtzan') - assert not jams.nsconvert.can_convert(ann, 'chord') + ann = jams.Annotation(namespace="tag_gtzan") + assert not jams.nsconvert.can_convert(ann, "chord") diff --git a/tests/test_display.py b/tests/test_display.py index 40b3cd5c..75a33780 100644 --- a/tests/test_display.py +++ b/tests/test_display.py @@ -4,7 +4,8 @@ import numpy as np import matplotlib -matplotlib.use('Agg') + +matplotlib.use("Agg") import matplotlib.pyplot as plt import pytest @@ -15,18 +16,29 @@ # A simple run-without-fail test for plotting -@pytest.mark.parametrize('namespace', - ['segment_open', 'chord', 'multi_segment', - 'pitch_contour', 'beat_position', 'beat', - 'onset', 'note_midi', 'tag_open']) -@pytest.mark.parametrize('meta', [False, True]) +@pytest.mark.parametrize( + "namespace", + [ + "segment_open", + "chord", + "multi_segment", + "pitch_contour", + "beat_position", + "beat", + "onset", + "note_midi", + "tag_open", + ], +) +@pytest.mark.parametrize("meta", [False, True]) def test_display(namespace, meta): ann = jams.Annotation(namespace=namespace) jams.display.display(ann, meta=meta) -@pytest.mark.parametrize('namespace', ['tempo']) -@pytest.mark.parametrize('meta', [False, True]) + +@pytest.mark.parametrize("namespace", ["tempo"]) +@pytest.mark.parametrize("meta", [False, True]) def test_display_exception(namespace, meta): with pytest.raises(NamespaceError): ann = jams.Annotation(namespace=namespace) @@ -36,22 +48,22 @@ def test_display_exception(namespace, meta): def test_display_multi(): jam = jams.JAMS() - jam.annotations.append(jams.Annotation(namespace='beat')) + jam.annotations.append(jams.Annotation(namespace="beat")) jams.display.display_multi(jam.annotations) def test_display_multi_multi(): jam = jams.JAMS() - jam.annotations.append(jams.Annotation(namespace='beat')) - jam.annotations.append(jams.Annotation(namespace='chord')) + jam.annotations.append(jams.Annotation(namespace="beat")) + jam.annotations.append(jams.Annotation(namespace="chord")) jams.display.display_multi(jam.annotations) def test_display_pitch_contour(): - ann = jams.Annotation(namespace='pitch_hz', duration=5) + ann = jams.Annotation(namespace="pitch_hz", duration=5) values = np.arange(100, 200) times = np.linspace(0, 2, num=len(values)) @@ -67,7 +79,7 @@ def test_display_labeled_events(): times = np.arange(40) values = times % 4 - ann = jams.Annotation(namespace='beat', duration=60) + ann = jams.Annotation(namespace="beat", duration=60) for t, v in zip(times, values): ann.append(time=t, value=v, duration=0) diff --git a/tests/test_eval.py b/tests/test_eval.py index 1542f28d..fd036688 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -'''mir_eval integration tests''' +"""mir_eval integration tests""" import numpy as np import pytest @@ -10,8 +10,7 @@ # Fixtures -def create_annotation(values, namespace='beat', offset=0.0, duration=1, - confidence=1): +def create_annotation(values, namespace="beat", offset=0.0, duration=1, confidence=1): ann = jams.Annotation(namespace=namespace) time = np.arange(offset, offset + len(values)) @@ -30,11 +29,10 @@ def create_annotation(values, namespace='beat', offset=0.0, duration=1, def create_hierarchy(values, offset=0.0, duration=20): - ann = jams.Annotation(namespace='multi_segment') + ann = jams.Annotation(namespace="multi_segment") for level, labels in enumerate(values): - times = np.linspace(offset, offset + duration, num=len(labels), - endpoint=False) + times = np.linspace(offset, offset + duration, num=len(labels), endpoint=False) durations = list(np.diff(times)) durations.append(duration + offset - times[-1]) @@ -45,175 +43,179 @@ def create_hierarchy(values, offset=0.0, duration=20): return ann -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ref_beat(): - return create_annotation(values=np.arange(10) % 4 + 0.5, - namespace='beat') + return create_annotation(values=np.arange(10) % 4 + 0.5, namespace="beat") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_beat(): - return create_annotation(values=np.arange(9) % 4 + 1, - namespace='beat', offset=0.01) + return create_annotation(values=np.arange(9) % 4 + 1, namespace="beat", offset=0.01) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ref_onset(): - return create_annotation(values=np.arange(10) % 4 + 1., - namespace='onset') + return create_annotation(values=np.arange(10) % 4 + 1.0, namespace="onset") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_onset(): - return create_annotation(values=np.arange(9) % 4 + 1., - namespace='onset', offset=0.01) + return create_annotation( + values=np.arange(9) % 4 + 1.0, namespace="onset", offset=0.01 + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ref_chord(): - return create_annotation(values=['C', 'E', 'G:min7'], - namespace='chord') + return create_annotation(values=["C", "E", "G:min7"], namespace="chord") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_chord(): - return create_annotation(values=['D', 'E', 'G:maj'], - namespace='chord_harte') + return create_annotation(values=["D", "E", "G:maj"], namespace="chord_harte") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_roman(): - return create_annotation(values=[{'tonic': 'C', 'chord': 'I'}], - namespace='chord_roman') + return create_annotation( + values=[{"tonic": "C", "chord": "I"}], namespace="chord_roman" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_badchord(): - return create_annotation(values=['D', 'E', 'not at all a chord'], - namespace='chord_harte', offset=0.01) + return create_annotation( + values=["D", "E", "not at all a chord"], namespace="chord_harte", offset=0.01 + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ref_segment(): - return create_annotation(values=['A', 'B', 'A', 'C'], - namespace='segment_open') + return create_annotation(values=["A", "B", "A", "C"], namespace="segment_open") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_segment(): - return create_annotation(values=['E', 'B', 'E', 'B'], - namespace='segment_open') + return create_annotation(values=["E", "B", "E", "B"], namespace="segment_open") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_segtut(): - return create_annotation(values=[['F'], 'E', 'B', 'E', 'B'], - namespace='segment_tut') + return create_annotation( + values=[["F"], "E", "B", "E", "B"], namespace="segment_tut" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ref_tempo(): - return create_annotation(values=[120.0, 60.0], confidence=[0.75, 0.25], - namespace='tempo') + return create_annotation( + values=[120.0, 60.0], confidence=[0.75, 0.25], namespace="tempo" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_tempo(): - return create_annotation(values=[120.0, 80.0], confidence=[0.5, 0.5], - namespace='tempo') + return create_annotation( + values=[120.0, 80.0], confidence=[0.5, 0.5], namespace="tempo" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_badtempo(): - return create_annotation(values=[120.0, 80.0], confidence=[-5, 1.5], - namespace='tempo') + return create_annotation( + values=[120.0, 80.0], confidence=[-5, 1.5], namespace="tempo" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_tag(): - return create_annotation(values=['120.0', '80.0'], confidence=[0.5, 0.5], - namespace='tag_open') + return create_annotation( + values=["120.0", "80.0"], confidence=[0.5, 0.5], namespace="tag_open" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ref_melody(): srand() freq = np.linspace(110.0, 440.0, 10) voice = np.sign(np.random.randn(len(freq))) - return create_annotation(values=freq * voice, confidence=1.0, - duration=0.01, - namespace='pitch_hz') + return create_annotation( + values=freq * voice, confidence=1.0, duration=0.01, namespace="pitch_hz" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_melody(): srand() freq = np.linspace(110.0, 440.0, 10) voice = np.sign(np.random.randn(len(freq))) - return create_annotation(values=freq * voice, confidence=1.0, - duration=0.01, - namespace='pitch_hz') + return create_annotation( + values=freq * voice, confidence=1.0, duration=0.01, namespace="pitch_hz" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_badmelody(): - return create_annotation(values=['a', 'b', 'c'], - confidence=1.0, - duration=0.01, - namespace='pitch_hz') + return create_annotation( + values=["a", "b", "c"], confidence=1.0, duration=0.01, namespace="pitch_hz" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ref_pattern(): - ref_jam = jams.load('tests/fixtures/pattern_data.jams', validate=False) - return ref_jam.annotations['pattern_jku', 0] + ref_jam = jams.load("tests/fixtures/pattern_data.jams", validate=False) + return ref_jam.annotations["pattern_jku", 0] -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_badpattern(): - pattern = {'midi_pitch': 3, 'morph_pitch': 5, 'staff': 1, - 'pattern_id': None, 'occurrence_id': 1} + pattern = { + "midi_pitch": 3, + "morph_pitch": 5, + "staff": 1, + "pattern_id": None, + "occurrence_id": 1, + } - return create_annotation(values=[pattern], - confidence=1.0, - duration=0.01, - namespace='pattern_jku') + return create_annotation( + values=[pattern], confidence=1.0, duration=0.01, namespace="pattern_jku" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ref_hier(): - return create_hierarchy(values=['AB', 'abac']) + return create_hierarchy(values=["AB", "abac"]) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_hier(): - return create_hierarchy(values=['ABCD', 'abacbcbd']) + return create_hierarchy(values=["ABCD", "abacbcbd"]) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_badhier(): return create_hierarchy(values=[[1, 2], [1, 2, 1, 3]]) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ref_transcript(): - ref_jam = jams.load('tests/fixtures/transcription_ref.jams', validate=False) - return ref_jam.annotations['pitch_hz', 0] + ref_jam = jams.load("tests/fixtures/transcription_ref.jams", validate=False) + return ref_jam.annotations["pitch_hz", 0] -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_transcript(): - est_jam = jams.load('tests/fixtures/transcription_est.jams', validate=False) - return est_jam.annotations['pitch_hz', 0] + est_jam = jams.load("tests/fixtures/transcription_est.jams", validate=False) + return est_jam.annotations["pitch_hz", 0] -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def est_badtranscript(): - est_jam = jams.load('tests/fixtures/transcription_est.jams', validate=False) - ann = est_jam.annotations['pitch_hz', 0] - ann.append(time=2., duration=1., value=None, confidence=1) + est_jam = jams.load("tests/fixtures/transcription_est.jams", validate=False) + ann = est_jam.annotations["pitch_hz", 0] + ann.append(time=2.0, duration=1.0, value=None, confidence=1) return ann diff --git a/tests/test_jams.py b/tests/test_jams.py index 4f417150..c808f172 100644 --- a/tests/test_jams.py +++ b/tests/test_jams.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- # CREATED:2015-03-06 14:24:58 by Brian McFee -'''Unit tests for JAMS core objects''' +"""Unit tests for JAMS core objects""" import os import tempfile @@ -17,14 +17,13 @@ import jams - # Borrowed from sklearn def clean_warning_registry(): - """Safe way to reset warnings """ + """Safe way to reset warnings""" warnings.resetwarnings() reg = "__warningregistry__" for mod_name, mod in list(sys.modules.items()): - if 'six.moves' in mod_name: + if "six.moves" in mod_name: continue if hasattr(mod, reg): getattr(mod, reg).clear() @@ -32,9 +31,10 @@ def clean_warning_registry(): # JObject + def test_jobject_dict(): - data = dict(key1='value 1', key2='value 2') + data = dict(key1="value 1", key2="value 2") J = jams.JObject(**data) @@ -45,7 +45,7 @@ def test_jobject_dict(): def test_jobject_serialize(): - data = dict(key1='value 1', key2='value 2') + data = dict(key1="value 1", key2="value 2") json_data = json.dumps(data, indent=2) @@ -62,7 +62,7 @@ def test_jobject_serialize(): def test_jobject_deserialize(): - data = dict(key1='value 1', key2='value 2') + data = dict(key1="value 1", key2="value 2") J = jams.JObject(**data) @@ -71,10 +71,14 @@ def test_jobject_deserialize(): assert J == jams.JObject.loads(json_jobject) -@pytest.mark.parametrize('d1', [dict(key1='value 1', key2='value 2')]) -@pytest.mark.parametrize('d2, match', - [(dict(key1='value 1', key2='value 2'), True), - (dict(key1='value 1', key2='value 3'), False)]) +@pytest.mark.parametrize("d1", [dict(key1="value 1", key2="value 2")]) +@pytest.mark.parametrize( + "d2, match", + [ + (dict(key1="value 1", key2="value 2"), True), + (dict(key1="value 1", key2="value 3"), False), + ], +) def test_jobject_eq(d1, d2, match): J1 = jams.JObject(**d1) J2 = jams.JObject(**d2) @@ -92,7 +96,7 @@ def test_jobject_eq(d1, d2, match): assert not J1 == J3 -@pytest.mark.parametrize('data, value', [({'key': True}, True), ({}, False)]) +@pytest.mark.parametrize("data, value", [({"key": True}, True), ({}, False)]) def test_jobject_nonzero(data, value): J = jams.JObject(**data) @@ -100,8 +104,7 @@ def test_jobject_nonzero(data, value): def test_jobject_repr(): - assert (repr(jams.JObject(foo=1, bar=2)) == - '') + assert repr(jams.JObject(foo=1, bar=2)) == "" def test_jobject_repr_html(): @@ -117,7 +120,7 @@ def test_jobject_repr_html(): # Sandbox def test_sandbox(): - data = dict(key1='value 1', key2='value 2') + data = dict(key1="value 1", key2="value 2") J = jams.Sandbox(**data) @@ -136,30 +139,32 @@ def test_sandbox_contains(): # Curator def test_curator(): - c = jams.Curator(name='myself', email='you@me.com') + c = jams.Curator(name="myself", email="you@me.com") - assert c.name == 'myself' - assert c.email == 'you@me.com' + assert c.name == "myself" + assert c.email == "you@me.com" # AnnotationMetadata @pytest.fixture def ann_meta_dummy(): - return dict(version='0', - corpus='test', - annotation_tools='nose', - annotation_rules='brains', - validation='unnecessary', - data_source='null') - - -@pytest.mark.parametrize('curator', [None, jams.Curator(name='nobody', - email='none@none.com')]) -@pytest.mark.parametrize('annotator', [None, jams.Sandbox(description='desc')]) + return dict( + version="0", + corpus="test", + annotation_tools="nose", + annotation_rules="brains", + validation="unnecessary", + data_source="null", + ) + + +@pytest.mark.parametrize( + "curator", [None, jams.Curator(name="nobody", email="none@none.com")] +) +@pytest.mark.parametrize("annotator", [None, jams.Sandbox(description="desc")]) def test_annotation_metadata(ann_meta_dummy, curator, annotator): - md = jams.AnnotationMetadata(curator=curator, annotator=annotator, - **ann_meta_dummy) + md = jams.AnnotationMetadata(curator=curator, annotator=annotator, **ann_meta_dummy) if curator is not None: assert dict(md.curator) == dict(curator) @@ -168,33 +173,35 @@ def test_annotation_metadata(ann_meta_dummy, curator, annotator): assert dict(md.annotator) == dict(annotator) real_data = dict(md) - real_data.pop('curator') - real_data.pop('annotator') + real_data.pop("curator") + real_data.pop("annotator") assert real_data == ann_meta_dummy # Annotation -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def tag_data(): - return [dict(time=0, duration=0.5, value='one', confidence=0.9), - dict(time=1.0, duration=0.5, value='two', confidence=0.9)] + return [ + dict(time=0, duration=0.5, value="one", confidence=0.9), + dict(time=1.0, duration=0.5, value="two", confidence=0.9), + ] -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ann_sandbox(): - return jams.Sandbox(description='ann_sandbox') + return jams.Sandbox(description="ann_sandbox") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ann_metadata(): - return jams.AnnotationMetadata(corpus='test collection') + return jams.AnnotationMetadata(corpus="test collection") -@pytest.mark.parametrize('namespace', ['tag_open']) +@pytest.mark.parametrize("namespace", ["tag_open"]) def test_annotation(namespace, tag_data, ann_metadata, ann_sandbox): - ann = jams.Annotation(namespace, data=tag_data, - annotation_metadata=ann_metadata, - sandbox=ann_sandbox) + ann = jams.Annotation( + namespace, data=tag_data, annotation_metadata=ann_metadata, sandbox=ann_sandbox + ) assert namespace == ann.namespace @@ -209,14 +216,16 @@ def test_annotation(namespace, tag_data, ann_metadata, ann_sandbox): def test_annotation_append(): - data = [dict(time=0, duration=0.5, value='one', confidence=0.9), - dict(time=1.0, duration=0.5, value='two', confidence=0.9)] + data = [ + dict(time=0, duration=0.5, value="one", confidence=0.9), + dict(time=1.0, duration=0.5, value="two", confidence=0.9), + ] - namespace = 'tag_open' + namespace = "tag_open" ann = jams.Annotation(namespace, data=data) - update = dict(time=2.0, duration=1.0, value='three', confidence=0.8) + update = dict(time=2.0, duration=1.0, value="three", confidence=0.8) ann.append(**update) @@ -224,7 +233,7 @@ def test_annotation_append(): def test_annotation_eq(tag_data): - namespace = 'tag_open' + namespace = "tag_open" ann1 = jams.Annotation(namespace, data=tag_data) ann2 = jams.Annotation(namespace, data=tag_data) @@ -234,7 +243,7 @@ def test_annotation_eq(tag_data): # Test the type-check in equality assert not (ann1 == tag_data) - update = dict(time=2.0, duration=1.0, value='three', confidence=0.8) + update = dict(time=2.0, duration=1.0, value="three", confidence=0.8) ann2.append(**update) @@ -243,10 +252,12 @@ def test_annotation_eq(tag_data): def test_annotation_iterator(): - data = [dict(time=0, duration=0.5, value='one', confidence=0.2), - dict(time=1, duration=1, value='two', confidence=0.5)] + data = [ + dict(time=0, duration=0.5, value="one", confidence=0.2), + dict(time=1, duration=1, value="two", confidence=0.5), + ] - namespace = 'tag_open' + namespace = "tag_open" ann = jams.Annotation(namespace, data=data) @@ -257,17 +268,17 @@ def test_annotation_iterator(): def test_annotation_interval_values(tag_data): - ann = jams.Annotation(namespace='tag_open', data=tag_data) + ann = jams.Annotation(namespace="tag_open", data=tag_data) intervals, values = ann.to_interval_values() assert np.allclose(intervals, np.array([[0.0, 0.5], [1.0, 1.5]])) - assert values == ['one', 'two'] + assert values == ["one", "two"] def test_annotation_badtype(): - an = jams.Annotation(namespace='tag_open') + an = jams.Annotation(namespace="tag_open") # This should throw a jams error because NoneType can't be indexed with pytest.raises(jams.JamsError): @@ -277,10 +288,9 @@ def test_annotation_badtype(): # FileMetadata def test_filemetadata(): - meta = dict(title='Test track', - artist='Test artist', - release='Test release', - duration=31.3) + meta = dict( + title="Test track", artist="Test artist", release="Test release", duration=31.3 + ) fm = jams.FileMetadata(**meta) dict_fm = dict(fm) @@ -291,23 +301,21 @@ def test_filemetadata(): def test_filemetadata_validation_warning(): # This should fail validation because null duration is not allowed - fm = jams.FileMetadata(title='Test track', - artist='Test artist', - release='Test release', - duration=None) + fm = jams.FileMetadata( + title="Test track", artist="Test artist", release="Test release", duration=None + ) clean_warning_registry() - with pytest.warns(UserWarning, match='.*(Failed validating).*') as out: + with pytest.warns(UserWarning, match=".*(Failed validating).*") as out: fm.validate(strict=False) def test_filemetadata_validation_strict(): # This should fail validation because null duration is not allowed - fm = jams.FileMetadata(title='Test track', - artist='Test artist', - release='Test release', - duration=None) + fm = jams.FileMetadata( + title="Test track", artist="Test artist", release="Test release", duration=None + ) clean_warning_registry() @@ -325,7 +333,7 @@ def test_annotation_array(): def test_annotation_array_data(tag_data): - ann = jams.Annotation('tag_open', data=tag_data) + ann = jams.Annotation("tag_open", data=tag_data) arr = jams.AnnotationArray(annotations=[ann, ann]) assert len(arr) == 2 @@ -339,7 +347,7 @@ def test_annotation_array_data(tag_data): def test_annotation_array_serialize(tag_data): - namespace = 'tag_open' + namespace = "tag_open" ann = jams.Annotation(namespace, data=tag_data) arr = jams.AnnotationArray(annotations=[ann, ann]) @@ -355,7 +363,7 @@ def test_annotation_array_index_simple(): jam = jams.JAMS() - anns = [jams.Annotation('beat') for _ in range(5)] + anns = [jams.Annotation("beat") for _ in range(5)] for ann in anns: jam.annotations.append(ann) @@ -370,7 +378,7 @@ def test_annotation_array_slice_simple(): jam = jams.JAMS() - anns = [jams.Annotation('beat') for _ in range(5)] + anns = [jams.Annotation("beat") for _ in range(5)] for ann in anns: jam.annotations.append(ann) @@ -383,34 +391,34 @@ def test_annotation_array_slice_simple(): def test_annotation_array_index_fancy(): jam = jams.JAMS() - ann = jams.Annotation(namespace='beat') + ann = jams.Annotation(namespace="beat") jam.annotations.append(ann) # We should have exactly one beat annotation - res = jam.annotations['beat'] + res = jam.annotations["beat"] assert len(res) == 1 assert res[0] == ann # Any other namespace should give an empty list - assert jam.annotations['segment'] == [] + assert jam.annotations["segment"] == [] def test_annotation_array_composite(): jam = jams.JAMS() for _ in range(10): - ann = jams.Annotation(namespace='beat') + ann = jams.Annotation(namespace="beat") jam.annotations.append(ann) - assert len(jam.annotations['beat', :3]) == 3 - assert len(jam.annotations['beat', 3:]) == 7 - assert len(jam.annotations['beat', 2::2]) == 4 + assert len(jam.annotations["beat", :3]) == 3 + assert len(jam.annotations["beat", 3:]) == 7 + assert len(jam.annotations["beat", 2::2]) == 4 def test_annotation_array_index_error(): jam = jams.JAMS() - ann = jams.Annotation(namespace='beat') + ann = jams.Annotation(namespace="beat") jam.annotations.append(ann) with pytest.raises(IndexError): @@ -418,38 +426,39 @@ def test_annotation_array_index_error(): # JAMS -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def file_metadata(): - return jams.FileMetadata(title='Test track', artist='Test artist', - release='Test release', duration=31.3) + return jams.FileMetadata( + title="Test track", artist="Test artist", release="Test release", duration=31.3 + ) def test_jams(tag_data, file_metadata, ann_sandbox): - ann = jams.Annotation('tag_open', data=tag_data) + ann = jams.Annotation("tag_open", data=tag_data) annotations = jams.AnnotationArray(annotations=[ann]) - jam = jams.JAMS(annotations=annotations, - file_metadata=file_metadata, - sandbox=ann_sandbox) + jam = jams.JAMS( + annotations=annotations, file_metadata=file_metadata, sandbox=ann_sandbox + ) assert dict(file_metadata) == dict(jam.file_metadata) assert dict(ann_sandbox) == dict(jam.sandbox) assert annotations == jam.annotations -@pytest.fixture(params=['jams', 'jamz']) -def output_path(request): - - _, jam_out = tempfile.mkstemp(suffix='.{:s}'.format(request.param)) - - yield jam_out +@pytest.fixture(params=["jams", "jamz"]) +def output_path(tmp_path, request): + # tmp_path is a pathlib.Path unique to this test‐invocation (including each param) + # build your filename with the desired extension + path = tmp_path / f"output.{request.param}" + # (optional) if you really need an existing file, you could touch it: + # path.write_bytes(b"") + return str(path) - os.unlink(jam_out) - -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def input_jam(): - return jams.load('tests/fixtures/valid.jams') + return jams.load("tests/fixtures/valid.jams") def test_jams_save(input_jam, output_path): @@ -461,7 +470,7 @@ def test_jams_save(input_jam, output_path): def test_jams_add(tag_data): - fn = 'tests/fixtures/valid.jams' + fn = "tests/fixtures/valid.jams" # The original jam jam_orig = jams.load(fn) @@ -469,7 +478,7 @@ def test_jams_add(tag_data): # Make a new jam with the same metadata and different data jam2 = jams.load(fn) - ann = jams.Annotation('tag_open', data=tag_data) + ann = jams.Annotation("tag_open", data=tag_data) jam2.annotations = jams.AnnotationArray(annotations=[ann]) # Add the two @@ -480,10 +489,9 @@ def test_jams_add(tag_data): assert jam.annotations[-1] == jam2.annotations[0] -@pytest.mark.parametrize('on_conflict', - ['overwrite', 'ignore']) +@pytest.mark.parametrize("on_conflict", ["overwrite", "ignore"]) def test_jams_add_conflict(on_conflict): - fn = 'tests/fixtures/valid.jams' + fn = "tests/fixtures/valid.jams" # The original jam jam = jams.load(fn) @@ -496,18 +504,18 @@ def test_jams_add_conflict(on_conflict): jam.add(jam2, on_conflict=on_conflict) - if on_conflict == 'overwrite': + if on_conflict == "overwrite": assert jam.file_metadata == jam2.file_metadata - elif on_conflict == 'ignore': + elif on_conflict == "ignore": assert jam.file_metadata == jam_orig.file_metadata -@pytest.mark.parametrize('on_conflict,exception', [ - ('fail', jams.JamsError), - ('bad_fail_mdoe', jams.ParameterError) -]) +@pytest.mark.parametrize( + "on_conflict,exception", + [("fail", jams.JamsError), ("bad_fail_mdoe", jams.ParameterError)], +) def test_jams_add_conflict_exceptions(on_conflict, exception): - fn = 'tests/fixtures/valid.jams' + fn = "tests/fixtures/valid.jams" # The original jam jam = jams.load(fn) @@ -520,24 +528,28 @@ def test_jams_add_conflict_exceptions(on_conflict, exception): jam.add(jam2, on_conflict=on_conflict) -jam = jams.load('tests/fixtures/valid.jams', validate=False) +jam = jams.load("tests/fixtures/valid.jams", validate=False) jam.annotations[0].sandbox.foo = None -@pytest.mark.parametrize('query, expected', - [(dict(corpus='SMC_MIREX'), jam.annotations), - (dict(), []), - (dict(namespace='beat'), jam.annotations[:1]), - (dict(namespace='tag_open'), jam.annotations[1:]), - (dict(namespace='segment_tut'), jams.AnnotationArray()), - (dict(foo='bar'), jams.AnnotationArray())]) +@pytest.mark.parametrize( + "query, expected", + [ + (dict(corpus="SMC_MIREX"), jam.annotations), + (dict(), []), + (dict(namespace="beat"), jam.annotations[:1]), + (dict(namespace="tag_open"), jam.annotations[1:]), + (dict(namespace="segment_tut"), jams.AnnotationArray()), + (dict(foo="bar"), jams.AnnotationArray()), + ], +) def test_jams_search(query, expected): result = jam.search(**query) def test_jams_validate_good(): - fn = 'tests/fixtures/valid.jams' + fn = "tests/fixtures/valid.jams" j1 = jams.load(fn, validate=False) j1.validate() @@ -545,9 +557,9 @@ def test_jams_validate_good(): j1.file_metadata.validate() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def jam_validate(): - j1 = jams.load('tests/fixtures/invalid.jams', validate=False) + j1 = jams.load("tests/fixtures/invalid.jams", validate=False) return j1 @@ -555,9 +567,10 @@ def test_jams_validate_warning(jam_validate): clean_warning_registry() - with pytest.warns(UserWarning, match='.*(Failed validating).*') as out: + with pytest.warns(UserWarning, match=".*(Failed validating).*") as out: jam_validate.validate(strict=False) + def test_jams_validate_exception(jam_validate): clean_warning_registry() @@ -577,11 +590,13 @@ def test_jams_bad_annotation_warnings(): jam = jams.JAMS() jam.file_metadata.duration = 10 - jam.annotations.append('not an annotation') + jam.annotations.append("not an annotation") clean_warning_registry() - with pytest.warns(UserWarning, match='.*(is not a well-formed JAMS Annotation).*') as out: + with pytest.warns( + UserWarning, match=".*(is not a well-formed JAMS Annotation).*" + ) as out: jam.validate(strict=False) @@ -589,7 +604,7 @@ def test_jams_bad_annotation_exception(): jam = jams.JAMS() jam.file_metadata.duration = 10 - jam.annotations.append('not an annotation') + jam.annotations.append("not an annotation") clean_warning_registry() @@ -602,7 +617,7 @@ def test_jams_bad_jam_warning(): clean_warning_registry() - with pytest.warns(UserWarning, match='.*(Failed validating).*') as out: + with pytest.warns(UserWarning, match=".*(Failed validating).*") as out: jam.validate(strict=False) @@ -637,34 +652,34 @@ def test_load_fail(): # Make a non-existent file tdir = tempfile.mkdtemp() with pytest.raises(IOError): - jams.load(os.path.join(tdir, 'nonexistent.jams'), fmt='jams') + jams.load(os.path.join(tdir, "nonexistent.jams"), fmt="jams") os.rmdir(tdir) # Make a non-json file tdir = tempfile.mkdtemp() - badfile = os.path.join(tdir, 'nonexistent.jams') - with open(badfile, mode='w') as fp: - fp.write('some garbage') + badfile = os.path.join(tdir, "nonexistent.jams") + with open(badfile, mode="w") as fp: + fp.write("some garbage") with pytest.raises(ValueError): - jams.load(os.path.join(tdir, 'nonexistent.jams'), fmt='jams') + jams.load(os.path.join(tdir, "nonexistent.jams"), fmt="jams") os.unlink(badfile) os.rmdir(tdir) tdir = tempfile.mkdtemp() - for ext in ['txt', '']: - badfile = os.path.join(tdir, 'nonexistent') + for ext in ["txt", ""]: + badfile = os.path.join(tdir, "nonexistent") with pytest.raises(jams.ParameterError): - jams.load('{:s}.{:s}'.format(badfile, ext), fmt='auto') + jams.load("{:s}.{:s}".format(badfile, ext), fmt="auto") with pytest.raises(jams.ParameterError): - jams.load('{:s}.{:s}'.format(badfile, ext), fmt=ext) + jams.load("{:s}.{:s}".format(badfile, ext), fmt=ext) with pytest.raises(jams.ParameterError): - jams.load('{:s}.jams'.format(badfile, ext), fmt=ext) + jams.load("{:s}.jams".format(badfile, ext), fmt=ext) # one last test, trying to load form a non-file-like object with pytest.raises(jams.ParameterError): - jams.load(None, fmt='auto') + jams.load(None, fmt="auto") os.rmdir(tdir) @@ -673,14 +688,12 @@ def test_load_valid(): # 3. test good jams file with strict validation # 4. test good jams file without strict validation - fn = 'tests/fixtures/valid' + fn = "tests/fixtures/valid" - for ext in ['jams', 'jamz']: + for ext in ["jams", "jamz"]: for validate in [False, True]: for strict in [False, True]: - jams.load('{:s}.{:s}'.format(fn, ext), - validate=validate, - strict=strict) + jams.load("{:s}.{:s}".format(fn, ext), validate=validate, strict=strict) def test_load_invalid(): @@ -688,12 +701,12 @@ def test_load_invalid(): def __test_warn(filename, valid, strict): clean_warning_registry() - with pytest.warns(UserWarning, match='.*(Failed validating).*'): + with pytest.warns(UserWarning, match=".*(Failed validating).*"): jams.load(filename, validate=valid, strict=strict) # 5. test bad jams file with strict validation # 6. test bad jams file without strict validation - fn = 'tests/fixtures/invalid.jams' + fn = "tests/fixtures/invalid.jams" # Test once with no validation jams.load(fn, validate=False, strict=False) @@ -708,14 +721,14 @@ def __test_warn(filename, valid, strict): def test_annotation_trim_bad_params(): # end_time must be greater than start_time - ann = jams.Annotation('tag_open') + ann = jams.Annotation("tag_open") with pytest.raises(jams.ParameterError): ann.trim(5, 3, strict=False) def test_annotation_trim_no_duration(): # When ann.duration is not set prior to trim should raise warning - ann = jams.Annotation('tag_open') + ann = jams.Annotation("tag_open") ann.duration = None clean_warning_registry() @@ -724,15 +737,15 @@ def test_annotation_trim_no_duration(): assert len(out) > 0 assert out[0].category is UserWarning - assert 'annotation.duration is not defined' in str(out[0].message).lower() + assert "annotation.duration is not defined" in str(out[0].message).lower() # When duration is not defined trim should keep all observations in the # user-specified trim range. - namespace = 'tag_open' + namespace = "tag_open" ann = jams.Annotation(namespace) ann.time = 100 ann.duration = None - ann.append(time=5, duration=2, value='one') + ann.append(time=5, duration=2, value="one") clean_warning_registry() with warnings.catch_warnings(record=True) as out: @@ -740,14 +753,12 @@ def test_annotation_trim_no_duration(): assert len(out) > 0 assert out[0].category is UserWarning - assert 'annotation.duration is not defined' in str(out[0].message).lower() + assert "annotation.duration is not defined" in str(out[0].message).lower() - expected_data = dict(time=[5.0], - duration=[2.0], - value=['one'], - confidence=[None]) - expected_ann = jams.Annotation(namespace, data=expected_data, time=5.0, - duration=3.0) + expected_data = dict(time=[5.0], duration=[2.0], value=["one"], confidence=[None]) + expected_ann = jams.Annotation( + namespace, data=expected_data, time=5.0, duration=3.0 + ) assert ann_trim.data == expected_ann.data @@ -755,7 +766,7 @@ def test_annotation_trim_no_duration(): def test_annotation_trim_no_overlap(): # when there's no overlap, a warning is raised and the # returned annotation should be empty - ann = jams.Annotation('tag_open') + ann = jams.Annotation("tag_open") ann.time = 5 ann.duration = 10 @@ -769,7 +780,7 @@ def test_annotation_trim_no_overlap(): assert len(out) > 0 assert out[0].category is UserWarning - assert 'does not intersect' in str(out[0].message).lower() + assert "does not intersect" in str(out[0].message).lower() assert len(ann_trim.data) == 0 assert ann_trim.time == ann.time @@ -778,11 +789,13 @@ def test_annotation_trim_no_overlap(): def test_annotation_trim_complete_overlap(): # For a valid scenario, ensure everything behaves as expected - namespace = 'tag_open' - data = dict(time=[5.0, 5.0, 10.0], - duration=[2.0, 4.0, 4.0], - value=['one', 'two', 'three'], - confidence=[0.9, 0.9, 0.9]) + namespace = "tag_open" + data = dict( + time=[5.0, 5.0, 10.0], + duration=[2.0, 4.0, 4.0], + value=["one", "two", "three"], + confidence=[0.9, 0.9, 0.9], + ) ann = jams.Annotation(namespace, data=data, time=5.0, duration=10.0) # When the trim region is completely inside the annotation time range @@ -792,17 +805,21 @@ def test_annotation_trim_complete_overlap(): assert ann_trim.time == 8 assert ann_trim.duration == 4 - assert ann_trim.sandbox.trim == [{'start_time': 8, 'end_time': 12, - 'trim_start': 8, 'trim_end': 12}] + assert ann_trim.sandbox.trim == [ + {"start_time": 8, "end_time": 12, "trim_start": 8, "trim_end": 12} + ] assert ann_trim.namespace == ann.namespace assert ann_trim.annotation_metadata == ann.annotation_metadata - expected_data = dict(time=[8.0, 10.0], - duration=[1.0, 2.0], - value=['two', 'three'], - confidence=[0.9, 0.9]) - expected_ann = jams.Annotation(namespace, data=expected_data, time=8.0, - duration=4.0) + expected_data = dict( + time=[8.0, 10.0], + duration=[1.0, 2.0], + value=["two", "three"], + confidence=[0.9, 0.9], + ) + expected_ann = jams.Annotation( + namespace, data=expected_data, time=8.0, duration=4.0 + ) assert ann_trim.data == expected_ann.data @@ -811,14 +828,16 @@ def test_annotation_trim_complete_overlap(): assert ann_trim.time == 8 assert ann_trim.duration == 4 - assert ann_trim.sandbox.trim == [{'start_time': 8, 'end_time': 12, - 'trim_start': 8, 'trim_end': 12}] + assert ann_trim.sandbox.trim == [ + {"start_time": 8, "end_time": 12, "trim_start": 8, "trim_end": 12} + ] assert ann_trim.namespace == ann.namespace assert ann_trim.annotation_metadata == ann.annotation_metadata expected_data = None - expected_ann = jams.Annotation(namespace, data=expected_data, time=8.0, - duration=4.0) + expected_ann = jams.Annotation( + namespace, data=expected_data, time=8.0, duration=4.0 + ) assert ann_trim.data == expected_ann.data @@ -827,28 +846,34 @@ def test_annotation_trim_partial_overlap_beginning(): # When the trim region only partially overlaps with the annotation time # range: at the beginning # strict=False - namespace = 'tag_open' - data = dict(time=[4.0, 5.0, 5.0, 5.0, 10.0], - duration=[1.0, 0.0, 2.0, 4.0, 4.0], - value=['none', 'zero', 'one', 'two', 'three'], - confidence=[1, 0.1, 0.9, 0.9, 0.9]) + namespace = "tag_open" + data = dict( + time=[4.0, 5.0, 5.0, 5.0, 10.0], + duration=[1.0, 0.0, 2.0, 4.0, 4.0], + value=["none", "zero", "one", "two", "three"], + confidence=[1, 0.1, 0.9, 0.9, 0.9], + ) ann = jams.Annotation(namespace, data=data, time=5.0, duration=10.0) ann_trim = ann.trim(1, 8, strict=False) assert ann_trim.time == 5 assert ann_trim.duration == 3 - assert ann_trim.sandbox.trim == [{'start_time': 1, 'end_time': 8, - 'trim_start': 5, 'trim_end': 8}] + assert ann_trim.sandbox.trim == [ + {"start_time": 1, "end_time": 8, "trim_start": 5, "trim_end": 8} + ] assert ann_trim.namespace == ann.namespace assert ann_trim.annotation_metadata == ann.annotation_metadata - expected_data = dict(time=[5.0, 5.0, 5.0], - duration=[0.0, 2.0, 3.0], - value=['zero', 'one', 'two'], - confidence=[0.1, 0.9, 0.9]) - expected_ann = jams.Annotation(namespace, data=expected_data, time=5.0, - duration=3.0) + expected_data = dict( + time=[5.0, 5.0, 5.0], + duration=[0.0, 2.0, 3.0], + value=["zero", "one", "two"], + confidence=[0.1, 0.9, 0.9], + ) + expected_ann = jams.Annotation( + namespace, data=expected_data, time=5.0, duration=3.0 + ) assert ann_trim.data == expected_ann.data @@ -857,17 +882,21 @@ def test_annotation_trim_partial_overlap_beginning(): assert ann_trim.time == 5 assert ann_trim.duration == 3 - assert ann_trim.sandbox.trim == [{'start_time': 1, 'end_time': 8, - 'trim_start': 5, 'trim_end': 8}] + assert ann_trim.sandbox.trim == [ + {"start_time": 1, "end_time": 8, "trim_start": 5, "trim_end": 8} + ] assert ann_trim.namespace == ann.namespace assert ann_trim.annotation_metadata == ann.annotation_metadata - expected_data = dict(time=[5.0, 5.0], - duration=[0.0, 2.0], - value=['zero', 'one'], - confidence=[0.1, 0.9]) - expected_ann = jams.Annotation(namespace, data=expected_data, time=5.0, - duration=3.0) + expected_data = dict( + time=[5.0, 5.0], + duration=[0.0, 2.0], + value=["zero", "one"], + confidence=[0.1, 0.9], + ) + expected_ann = jams.Annotation( + namespace, data=expected_data, time=5.0, duration=3.0 + ) assert ann_trim.data == expected_ann.data @@ -876,28 +905,34 @@ def test_annotation_trim_partial_overlap_end(): # When the trim region only partially overlaps with the annotation time # range: at the end # strict=False - namespace = 'tag_open' - data = dict(time=[5.0, 5.0, 10.0], - duration=[2.0, 4.0, 4.0], - value=['one', 'two', 'three'], - confidence=[0.9, 0.9, 0.9]) + namespace = "tag_open" + data = dict( + time=[5.0, 5.0, 10.0], + duration=[2.0, 4.0, 4.0], + value=["one", "two", "three"], + confidence=[0.9, 0.9, 0.9], + ) ann = jams.Annotation(namespace, data=data, time=5.0, duration=10.0) ann_trim = ann.trim(8, 20, strict=False) assert ann_trim.time == 8 assert ann_trim.duration == 7 - assert ann_trim.sandbox.trim == [{'start_time': 8, 'end_time': 20, - 'trim_start': 8, 'trim_end': 15}] + assert ann_trim.sandbox.trim == [ + {"start_time": 8, "end_time": 20, "trim_start": 8, "trim_end": 15} + ] assert ann_trim.namespace == ann.namespace assert ann_trim.annotation_metadata == ann.annotation_metadata - expected_data = dict(time=[8.0, 10.0], - duration=[1.0, 4.0], - value=['two', 'three'], - confidence=[0.9, 0.9]) - expected_ann = jams.Annotation(namespace, data=expected_data, time=8.0, - duration=7.0) + expected_data = dict( + time=[8.0, 10.0], + duration=[1.0, 4.0], + value=["two", "three"], + confidence=[0.9, 0.9], + ) + expected_ann = jams.Annotation( + namespace, data=expected_data, time=8.0, duration=7.0 + ) assert ann_trim.data == expected_ann.data @@ -906,17 +941,16 @@ def test_annotation_trim_partial_overlap_end(): assert ann_trim.time == 8 assert ann_trim.duration == 7 - assert ann_trim.sandbox.trim == [{'start_time': 8, 'end_time': 20, - 'trim_start': 8, 'trim_end': 15}] + assert ann_trim.sandbox.trim == [ + {"start_time": 8, "end_time": 20, "trim_start": 8, "trim_end": 15} + ] assert ann_trim.namespace == ann.namespace assert ann_trim.annotation_metadata == ann.annotation_metadata - expected_data = dict(time=[10.0], - duration=[4.0], - value=['three'], - confidence=[0.9]) - expected_ann = jams.Annotation(namespace, data=expected_data, time=8.0, - duration=7.0) + expected_data = dict(time=[10.0], duration=[4.0], value=["three"], confidence=[0.9]) + expected_ann = jams.Annotation( + namespace, data=expected_data, time=8.0, duration=7.0 + ) assert ann_trim.data == expected_ann.data @@ -924,29 +958,32 @@ def test_annotation_trim_partial_overlap_end(): def test_annotation_trim_multiple(): # Multiple trims # strict=False - namespace = 'tag_open' - data = dict(time=[5.0, 5.0, 10.0], - duration=[2.0, 4.0, 4.0], - value=['one', 'two', 'three'], - confidence=[0.9, 0.9, 0.9]) + namespace = "tag_open" + data = dict( + time=[5.0, 5.0, 10.0], + duration=[2.0, 4.0, 4.0], + value=["one", "two", "three"], + confidence=[0.9, 0.9, 0.9], + ) ann = jams.Annotation(namespace, data=data, time=5.0, duration=10.0) ann_trim = ann.trim(0, 10, strict=False).trim(8, 20, strict=False) assert ann_trim.time == 8 assert ann_trim.duration == 2 assert ann_trim.sandbox.trim == ( - [{'start_time': 0, 'end_time': 10, 'trim_start': 5, 'trim_end': 10}, - {'start_time': 8, 'end_time': 20, 'trim_start': 8, 'trim_end': 10}]) + [ + {"start_time": 0, "end_time": 10, "trim_start": 5, "trim_end": 10}, + {"start_time": 8, "end_time": 20, "trim_start": 8, "trim_end": 10}, + ] + ) assert ann_trim.namespace == ann.namespace assert ann_trim.annotation_metadata == ann.annotation_metadata - expected_data = dict(time=[8.0], - duration=[1.0], - value=['two'], - confidence=[0.9]) + expected_data = dict(time=[8.0], duration=[1.0], value=["two"], confidence=[0.9]) - expected_ann = jams.Annotation(namespace, data=expected_data, time=8.0, - duration=2.0) + expected_ann = jams.Annotation( + namespace, data=expected_data, time=8.0, duration=2.0 + ) assert ann_trim.data == expected_ann.data @@ -956,14 +993,18 @@ def test_annotation_trim_multiple(): assert ann_trim.duration == 2 # assert ann_trim.sandbox.trim == [(0, 10, 5, 10), (8, 20, 8, 10)] assert ann_trim.sandbox.trim == ( - [{'start_time': 0, 'end_time': 10, 'trim_start': 5, 'trim_end': 10}, - {'start_time': 8, 'end_time': 20, 'trim_start': 8, 'trim_end': 10}]) + [ + {"start_time": 0, "end_time": 10, "trim_start": 5, "trim_end": 10}, + {"start_time": 8, "end_time": 20, "trim_start": 8, "trim_end": 10}, + ] + ) assert ann_trim.namespace == ann.namespace assert ann_trim.annotation_metadata == ann.annotation_metadata expected_data = None - expected_ann = jams.Annotation(namespace, data=expected_data, time=8.0, - duration=2.0) + expected_ann = jams.Annotation( + namespace, data=expected_data, time=8.0, duration=2.0 + ) assert ann_trim.data == expected_ann.data @@ -995,11 +1036,13 @@ def test_jams_trim_valid(): jam = jams.JAMS() jam.file_metadata.duration = 15 - namespace = 'tag_open' - data = dict(time=[5.0, 5.0, 10.0], - duration=[2.0, 4.0, 4.0], - value=['one', 'two', 'three'], - confidence=[0.9, 0.9, 0.9]) + namespace = "tag_open" + data = dict( + time=[5.0, 5.0, 10.0], + duration=[2.0, 4.0, 4.0], + value=["one", "two", "three"], + confidence=[0.9, 0.9, 0.9], + ) ann = jams.Annotation(namespace, data=data, time=5.0, duration=10.0) for _ in range(5): jam.annotations.append(ann) @@ -1012,7 +1055,7 @@ def test_jams_trim_valid(): assert ann.data == ann_trim.data assert jam_trim.file_metadata.duration == jam.file_metadata.duration - assert jam_trim.sandbox.trim == [{'start_time': 0, 'end_time': 10}] + assert jam_trim.sandbox.trim == [{"start_time": 0, "end_time": 10}] # Multiple trims jam_trim = jam.trim(0, 10).trim(8, 10) @@ -1022,7 +1065,8 @@ def test_jams_trim_valid(): assert ann.data == ann_trim.data assert jam_trim.sandbox.trim == ( - [{'start_time': 0, 'end_time': 10}, {'start_time': 8, 'end_time': 10}]) + [{"start_time": 0, "end_time": 10}, {"start_time": 8, "end_time": 10}] + ) # Make sure file metadata copied over correctly assert jam_trim.file_metadata == jam.file_metadata @@ -1030,82 +1074,82 @@ def test_jams_trim_valid(): def test_annotation_slice(): - namespace = 'tag_open' - data = dict(time=[5.0, 6.0, 10.0], - duration=[2.0, 4.0, 4.0], - value=['one', 'two', 'three'], - confidence=[0.9, 0.9, 0.9]) + namespace = "tag_open" + data = dict( + time=[5.0, 6.0, 10.0], + duration=[2.0, 4.0, 4.0], + value=["one", "two", "three"], + confidence=[0.9, 0.9, 0.9], + ) ann = jams.Annotation(namespace, data=data, time=5.0, duration=10.0) # Slice out range that's completely inside the time range spanned by the # annotation ann_slice = ann.slice(8, 10, strict=False) - expected_data = dict(time=[0.0], - duration=[2.0], - value=['two'], - confidence=[0.9]) + expected_data = dict(time=[0.0], duration=[2.0], value=["two"], confidence=[0.9]) - expected_ann = jams.Annotation(namespace, data=expected_data, time=0, - duration=2.0) + expected_ann = jams.Annotation(namespace, data=expected_data, time=0, duration=2.0) assert ann_slice.data == expected_ann.data - assert ann_slice.sandbox.slice == [{'start_time': 8, - 'end_time': 10, - 'slice_start': 8, - 'slice_end': 10}] + assert ann_slice.sandbox.slice == [ + {"start_time": 8, "end_time": 10, "slice_start": 8, "slice_end": 10} + ] assert ann_slice.time == expected_ann.time assert ann_slice.duration == expected_ann.duration # Slice out range that's partially inside the time range spanned by the # annotation (starts BEFORE annotation starts) ann_slice = ann.slice(3, 10, strict=False) - expected_data = dict(time=[2.0, 3.0], - duration=[2.0, 4.0], - value=['one', 'two'], - confidence=[0.9, 0.9]) - - expected_ann = jams.Annotation(namespace, data=expected_data, time=2.0, - duration=5.0) + expected_data = dict( + time=[2.0, 3.0], + duration=[2.0, 4.0], + value=["one", "two"], + confidence=[0.9, 0.9], + ) + + expected_ann = jams.Annotation( + namespace, data=expected_data, time=2.0, duration=5.0 + ) assert ann_slice.time == expected_ann.time assert ann_slice.duration == expected_ann.duration assert ann_slice.data == expected_ann.data - assert ann_slice.sandbox.slice == [{'start_time': 3, - 'end_time': 10, - 'slice_start': 5, - 'slice_end': 10}] + assert ann_slice.sandbox.slice == [ + {"start_time": 3, "end_time": 10, "slice_start": 5, "slice_end": 10} + ] # Slice out range that's partially inside the time range spanned by the # annotation (starts AFTER annotation starts) ann_slice = ann.slice(8, 20, strict=False) - expected_data = dict(time=[0.0, 2.0], - duration=[2.0, 4.0], - value=['two', 'three'], - confidence=[0.9, 0.9]) + expected_data = dict( + time=[0.0, 2.0], + duration=[2.0, 4.0], + value=["two", "three"], + confidence=[0.9, 0.9], + ) - expected_ann = jams.Annotation(namespace, data=expected_data, time=0, - duration=7.0) + expected_ann = jams.Annotation(namespace, data=expected_data, time=0, duration=7.0) assert ann_slice.data == expected_ann.data assert ann_slice.sandbox.slice == ( - [{'start_time': 8, 'end_time': 20, 'slice_start': 8, 'slice_end': 15}]) + [{"start_time": 8, "end_time": 20, "slice_start": 8, "slice_end": 15}] + ) assert ann_slice.time == expected_ann.time assert ann_slice.duration == expected_ann.duration # Multiple slices ann_slice = ann.slice(0, 10).slice(8, 10) - expected_data = dict(time=[0.0], - duration=[2.0], - value=['two'], - confidence=[0.9]) + expected_data = dict(time=[0.0], duration=[2.0], value=["two"], confidence=[0.9]) - expected_ann = jams.Annotation(namespace, data=expected_data, time=0, - duration=2.0) + expected_ann = jams.Annotation(namespace, data=expected_data, time=0, duration=2.0) assert ann_slice.data == expected_ann.data assert ann_slice.sandbox.slice == ( - [{'start_time': 0, 'end_time': 10, 'slice_start': 5, 'slice_end': 10}, - {'start_time': 8, 'end_time': 10, 'slice_start': 8, 'slice_end': 10}]) + [ + {"start_time": 0, "end_time": 10, "slice_start": 5, "slice_end": 10}, + {"start_time": 8, "end_time": 10, "slice_start": 8, "slice_end": 10}, + ] + ) assert ann_slice.time == expected_ann.time assert ann_slice.duration == expected_ann.duration @@ -1127,11 +1171,13 @@ def test_jams_slice(): jam.slice(tt[0], tt[1], strict=False) # For a valid scenario, ensure everything behaves as expected - namespace = 'tag_open' - data = dict(time=[5.0, 5.0, 10.0], - duration=[2.0, 4.0, 4.0], - value=['one', 'two', 'three'], - confidence=[0.9, 0.9, 0.9]) + namespace = "tag_open" + data = dict( + time=[5.0, 5.0, 10.0], + duration=[2.0, 4.0, 4.0], + value=["one", "two", "three"], + confidence=[0.9, 0.9, 0.9], + ) ann = jams.Annotation(namespace, data=data, time=5.0, duration=10.0) for _ in range(5): jam.annotations.append(ann) @@ -1144,8 +1190,7 @@ def test_jams_slice(): assert ann.data == ann_slice.data assert jam_slice.file_metadata.duration == 10 - assert jam_slice.sandbox.slice == [{'start_time': 0, 'end_time': 10}] - + assert jam_slice.sandbox.slice == [{"start_time": 0, "end_time": 10}] # Multiple trims jam_slice = jam.slice(0, 10).slice(8, 10) @@ -1155,44 +1200,47 @@ def test_jams_slice(): assert ann.data == ann_slice.data assert jam_slice.sandbox.slice == ( - [{'start_time': 0, 'end_time': 10}, {'start_time': 8, 'end_time': 10}]) + [{"start_time": 0, "end_time": 10}, {"start_time": 8, "end_time": 10}] + ) # Make sure file metadata copied over correctly (except for duration) orig_metadata = dict(jam.file_metadata) slice_metadata = dict(jam_slice.file_metadata) - del orig_metadata['duration'] - del slice_metadata['duration'] + del orig_metadata["duration"] + del slice_metadata["duration"] assert slice_metadata == orig_metadata assert jam_slice.file_metadata.duration == 2 def test_annotation_data_frame(): - namespace = 'tag_open' - data = dict(time=[5.0, 5.0, 10.0], - duration=[2.0, 4.0, 4.0], - value=['one', 'two', 'three'], - confidence=[0.9, 0.9, 0.9]) + namespace = "tag_open" + data = dict( + time=[5.0, 5.0, 10.0], + duration=[2.0, 4.0, 4.0], + value=["one", "two", "three"], + confidence=[0.9, 0.9, 0.9], + ) ann = jams.Annotation(namespace, data=data, time=5.0, duration=10.0) df = ann.to_dataframe() - assert list(df.columns) == ['time', 'duration', 'value', 'confidence'] + assert list(df.columns) == ["time", "duration", "value", "confidence"] for i, row in df.iterrows(): - assert row.time == data['time'][i] - assert row.duration == data['duration'][i] - assert row.value == data['value'][i] - assert row.confidence == data['confidence'][i] + assert row.time == data["time"][i] + assert row.duration == data["duration"][i] + assert row.value == data["value"][i] + assert row.confidence == data["confidence"][i] def test_deprecated(): - @jams.core.deprecated('old version', 'new version') + @jams.core.deprecated("old version", "new version") def _foo(): pass warnings.resetwarnings() - warnings.simplefilter('always') + warnings.simplefilter("always") with warnings.catch_warnings(record=True) as out: _foo() @@ -1203,7 +1251,7 @@ def _foo(): assert out[0].category is DeprecationWarning # And that it says the right thing (roughly) - assert 'deprecated' in str(out[0].message).lower() + assert "deprecated" in str(out[0].message).lower() def test_numpy_serialize(): @@ -1214,21 +1262,25 @@ def test_numpy_serialize(): def test_annotation_serialize(): # Secondary test to trigger #159 on observation data - ann = jams.Annotation(namespace='tag_open', duration=1.0) - ann.append(time=np.float32(0), duration=np.float32(1), - value=np.float32(5), confidence=np.float32(0.5)) + ann = jams.Annotation(namespace="tag_open", duration=1.0) + ann.append( + time=np.float32(0), + duration=np.float32(1), + value=np.float32(5), + confidence=np.float32(0.5), + ) ann.dumps() -@pytest.mark.parametrize('confidence', [False, True]) +@pytest.mark.parametrize("confidence", [False, True]) def test_annotation_to_samples(confidence): - ann = jams.Annotation('tag_open') + ann = jams.Annotation("tag_open") - ann.append(time=0, duration=0.5, value='one', confidence=0.1) - ann.append(time=0.25, duration=0.5, value='two', confidence=0.2) - ann.append(time=0.75, duration=0.5, value='three', confidence=0.3) - ann.append(time=1.5, duration=0.5, value='four', confidence=0.4) + ann.append(time=0, duration=0.5, value="one", confidence=0.1) + ann.append(time=0.25, duration=0.5, value="two", confidence=0.2) + ann.append(time=0.75, duration=0.5, value="three", confidence=0.3) + ann.append(time=1.5, duration=0.5, value="four", confidence=0.4) values = ann.to_samples([0.2, 0.4, 0.75, 1.25, 1.75, 1.4], confidence=confidence) @@ -1236,30 +1288,37 @@ def test_annotation_to_samples(confidence): values, confs = values assert confs == [[0.1], [0.1, 0.2], [0.2, 0.3], [0.3], [0.4], []] - assert values == [['one'], ['one', 'two'], ['two', 'three'], ['three'], ['four'], []] + assert values == [ + ["one"], + ["one", "two"], + ["two", "three"], + ["three"], + ["four"], + [], + ] + def test_annotation_to_samples_fail_neg(): - ann = jams.Annotation('tag_open') + ann = jams.Annotation("tag_open") - ann.append(time=0, duration=0.5, value='one', confidence=0.1) - ann.append(time=0.25, duration=0.5, value='two', confidence=0.2) - ann.append(time=0.75, duration=0.5, value='three', confidence=0.3) - ann.append(time=1.5, duration=0.5, value='four', confidence=0.4) + ann.append(time=0, duration=0.5, value="one", confidence=0.1) + ann.append(time=0.25, duration=0.5, value="two", confidence=0.2) + ann.append(time=0.75, duration=0.5, value="three", confidence=0.3) + ann.append(time=1.5, duration=0.5, value="four", confidence=0.4) with pytest.raises(jams.ParameterError): values = ann.to_samples([-0.2, 0.4, 0.75, 1.25, 1.75, 1.4]) - def test_annotation_to_samples_fail_shape(): - ann = jams.Annotation('tag_open') + ann = jams.Annotation("tag_open") - ann.append(time=0, duration=0.5, value='one', confidence=0.1) - ann.append(time=0.25, duration=0.5, value='two', confidence=0.2) - ann.append(time=0.75, duration=0.5, value='three', confidence=0.3) - ann.append(time=1.5, duration=0.5, value='four', confidence=0.4) + ann.append(time=0, duration=0.5, value="one", confidence=0.1) + ann.append(time=0.25, duration=0.5, value="two", confidence=0.2) + ann.append(time=0.75, duration=0.5, value="three", confidence=0.3) + ann.append(time=1.5, duration=0.5, value="four", confidence=0.4) with pytest.raises(jams.ParameterError): values = ann.to_samples([[0.2, 0.4, 0.75, 1.25, 1.75, 1.4]]) diff --git a/tests/test_ns.py b/tests/test_ns.py index 96aaa748..ae847972 100644 --- a/tests/test_ns.py +++ b/tests/test_ns.py @@ -20,7 +20,7 @@ def test_ns_time_valid(): - ann = Annotation(namespace='onset') + ann = Annotation(namespace="onset") for time in np.arange(5.0, 10.0): ann.append(time=time, duration=0.0, value=None, confidence=None) @@ -28,14 +28,13 @@ def test_ns_time_valid(): ann.validate() -@parametrize('time, duration', [(-1, 0), (1, -1)]) +@parametrize("time, duration", [(-1, 0), (1, -1)]) def test_ns_time_invalid(time, duration): - ann = Annotation(namespace='onset') + ann = Annotation(namespace="onset") # Bypass the safety checks in append - ann.data.add(Observation(time=time, duration=duration, - value=None, confidence=None)) + ann.data.add(Observation(time=time, duration=duration, value=None, confidence=None)) with pytest.raises(jams.SchemaError): ann.validate() @@ -44,7 +43,7 @@ def test_ns_time_invalid(time, duration): def test_ns_beat_valid(): # A valid example - ann = Annotation(namespace='beat') + ann = Annotation(namespace="beat") for time in np.arange(5.0): ann.append(time=time, duration=0.0, value=1, confidence=None) @@ -57,10 +56,10 @@ def test_ns_beat_valid(): def test_ns_beat_invalid(): - ann = Annotation(namespace='beat') + ann = Annotation(namespace="beat") for time in np.arange(5.0): - ann.append(time=time, duration=0.0, value='foo', confidence=None) + ann.append(time=time, duration=0.0, value="foo", confidence=None) with pytest.raises(jams.SchemaError): ann.validate() @@ -68,45 +67,57 @@ def test_ns_beat_invalid(): def test_ns_beat_position_valid(): - ann = Annotation(namespace='beat_position') + ann = Annotation(namespace="beat_position") - ann.append(time=0, duration=1.0, value=dict(position=1, - measure=1, - num_beats=3, - beat_units=4)) + ann.append( + time=0, + duration=1.0, + value=dict(position=1, measure=1, num_beats=3, beat_units=4), + ) ann.validate() -@parametrize('key, value', - [('position', -1), ('position', 0), - ('position', 'a'), ('position', None), - ('measure', -1), ('measure', 1.0), - ('measure', 'a'), ('measure', None), - ('num_beats', -1), ('num_beats', 1.5), - ('num_beats', 'a'), ('num_beats', None), - ('beat_units', -1), ('beat_units', 1.5), - ('beat_units', 3), ('beat_units', 'a'), - ('beat_units', None)]) +@parametrize( + "key, value", + [ + ("position", -1), + ("position", 0), + ("position", "a"), + ("position", None), + ("measure", -1), + ("measure", 1.0), + ("measure", "a"), + ("measure", None), + ("num_beats", -1), + ("num_beats", 1.5), + ("num_beats", "a"), + ("num_beats", None), + ("beat_units", -1), + ("beat_units", 1.5), + ("beat_units", 3), + ("beat_units", "a"), + ("beat_units", None), + ], +) def test_ns_beat_position_invalid(key, value): data = dict(position=1, measure=1, num_beats=3, beat_units=4) data[key] = value - ann = Annotation(namespace='beat_position') + ann = Annotation(namespace="beat_position") ann.append(time=0, duration=1.0, value=data) with pytest.raises(jams.SchemaError): ann.validate() -@parametrize('key', - ['position', 'measure', 'num_beats', 'beat_units']) +@parametrize("key", ["position", "measure", "num_beats", "beat_units"]) def test_ns_beat_position_missing(key): data = dict(position=1, measure=1, num_beats=3, beat_units=4) del data[key] - ann = Annotation(namespace='beat_position') + ann = Annotation(namespace="beat_position") ann.append(time=0, duration=1.0, value=data) with pytest.raises(jams.SchemaError): @@ -115,17 +126,17 @@ def test_ns_beat_position_missing(key): def test_ns_mood_thayer_valid(): - ann = Annotation(namespace='mood_thayer') + ann = Annotation(namespace="mood_thayer") ann.append(time=0, duration=1.0, value=[0.3, 2.0]) ann.validate() -@parametrize('value', [[0], [0, 1, 2], ['a', 'b'], None, 0]) +@parametrize("value", [[0], [0, 1, 2], ["a", "b"], None, 0]) def test_ns_mood_thayer_invalid(value): - ann = Annotation(namespace='mood_thayer') + ann = Annotation(namespace="mood_thayer") ann.append(time=0, duration=1.0, value=value) with pytest.raises(jams.SchemaError): ann.validate() @@ -134,7 +145,7 @@ def test_ns_mood_thayer_invalid(value): def test_ns_onset(): # A valid example - ann = Annotation(namespace='onset') + ann = Annotation(namespace="onset") for time in np.arange(5.0): ann.append(time=time, duration=0.0, value=1, confidence=None) @@ -145,18 +156,17 @@ def test_ns_onset(): ann.validate() -@parametrize('lyric', - ['Check yourself', six.u('before you wreck yourself')]) +@parametrize("lyric", ["Check yourself", six.u("before you wreck yourself")]) def test_ns_lyrics(lyric): - ann = Annotation(namespace='lyrics') + ann = Annotation(namespace="lyrics") ann.append(time=0, duration=1, value=lyric) ann.validate() -@parametrize('lyric', [23, None]) +@parametrize("lyric", [23, None]) def test_ns_lyrics_invalid(lyric): - ann = Annotation(namespace='lyrics') + ann = Annotation(namespace="lyrics") ann.append(time=0, duration=1, value=lyric) with pytest.raises(SchemaError): ann.validate() @@ -164,20 +174,28 @@ def test_ns_lyrics_invalid(lyric): def test_ns_tempo_valid(): - ann = Annotation(namespace='tempo') + ann = Annotation(namespace="tempo") ann.append(time=0, duration=0, value=1, confidence=0.85) ann.validate() -@parametrize('value, confidence', - [(-1, 0.5), (-0.5, 0.5), ('a', 0.5), - (120.0, -1), (120.0, -0.5), - (120.0, 2.0), (120.0, 'a')]) +@parametrize( + "value, confidence", + [ + (-1, 0.5), + (-0.5, 0.5), + ("a", 0.5), + (120.0, -1), + (120.0, -0.5), + (120.0, 2.0), + (120.0, "a"), + ], +) def test_ns_tempo_invalid(value, confidence): - ann = Annotation(namespace='tempo') + ann = Annotation(namespace="tempo") ann.append(time=0, duration=0, value=value, confidence=confidence) with pytest.raises(jams.SchemaError): @@ -186,25 +204,25 @@ def test_ns_tempo_invalid(value, confidence): def test_ns_note_hz_valid(): - ann = Annotation(namespace='note_hz') + ann = Annotation(namespace="note_hz") seq_len = 21 times = np.arange(seq_len) durations = np.zeros(seq_len) values = np.linspace(0, 22050, seq_len) # includes 0 (odd symmetric) - confidences = np.linspace(0, 1., seq_len) - confidences[seq_len//2] = None # throw in a None confidence value + confidences = np.linspace(0, 1.0, seq_len) + confidences[seq_len // 2] = None # throw in a None confidence value - for (t, d, v, c) in zip(times, durations, values, confidences): + for t, d, v, c in zip(times, durations, values, confidences): ann.append(time=t, duration=d, value=v, confidence=c) ann.validate() -@parametrize('value', ['a', -23]) +@parametrize("value", ["a", -23]) def test_ns_note_hz_invalid(value): - ann = Annotation(namespace='note_hz') + ann = Annotation(namespace="note_hz") ann.append(time=0, duration=0, value=value, confidence=0.5) with pytest.raises(jams.SchemaError): @@ -213,25 +231,25 @@ def test_ns_note_hz_invalid(value): def test_ns_pitch_hz_valid(): - ann = Annotation(namespace='pitch_hz') + ann = Annotation(namespace="pitch_hz") seq_len = 21 times = np.arange(seq_len) durations = np.zeros(seq_len) - values = np.linspace(-22050., 22050, seq_len) # includes 0 (odd symmetric) - confidences = np.linspace(0, 1., seq_len) - confidences[seq_len//2] = None # throw in a None confidence value + values = np.linspace(-22050.0, 22050, seq_len) # includes 0 (odd symmetric) + confidences = np.linspace(0, 1.0, seq_len) + confidences[seq_len // 2] = None # throw in a None confidence value - for (t, d, v, c) in zip(times, durations, values, confidences): + for t, d, v, c in zip(times, durations, values, confidences): ann.append(time=t, duration=d, value=v, confidence=c) ann.validate() -@parametrize('value', ['a']) +@parametrize("value", ["a"]) def test_ns_pitch_hz_invalid(value): - ann = Annotation(namespace='pitch_hz') + ann = Annotation(namespace="pitch_hz") ann.append(time=0, duration=0, value=value, confidence=0.5) with pytest.raises(jams.SchemaError): @@ -240,25 +258,25 @@ def test_ns_pitch_hz_invalid(value): def test_ns_note_midi_valid(): - ann = Annotation(namespace='note_midi') + ann = Annotation(namespace="note_midi") seq_len = 21 times = np.arange(seq_len) durations = np.zeros(seq_len) - values = np.linspace(-108., 108, seq_len) # includes 0 (odd symmetric) - confidences = np.linspace(0, 1., seq_len) - confidences[seq_len//2] = None # throw in a None confidence value + values = np.linspace(-108.0, 108, seq_len) # includes 0 (odd symmetric) + confidences = np.linspace(0, 1.0, seq_len) + confidences[seq_len // 2] = None # throw in a None confidence value - for (t, d, v, c) in zip(times, durations, values, confidences): + for t, d, v, c in zip(times, durations, values, confidences): ann.append(time=t, duration=d, value=v, confidence=c) ann.validate() -@parametrize('value', ['a']) +@parametrize("value", ["a"]) def test_ns_note_midi_invalid(value): - ann = Annotation(namespace='note_midi') + ann = Annotation(namespace="note_midi") ann.append(time=0, duration=0, value=value, confidence=0.5) with pytest.raises(jams.SchemaError): @@ -267,25 +285,25 @@ def test_ns_note_midi_invalid(value): def test_ns_pitch_midi_valid(): - ann = Annotation(namespace='pitch_midi') + ann = Annotation(namespace="pitch_midi") seq_len = 21 times = np.arange(seq_len) durations = np.zeros(seq_len) - values = np.linspace(-108., 108, seq_len) # includes 0 (odd symmetric) - confidences = np.linspace(0, 1., seq_len) - confidences[seq_len//2] = None # throw in a None confidence value + values = np.linspace(-108.0, 108, seq_len) # includes 0 (odd symmetric) + confidences = np.linspace(0, 1.0, seq_len) + confidences[seq_len // 2] = None # throw in a None confidence value - for (t, d, v, c) in zip(times, durations, values, confidences): + for t, d, v, c in zip(times, durations, values, confidences): ann.append(time=t, duration=d, value=v, confidence=c) ann.validate() -@parametrize('value', ['a']) +@parametrize("value", ["a"]) def test_ns_pitch_midi_invalid(value): - ann = Annotation(namespace='pitch_midi') + ann = Annotation(namespace="pitch_midi") ann.append(time=0, duration=0, value=value, confidence=0.5) with pytest.raises(jams.SchemaError): @@ -296,7 +314,7 @@ def test_ns_contour_valid(): srand() - ann = Annotation(namespace='pitch_contour') + ann = Annotation(namespace="pitch_contour") seq_len = 21 times = np.arange(seq_len) @@ -305,13 +323,13 @@ def test_ns_contour_valid(): ids = np.arange(len(values)) // 4 voicing = np.random.randn(len(ids)) > 0 - confidences = np.linspace(0, 1., seq_len) - confidences[seq_len//2] = None # throw in a None confidence value + confidences = np.linspace(0, 1.0, seq_len) + confidences[seq_len // 2] = None # throw in a None confidence value - for (t, d, v, c, i, b) in zip(times, durations, values, - confidences, ids, voicing): - ann.append(time=t, duration=d, - value={'pitch': v, 'id': i, 'voiced': b}, confidence=c) + for t, d, v, c, i, b in zip(times, durations, values, confidences, ids, voicing): + ann.append( + time=t, duration=d, value={"pitch": v, "id": i, "voiced": b}, confidence=c + ) ann.validate() @@ -320,7 +338,7 @@ def test_ns_contour_invalid(): srand() - ann = Annotation(namespace='pitch_contour') + ann = Annotation(namespace="pitch_contour") seq_len = 21 times = np.arange(seq_len) @@ -329,261 +347,289 @@ def test_ns_contour_invalid(): ids = np.arange(len(values)) // 4 voicing = np.random.randn(len(ids)) * 2 - confidences = np.linspace(0, 1., seq_len) - confidences[seq_len//2] = None # throw in a None confidence value + confidences = np.linspace(0, 1.0, seq_len) + confidences[seq_len // 2] = None # throw in a None confidence value - for (t, d, v, c, i, b) in zip(times, durations, values, - confidences, ids, voicing): - ann.append(time=t, duration=d, - value={'pitch': v, 'id': i, 'voiced': b}, confidence=c) + for t, d, v, c, i, b in zip(times, durations, values, confidences, ids, voicing): + ann.append( + time=t, duration=d, value={"pitch": v, "id": i, "voiced": b}, confidence=c + ) ann.validate() -@parametrize('value', - ['B#:locrian', six.u('A:minor'), 'N', 'E']) +@parametrize("value", ["B#:locrian", six.u("A:minor"), "N", "E"]) def test_ns_key_mode(value): - ann = Annotation(namespace='key_mode') + ann = Annotation(namespace="key_mode") ann.append(time=0, duration=0, value=value, confidence=None) ann.validate() -@parametrize('value', - ['asdf', 'A&:phrygian', 11, '', ':dorian', None]) + +@parametrize("value", ["asdf", "A&:phrygian", 11, "", ":dorian", None]) def test_ns_key_mode_schema_error(value): - ann = Annotation(namespace='key_mode') + ann = Annotation(namespace="key_mode") ann.append(time=0, duration=0, value=value, confidence=None) with pytest.raises(jams.SchemaError): ann.validate() -@parametrize('value', - ['A:9', 'Gb:sus2(1,3,5)', 'X', 'C:13(*9)/b7']) +@parametrize("value", ["A:9", "Gb:sus2(1,3,5)", "X", "C:13(*9)/b7"]) def test_ns_chord_valid(value): - ann = Annotation(namespace='chord') + ann = Annotation(namespace="chord") ann.append(time=0, duration=1.0, value=value) ann.validate() -@parametrize('value', - ['{}:maj'.format(_) - for _ in [42, 'H', 'a', 'F1', True, None]] + - ['C:{}'.format(_) - for _ in [64, 'z', 'mj', 'Ab', 'iiii', False, None]] + - ['C/{}'.format(_) - for _ in ['A', 7.5, '8b']] + [None]) +@parametrize( + "value", + ["{}:maj".format(_) for _ in [42, "H", "a", "F1", True, None]] + + ["C:{}".format(_) for _ in [64, "z", "mj", "Ab", "iiii", False, None]] + + ["C/{}".format(_) for _ in ["A", 7.5, "8b"]] + + [None], +) def test_ns_chord_invalid(value): - ann = Annotation(namespace='chord') + ann = Annotation(namespace="chord") ann.append(time=0, duration=1.0, value=value) with pytest.raises(SchemaError): ann.validate() -@parametrize('value', ['B:7', 'Gb:(1,3,5)', 'A#:(*3)', 'C:sus4(*5)/b7']) +@parametrize("value", ["B:7", "Gb:(1,3,5)", "A#:(*3)", "C:sus4(*5)/b7"]) def test_ns_chord_harte_valid(value): - ann = Annotation(namespace='chord_harte') + ann = Annotation(namespace="chord_harte") ann.append(time=0, duration=1.0, value=value) ann.validate() -@parametrize('value', - ['{}:maj'.format(_) - for _ in [42, 'X', 'a', 'F1', True, None]] + - ['C:{}'.format(_) - for _ in [64, 'z', 'mj', 'Ab', 'iiii', False, None]] + - ['C/{}'.format(_) - for _ in ['A', 7.5, '8b']] + [None]) +@parametrize( + "value", + ["{}:maj".format(_) for _ in [42, "X", "a", "F1", True, None]] + + ["C:{}".format(_) for _ in [64, "z", "mj", "Ab", "iiii", False, None]] + + ["C/{}".format(_) for _ in ["A", 7.5, "8b"]] + + [None], +) def test_ns_chord_harte_invalid(value): - ann = Annotation(namespace='chord_harte') + ann = Annotation(namespace="chord_harte") ann.append(time=0, duration=1.0, value=value) with pytest.raises(SchemaError): ann.validate() -@parametrize('value', - [dict(tonic='B', chord='bII7'), - dict(tonic=six.u('Gb'), chord=six.u('ii7/#V'))]) +@parametrize( + "value", + [dict(tonic="B", chord="bII7"), dict(tonic=six.u("Gb"), chord=six.u("ii7/#V"))], +) def test_ns_chord_roman_valid(value): - ann = Annotation(namespace='chord_roman') + ann = Annotation(namespace="chord_roman") ann.append(time=0, duration=1.0, value=value) ann.validate() -@parametrize('key, value', - [('tonic', 42), ('tonic', 'H'), - ('tonic', 'a'), ('tonic', 'F#b'), - ('tonic', True), ('tonic', None), - ('chord', 64), ('chord', 'z'), - ('chord', 'i/V64'), ('chord', 'Ab'), - ('chord', 'iiii'), ('chord', False), - ('chord', None)]) +@parametrize( + "key, value", + [ + ("tonic", 42), + ("tonic", "H"), + ("tonic", "a"), + ("tonic", "F#b"), + ("tonic", True), + ("tonic", None), + ("chord", 64), + ("chord", "z"), + ("chord", "i/V64"), + ("chord", "Ab"), + ("chord", "iiii"), + ("chord", False), + ("chord", None), + ], +) def test_ns_chord_roman_invalid(key, value): - data = dict(tonic='E', chord='iv64') + data = dict(tonic="E", chord="iv64") data[key] = value - ann = Annotation(namespace='chord_roman') + ann = Annotation(namespace="chord_roman") ann.append(time=0, duration=1.0, value=data) with pytest.raises(SchemaError): ann.validate() -@parametrize('key', ['tonic', 'chord']) +@parametrize("key", ["tonic", "chord"]) def test_ns_chord_roman_missing(key): - data = dict(tonic='E', chord='iv64') + data = dict(tonic="E", chord="iv64") del data[key] - ann = Annotation(namespace='chord_roman') + ann = Annotation(namespace="chord_roman") ann.append(time=0, duration=1.0, value=data) with pytest.raises(SchemaError): ann.validate() -@parametrize('value', - [dict(tonic='B', pitch=0), - dict(tonic=six.u('Gb'), pitch=11)]) +@parametrize("value", [dict(tonic="B", pitch=0), dict(tonic=six.u("Gb"), pitch=11)]) def test_ns_pitch_class_valid(value): - ann = Annotation(namespace='pitch_class') + ann = Annotation(namespace="pitch_class") ann.append(time=0, duration=1.0, value=value) ann.validate() -@parametrize('key, value', - [('tonic', 42), ('tonic', 'H'), - ('tonic', 'a'), ('tonic', 'F#b'), - ('tonic', True), ('tonic', None), - ('pitch', 1.5), ('pitch', 'xyz'), - ('pitch', '3'), ('pitch', False), - ('pitch', None)]) +@parametrize( + "key, value", + [ + ("tonic", 42), + ("tonic", "H"), + ("tonic", "a"), + ("tonic", "F#b"), + ("tonic", True), + ("tonic", None), + ("pitch", 1.5), + ("pitch", "xyz"), + ("pitch", "3"), + ("pitch", False), + ("pitch", None), + ], +) def test_ns_pitch_class_invalid(key, value): - data = dict(tonic='E', pitch=7) + data = dict(tonic="E", pitch=7) data[key] = value - ann = Annotation(namespace='pitch_class') + ann = Annotation(namespace="pitch_class") ann.append(time=0, duration=1.0, value=data) with pytest.raises(SchemaError): ann.validate() -@parametrize('key', ['tonic', 'pitch']) +@parametrize("key", ["tonic", "pitch"]) def test_ns_pitch_class_missing(key): - data = dict(tonic='E', pitch=7) + data = dict(tonic="E", pitch=7) del data[key] - ann = Annotation(namespace='pitch_class') + ann = Annotation(namespace="pitch_class") ann.append(time=0, duration=1.0, value=data) with pytest.raises(SchemaError): ann.validate() -@parametrize('namespace,tag', [ - ('tag_cal500', 'Emotion-Angry_/_Aggressive' ), - ('tag_cal500', 'Genre--_Metal/Hard_Rock' ), - ('tag_cal500', six.u('Genre-Best-Jazz') ), - ('tag_cal10k', 'a dub production'), - ('tag_cal10k', "boomin' kick drum"), - ('tag_cal10k', six.u('rock & roll ? roots')), - ('tag_gtzan', 'blues'), - ('tag_gtzan', 'classical'), - ('tag_gtzan', 'country'), - ('tag_gtzan', 'disco'), - ('tag_gtzan', 'hip-hop'), - ('tag_gtzan', 'jazz'), - ('tag_gtzan', 'metal'), - ('tag_gtzan', 'pop'), - ('tag_gtzan', 'reggae'), - ('tag_gtzan', six.u('rock')), - ('tag_msd_tagtraum_cd1', 'reggae'), - ('tag_msd_tagtraum_cd1', 'pop/rock'), - ('tag_msd_tagtraum_cd1', 'rnb'), - ('tag_msd_tagtraum_cd1', 'jazz'), - ('tag_msd_tagtraum_cd1', 'vocal'), - ('tag_msd_tagtraum_cd1', 'new age'), - ('tag_msd_tagtraum_cd1', 'latin'), - ('tag_msd_tagtraum_cd1', 'rap'), - ('tag_msd_tagtraum_cd1', 'country'), - ('tag_msd_tagtraum_cd1', 'international'), - ('tag_msd_tagtraum_cd1', 'blues'), - ('tag_msd_tagtraum_cd1', 'electronic'), - ('tag_msd_tagtraum_cd1', six.u('folk')), - ('tag_medleydb_instruments','accordion'), - ('tag_medleydb_instruments','alto saxophone'), - ('tag_medleydb_instruments',six.u('fx/processed sound')), - ('tag_open', 'a tag'), - ('segment_open', 'a segment'), - ('segment_salami_lower','a'), - ('segment_salami_lower',"a'"), - ('segment_salami_lower',"a'''"), - ('segment_salami_lower',"silence"), - ('segment_salami_lower',"Silence"), - ('segment_salami_lower',six.u('a')), - ('segment_salami_lower','aa'), - ('segment_salami_lower',"aa'"), - ('segment_salami_lower','ab'), - ('segment_salami_upper', 'A'), - ('segment_salami_upper', "A'"), - ('segment_salami_upper', "A'''"), - ('segment_salami_upper', "silence"), - ('segment_salami_upper', "Silence"), - ('segment_salami_upper', six.u('A')), - ('segment_salami_function', 'verse'), - ('segment_salami_function', "chorus"), - ('segment_salami_function', "theme"), - ('segment_salami_function', "voice"), - ('segment_salami_function', "silence"), - ('segment_salami_function', six.u('verse')), - ('segment_tut', 'verse'), - ('segment_tut', "refrain"), - ('segment_tut', "Si"), - ('segment_tut', "bridge"), - ('segment_tut', "Bridge"), - ('segment_tut', six.u('verse')), - ('vector', [1]), - ('vector', [1, 2]), - ('vector', np.asarray([1])), - ('vector', np.asarray([1, 2])), - ('blob', 'a tag'), - ('blob', six.u('a unicode tag')), - ('blob', 23), - ('blob', None), - ('blob', dict()), - ('blob', list()), - ('lyrics_bow', [['foo', 23]],), - ('lyrics_bow', [['foo', 23], ['bar', 35]],), - ('lyrics_bow', [['foo', 23], [['foo', 'bar'], 13]],), - ('lyrics_bow', []), - ('tag_audioset', 'Accordion'), - ('tag_audioset', 'Afrobeat'), - ('tag_audioset', six.u('Cacophony')), - ('tag_audioset_genre', 'Afrobeat'), - ('tag_audioset_genre', 'Disco'), - ('tag_audioset_genre', six.u('Opera')), - ('tag_audioset_instruments' ,'Organ'), - ('tag_audioset_instruments' ,'Harmonica'), - ('tag_audioset_instruments' ,six.u('Zither')), - ('tag_fma_genre', 'Blues'), - ('tag_fma_genre', 'Classical'), - ('tag_fma_genre', six.u('Soul-RnB')), - ('tag_fma_subgenre', 'Blues'), - ('tag_fma_subgenre', 'British Folk'), - ('tag_fma_subgenre', six.u('Klezmer')), - ('tag_urbansound', 'air_conditioner'), - ('tag_urbansound', 'car_horn'), - ('tag_urbansound', 'children_playing'), - ('tag_urbansound', 'dog_bark'), - ('tag_urbansound', 'drilling'), - ('tag_urbansound', 'engine_idling'), - ('tag_urbansound', 'gun_shot'), - ('tag_urbansound', 'jackhammer'), - ('tag_urbansound', 'siren'), - ('tag_urbansound', six.u('street_music')), -]) +@parametrize( + "namespace,tag", + [ + ("tag_cal500", "Emotion-Angry_/_Aggressive"), + ("tag_cal500", "Genre--_Metal/Hard_Rock"), + ("tag_cal500", six.u("Genre-Best-Jazz")), + ("tag_cal10k", "a dub production"), + ("tag_cal10k", "boomin' kick drum"), + ("tag_cal10k", six.u("rock & roll ? roots")), + ("tag_gtzan", "blues"), + ("tag_gtzan", "classical"), + ("tag_gtzan", "country"), + ("tag_gtzan", "disco"), + ("tag_gtzan", "hip-hop"), + ("tag_gtzan", "jazz"), + ("tag_gtzan", "metal"), + ("tag_gtzan", "pop"), + ("tag_gtzan", "reggae"), + ("tag_gtzan", six.u("rock")), + ("tag_msd_tagtraum_cd1", "reggae"), + ("tag_msd_tagtraum_cd1", "pop/rock"), + ("tag_msd_tagtraum_cd1", "rnb"), + ("tag_msd_tagtraum_cd1", "jazz"), + ("tag_msd_tagtraum_cd1", "vocal"), + ("tag_msd_tagtraum_cd1", "new age"), + ("tag_msd_tagtraum_cd1", "latin"), + ("tag_msd_tagtraum_cd1", "rap"), + ("tag_msd_tagtraum_cd1", "country"), + ("tag_msd_tagtraum_cd1", "international"), + ("tag_msd_tagtraum_cd1", "blues"), + ("tag_msd_tagtraum_cd1", "electronic"), + ("tag_msd_tagtraum_cd1", six.u("folk")), + ("tag_medleydb_instruments", "accordion"), + ("tag_medleydb_instruments", "alto saxophone"), + ("tag_medleydb_instruments", six.u("fx/processed sound")), + ("tag_open", "a tag"), + ("segment_open", "a segment"), + ("segment_salami_lower", "a"), + ("segment_salami_lower", "a'"), + ("segment_salami_lower", "a'''"), + ("segment_salami_lower", "silence"), + ("segment_salami_lower", "Silence"), + ("segment_salami_lower", six.u("a")), + ("segment_salami_lower", "aa"), + ("segment_salami_lower", "aa'"), + ("segment_salami_lower", "ab"), + ("segment_salami_upper", "A"), + ("segment_salami_upper", "A'"), + ("segment_salami_upper", "A'''"), + ("segment_salami_upper", "silence"), + ("segment_salami_upper", "Silence"), + ("segment_salami_upper", six.u("A")), + ("segment_salami_function", "verse"), + ("segment_salami_function", "chorus"), + ("segment_salami_function", "theme"), + ("segment_salami_function", "voice"), + ("segment_salami_function", "silence"), + ("segment_salami_function", six.u("verse")), + ("segment_tut", "verse"), + ("segment_tut", "refrain"), + ("segment_tut", "Si"), + ("segment_tut", "bridge"), + ("segment_tut", "Bridge"), + ("segment_tut", six.u("verse")), + ("vector", [1]), + ("vector", [1, 2]), + ("vector", np.asarray([1])), + ("vector", np.asarray([1, 2])), + ("blob", "a tag"), + ("blob", six.u("a unicode tag")), + ("blob", 23), + ("blob", None), + ("blob", dict()), + ("blob", list()), + ( + "lyrics_bow", + [["foo", 23]], + ), + ( + "lyrics_bow", + [["foo", 23], ["bar", 35]], + ), + ( + "lyrics_bow", + [["foo", 23], [["foo", "bar"], 13]], + ), + ("lyrics_bow", []), + ("tag_audioset", "Accordion"), + ("tag_audioset", "Afrobeat"), + ("tag_audioset", six.u("Cacophony")), + ("tag_audioset_genre", "Afrobeat"), + ("tag_audioset_genre", "Disco"), + ("tag_audioset_genre", six.u("Opera")), + ("tag_audioset_instruments", "Organ"), + ("tag_audioset_instruments", "Harmonica"), + ("tag_audioset_instruments", six.u("Zither")), + ("tag_fma_genre", "Blues"), + ("tag_fma_genre", "Classical"), + ("tag_fma_genre", six.u("Soul-RnB")), + ("tag_fma_subgenre", "Blues"), + ("tag_fma_subgenre", "British Folk"), + ("tag_fma_subgenre", six.u("Klezmer")), + ("tag_urbansound", "air_conditioner"), + ("tag_urbansound", "car_horn"), + ("tag_urbansound", "children_playing"), + ("tag_urbansound", "dog_bark"), + ("tag_urbansound", "drilling"), + ("tag_urbansound", "engine_idling"), + ("tag_urbansound", "gun_shot"), + ("tag_urbansound", "jackhammer"), + ("tag_urbansound", "siren"), + ("tag_urbansound", six.u("street_music")), + ], +) def test_ns_tag(namespace, tag): ann = Annotation(namespace=namespace) @@ -591,27 +637,30 @@ def test_ns_tag(namespace, tag): ann.validate() -@parametrize('namespace', [ - 'tag_cal500', - 'tag_cal10k', - 'tag_gtzan', - 'tag_msd_tagtraum_cd1', - 'tag_medleydb_instruments', - 'tag_open', - 'segment_open', - 'segment_salami_lower', - 'segment_salami_upper', - 'segment_salami_function', - 'segment_tut', - 'tag_audioset', - 'tag_audioset_genre', - 'tag_audioset_instruments', - 'tag_fma_genre', - 'tag_fma_subgenre', - 'tag_urbansound', - 'multi_segment', -]) -@parametrize('value', [23, None]) +@parametrize( + "namespace", + [ + "tag_cal500", + "tag_cal10k", + "tag_gtzan", + "tag_msd_tagtraum_cd1", + "tag_medleydb_instruments", + "tag_open", + "segment_open", + "segment_salami_lower", + "segment_salami_upper", + "segment_salami_function", + "segment_tut", + "tag_audioset", + "tag_audioset_genre", + "tag_audioset_instruments", + "tag_fma_genre", + "tag_fma_subgenre", + "tag_urbansound", + "multi_segment", + ], +) +@parametrize("value", [23, None]) def test_ns_tag_invalid_type(namespace, value): ann = Annotation(namespace=namespace) @@ -619,52 +668,56 @@ def test_ns_tag_invalid_type(namespace, value): with pytest.raises(SchemaError): ann.validate() -@parametrize('namespace,value', [ - ('tag_cal500', 'GENRE-BEST-JAZZ'), - ('tag_cal10k', 'A DUB PRODUCTION'), - ('tag_gtzan', 'ROCK'), - ('tag_msd_tagtraum_cd1', 'FOLK'), - ('tag_medleydb_instruments', 'ACCORDION'), - ('segment_salami_lower', 'A'), - ('segment_salami_lower', 'S'), - ('segment_salami_lower', 'a23'), - ('segment_salami_lower', ' Silence 23'), - ('segment_salami_lower', 'aba'), - ('segment_salami_lower', 'aab'), - ('segment_salami_upper', 'a'), - ('segment_salami_upper', 'A23'), - ('segment_salami_upper', ' Silence 23'), - ('segment_salami_upper', 'ABA'), - ('segment_salami_upper', 'AAB'), - ('segment_salami_upper', 'AA'), - ('segment_salami_function', 'a'), - ('segment_salami_function', 'a'), - ('segment_salami_function', 'A23'), - ('segment_salami_function', ' Silence 23'), - ('segment_salami_function', 'Some Garbage'), - ('segment_tut', 'chorus'), - ('segment_tut', 'a'), - ('segment_tut', 'a'), - ('segment_tut', 'A23'), - ('segment_tut', ' Silence 23'), - ('segment_tut', 'Some Garbage'), - ('vector', 'a tag'), - ('vector', six.u('a unicode tag')), - ('vector', 23), - ('vector', None), - ('vector', dict()), - ('vector', list()), - ('lyrics_bow', ('foo', 23)), - ('lyrics_bow', [('foo', -23)]), - ('lyrics_bow', [(23, 'foo')]), - ('tag_audioset', 'ACCORDION'), - ('tag_audioset_genre', 'Accordion'), - ('tag_audioset_instruments', 'Afrobeat'), - ('tag_fma_genre', 'Afrobeat'), - ('tag_fma_subgenre', 'title'), - ('tag_urbansound', 'air conditioner'), - ('tag_urbansound', 'AIR_CONDITIONER'), -]) + +@parametrize( + "namespace,value", + [ + ("tag_cal500", "GENRE-BEST-JAZZ"), + ("tag_cal10k", "A DUB PRODUCTION"), + ("tag_gtzan", "ROCK"), + ("tag_msd_tagtraum_cd1", "FOLK"), + ("tag_medleydb_instruments", "ACCORDION"), + ("segment_salami_lower", "A"), + ("segment_salami_lower", "S"), + ("segment_salami_lower", "a23"), + ("segment_salami_lower", " Silence 23"), + ("segment_salami_lower", "aba"), + ("segment_salami_lower", "aab"), + ("segment_salami_upper", "a"), + ("segment_salami_upper", "A23"), + ("segment_salami_upper", " Silence 23"), + ("segment_salami_upper", "ABA"), + ("segment_salami_upper", "AAB"), + ("segment_salami_upper", "AA"), + ("segment_salami_function", "a"), + ("segment_salami_function", "a"), + ("segment_salami_function", "A23"), + ("segment_salami_function", " Silence 23"), + ("segment_salami_function", "Some Garbage"), + ("segment_tut", "chorus"), + ("segment_tut", "a"), + ("segment_tut", "a"), + ("segment_tut", "A23"), + ("segment_tut", " Silence 23"), + ("segment_tut", "Some Garbage"), + ("vector", "a tag"), + ("vector", six.u("a unicode tag")), + ("vector", 23), + ("vector", None), + ("vector", dict()), + ("vector", list()), + ("lyrics_bow", ("foo", 23)), + ("lyrics_bow", [("foo", -23)]), + ("lyrics_bow", [(23, "foo")]), + ("tag_audioset", "ACCORDION"), + ("tag_audioset_genre", "Accordion"), + ("tag_audioset_instruments", "Afrobeat"), + ("tag_fma_genre", "Afrobeat"), + ("tag_fma_subgenre", "title"), + ("tag_urbansound", "air conditioner"), + ("tag_urbansound", "AIR_CONDITIONER"), + ], +) def test_ns_invalid_value(namespace, value): ann = Annotation(namespace=namespace) ann.append(time=0, duration=1, value=value) @@ -672,91 +725,97 @@ def test_ns_invalid_value(namespace, value): ann.validate() -@parametrize('confidence', [0.0, 1.0, None]) +@parametrize("confidence", [0.0, 1.0, None]) def test_ns_tag_msd_tagtraum_cd1_confidence(confidence): - ann = Annotation(namespace='tag_msd_tagtraum_cd1') - ann.append(time=0, duration=1, value='rnb', confidence=confidence) + ann = Annotation(namespace="tag_msd_tagtraum_cd1") + ann.append(time=0, duration=1, value="rnb", confidence=confidence) ann.validate() -@parametrize('confidence', [1.2, -0.1]) +@parametrize("confidence", [1.2, -0.1]) def test_ns_tag_msd_tagtraum_cd1_bad_confidence(confidence): - ann = Annotation(namespace='tag_msd_tagtraum_cd1') - ann.append(time=0, duration=1, value='rnb', confidence=confidence) + ann = Annotation(namespace="tag_msd_tagtraum_cd1") + ann.append(time=0, duration=1, value="rnb", confidence=confidence) with pytest.raises(SchemaError): ann.validate() -@parametrize('pattern', [ - dict(midi_pitch=3, morph_pitch=5, staff=1, pattern_id=1, occurrence_id=1), - dict(midi_pitch=-3, morph_pitch=-1.5, staff=1.0, pattern_id=1, occurrence_id=1) -]) +@parametrize( + "pattern", + [ + dict(midi_pitch=3, morph_pitch=5, staff=1, pattern_id=1, occurrence_id=1), + dict(midi_pitch=-3, morph_pitch=-1.5, staff=1.0, pattern_id=1, occurrence_id=1), + ], +) def test_ns_pattern_valid(pattern): - ann = Annotation(namespace='pattern_jku') + ann = Annotation(namespace="pattern_jku") ann.append(time=0, duration=1.0, value=pattern) ann.validate() -@parametrize('key', ['midi_pitch', 'morph_pitch', 'staff', 'pattern_id', 'occurrence_id']) -@parametrize('value', ['foo', None, dict(), list()]) +@parametrize( + "key", ["midi_pitch", "morph_pitch", "staff", "pattern_id", "occurrence_id"] +) +@parametrize("value", ["foo", None, dict(), list()]) def test_ns_pattern_invalid(key, value): - data = dict(midi_pitch=3, morph_pitch=5, - staff=1, pattern_id=1, occurrence_id=1) + data = dict(midi_pitch=3, morph_pitch=5, staff=1, pattern_id=1, occurrence_id=1) data[key] = value - ann = Annotation(namespace='pattern_jku') + ann = Annotation(namespace="pattern_jku") ann.append(time=0, duration=1.0, value=data) with pytest.raises(SchemaError): ann.validate() -@parametrize('key', ['pattern_id', 'occurrence_id']) -@parametrize('value', [-1, 0, 0.5]) +@parametrize("key", ["pattern_id", "occurrence_id"]) +@parametrize("value", [-1, 0, 0.5]) def test_ns_pattern_invalid_bounded(key, value): - data = dict(midi_pitch=3, morph_pitch=5, - staff=1, pattern_id=1, occurrence_id=1) + data = dict(midi_pitch=3, morph_pitch=5, staff=1, pattern_id=1, occurrence_id=1) data[key] = value - ann = Annotation(namespace='pattern_jku') + ann = Annotation(namespace="pattern_jku") ann.append(time=0, duration=1.0, value=data) with pytest.raises(SchemaError): ann.validate() -@parametrize('label', ['a segment', six.u('a unicode segment')]) -@parametrize('level', [0, 2]) +@parametrize("label", ["a segment", six.u("a unicode segment")]) +@parametrize("level", [0, 2]) def test_ns_multi_segment_label(label, level): - ann = Annotation(namespace='multi_segment') + ann = Annotation(namespace="multi_segment") ann.append(time=0, duration=1, value=dict(label=label, level=level)) ann.validate() -@parametrize('label', [23, None]) + +@parametrize("label", [23, None]) def test_ns_multi_segment_invalid_label(label): - ann = Annotation(namespace='multi_segment') + ann = Annotation(namespace="multi_segment") ann.append(time=0, duration=1, value=dict(label=label, level=0)) with pytest.raises(SchemaError): ann.validate() -@parametrize('level', [-1, 'foo', None]) + +@parametrize("level", [-1, "foo", None]) def test_ns_multi_segment_invalid_level(level): - ann = Annotation(namespace='multi_segment') - ann.append(time=0, duration=1, value=dict(label='a segment', level=level)) + ann = Annotation(namespace="multi_segment") + ann.append(time=0, duration=1, value=dict(label="a segment", level=level)) with pytest.raises(SchemaError): ann.validate() + def test_ns_multi_segment_invalid_both(): - ann = Annotation(namespace='multi_segment') + ann = Annotation(namespace="multi_segment") ann.append(time=0, duration=1, value=dict(label=None, level=None)) with pytest.raises(SchemaError): ann.validate() def test_ns_multi_segment_bad_type(): - ann = Annotation(namespace='multi_segment') - ann.append(time=0, duration=1, value='a string') + ann = Annotation(namespace="multi_segment") + ann.append(time=0, duration=1, value="a string") with pytest.raises(SchemaError): ann.validate() @@ -770,17 +829,16 @@ def scraper_value(): "time_stretch": 0.8455598669219283, "pitch_shift": -1.2204911976305648, "snr": 7.790682558359417, - "label": 'gun_shot', + "label": "gun_shot", "role": "foreground", - "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav" + "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav", } - -@parametrize('source_time', [0, 5, 1.0]) +@parametrize("source_time", [0, 5, 1.0]) def test_ns_scaper_source_time(source_time): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = { "source_time": source_time, @@ -789,19 +847,19 @@ def test_ns_scaper_source_time(source_time): "time_stretch": 0.8455598669219283, "pitch_shift": -1.2204911976305648, "snr": 7.790682558359417, - "label": 'gun_shot', + "label": "gun_shot", "role": "foreground", - "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav" + "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav", } ann.append(time=0, duration=1, value=value) ann.validate() -@parametrize('source_time', [-1, -1.0, 'zero', None]) +@parametrize("source_time", [-1, -1.0, "zero", None]) def test_ns_scraper_source_time_invalid(scraper_value, source_time): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = dict(scraper_value, source_time=source_time) ann.append(time=0, duration=1, value=value) @@ -809,11 +867,10 @@ def test_ns_scraper_source_time_invalid(scraper_value, source_time): ann.validate() -@parametrize('event_duration', - [0.5, 5, 1.0]) +@parametrize("event_duration", [0.5, 5, 1.0]) def test_ns_scaper_event_duration(event_duration): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = { "source_time": 0.0, @@ -822,19 +879,19 @@ def test_ns_scaper_event_duration(event_duration): "time_stretch": 0.8455598669219283, "pitch_shift": -1.2204911976305648, "snr": 7.790682558359417, - "label": 'gun_shot', + "label": "gun_shot", "role": "foreground", - "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav" + "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav", } ann.append(time=0, duration=1, value=value) ann.validate() -@parametrize('event_duration', [0, -1, -1.0, 'zero', None]) +@parametrize("event_duration", [0, -1, -1.0, "zero", None]) def test_ns_scraper_event_duration_invalid(scraper_value, event_duration): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = dict(scraper_value, event_duration=event_duration) ann.append(time=0, duration=1, value=value) @@ -842,11 +899,10 @@ def test_ns_scraper_event_duration_invalid(scraper_value, event_duration): ann.validate() -@parametrize('event_time', - [0, 5, 1.0]) +@parametrize("event_time", [0, 5, 1.0]) def test_ns_scaper_event_time(event_time): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = { "source_time": 0.0, @@ -855,19 +911,19 @@ def test_ns_scaper_event_time(event_time): "time_stretch": 0.8455598669219283, "pitch_shift": -1.2204911976305648, "snr": 7.790682558359417, - "label": 'gun_shot', + "label": "gun_shot", "role": "foreground", - "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav" + "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav", } ann.append(time=0, duration=1, value=value) ann.validate() -@parametrize('event_time', [-1, -1.0, 'zero', None]) +@parametrize("event_time", [-1, -1.0, "zero", None]) def test_ns_scraper_event_time_invalid(scraper_value, event_time): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = dict(scraper_value, event_time=event_time) ann.append(time=0, duration=1, value=value) @@ -875,11 +931,10 @@ def test_ns_scraper_event_time_invalid(scraper_value, event_time): ann.validate() -@parametrize('time_stretch', - [0.5, 5, 1.0, None]) +@parametrize("time_stretch", [0.5, 5, 1.0, None]) def test_ns_scaper_time_stretch(time_stretch): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = { "source_time": 0.0, @@ -888,19 +943,19 @@ def test_ns_scaper_time_stretch(time_stretch): "time_stretch": time_stretch, "pitch_shift": -1.2204911976305648, "snr": 7.790682558359417, - "label": 'gun_shot', + "label": "gun_shot", "role": "foreground", - "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav" + "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav", } ann.append(time=0, duration=1, value=value) ann.validate() -@parametrize('time_stretch', [0, -1, -1.0, 'zero']) +@parametrize("time_stretch", [0, -1, -1.0, "zero"]) def test_ns_scraper_time_stretch_invalid(scraper_value, time_stretch): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = dict(scraper_value, time_stretch=time_stretch) ann.append(time=0, duration=1, value=value) @@ -908,11 +963,10 @@ def test_ns_scraper_time_stretch_invalid(scraper_value, time_stretch): ann.validate() -@parametrize('pitch_shift', - [0.5, 5, 1.0, -1, -3.5, 0, None]) +@parametrize("pitch_shift", [0.5, 5, 1.0, -1, -3.5, 0, None]) def test_ns_scaper_pitch_shift(pitch_shift): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = { "source_time": 0.0, @@ -921,9 +975,9 @@ def test_ns_scaper_pitch_shift(pitch_shift): "time_stretch": 0.8455598669219283, "pitch_shift": pitch_shift, "snr": 7.790682558359417, - "label": 'gun_shot', + "label": "gun_shot", "role": "foreground", - "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav" + "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav", } ann.append(time=0, duration=1, value=value) @@ -932,19 +986,18 @@ def test_ns_scaper_pitch_shift(pitch_shift): def test_ns_scraper_pitch_shift_invalid(scraper_value): - ann = Annotation(namespace='scaper') - value = dict(scraper_value, pitch_shift='zero') + ann = Annotation(namespace="scaper") + value = dict(scraper_value, pitch_shift="zero") ann.append(time=0, duration=1, value=value) with pytest.raises(SchemaError): ann.validate() -@parametrize('snr', - [0.5, 5, 1.0, -1, -3.5, 0]) +@parametrize("snr", [0.5, 5, 1.0, -1, -3.5, 0]) def test_ns_scaper_snr(snr): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = { "source_time": 0.0, @@ -953,19 +1006,19 @@ def test_ns_scaper_snr(snr): "time_stretch": 0.8455598669219283, "pitch_shift": -1.2204911976305648, "snr": snr, - "label": 'gun_shot', + "label": "gun_shot", "role": "foreground", - "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav" + "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav", } ann.append(time=0, duration=1, value=value) ann.validate() -@parametrize('snr', ['zero', None]) +@parametrize("snr", ["zero", None]) def test_ns_scraper_snr_invalid(scraper_value, snr): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = dict(scraper_value, snr=snr) ann.append(time=0, duration=1, value=value) @@ -973,11 +1026,12 @@ def test_ns_scraper_snr_invalid(scraper_value, snr): ann.validate() -@parametrize('label', - ['air_conditioner', 'car_horn', six.u('street_music'), 'any string']) +@parametrize( + "label", ["air_conditioner", "car_horn", six.u("street_music"), "any string"] +) def test_ns_scaper_label(label): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = { "source_time": 0.0, @@ -988,17 +1042,17 @@ def test_ns_scaper_label(label): "snr": 7.790682558359417, "label": label, "role": "foreground", - "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav" + "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav", } ann.append(time=0, duration=1, value=value) ann.validate() -@parametrize('label', [23, None]) +@parametrize("label", [23, None]) def test_ns_scraper_label_invalid(scraper_value, label): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = dict(scraper_value, label=label) ann.append(time=0, duration=1, value=value) @@ -1006,12 +1060,17 @@ def test_ns_scraper_label_invalid(scraper_value, label): ann.validate() -@parametrize('role', - ['foreground', 'background', six.u('background'), - ]) +@parametrize( + "role", + [ + "foreground", + "background", + six.u("background"), + ], +) def test_ns_scaper_role(role): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = { "source_time": 0.0, @@ -1022,17 +1081,17 @@ def test_ns_scaper_role(role): "snr": 7.790682558359417, "label": "gun_shot", "role": role, - "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav" + "source_file": "/audio/foreground/gun_shot/135544-6-17-0.wav", } ann.append(time=0, duration=1, value=value) ann.validate() -@parametrize('role', ['FOREGROUND', 'BACKGROUND', 'something', 23, None]) +@parametrize("role", ["FOREGROUND", "BACKGROUND", "something", 23, None]) def test_ns_scraper_role_invalid(scraper_value, role): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = dict(scraper_value, role=role) ann.append(time=0, duration=1, value=value) @@ -1040,11 +1099,10 @@ def test_ns_scraper_role_invalid(scraper_value, role): ann.validate() -@parametrize('source_file', - ['filename', '/a/b/c.wav', six.u('filename.wav')]) +@parametrize("source_file", ["filename", "/a/b/c.wav", six.u("filename.wav")]) def test_ns_scaper_source_file(source_file): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = { "source_time": 0.0, @@ -1055,17 +1113,17 @@ def test_ns_scaper_source_file(source_file): "snr": 7.790682558359417, "label": "gun_shot", "role": "foreground", - "source_file": source_file + "source_file": source_file, } ann.append(time=0, duration=1, value=value) ann.validate() -@parametrize('source_file', [23, None]) +@parametrize("source_file", [23, None]) def test_ns_scraper_source_file_invalid(scraper_value, source_file): - ann = Annotation(namespace='scaper') + ann = Annotation(namespace="scaper") value = dict(scraper_value, source_file=source_file) ann.append(time=0, duration=1, value=value) diff --git a/tests/test_schema.py b/tests/test_schema.py index 165e3774..0e893936 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # CREATED:2015-07-15 10:21:30 by Brian McFee -'''Namespace management tests''' +"""Namespace management tests""" from six.moves import reload_module @@ -11,29 +11,30 @@ import jams -@pytest.mark.parametrize('ns_key', ['pitch_hz', 'beat']) +@pytest.mark.parametrize("ns_key", ["pitch_hz", "beat"]) def test_schema_namespace(ns_key): # Get the schema schema = jams.schema.namespace(ns_key) # Make sure it has the correct properties - valid_keys = set(['time', 'duration', 'value', 'confidence']) - for key in schema['properties']: + valid_keys = set(["time", "duration", "value", "confidence"]) + for key in schema["properties"]: assert key in valid_keys - for key in ['time', 'duration']: - assert key in schema['properties'] + for key in ["time", "duration"]: + assert key in schema["properties"] -@pytest.mark.parametrize('ns_key', ['DNE']) + +@pytest.mark.parametrize("ns_key", ["DNE"]) def test_schema_namespace_exception(ns_key): with pytest.raises(NamespaceError): jams.schema.namespace(ns_key) - -@pytest.mark.parametrize('ns, dense', [('pitch_hz', True), ('beat', False)]) +@pytest.mark.parametrize("ns, dense", [("pitch_hz", True), ("beat", False)]) def test_schema_is_dense(ns, dense): assert dense == jams.schema.is_dense(ns) -@pytest.mark.parametrize('ns', ['DNE']) + +@pytest.mark.parametrize("ns", ["DNE"]) def test_schema_is_dense_exception(ns): with pytest.raises(NamespaceError): jams.schema.is_dense(ns) @@ -42,14 +43,14 @@ def test_schema_is_dense_exception(ns): @pytest.fixture def local_namespace(): - os.environ['JAMS_SCHEMA_DIR'] = os.path.join('tests', 'fixtures', 'schema') + os.environ["JAMS_SCHEMA_DIR"] = os.path.join("tests", "fixtures", "schema") reload_module(jams) # This one should pass - yield 'testing_tag_upper', True + yield "testing_tag_upper", True # Cleanup - del os.environ['JAMS_SCHEMA_DIR'] + del os.environ["JAMS_SCHEMA_DIR"] reload_module(jams) @@ -62,12 +63,12 @@ def test_schema_local(local_namespace): schema = jams.schema.namespace(ns_key) # Make sure it has the correct properties - valid_keys = set(['time', 'duration', 'value', 'confidence']) - for key in schema['properties']: + valid_keys = set(["time", "duration", "value", "confidence"]) + for key in schema["properties"]: assert key in valid_keys - for key in ['time', 'duration']: - assert key in schema['properties'] + for key in ["time", "duration"]: + assert key in schema["properties"] else: with pytest.raises(NamespaceError): schema = jams.schema.namespace(ns_key) @@ -75,21 +76,30 @@ def test_schema_local(local_namespace): def test_schema_values_pass(): - values = jams.schema.values('tag_gtzan') + values = jams.schema.values("tag_gtzan") - assert values == ['blues', 'classical', 'country', - 'disco', 'hip-hop', 'jazz', 'metal', - 'pop', 'reggae', 'rock'] + assert values == [ + "blues", + "classical", + "country", + "disco", + "hip-hop", + "jazz", + "metal", + "pop", + "reggae", + "rock", + ] def test_schema_values_missing(): with pytest.raises(NamespaceError): - jams.schema.values('imaginary namespace') + jams.schema.values("imaginary namespace") def test_schema_values_notenum(): with pytest.raises(NamespaceError): - jams.schema.values('chord_harte') + jams.schema.values("chord_harte") def test_schema_dtypes(): @@ -100,7 +110,7 @@ def test_schema_dtypes(): def test_schema_dtypes_badns(): with pytest.raises(NamespaceError): - jams.schema.get_dtypes('unknown namespace') + jams.schema.get_dtypes("unknown namespace") def test_list_namespaces(): diff --git a/tests/test_sonify.py b/tests/test_sonify.py index 31b9648f..07cf62a2 100644 --- a/tests/test_sonify.py +++ b/tests/test_sonify.py @@ -11,26 +11,26 @@ def test_no_sonify(): - ann = jams.Annotation(namespace='vector') + ann = jams.Annotation(namespace="vector") with pytest.raises(jams.NamespaceError): jams.sonify.sonify(ann) def test_bad_sonify(): - ann = jams.Annotation(namespace='chord') - ann.append(time=0, duration=1, value='not a chord') + ann = jams.Annotation(namespace="chord") + ann.append(time=0, duration=1, value="not a chord") with pytest.raises(jams.SchemaError): jams.sonify.sonify(ann) -@pytest.mark.parametrize('ns', ['segment_open', 'chord']) -@pytest.mark.parametrize('sr', [8000, 11025]) -@pytest.mark.parametrize('duration', [None, 5.0, 1.0]) +@pytest.mark.parametrize("ns", ["segment_open", "chord"]) +@pytest.mark.parametrize("sr", [8000, 11025]) +@pytest.mark.parametrize("duration", [None, 5.0, 1.0]) def test_duration(ns, sr, duration): ann = jams.Annotation(namespace=ns) - ann.append(time=3, duration=1, value='C') + ann.append(time=3, duration=1, value="C") y = jams.sonify.sonify(ann, sr=sr, duration=duration) @@ -39,7 +39,7 @@ def test_duration(ns, sr, duration): def test_note_hz(): - ann = jams.Annotation(namespace='note_hz') + ann = jams.Annotation(namespace="note_hz") ann.append(time=0, duration=1, value=261.0) y = jams.sonify.sonify(ann, sr=8000, duration=2.0) @@ -47,7 +47,7 @@ def test_note_hz(): def test_note_hz_nolength(): - ann = jams.Annotation(namespace='note_hz') + ann = jams.Annotation(namespace="note_hz") ann.append(time=0, duration=1, value=261.0) y = jams.sonify.sonify(ann, sr=8000) @@ -56,16 +56,16 @@ def test_note_hz_nolength(): def test_note_midi(): - ann = jams.Annotation(namespace='note_midi') + ann = jams.Annotation(namespace="note_midi") ann.append(time=0, duration=1, value=60) y = jams.sonify.sonify(ann, sr=8000, duration=2.0) assert len(y) == 8000 * 2 -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ann_contour(): - ann = jams.Annotation(namespace='pitch_contour') + ann = jams.Annotation(namespace="pitch_contour") duration = 5.0 fs = 0.01 @@ -75,25 +75,27 @@ def ann_contour(): vibrato = 220 + 20 * np.sin(2 * np.pi * times * rate) for t, v in zip(times, vibrato): - ann.append(time=t, duration=fs, value={'frequency': v, - 'index': 0, - 'voiced': (t < 3 or t > 4)}) + ann.append( + time=t, + duration=fs, + value={"frequency": v, "index": 0, "voiced": (t < 3 or t > 4)}, + ) return ann -@pytest.mark.parametrize('duration', [None, 5.0, 10.0]) -@pytest.mark.parametrize('sr', [8000]) +@pytest.mark.parametrize("duration", [None, 5.0, 10.0]) +@pytest.mark.parametrize("sr", [8000]) def test_contour(ann_contour, duration, sr): y = jams.sonify.sonify(ann_contour, sr=sr, duration=duration) if duration is not None: assert len(y) == sr * duration -@pytest.mark.parametrize('namespace', ['chord', 'chord_harte']) -@pytest.mark.parametrize('sr', [8000]) -@pytest.mark.parametrize('duration', [2.0]) -@pytest.mark.parametrize('value', ['C:maj/5']) +@pytest.mark.parametrize("namespace", ["chord", "chord_harte"]) +@pytest.mark.parametrize("sr", [8000]) +@pytest.mark.parametrize("duration", [2.0]) +@pytest.mark.parametrize("value", ["C:maj/5"]) def test_chord(namespace, sr, duration, value): ann = jams.Annotation(namespace=namespace) @@ -103,12 +105,11 @@ def test_chord(namespace, sr, duration, value): assert len(y) == sr * duration -@pytest.mark.parametrize('namespace, value', - [('beat', 1), - ('segment_open', 'C'), - ('onset', 1)]) -@pytest.mark.parametrize('sr', [8000]) -@pytest.mark.parametrize('duration', [2.0]) +@pytest.mark.parametrize( + "namespace, value", [("beat", 1), ("segment_open", "C"), ("onset", 1)] +) +@pytest.mark.parametrize("sr", [8000]) +@pytest.mark.parametrize("duration", [2.0]) def test_event(namespace, sr, duration, value): ann = jams.Annotation(namespace=namespace) @@ -117,21 +118,23 @@ def test_event(namespace, sr, duration, value): assert len(y) == sr * duration -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def beat_pos_ann(): - ann = jams.Annotation(namespace='beat_position') + ann = jams.Annotation(namespace="beat_position") for i, t in enumerate(np.arange(0, 10, 0.25)): - ann.append(time=t, duration=0, - value=dict(position=1 + i % 4, - measure=1 + i // 4, - num_beats=4, - beat_units=4)) + ann.append( + time=t, + duration=0, + value=dict( + position=1 + i % 4, measure=1 + i // 4, num_beats=4, beat_units=4 + ), + ) return ann -@pytest.mark.parametrize('sr', [8000]) -@pytest.mark.parametrize('duration', [None, 5, 15]) +@pytest.mark.parametrize("sr", [8000]) +@pytest.mark.parametrize("duration", [None, 5, 15]) def test_beat_position(beat_pos_ann, sr, duration): yout = jams.sonify.sonify(beat_pos_ann, sr=sr, duration=duration) @@ -139,13 +142,13 @@ def test_beat_position(beat_pos_ann, sr, duration): assert len(yout) == duration * sr -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ann_hier(): - return create_hierarchy(values=['AB', 'abac', 'xxyyxxzz'], duration=30) + return create_hierarchy(values=["AB", "abac", "xxyyxxzz"], duration=30) -@pytest.mark.parametrize('sr', [8000]) -@pytest.mark.parametrize('duration', [None, 15, 30]) +@pytest.mark.parametrize("sr", [8000]) +@pytest.mark.parametrize("duration", [None, 15, 30]) def test_multi_segment(ann_hier, sr, duration): y = jams.sonify.sonify(ann_hier, sr=sr, duration=duration) if duration: diff --git a/tests/test_util.py b/tests/test_util.py index a1d30b90..5ce7037e 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -18,30 +18,29 @@ def srand(seed=628318530): pass -@pytest.mark.parametrize('ns, lab, ints, y, infer_duration', - [('beat', - "1.0 1\n3.0 2", - np.array([[1.0, 3.0], [3.0, 3.0]]), - [1, 2], - True), - ('beat', - "1.0 1\n3.0 2", - np.array([[1.0, 1.0], [3.0, 3.0]]), - [1, 2], - False), - ('chord_harte', - "1.0 2.0 a\n2.0 4.0 b", - np.array([[1.0, 2.0], [2.0, 4.0]]), - ['a', 'b'], - True), - ('chord', - "1.0 1.0 c\n2.0 2.0 d", - np.array([[1.0, 2.0], [2.0, 4.0]]), - ['c', 'd'], - False)]) +@pytest.mark.parametrize( + "ns, lab, ints, y, infer_duration", + [ + ("beat", "1.0 1\n3.0 2", np.array([[1.0, 3.0], [3.0, 3.0]]), [1, 2], True), + ("beat", "1.0 1\n3.0 2", np.array([[1.0, 1.0], [3.0, 3.0]]), [1, 2], False), + ( + "chord_harte", + "1.0 2.0 a\n2.0 4.0 b", + np.array([[1.0, 2.0], [2.0, 4.0]]), + ["a", "b"], + True, + ), + ( + "chord", + "1.0 1.0 c\n2.0 2.0 d", + np.array([[1.0, 2.0], [2.0, 4.0]]), + ["c", "d"], + False, + ), + ], +) def test_import_lab(ns, lab, ints, y, infer_duration): - ann = util.import_lab(ns, six.StringIO(lab), - infer_duration=infer_duration) + ann = util.import_lab(ns, six.StringIO(lab), infer_duration=infer_duration) assert len(ints) == len(ann.data) assert len(y) == len(ann.data) @@ -52,26 +51,34 @@ def test_import_lab(ns, lab, ints, y, infer_duration): assert obs.value == yi -@pytest.mark.parametrize('query, prefix, sep, target', - [('al.beta.gamma', 'al', '.', 'beta.gamma'), - ('al/beta/gamma', 'al', '/', 'beta/gamma'), - ('al.beta.gamma', 'beta', '.', 'al.beta.gamma'), - ('al.beta.gamma', 'beta', '/', 'al.beta.gamma'), - ('al.pha.beta.gamma', 'al', '.', 'pha.beta.gamma')]) +@pytest.mark.parametrize( + "query, prefix, sep, target", + [ + ("al.beta.gamma", "al", ".", "beta.gamma"), + ("al/beta/gamma", "al", "/", "beta/gamma"), + ("al.beta.gamma", "beta", ".", "al.beta.gamma"), + ("al.beta.gamma", "beta", "/", "al.beta.gamma"), + ("al.pha.beta.gamma", "al", ".", "pha.beta.gamma"), + ], +) def test_query_pop(query, prefix, sep, target): assert target == core.query_pop(query, prefix, sep=sep) -@pytest.mark.parametrize('needle, haystack, result', - [('abcdeABCDE123', 'abcdeABCDE123', True), - ('.*cde.*', 'abcdeABCDE123', True), - ('cde$', 'abcdeABCDE123', False), - (r'.*\d+$', 'abcdeABCDE123', True), - (r'^\d+$', 'abcdeABCDE123', False), - (lambda x: True, 'abcdeABCDE123', True), - (lambda x: False, 'abcdeABCDE123', False), - (5, 5, True), - (5, 4, False)]) +@pytest.mark.parametrize( + "needle, haystack, result", + [ + ("abcdeABCDE123", "abcdeABCDE123", True), + (".*cde.*", "abcdeABCDE123", True), + ("cde$", "abcdeABCDE123", False), + (r".*\d+$", "abcdeABCDE123", True), + (r"^\d+$", "abcdeABCDE123", False), + (lambda x: True, "abcdeABCDE123", True), + (lambda x: False, "abcdeABCDE123", False), + (5, 5, True), + (5, 4, False), + ], +) def test_match_query(needle, haystack, result): assert result == core.match_query(haystack, needle) @@ -79,7 +86,7 @@ def test_match_query(needle, haystack, result): def test_smkdirs(): root = tempfile.mkdtemp() - my_dirs = [root, 'level1', 'level2', 'level3'] + my_dirs = [root, "level1", "level2", "level3"] try: target = os.sep.join(my_dirs) @@ -95,11 +102,15 @@ def test_smkdirs(): os.rmdir(tmpdir) -@pytest.mark.parametrize('query, target', - [('foo', 'foo'), - ('foo.txt', 'foo'), - ('/path/to/foo.txt', 'foo'), - ('/path/to/foo', 'foo')]) +@pytest.mark.parametrize( + "query, target", + [ + ("foo", "foo"), + ("foo.txt", "foo"), + ("/path/to/foo.txt", "foo"), + ("/path/to/foo", "foo"), + ], +) def test_filebase(query, target): assert target == util.filebase(query) @@ -109,20 +120,22 @@ def root_and_files(): root = tempfile.mkdtemp() - files = [[root, 'file1.txt'], - [root, 'sub1', 'file2.txt'], - [root, 'sub1', 'sub2', 'file3.txt'], - [root, 'sub1', 'sub2', 'sub3', 'file4.txt']] + files = [ + [root, "file1.txt"], + [root, "sub1", "file2.txt"], + [root, "sub1", "sub2", "file3.txt"], + [root, "sub1", "sub2", "sub3", "file4.txt"], + ] files = [os.sep.join(_) for _ in files] - badfiles = [_.replace('.txt', '.csv') for _ in files] + badfiles = [_.replace(".txt", ".csv") for _ in files] # Create all the necessary directories util.smkdirs(os.path.dirname(files[-1])) # Create the dummy files for fname in files + badfiles: - with open(fname, 'w'): + with open(fname, "w"): pass yield root, files @@ -133,21 +146,24 @@ def root_and_files(): os.rmdir(os.path.dirname(fname)) -@pytest.mark.parametrize('level', [1, 2, 3, 4]) -@pytest.mark.parametrize('sort', [False, True]) +@pytest.mark.parametrize("level", [1, 2, 3, 4]) +@pytest.mark.parametrize("sort", [False, True]) def test_find_with_extension(root_and_files, level, sort): root, files = root_and_files - results = util.find_with_extension(root, 'txt', depth=level, sort=sort) + results = util.find_with_extension(root, "txt", depth=level, sort=sort) assert sorted(results) == sorted(files[:level]) -@pytest.mark.skipif(sys.platform == "win32", reason="os.path.normpath does something different on windows") +@pytest.mark.skipif( + sys.platform == "win32", + reason="os.path.normpath does something different on windows", +) def test_expand_filepaths(): - targets = ['foo.bar', 'dir/file.txt', 'dir2///file2.txt', '/q.bin'] + targets = ["foo.bar", "dir/file.txt", "dir2///file2.txt", "/q.bin"] - target_dir = '/tmp' + target_dir = "/fake_directory" paths = util.expand_filepaths(target_dir, targets)