Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions pyrato/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import numpy as np
import pyfar as pf
import numbers


def reverberation_time_linear_regression(
Expand Down Expand Up @@ -202,3 +203,177 @@ def clarity(energy_decay_curve, early_time_limit=80):
clarity_db = 10 * np.log10(clarity)

return clarity_db



def clarity_from_energy_balance(energy_decay_curve, early_time_limit=80):
r"""
Calculate the clarity from the energy decay curve (EDC).

The clarity parameter (C50 or C80) is defined as the ratio of early-to-late
arriving energy in an impulse response and is a measure for how clearly
speech or music can be perceived in a room. The early-to-late boundary is
typically set at 50 ms (C50) or 80 ms (C80) [#iso]_.

Clarity is calculated as:

.. math::

C_{t_e} = 10 \log_{10} \frac{
\displaystyle \int_0^{t_e} p^2(t) \, dt
}{
\displaystyle \int_{t_e}^{\infty} p^2(t) \, dt
}

where :math:`t_e` is the early time limit and :math:`p(t)` is the pressure
of a room impulse response. Here, the clarity is efficiently computed
from the EDC :math:`e(t)` directly by:

.. math::

C_{t_e} = 10 \log_{10} \left( \frac{e(0)}{e(t_e)} - 1 \right).

Parameters
----------
energy_decay_curve : pyfar.TimeData
Energy decay curve (EDC) of the room impulse response
(time-domain signal). The EDC must start at time zero.
early_time_limit : float, optional
Early time limit (:math:`t_e`) in milliseconds. Defaults to 80 (C80).
Typical values are 50 ms (C50) or 80 ms (C80) [#iso]_.

Returns
-------
clarity : numpy.ndarray[float]
Clarity index (early-to-late energy ratio) in decibels,
shaped according to the channel shape of the input EDC.

References
----------
.. [#iso] ISO 3382, Acoustics — Measurement of the reverberation time of
rooms with reference to other acoustical parameters.

Examples
--------
Estimate the clarity from a real room impulse response filtered in
octave bands:

>>> import numpy as np
>>> import pyfar as pf
>>> import pyrato as ra
>>> rir = pf.signals.files.room_impulse_response(sampling_rate=44100)
>>> rir = pf.dsp.filter.fractional_octave_bands(rir)
>>> edc = ra.edc.energy_decay_curve_lundeby(rir)
>>> C80 = clarity(edc, early_time_limit=80)
"""

if not isinstance(energy_decay_curve, pf.TimeData):
raise TypeError("Input must be a pyfar.TimeData object.")

if not isinstance(early_time_limit, (int, float)):
raise TypeError('early_time_limit must be a number.')

# Validate time range
if (early_time_limit > energy_decay_curve.signal_length * 1000) or (
early_time_limit <= 0):
raise ValueError(
"early_time_limit must be in the range of 0"
f"and {energy_decay_curve.signal_length * 1000}.",
)

# Convert milliseconds to seconds
early_time_limit_sec = early_time_limit / 1000

# calculate lim1 - lim4 for each channel
lim1, lim2 = early_time_limit_sec, energy_decay_curve.times[-1]
lim3, lim4 = 0.0, early_time_limit_sec

# return in dB
return 10* np.log10(__energy_balance(lim1, lim2, lim3, lim4, energy_decay_curve, energy_decay_curve))


def __energy_balance(lim1, lim2, lim3, lim4,
energy_decay_curve1,
energy_decay_curve2):
r"""
Calculate the energy balance for the time limits from the two energy
decay curves (EDC). If second one is not provided, the first will be used for both.

A collection of roomacoustic parameters are defined by their
time-respective energy balance, where the differentiation is made by
the four given time limits [#iso]_.

Energy-Balance is calculated as:

.. math::

EB(p) = 10 \log_{10} \frac{
\displaystyle \int_{lim3}^{lim4} p_2^2(t) \, dt
}{
\displaystyle \int_{lim1}^{lim2} p_1^2(t) \, dt
}

where :math:`lim1 - lim4` are the time limits and :math:`p(t)` is the
pressure of a room impulse response. Here, the energy balance is
efficiently computed from the EDC :math:`e(t)` directly by:

.. math::

EB(e) = 10 \log_{10} \left( \frac{e_2(lim3) -
e_2(lim4)}{e_1(lim1) - e_1(lim2)} \right).

Parameters
----------
lim1, lim2, lim3, lim4 : float
Time limits (:math:`t_e`) in seconds.
energy_decay_curve1 : pyfar.TimeData
Energy decay curve 1 (EDC1) of the room impulse response
(time-domain signal). The EDC must start at time zero.
energy_decay_curve2 : pyfar.TimeData
Energy decay curve 2 (EDC2) of the room impulse response
(time-domain signal). The EDC must start at time zero.

Returns
-------
energy balance : numpy.ndarray[float]
energy-balance index (early-to-late energy ratio),
shaped according to the channel shape of the input EDC.

References
----------
.. [#iso] ISO 3382, Acoustics — Measurement of the reverberation time of
rooms with reference to other acoustical parameters.
"""
# Check input type
if not isinstance(energy_decay_curve1, pf.TimeData):
raise TypeError("energy_decay_curve1 must be a pyfar.TimeData object.")
if not isinstance(energy_decay_curve2, pf.TimeData):
raise TypeError("energy_decay_curve2 must be a pyfar.TimeData object.")

for name, val in zip(("lim1","lim2","lim3","lim4"), (lim1,
lim2,
lim3,
lim4)):
if not isinstance(val, numbers.Real):
raise TypeError(f"{name} must be numeric.")

if not (lim2 > lim1 if np.isscalar(lim1) and np.isscalar(lim2) else True):
raise ValueError("If scalars, require lim1 < lim2.")
if not (lim4 > lim3 if np.isscalar(lim3) and np.isscalar(lim4) else True):
raise ValueError("If scalars, require lim3 < lim4.")

lim1_idx = energy_decay_curve1.find_nearest_time(lim1)
lim2_idx = energy_decay_curve1.find_nearest_time(lim2)
lim3_idx = energy_decay_curve2.find_nearest_time(lim3)
lim4_idx = energy_decay_curve2.find_nearest_time(lim4)

lim1_vals = energy_decay_curve1.time[..., lim1_idx]
lim2_vals = energy_decay_curve1.time[..., lim2_idx]
lim3_vals = energy_decay_curve2.time[..., lim3_idx]
lim4_vals = energy_decay_curve2.time[..., lim4_idx]

numerator = lim3_vals - lim4_vals # edc 2
denominator = lim1_vals - lim2_vals # edc 1

energy_balance = numerator / denominator
return energy_balance
116 changes: 116 additions & 0 deletions tests/test_parameters_energy_balance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import numpy as np
import pytest
import pyfar as pf
import numpy.testing as npt
import re

from pyrato.parameters import __energy_balance


def make_edc_from_energy(energy, sampling_rate=1000):
"""Helper: build normalized EDC TimeData from an energy curve."""
energy = np.asarray(energy, dtype=float)
energy = energy / np.max(energy) if np.max(energy) != 0 else energy
times = np.arange(energy.shape[-1]) / sampling_rate
if energy.ndim == 1:
energy = energy[np.newaxis, :]
return pf.TimeData(energy, times)


# --- Basic type and shape tests ---
def test_energy_balance_accepts_timedata_and_returns_correct_shape():
energy = np.linspace(1, 0, 10)
edc = make_edc_from_energy(energy)
result = __energy_balance(0.0, 0.005, 0.0, 0.001, edc, edc)
assert isinstance(result, np.ndarray)
assert result.shape == edc.cshape


def test_energy_balance_rejects_non_timedata_input():
invalid_input = np.arange(10)
expected_message = "pyfar.TimeData"
with pytest.raises(TypeError, match=expected_message):
__energy_balance(0.0, 0.005, 0.0, 0.001, invalid_input, invalid_input)


def test_energy_balance_rejects_if_second_edc_is_not_timedata():
edc = make_edc_from_energy(np.linspace(1, 0, 10))
with pytest.raises(TypeError, match="pyfar.TimeData"):
__energy_balance(0.0, 0.005, 0.0, 0.001, edc, "invalid_type")


def test_energy_balance_rejects_non_numeric_limits():
edc = make_edc_from_energy(np.linspace(1, 0, 10))
with pytest.raises(TypeError, match="lim1 must be numeric."):
__energy_balance("not_a_number", 1, 0, 1, edc, edc)


def test_energy_balance_rejects_invalid_limit_order():
edc = make_edc_from_energy(np.linspace(1, 0, 10))
with pytest.raises(ValueError, match="If scalars, require lim1 < lim2."):
__energy_balance(1.0, 0.5, 0.0, 1.0, edc, edc)
with pytest.raises(ValueError, match="If scalars, require lim3 < lim4."):
__energy_balance(0.0, 1.0, 1.0, 0.5, edc, edc)


# --- Functional correctness ---
def test_energy_balance_computes_known_ratio_correctly():
"""If EDC is linear, energy balance ratio should be 1."""
edc_vals = np.array([1.0, 0.75, 0.5, 0.25])
edc = make_edc_from_energy(edc_vals, sampling_rate=1000)

# For linear EDC: numerator = e(lim3)-e(lim4) = (1.0 - 0.75) = 0.25
# denominator = e(lim1)-e(lim2) = (0.75 - 0.5) = 0.25
# ratio = 1
result = __energy_balance(0.001, 0.002, 0.0, 0.001, edc, edc)
npt.assert_allclose(result, 1.0, atol=1e-12)


def test_energy_balance_handles_multichannel_data_correctly():
energy = np.linspace(1, 0, 10)
multi = np.stack([energy, energy * 0.5])
edc = make_edc_from_energy(multi)
result = __energy_balance(0.0, 0.005, 0.0, 0.001, edc, edc)
assert result.shape == edc.cshape


def test_energy_balance_returns_nan_for_zero_denominator():
"""If denominator e(lim1)-e(lim2)=0, expect NaN (invalid ratio)."""
energy = np.ones(10)
edc = make_edc_from_energy(energy)
result = __energy_balance(0.0, 0.001, 0.002, 0.003, edc, edc)
assert np.isnan(result)


def test_energy_balance_matches_reference_case():
"""
Analytical reference:
EDC = exp(-a*t). For exponential decay, ratio known analytically.
"""
sampling_rate = 1000
a = 13.8155 # decay constant
times = np.arange(1000) / sampling_rate
edc_vals = np.exp(-a * times)
edc = pf.TimeData(edc_vals[np.newaxis, :], times)

lim1, lim2, lim3, lim4 = 0.0, 0.05, 0.0, 0.02

analytical_ratio = (
(np.exp(-a*lim3) - np.exp(-a*lim4)) /
(np.exp(-a*lim1) - np.exp(-a*lim2))
)

result = __energy_balance(lim1, lim2, lim3, lim4, edc, edc)
npt.assert_allclose(result, analytical_ratio, atol=1e-8)


def test_energy_balance_works_with_two_different_edcs():
"""Energy balance between two different EDCs should compute distinct ratio."""
times = np.linspace(0, 0.009, 10)
edc1 = pf.TimeData(np.linspace(1, 0, 10)[np.newaxis, :], times)
edc2 = pf.TimeData((np.linspace(1, 0, 10) ** 2)[np.newaxis, :], times)

# Expect a ratio != 1 because edc2 decays faster
ratio = __energy_balance(0.0, 0.004, 0.0, 0.002, edc1, edc2)
assert np.all(np.isfinite(ratio))
assert not np.allclose(ratio, 1.0)