Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 60 additions & 33 deletions unisens/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,30 @@
from __future__ import annotations

import importlib
import os, sys
import logging
import os
import sys
import warnings
from abc import ABC
from copy import deepcopy
from typing import List, Tuple

import numpy as np
import logging
from .utils import validkey, strip, lowercase, make_key, valid_filename, infer_dtype
from .utils import read_csv, write_csv
from xml.etree import ElementTree as ET
from xml.etree.ElementTree import Element
from copy import deepcopy

import numpy as np

from .utils import (
infer_dtype,
lowercase,
make_key,
read_csv,
strip,
valid_filename,
validkey,
write_csv,
)

logger = logging.getLogger("unisens")


def get_module(name):
Expand Down Expand Up @@ -74,8 +86,10 @@ def __init__(self, attrib=None, parent='.', **kwargs):
self._autosave()

def __contains__(self, item):
if item in self.__dict__: return True
if make_key(item) in self.__dict__: return True
if item in self.__dict__:
return True
if make_key(item) in self.__dict__:
return True
try:
self.__getitem__(item)
return True
Expand All @@ -87,7 +101,8 @@ def __setattr__(self, name: str, value: str):
Allows settings of attributes via .name = value.
"""
super.__setattr__(self, name, value)
if name.startswith('_'): return
if name.startswith('_'):
return
methods = dir(type(self))
# do not overwrite if it's a builtin method
if name not in methods and \
Expand Down Expand Up @@ -194,8 +209,10 @@ def _get_index(self, id_or_name: str, raises: bool = True) -> Tuple[int, str]:
elif os.path.basename(id_upper) == id_or_name_upper:
found += [(i, make_key(entry.id))]

if len(found) == 1: return found[0]
if len(found) > 1: raise IndexError(f'More than one match for {id_or_name}: {found}')
if len(found) == 1:
return found[0]
if len(found) > 1:
raise IndexError(f'More than one match for {id_or_name}: {found}')
raise KeyError(f'{id_or_name} not found')

def _set_channels(self, ch_names: List[str], n_data: int):
Expand All @@ -211,11 +228,12 @@ def _set_channels(self, ch_names: List[str], n_data: int):
one less if an index is expected.
"""
if ch_names is not None:
if isinstance(ch_names, str): ch_names = [ch_names]
if isinstance(ch_names, str):
ch_names = [ch_names]
# this means new channel names are indicated and will overwrite.
assert len(ch_names) == n_data, f'len {ch_names}!={n_data}'
if hasattr(self, 'channel'):
logging.warning('Channels present will be overwritten')
logger.warning('Channels present will be overwritten')
self.remove_entry('channel')
for name in ch_names:
channel = MiscEntry('channel', key='name', value=name)
Expand All @@ -226,7 +244,7 @@ def _set_channels(self, ch_names: List[str], n_data: int):
'Please provide a list of channel names with set_data().',
category=DeprecationWarning, stacklevel=2)
# we create new generic names for the channels
logging.info('No channel names indicated, will use generic names')
logger.info('No channel names indicated, will use generic names')
for i in range(n_data):
channel = MiscEntry('channel', key='name', value=f'ch_{i}')
self.add_entry(channel)
Expand Down Expand Up @@ -373,7 +391,7 @@ def remove_attr(self, name: str):
del self.attrib[name]
del self.__dict__[name]
else:
logging.error('{} not in attrib'.format(name))
logger.error('{} not in attrib'.format(name))
self._autosave()
return self

Expand Down Expand Up @@ -438,12 +456,12 @@ def __init__(self, id, attrib=None, parent='.', **kwargs):
valid_filename(self.id)
self._filename = os.path.join(self._folder, self.id)
if not os.access(self._filename, os.F_OK):
logging.error('File {} does not exist'.format(self.id))
logger.error('File {} does not exist'.format(self.id))
elif id:
# writing entry
valid_filename(id)
if os.path.splitext(str(id))[-1] == '':
logging.warning('id should be a filename with extension ie. .bin')
logger.warning('id should be a filename with extension ie. .bin')
self._filename = os.path.join(self._folder, id)
self.set_attrib('id', id)
# ensure subdirectories exist to write data
Expand All @@ -452,7 +470,8 @@ def __init__(self, id, attrib=None, parent='.', **kwargs):
os.makedirs(sub_folder, exist_ok=True)
else:
raise ValueError('The id must be supplied if it is not yet set.')
if isinstance(parent, Entry): parent.add_entry(self)
if isinstance(parent, Entry):
parent.add_entry(self)


class SignalEntry(FileEntry):
Expand Down Expand Up @@ -480,7 +499,7 @@ def get_data(self, scaled: bool = True, return_type: str = None) -> np.array:
"""

if return_type is not None:
warnings.warn(f'The argument `return_type` has no effect and will be removed with the next release.',
warnings.warn('The argument `return_type` has no effect and will be removed with the next release.',
category=DeprecationWarning, stacklevel=2)

if self.id.endswith('csv'):
Expand Down Expand Up @@ -594,15 +613,23 @@ def set_data(self, data: np.ndarray, sampleRate: float = None, dataType: str = N

self._set_channels(ch_names, n_data=len(data))

if sampleRate is not None: self.set_attrib('sampleRate', sampleRate)
if sampleRate is not None:
self.set_attrib('sampleRate', sampleRate)
assert 'sampleRate' in self.attrib, "Please specify sampleRate for correct visualization."
if unit is not None: self.set_attrib('unit', unit)
if comment is not None: self.set_attrib('comment', comment)
if contentClass is not None: self.set_attrib('contentClass', contentClass)
if adcZero is not None: self.set_attrib('adcZero', adcZero)
if adcResolution is not None: self.set_attrib('adcResolution', adcResolution)
if source is not None: self.set_attrib('source', source)
if sourceId is not None: self.set_attrib('sourceId', sourceId)
if unit is not None:
self.set_attrib('unit', unit)
if comment is not None:
self.set_attrib('comment', comment)
if contentClass is not None:
self.set_attrib('contentClass', contentClass)
if adcZero is not None:
self.set_attrib('adcZero', adcZero)
if adcResolution is not None:
self.set_attrib('adcResolution', adcResolution)
if source is not None:
self.set_attrib('source', source)
if sourceId is not None:
self.set_attrib('sourceId', sourceId)

# set all other keyword arguments/comments as well.
for key in kwargs:
Expand All @@ -623,7 +650,7 @@ def __init__(self, id=None, attrib=None, parent='.',
assert decimalSeparator and separator, 'Must supply separators'

if not self.id.endswith('csv'):
logging.warning(f'id "{id}" does not end in .csv')
logger.warning(f'id "{id}" does not end in .csv')

csvFileFormat = MiscEntry('csvFileFormat', parent=self)
csvFileFormat.set_attrib('decimalSeparator', decimalSeparator)
Expand Down Expand Up @@ -652,8 +679,8 @@ def set_data(self, data: list, **kwargs):
sep = self.csvFileFormat.separator
dec = self.csvFileFormat.decimalSeparator

if len(data) == 0 or len(data[0]) < 2: logging.warning('Should supply at least two columns: ' \
'time and data')
if len(data) == 0 or len(data[0]) < 2:
logger.warning('Should supply at least two columns: time and data')

write_csv(self._filename, data, sep=sep, decimal_sep=dec)

Expand Down Expand Up @@ -802,7 +829,7 @@ def get_data(self, dtype='auto'):
elif dtype == 'json':
try:
import json_tricks as json
except:
except ImportError:
json = get_module('json')
with open(self._filename, 'r') as f:
data = json.load(f)
Expand Down Expand Up @@ -861,7 +888,7 @@ def set_data(self, data, dtype='auto', **kwargs):
try:
import json_tricks as json
tricks_installed = True
except:
except ImportError:
json = get_module('json')
tricks_installed = False
with open(self._filename, 'w') as f:
Expand Down
18 changes: 10 additions & 8 deletions unisens/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
from xml.etree.ElementTree import Element
from .entry import Entry, FileEntry, ValuesEntry, SignalEntry, MiscEntry
from .entry import EventEntry, CustomEntry, CustomAttributes
from .utils import AttrDict, strip, validkey, lowercase, make_key, indent
from .utils import AttrDict, strip, make_key, indent
from .utils import str2num

logger = logging.getLogger("unisens")


class Unisens(Entry):
"""
Expand Down Expand Up @@ -78,12 +80,12 @@ def __init__(self, folder: str, makenew=False, autosave=False, readonly=False,
self._convert_nums = convert_nums

if os.path.isfile(self._file) and not makenew:
logging.debug('loading unisens.xml from {}'.format(self._file))
logger.debug('loading unisens.xml from {}'.format(self._file))
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.read_unisens()
else:
logging.debug('New unisens.xml will be created at {}'.format(self._file))
logger.debug('New unisens.xml will be created at {}'.format(self._file))
if not timestampStart:
now = datetime.datetime.now()
timestampStart = now.strftime('%Y-%m-%dT%H:%M:%S')
Expand Down Expand Up @@ -187,8 +189,8 @@ def unpack_element(self, element: (Element, ET)) -> Entry:
name = element.tag
entry = MiscEntry(name=name, attrib=attrib, parent=self._folder)
else:
if not 'Entry' in element.tag:
logging.warning('Unknown entry type: {}'.format(entryType))
if 'Entry' not in element.tag:
logger.warning('Unknown entry type: {}'.format(entryType))
name = element.tag
entry = MiscEntry(name=name, attrib=attrib, parent=self._folder)

Expand Down Expand Up @@ -228,9 +230,9 @@ def read_unisens(self, folder: str = None, filename='unisens.xml') -> Entry:
That means, self.attrib and self.children are added
as well as tag, tail and text
"""
warnings.warn(f'`read_unisens` is deprecated and will be removed with the '
f'next release. Please read your unisens file by calling'
f' Unisens(folder=folder, filename=filename).',
warnings.warn('`read_unisens` is deprecated and will be removed with the '
'next release. Please read your unisens file by calling'
' Unisens(folder=folder, filename=filename).',
category=DeprecationWarning, stacklevel=2)
# Saving data from one unisens file to another is still possible with Unisens.save() .
if folder is None:
Expand Down
Loading