Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions spharpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from . import interpolate
from . import spatial
from . import special
from . import sht


__all__ = [
Expand All @@ -31,4 +32,5 @@
'spatial',
'special',
'SamplingSphere',
'sht'
]
119 changes: 119 additions & 0 deletions spharpy/sht.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import numpy as np
from pyfar import Signal
from . import SphericalHarmonicSignal
from . import SphericalHarmonics


def sht(signal, spherical_harmonics, axis='auto'):
"""Compute the spherical harmonics transform

Parameters
----------
signal: Signal, TimeData, or FrequencyData
the signal for which the spherical harmonics transform is computed
spherical_harmonics: :class:`spharpy.SphericalHarmonics`
Spherical harmonics object
axis: integer
Axis along which the SH transform is computed

Returns
----------
SphericalHarmonicSignal, SphericalHarmonicsTimeData,
or SphericalHarmonicsFrequencyData

References
----------

[#] Rafaely, B. (2015). Fundamentals of Spherical Array Processing,
(J. Benesty and W. Kellermann, Eds.) Springer Berlin Heidelberg,
2nd ed., 196 pages. doi:10.1007/978-3-319-99561-8
[#] Ramani Duraiswami, Dimitry N. Zotkin, and Nail A. Gumerov: "Inter-
polation and range extrapolation of HRTFs." IEEE Int. Conf.
Acoustics, Speech, and Signal Processing (ICASSP), Montreal,
Canada, May 2004, p. 45-48, doi: 10.1109/ICASSP.2004.1326759.
"""
Y_inv = spherical_harmonics.basis_inv

if axis == 'auto':
axis = np.where(np.array(signal.cshape) == Y_inv.shape[1])[0]
if len(axis) == 0:
raise ValueError("No axes matches the number of spherical "
"harmonics basis functions")
if len(axis) > 1:
raise ValueError("To many axis match the number of spherical "
"harmonics basis functions")
axis = axis[0]

if signal.cshape[axis] != Y_inv.shape[1]:
raise ValueError("Spherical samples of provided axis does not match "
"the number of spherical harmonics basis functions.")

# get data from Signal, TimeData or FrequencyData
data_nm = np.tensordot(Y_inv, signal.time, [1, axis])

if len(data_nm.shape) < 3:
data_nm = data_nm[np.newaxis, ...]

# ensure that number of SH channels is at -2
target_m = (spherical_harmonics.n_max+1)**2
target_n = signal.n_samples

# find corresponding axes
axis_m = next(i for i, dim in enumerate(data_nm.shape) if dim == target_m)
axis_n = next(i for i, dim in enumerate(data_nm.shape)
if dim == target_n and i != axis_m)

# create new shape
new_axes = [
i for i in range(len(data_nm.shape)) if i not in (axis_m, axis_n)
] + [axis_m, axis_n]

data_nm = data_nm.transpose(*new_axes)

return SphericalHarmonicSignal(
data=data_nm,
basis_type=spherical_harmonics.basis_type,
normalization=spherical_harmonics.normalization,
channel_convention=spherical_harmonics.channel_convention,
condon_shortley=spherical_harmonics.condon_shortley,
sampling_rate=signal.sampling_rate,
fft_norm=signal.fft_norm,
is_complex=signal.complex,
comment=signal.comment)


def isht(sh_signal, coordinates):
"""Compute the inverse spherical harmonics transform

Parameters
----------
sh_signal: Signal
The spherical harmonics signal for which the inverse spherical
harmonics transform is computed
coordinates: :class:`spharpy.samplings.Coordinates`, :doc:`pf.Coordinates
<pyfar:classes/pyfar.coordinates>`
Coordinates for which the inverse SH transform is computed

Returns
----------
Signal
"""

# get spherical harmonics basis functions according to sh_signals
# properties
spherical_harmonics = SphericalHarmonics(
sh_signal.n_max,
coordinates=coordinates,
basis_type=sh_signal.basis_type,
channel_convention=sh_signal.channel_convention,
normalization=sh_signal.normalization,
inverse_method="pseudo_inverse",
condon_shortley=sh_signal.condon_shortley)

# perform inverse transform
data = np.tensordot(spherical_harmonics.basis, sh_signal.time, [1, -2])

return Signal(data, sh_signal.sampling_rate,
fft_norm=sh_signal.fft_norm,
comment=sh_signal.comment,
is_complex=sh_signal.complex)
64 changes: 64 additions & 0 deletions tests/test_sht.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy as np
import numpy.testing as npt
import pyfar as pf
from pytest import raises, warns, mark
from spharpy.sht import sht, isht
from spharpy import SphericalHarmonicSignal
from spharpy import samplings


def test_sht_assert_num_channels():
"test assert match of number of channels and number of sampling positions"
n_max = 3
signal = pf.Signal(data=np.zeros((7, 512)), sampling_rate=48000)
coords = pf.Coordinates.from_spherical_elevation(np.zeros((8)),
np.zeros((8)),
np.ones((8)))

with raises(ValueError, match="Signal shape does not match number of "
"coordinates."):
_ = sht(signal, coords, n_max)


def test_sht_wrong_axis():
"test warning wrong axis"
n_max = 3
signal = pf.Signal(data=np.zeros((8, 1, 512)), sampling_rate=48000)
coords = pf.Coordinates.from_spherical_elevation(np.zeros((8)),
np.zeros((8)),
np.ones((8)))

with warns(UserWarning, match="Compute spherical harmonics transform "
"along axis = 0."):
_ = sht(signal, coords, n_max, axis=1)


@mark.parametrize("n_max", [3, 12, 20])
@mark.parametrize("basis_type", ["real", "complex"])
@mark.parametrize("normalization", ["n3d", "sn3d"])
@mark.parametrize("condon_shortley", [True, False])
def test_back_and_forth(n_max, basis_type, normalization, condon_shortley):

sampling = samplings.equiangular(n_max=n_max)

data = np.ones((1, (n_max+1) ** 2, 16), dtype=complex)
is_complex = True

if basis_type == 'real':
data = np.real(data)
is_complex = False

# generate unit amplitude sh signal
a_nm = SphericalHarmonicSignal(data, basis_type=basis_type,
channel_convention='acn',
condon_shortley=condon_shortley,
normalization=normalization,
sampling_rate=48000,
is_complex=is_complex)

a = isht(a_nm, sampling)
a_eval_nm = sht(a, sampling, n_max=n_max, basis_type=basis_type,
normalization=normalization,
condon_shortley=condon_shortley)

npt.assert_allclose(a_nm.time, a_eval_nm.time, rtol=1e-8)