diff --git a/pyrato/parameters.py b/pyrato/parameters.py index 1430a80..e8b2c00 100644 --- a/pyrato/parameters.py +++ b/pyrato/parameters.py @@ -5,6 +5,7 @@ import re import numpy as np import pyfar as pf +import numbers def reverberation_time_linear_regression( @@ -202,3 +203,177 @@ def clarity(energy_decay_curve, early_time_limit=80): clarity_db = 10 * np.log10(clarity) return clarity_db + + + +def clarity_from_energy_balance(energy_decay_curve, early_time_limit=80): + r""" + Calculate the clarity from the energy decay curve (EDC). + + The clarity parameter (C50 or C80) is defined as the ratio of early-to-late + arriving energy in an impulse response and is a measure for how clearly + speech or music can be perceived in a room. The early-to-late boundary is + typically set at 50 ms (C50) or 80 ms (C80) [#iso]_. + + Clarity is calculated as: + + .. math:: + + C_{t_e} = 10 \log_{10} \frac{ + \displaystyle \int_0^{t_e} p^2(t) \, dt + }{ + \displaystyle \int_{t_e}^{\infty} p^2(t) \, dt + } + + where :math:`t_e` is the early time limit and :math:`p(t)` is the pressure + of a room impulse response. Here, the clarity is efficiently computed + from the EDC :math:`e(t)` directly by: + + .. math:: + + C_{t_e} = 10 \log_{10} \left( \frac{e(0)}{e(t_e)} - 1 \right). + + Parameters + ---------- + energy_decay_curve : pyfar.TimeData + Energy decay curve (EDC) of the room impulse response + (time-domain signal). The EDC must start at time zero. + early_time_limit : float, optional + Early time limit (:math:`t_e`) in milliseconds. Defaults to 80 (C80). + Typical values are 50 ms (C50) or 80 ms (C80) [#iso]_. + + Returns + ------- + clarity : numpy.ndarray[float] + Clarity index (early-to-late energy ratio) in decibels, + shaped according to the channel shape of the input EDC. + + References + ---------- + .. [#iso] ISO 3382, Acoustics — Measurement of the reverberation time of + rooms with reference to other acoustical parameters. + + Examples + -------- + Estimate the clarity from a real room impulse response filtered in + octave bands: + + >>> import numpy as np + >>> import pyfar as pf + >>> import pyrato as ra + >>> rir = pf.signals.files.room_impulse_response(sampling_rate=44100) + >>> rir = pf.dsp.filter.fractional_octave_bands(rir) + >>> edc = ra.edc.energy_decay_curve_lundeby(rir) + >>> C80 = clarity(edc, early_time_limit=80) + """ + + if not isinstance(energy_decay_curve, pf.TimeData): + raise TypeError("Input must be a pyfar.TimeData object.") + + if not isinstance(early_time_limit, (int, float)): + raise TypeError('early_time_limit must be a number.') + + # Validate time range + if (early_time_limit > energy_decay_curve.signal_length * 1000) or ( + early_time_limit <= 0): + raise ValueError( + "early_time_limit must be in the range of 0" + f"and {energy_decay_curve.signal_length * 1000}.", + ) + + # Convert milliseconds to seconds + early_time_limit_sec = early_time_limit / 1000 + + # calculate lim1 - lim4 for each channel + lim1, lim2 = early_time_limit_sec, energy_decay_curve.times[-1] + lim3, lim4 = 0.0, early_time_limit_sec + + # return in dB + return 10* np.log10(__energy_balance(lim1, lim2, lim3, lim4, energy_decay_curve, energy_decay_curve)) + + +def __energy_balance(lim1, lim2, lim3, lim4, + energy_decay_curve1, + energy_decay_curve2): + r""" + Calculate the energy balance for the time limits from the two energy + decay curves (EDC). If second one is not provided, the first will be used for both. + + A collection of roomacoustic parameters are defined by their + time-respective energy balance, where the differentiation is made by + the four given time limits [#iso]_. + + Energy-Balance is calculated as: + + .. math:: + + EB(p) = 10 \log_{10} \frac{ + \displaystyle \int_{lim3}^{lim4} p_2^2(t) \, dt + }{ + \displaystyle \int_{lim1}^{lim2} p_1^2(t) \, dt + } + + where :math:`lim1 - lim4` are the time limits and :math:`p(t)` is the + pressure of a room impulse response. Here, the energy balance is + efficiently computed from the EDC :math:`e(t)` directly by: + + .. math:: + + EB(e) = 10 \log_{10} \left( \frac{e_2(lim3) - + e_2(lim4)}{e_1(lim1) - e_1(lim2)} \right). + + Parameters + ---------- + lim1, lim2, lim3, lim4 : float + Time limits (:math:`t_e`) in seconds. + energy_decay_curve1 : pyfar.TimeData + Energy decay curve 1 (EDC1) of the room impulse response + (time-domain signal). The EDC must start at time zero. + energy_decay_curve2 : pyfar.TimeData + Energy decay curve 2 (EDC2) of the room impulse response + (time-domain signal). The EDC must start at time zero. + + Returns + ------- + energy balance : numpy.ndarray[float] + energy-balance index (early-to-late energy ratio), + shaped according to the channel shape of the input EDC. + + References + ---------- + .. [#iso] ISO 3382, Acoustics — Measurement of the reverberation time of + rooms with reference to other acoustical parameters. + """ + # Check input type + if not isinstance(energy_decay_curve1, pf.TimeData): + raise TypeError("energy_decay_curve1 must be a pyfar.TimeData object.") + if not isinstance(energy_decay_curve2, pf.TimeData): + raise TypeError("energy_decay_curve2 must be a pyfar.TimeData object.") + + for name, val in zip(("lim1","lim2","lim3","lim4"), (lim1, + lim2, + lim3, + lim4)): + if not isinstance(val, numbers.Real): + raise TypeError(f"{name} must be numeric.") + + if not (lim2 > lim1 if np.isscalar(lim1) and np.isscalar(lim2) else True): + raise ValueError("If scalars, require lim1 < lim2.") + if not (lim4 > lim3 if np.isscalar(lim3) and np.isscalar(lim4) else True): + raise ValueError("If scalars, require lim3 < lim4.") + + lim1_idx = energy_decay_curve1.find_nearest_time(lim1) + lim2_idx = energy_decay_curve1.find_nearest_time(lim2) + lim3_idx = energy_decay_curve2.find_nearest_time(lim3) + lim4_idx = energy_decay_curve2.find_nearest_time(lim4) + + lim1_vals = energy_decay_curve1.time[..., lim1_idx] + lim2_vals = energy_decay_curve1.time[..., lim2_idx] + lim3_vals = energy_decay_curve2.time[..., lim3_idx] + lim4_vals = energy_decay_curve2.time[..., lim4_idx] + + numerator = lim3_vals - lim4_vals # edc 2 + denominator = lim1_vals - lim2_vals # edc 1 + + energy_balance = numerator / denominator + return energy_balance diff --git a/tests/test_parameters_energy_balance.py b/tests/test_parameters_energy_balance.py new file mode 100644 index 0000000..f489501 --- /dev/null +++ b/tests/test_parameters_energy_balance.py @@ -0,0 +1,116 @@ +import numpy as np +import pytest +import pyfar as pf +import numpy.testing as npt +import re + +from pyrato.parameters import __energy_balance + + +def make_edc_from_energy(energy, sampling_rate=1000): + """Helper: build normalized EDC TimeData from an energy curve.""" + energy = np.asarray(energy, dtype=float) + energy = energy / np.max(energy) if np.max(energy) != 0 else energy + times = np.arange(energy.shape[-1]) / sampling_rate + if energy.ndim == 1: + energy = energy[np.newaxis, :] + return pf.TimeData(energy, times) + + +# --- Basic type and shape tests --- +def test_energy_balance_accepts_timedata_and_returns_correct_shape(): + energy = np.linspace(1, 0, 10) + edc = make_edc_from_energy(energy) + result = __energy_balance(0.0, 0.005, 0.0, 0.001, edc, edc) + assert isinstance(result, np.ndarray) + assert result.shape == edc.cshape + + +def test_energy_balance_rejects_non_timedata_input(): + invalid_input = np.arange(10) + expected_message = "pyfar.TimeData" + with pytest.raises(TypeError, match=expected_message): + __energy_balance(0.0, 0.005, 0.0, 0.001, invalid_input, invalid_input) + + +def test_energy_balance_rejects_if_second_edc_is_not_timedata(): + edc = make_edc_from_energy(np.linspace(1, 0, 10)) + with pytest.raises(TypeError, match="pyfar.TimeData"): + __energy_balance(0.0, 0.005, 0.0, 0.001, edc, "invalid_type") + + +def test_energy_balance_rejects_non_numeric_limits(): + edc = make_edc_from_energy(np.linspace(1, 0, 10)) + with pytest.raises(TypeError, match="lim1 must be numeric."): + __energy_balance("not_a_number", 1, 0, 1, edc, edc) + + +def test_energy_balance_rejects_invalid_limit_order(): + edc = make_edc_from_energy(np.linspace(1, 0, 10)) + with pytest.raises(ValueError, match="If scalars, require lim1 < lim2."): + __energy_balance(1.0, 0.5, 0.0, 1.0, edc, edc) + with pytest.raises(ValueError, match="If scalars, require lim3 < lim4."): + __energy_balance(0.0, 1.0, 1.0, 0.5, edc, edc) + + +# --- Functional correctness --- +def test_energy_balance_computes_known_ratio_correctly(): + """If EDC is linear, energy balance ratio should be 1.""" + edc_vals = np.array([1.0, 0.75, 0.5, 0.25]) + edc = make_edc_from_energy(edc_vals, sampling_rate=1000) + + # For linear EDC: numerator = e(lim3)-e(lim4) = (1.0 - 0.75) = 0.25 + # denominator = e(lim1)-e(lim2) = (0.75 - 0.5) = 0.25 + # ratio = 1 + result = __energy_balance(0.001, 0.002, 0.0, 0.001, edc, edc) + npt.assert_allclose(result, 1.0, atol=1e-12) + + +def test_energy_balance_handles_multichannel_data_correctly(): + energy = np.linspace(1, 0, 10) + multi = np.stack([energy, energy * 0.5]) + edc = make_edc_from_energy(multi) + result = __energy_balance(0.0, 0.005, 0.0, 0.001, edc, edc) + assert result.shape == edc.cshape + + +def test_energy_balance_returns_nan_for_zero_denominator(): + """If denominator e(lim1)-e(lim2)=0, expect NaN (invalid ratio).""" + energy = np.ones(10) + edc = make_edc_from_energy(energy) + result = __energy_balance(0.0, 0.001, 0.002, 0.003, edc, edc) + assert np.isnan(result) + + +def test_energy_balance_matches_reference_case(): + """ + Analytical reference: + EDC = exp(-a*t). For exponential decay, ratio known analytically. + """ + sampling_rate = 1000 + a = 13.8155 # decay constant + times = np.arange(1000) / sampling_rate + edc_vals = np.exp(-a * times) + edc = pf.TimeData(edc_vals[np.newaxis, :], times) + + lim1, lim2, lim3, lim4 = 0.0, 0.05, 0.0, 0.02 + + analytical_ratio = ( + (np.exp(-a*lim3) - np.exp(-a*lim4)) / + (np.exp(-a*lim1) - np.exp(-a*lim2)) + ) + + result = __energy_balance(lim1, lim2, lim3, lim4, edc, edc) + npt.assert_allclose(result, analytical_ratio, atol=1e-8) + + +def test_energy_balance_works_with_two_different_edcs(): + """Energy balance between two different EDCs should compute distinct ratio.""" + times = np.linspace(0, 0.009, 10) + edc1 = pf.TimeData(np.linspace(1, 0, 10)[np.newaxis, :], times) + edc2 = pf.TimeData((np.linspace(1, 0, 10) ** 2)[np.newaxis, :], times) + + # Expect a ratio != 1 because edc2 decays faster + ratio = __energy_balance(0.0, 0.004, 0.0, 0.002, edc1, edc2) + assert np.all(np.isfinite(ratio)) + assert not np.allclose(ratio, 1.0)