diff --git a/docs/conf.py b/docs/conf.py index 20878e9f9..2babd0c03 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -30,7 +30,7 @@ # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest'] +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', 'stevedore.sphinxext'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] diff --git a/docs/index.rst b/docs/index.rst index b2ef4f7bf..fdf82f1c5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -26,6 +26,7 @@ Guide overview integrations code + plugins Project info ------------ diff --git a/docs/plugins.rst b/docs/plugins.rst new file mode 100644 index 000000000..05e75afaa --- /dev/null +++ b/docs/plugins.rst @@ -0,0 +1,20 @@ + +Plugins +------- + +Auth Providers +============== + +.. list-plugins:: mfr.providers + :detailed: + +Exporters +========= + +.. list-plugins:: mfr.exporters + +Renderers +========= + +.. list-plugins:: mfr.renderers + diff --git a/mfr/core/extension.py b/mfr/core/extension.py index cec133e7c..a2f202d5a 100644 --- a/mfr/core/extension.py +++ b/mfr/core/extension.py @@ -1,6 +1,15 @@ import abc +import asyncio +import time +from dataclasses import dataclass, field +from waterbutler.core.streams import StringStream +from waterbutler.core.utils import make_provider + +from mfr.server import settings from mfr.core.metrics import MetricsRecord +from mfr.core.provider import ProviderMetadata +from mfr.tasks.serializer import serializable class BaseExporter(metaclass=abc.ABCMeta): @@ -42,19 +51,28 @@ def _get_module_name(self): .replace('mfr.extensions.', '', 1) \ .replace('.export', '', 1) - +@serializable +@dataclass class BaseRenderer(metaclass=abc.ABCMeta): - - def __init__(self, metadata, file_path, url, assets_url, export_url): - self.metadata = metadata - self.file_path = file_path - self.url = url - self.assets_url = f'{assets_url}/{self._get_module_name()}' - self.export_url = export_url - self.renderer_metrics = MetricsRecord('renderer') + metadata: ProviderMetadata + file_path: str + url: str + assets_url: str + export_url: str + renderer_metrics: MetricsRecord = field(default=None) + metrics: MetricsRecord = field(default=None) + + def __post_init__(self): + self.assets_url = f'{self.assets_url}/{self._get_module_name()}' + self.renderer_metrics = MetricsRecord('renderer',) if self._get_module_name(): self.metrics = self.renderer_metrics.new_subrecord(self._get_module_name()) + if name := self.metadata.name: + self.cache_file_path_str = f'/export/{self.metadata.unique_key}.{name}' + else: + self.cache_file_path_str = f'/export/{self.metadata.unique_key}' + self.renderer_metrics.merge({ 'class': self._get_module_name(), 'ext': self.metadata.ext, @@ -73,10 +91,55 @@ def __init__(self, metadata, file_path, url, assets_url, export_url): except AttributeError: pass + @property + def cache_provider(self): + return make_provider( + settings.CACHE_PROVIDER_NAME, + {}, # User information which can be left blank + settings.CACHE_PROVIDER_CREDENTIALS, + settings.CACHE_PROVIDER_SETTINGS + ) + + async def get_cache_file_path(self): + return await self.cache_provider.validate_path(self.cache_file_path_str) + @abc.abstractmethod - def render(self): + def _render(self) -> str: pass + async def render(self): + if self.use_celery or self.cache_result: + self.cache_file_path = await self.cache_provider.validate_path(self.cache_file_path_str) + if not self.use_celery: + rendition = await self.do_render() + return StringStream(rendition) + else: + from mfr.tasks.render import render + result = render.delay(self) + for i in range(100 * 60 * 10): + if not result.ready(): + time.sleep(0.01) + else: + return await self.cache_provider.download(self.cache_file_path) + + return None + + async def do_render(self): + if self.use_celery or self.cache_result: + file_path_task = asyncio.ensure_future(self.get_cache_file_path()) + rendition = await asyncio.get_running_loop().run_in_executor(None, self._render) + if self.use_celery or self.cache_result: + upload_task = asyncio.ensure_future( + self.cache_provider.upload( + StringStream(rendition), + await file_path_task + ) + ) + if self.use_celery: + await upload_task + return rendition + + @property @abc.abstractmethod def file_required(self): """Does the rendering html need the raw file content to display correctly? @@ -85,10 +148,15 @@ def file_required(self): """ pass + @property @abc.abstractmethod - def cache_result(self): + def cache_result(self) -> bool: pass + @property + def use_celery(self) -> bool: + return False + def _get_module_name(self): return self.__module__ \ .replace('mfr.extensions.', '', 1) \ diff --git a/mfr/core/metrics.py b/mfr/core/metrics.py index 9dc29aefe..bc80819e5 100644 --- a/mfr/core/metrics.py +++ b/mfr/core/metrics.py @@ -1,4 +1,7 @@ import copy +from dataclasses import dataclass, field + +from mfr.tasks.serializer import serializable def _merge_dicts(a, b, path=None): @@ -93,16 +96,17 @@ def _set_dotted_key(store, key, value): current = current[part] current[parts[-1]] = value - +@serializable +@dataclass class MetricsRecord(MetricsBase): """An extension to MetricsBase that carries a category and list of submetrics. When serialized, will include the serialized child metrics """ + category: str + subrecords: list = field(default_factory=list) - def __init__(self, category): + def __post_init__(self): super().__init__() - self.category = category - self.subrecords = [] @property def key(self): @@ -121,20 +125,23 @@ def serialize(self): def new_subrecord(self, name): """Create a new MetricsSubRecord object with our category and save it to the subrecords list.""" - subrecord = MetricsSubRecord(self.category, name) + subrecord = MetricsSubRecord(category=self.category, name=name) self.subrecords.append(subrecord) return subrecord - +@serializable +@dataclass class MetricsSubRecord(MetricsRecord): """An extension to MetricsRecord that carries a name in addition to a category. Will identify itself as {category}_{name}. Can create its own subrecord whose category will be this subrecord's ``name``. """ + name: str = field(default=None) - def __init__(self, category, name): - super().__init__(category) - self.name = name + def __post_init__(self): + super().__post_init__() + if self.name is None: + raise TypeError('name must be provided') @property def key(self): @@ -153,6 +160,6 @@ def new_subrecord(self, name): print(child.key) # foo_bar print(grandchild.key) # bar_baz """ - subrecord = MetricsSubRecord(self.name, name) + subrecord = MetricsSubRecord(category=self.name, name=name) self.subrecords.append(subrecord) return subrecord diff --git a/mfr/core/provider.py b/mfr/core/provider.py index fb7a2b722..33fc705b9 100644 --- a/mfr/core/provider.py +++ b/mfr/core/provider.py @@ -1,4 +1,6 @@ import abc +from dataclasses import dataclass + import markupsafe import furl @@ -6,6 +8,7 @@ from mfr.core import exceptions from mfr.server import settings from mfr.core.metrics import MetricsRecord +from mfr.tasks.serializer import serializable class BaseProvider(metaclass=abc.ABCMeta): @@ -47,16 +50,15 @@ def metadata(self): def download(self): pass - +@serializable +@dataclass class ProviderMetadata: - - def __init__(self, name, ext, content_type, unique_key, download_url, stable_id=None): - self.name = name - self.ext = ext - self.content_type = content_type - self.unique_key = unique_key - self.download_url = download_url - self.stable_id = stable_id + name: str + ext: str + content_type: str + unique_key: str + download_url: str + stable_id: str = None def serialize(self): return { diff --git a/mfr/core/utils.py b/mfr/core/utils.py index 57c5ad68b..163093721 100644 --- a/mfr/core/utils.py +++ b/mfr/core/utils.py @@ -100,7 +100,13 @@ def make_renderer(name, metadata, file_path, url, assets_url, export_url): namespace='mfr.renderers', name=normalized_name, invoke_on_load=True, - invoke_args=(metadata, file_path, url, assets_url, export_url), + invoke_kwds={ + 'metadata': metadata, + 'file_path': file_path, + 'url': url, + 'assets_url': assets_url, + 'export_url': export_url + }, ).driver except RuntimeError: raise exceptions.MakeRendererError( diff --git a/mfr/extensions/audio/render.py b/mfr/extensions/audio/render.py index 00b8be3fd..8bdd490bc 100644 --- a/mfr/extensions/audio/render.py +++ b/mfr/extensions/audio/render.py @@ -13,7 +13,7 @@ class AudioRenderer(extension.BaseRenderer): os.path.join(os.path.dirname(__file__), 'templates') ]).get_template('viewer.mako') - def render(self): + def _render(self): safe_url = escape_url_for_template(self.url) return self.TEMPLATE.render(base=self.assets_url, url=safe_url) diff --git a/mfr/extensions/codepygments/render.py b/mfr/extensions/codepygments/render.py index d28739610..89dc364dc 100644 --- a/mfr/extensions/codepygments/render.py +++ b/mfr/extensions/codepygments/render.py @@ -30,7 +30,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.metrics.add('pygments_version', pygments.__version__) - def render(self): + def _render(self): file_size = os.path.getsize(self.file_path) if file_size > settings.MAX_SIZE: raise exceptions.FileTooLargeError( diff --git a/mfr/extensions/image/render.py b/mfr/extensions/image/render.py index dd8469433..827a1aef0 100644 --- a/mfr/extensions/image/render.py +++ b/mfr/extensions/image/render.py @@ -15,7 +15,7 @@ class ImageRenderer(extension.BaseRenderer): os.path.join(os.path.dirname(__file__), 'templates') ]).get_template('viewer.mako') - def render(self): + def _render(self): self.metrics.add('needs_export', False) if self.metadata.ext in settings.EXPORT_EXCLUSIONS: download_url = munge_url_for_localdev(self.url) diff --git a/mfr/extensions/ipynb/render.py b/mfr/extensions/ipynb/render.py index 3e046c269..c4fcc8d82 100644 --- a/mfr/extensions/ipynb/render.py +++ b/mfr/extensions/ipynb/render.py @@ -22,7 +22,7 @@ def __init__(self, *args, **kwargs): self.metrics.add('nbformat_version', nbformat.__version__) self.metrics.add('nbconvert_version', nbconvert.__version__) - def render(self): + def _render(self): try: with open(self.file_path) as file_pointer: notebook = nbformat.reads(file_pointer.read(), as_version=4) diff --git a/mfr/extensions/jamovi/render.py b/mfr/extensions/jamovi/render.py index d634e89b9..59c241e48 100644 --- a/mfr/extensions/jamovi/render.py +++ b/mfr/extensions/jamovi/render.py @@ -22,7 +22,7 @@ class JamoviRenderer(extension.BaseRenderer): MESSAGE_FILE_CORRUPT = 'This jamovi file is corrupt and cannot be viewed.' MESSAGE_NO_PREVIEW = 'This jamovi file does not support previews.' - def render(self): + def _render(self): try: with ZipFile(self.file_path) as zip_file: self._check_file(zip_file) diff --git a/mfr/extensions/jasp/render.py b/mfr/extensions/jasp/render.py index ae9bd0e65..ae2b2bf8f 100644 --- a/mfr/extensions/jasp/render.py +++ b/mfr/extensions/jasp/render.py @@ -21,7 +21,7 @@ class JASPRenderer(extension.BaseRenderer): MESSAGE_FILE_CORRUPT = 'This JASP file is corrupt and cannot be viewed.' - def render(self): + def _render(self): try: with ZipFile(self.file_path) as zip_file: self._check_file(zip_file) diff --git a/mfr/extensions/jsc3d/render.py b/mfr/extensions/jsc3d/render.py index 3bff7b024..2a9dd858b 100644 --- a/mfr/extensions/jsc3d/render.py +++ b/mfr/extensions/jsc3d/render.py @@ -17,7 +17,7 @@ class JSC3DRenderer(extension.BaseRenderer): os.path.join(os.path.dirname(__file__), 'templates') ]).get_template('viewer.mako') - def render(self): + def _render(self): self.metrics.add('needs_export', False) if self.metadata.ext in settings.EXPORT_EXCLUSIONS: download_url = munge_url_for_localdev(self.metadata.download_url) diff --git a/mfr/extensions/md/render.py b/mfr/extensions/md/render.py index d409ec889..11a1b1e49 100644 --- a/mfr/extensions/md/render.py +++ b/mfr/extensions/md/render.py @@ -25,7 +25,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.metrics.add('markdown_version', markdown.__version__) - def render(self): + def _render(self): """Render a markdown file to html.""" with open(self.file_path) as fp: body = markdown.markdown(fp.read(), extensions=[EscapeHtml()]) diff --git a/mfr/extensions/pdb/render.py b/mfr/extensions/pdb/render.py index 41d88da58..b779cb2a5 100644 --- a/mfr/extensions/pdb/render.py +++ b/mfr/extensions/pdb/render.py @@ -16,7 +16,7 @@ class PdbRenderer(extension.BaseRenderer): os.path.join(os.path.dirname(__file__), 'templates') ]).get_template('viewer.mako') - def render(self): + def _render(self): download_url = munge_url_for_localdev(self.metadata.download_url) safe_url = escape_url_for_template(download_url.geturl()) return self.TEMPLATE.render( diff --git a/mfr/extensions/pdf/render.py b/mfr/extensions/pdf/render.py index c0a10e94e..5d305b23d 100644 --- a/mfr/extensions/pdf/render.py +++ b/mfr/extensions/pdf/render.py @@ -19,7 +19,7 @@ class PdfRenderer(extension.BaseRenderer): os.path.join(os.path.dirname(__file__), 'templates') ]).get_template('viewer.mako') - def render(self): + def _render(self): download_url = munge_url_for_localdev(self.metadata.download_url) escaped_name = escape_url_for_template( diff --git a/mfr/extensions/rst/render.py b/mfr/extensions/rst/render.py index d0d37d8b9..115218b61 100644 --- a/mfr/extensions/rst/render.py +++ b/mfr/extensions/rst/render.py @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.metrics.add('docutils_version', docutils.__version__) - def render(self): + def _render(self): with open(self.file_path) as fp: body = publish_parts(fp.read(), writer_name='html')['html_body'] return self.TEMPLATE.render(base=self.assets_url, body=body) diff --git a/mfr/extensions/svg/render.py b/mfr/extensions/svg/render.py index ca0293107..9565240ac 100644 --- a/mfr/extensions/svg/render.py +++ b/mfr/extensions/svg/render.py @@ -14,7 +14,7 @@ class SvgRenderer(extension.BaseRenderer): os.path.join(os.path.dirname(__file__), 'templates') ]).get_template('viewer.mako') - def render(self): + def _render(self): safe_url = escape_url_for_template(self.url) return self.TEMPLATE.render(base=self.assets_url, url=safe_url) diff --git a/mfr/extensions/tabular/render.py b/mfr/extensions/tabular/render.py index a77f82c75..0f1e7168f 100644 --- a/mfr/extensions/tabular/render.py +++ b/mfr/extensions/tabular/render.py @@ -20,7 +20,7 @@ class TabularRenderer(extension.BaseRenderer): os.path.join(os.path.dirname(__file__), 'templates') ]).get_template('viewer.mako') - def render(self): + def _render(self): file_size = os.path.getsize(self.file_path) if file_size > settings.MAX_FILE_SIZE: raise exceptions.FileTooLargeError( diff --git a/mfr/extensions/unoconv/render.py b/mfr/extensions/unoconv/render.py index 36fe07495..62b13ddb7 100644 --- a/mfr/extensions/unoconv/render.py +++ b/mfr/extensions/unoconv/render.py @@ -1,19 +1,22 @@ +from dataclasses import dataclass +import logging import os + import furl -import logging from mfr.core import utils from mfr.core import extension - from mfr.extensions.unoconv import settings +from mfr.tasks.serializer import serializable logger = logging.getLogger(__name__) +@serializable +@dataclass class UnoconvRenderer(extension.BaseRenderer): - def __init__(self, metadata, file_path, url, assets_url, export_url): - super().__init__(metadata, file_path, url, assets_url, export_url) - + def __post_init__(self): + super().__post_init__() try: self.map = settings.RENDER_MAP[self.metadata.ext] except KeyError: @@ -21,7 +24,7 @@ def __init__(self, metadata, file_path, url, assets_url, export_url): self.export_file_path = self.file_path + self.map['renderer'] - exported_url = furl.furl(export_url) + exported_url = furl.furl(self.export_url) exported_url.args['format'] = self.map['format'] exported_metadata = self.metadata exported_metadata.download_url = exported_url.url @@ -31,8 +34,8 @@ def __init__(self, metadata, file_path, url, assets_url, export_url): exported_metadata, self.export_file_path, exported_url.url, - assets_url, - export_url + self.assets_url, + self.export_url ) self.renderer_metrics.add('file_required', self.file_required) @@ -45,7 +48,7 @@ def __init__(self, metadata, file_path, url, assets_url, export_url): }, }) - def render(self): + def _render(self): if self.renderer.file_required: exporter = utils.make_exporter( self.metadata.ext, @@ -56,7 +59,7 @@ def render(self): ) exporter.export() - rendition = self.renderer.render() + rendition = self.renderer._render() self.metrics.add('subrenderer', self.renderer.renderer_metrics.serialize()) if self.renderer.file_required: @@ -74,3 +77,6 @@ def file_required(self): @property def cache_result(self): return self.renderer.cache_result + + def use_celery(self) -> bool: + return True diff --git a/mfr/extensions/video/render.py b/mfr/extensions/video/render.py index 920d27959..eb2028e08 100644 --- a/mfr/extensions/video/render.py +++ b/mfr/extensions/video/render.py @@ -13,7 +13,7 @@ class VideoRenderer(extension.BaseRenderer): os.path.join(os.path.dirname(__file__), 'templates') ]).get_template('viewer.mako') - def render(self): + def _render(self): download_url = munge_url_for_localdev(self.metadata.download_url) safe_url = escape_url_for_template(download_url.geturl()) return self.TEMPLATE.render(url=safe_url) diff --git a/mfr/extensions/zip/render.py b/mfr/extensions/zip/render.py index a6370530c..fe89735e3 100644 --- a/mfr/extensions/zip/render.py +++ b/mfr/extensions/zip/render.py @@ -15,7 +15,7 @@ class ZipRenderer(extension.BaseRenderer): os.path.join(os.path.dirname(__file__), 'templates') ]).get_template('viewer.mako') - def render(self): + def _render(self): zip_file = zipfile.ZipFile(self.file_path, 'r') filelist = [{'name': markupsafe.escape(file.filename), diff --git a/mfr/providers/osf/provider.py b/mfr/providers/osf/provider.py index 01da33795..cb7db7839 100644 --- a/mfr/providers/osf/provider.py +++ b/mfr/providers/osf/provider.py @@ -68,6 +68,7 @@ async def metadata(self): metadata_url = download_url.replace('/file?', '/data?', 1) metadata_response = await self._make_request('GET', metadata_url) metadata = await metadata_response.json() + await metadata_response.release() else: # URL is for WaterButler v1 API self.metrics.add('metadata.wb_api', 'v1') @@ -140,6 +141,7 @@ async def download(self): if response.status >= 400: resp_text = await response.text() + await response.release() logger.error(f'Unable to download file: ({response.status}) {resp_text}') raise exceptions.DownloadError( 'Unable to download the requested file, please try again later.', diff --git a/mfr/server/handlers/core.py b/mfr/server/handlers/core.py index 841ca38d7..3375f3e19 100644 --- a/mfr/server/handlers/core.py +++ b/mfr/server/handlers/core.py @@ -1,8 +1,9 @@ import abc +import pathlib import uuid import asyncio import logging -import pathlib +import os from importlib.metadata import entry_points @@ -300,11 +301,11 @@ def initialize(self): namespace = "mfr.renderers" module_path_prefix = "mfr.extensions" self.modules = {} + root = pathlib.Path(os.path.abspath(__file__).split('mfr')[0]) for ep in entry_points().select(group=namespace): fq_mod = ep.value.split(":")[0] module = fq_mod.replace(f"{module_path_prefix}.", "").split(".")[0] - root = pathlib.Path(ep.dist.locate_file("")) static_dir = root / "mfr" / "extensions" / module / "static" if static_dir.is_dir(): self.modules[module] = static_dir.as_posix() @@ -321,11 +322,17 @@ def initialize(self): self.modules.setdefault(module_path.name, static_dir.as_posix()) logger.debug(f"{module_path.name}: {static_dir}") - async def get(self, module: str, path: str): + async def get(self, module: str, path: str, **kwargs): root = self.modules.get(module) if not root: - self.set_status(404) - return + while path: + module, path = path.split('/', maxsplit=1) + root = self.modules.get(module) + if root: + break + else: + self.set_status(404) + return try: super().initialize(root) diff --git a/mfr/server/handlers/render.py b/mfr/server/handlers/render.py index 842828f7a..025c48a42 100644 --- a/mfr/server/handlers/render.py +++ b/mfr/server/handlers/render.py @@ -1,5 +1,5 @@ +import http import os -import asyncio import logging import waterbutler.core.streams @@ -71,20 +71,14 @@ async def get(self): else: self.metrics.add('source_file.upload.required', False) - loop = asyncio.get_event_loop() - rendition = await loop.run_in_executor(None, renderer.render) + rendition = await renderer.render() self.renderer_metrics = renderer.renderer_metrics - - # Spin off upload into non-blocking operation - if renderer.cache_result and settings.CACHE_ENABLED: - asyncio.ensure_future( - self.cache_provider.upload( - waterbutler.core.streams.StringStream(rendition), - self.cache_file_path - ) - ) - - await self.write_stream(waterbutler.core.streams.StringStream(rendition)) + if rendition: + await self.write_stream(rendition) + else: + self.set_status(http.HTTPStatus.ACCEPTED) + await self.write("Accepted") + await self.flush() async def _cache_and_clean(self): if hasattr(self, 'source_file_path'): diff --git a/mfr/tasks/__init__.py b/mfr/tasks/__init__.py index b28aa8c65..7dee91fdb 100644 --- a/mfr/tasks/__init__.py +++ b/mfr/tasks/__init__.py @@ -1,5 +1,4 @@ from mfr.tasks.app import app -from mfr.tasks.render import render from mfr.tasks.core import celery_task from mfr.tasks.core import backgrounded from mfr.tasks.core import wait_on_celery @@ -7,7 +6,6 @@ __all__ = [ 'app', - 'render', 'celery_task', 'backgrounded', 'wait_on_celery', diff --git a/mfr/tasks/export.py b/mfr/tasks/export.py new file mode 100644 index 000000000..a7dee9313 --- /dev/null +++ b/mfr/tasks/export.py @@ -0,0 +1,10 @@ +import logging + +from mfr.tasks import core + +logger = logging.getLogger(__name__) + + +@core.celery_task +async def export(exporter): + exporter.do_export() diff --git a/mfr/tasks/render.py b/mfr/tasks/render.py index 0bcbd125b..deee8a3bd 100644 --- a/mfr/tasks/render.py +++ b/mfr/tasks/render.py @@ -1,9 +1,10 @@ import logging +from mfr.extensions.unoconv import UnoconvRenderer from mfr.tasks import core logger = logging.getLogger(__name__) @core.celery_task -async def render(*args, **kwargs): - logger.critical(f'Received task with {args=} and {kwargs=}') +async def render(renderer: UnoconvRenderer): + await renderer.do_render() diff --git a/mfr/tasks/serializer.py b/mfr/tasks/serializer.py new file mode 100644 index 000000000..32778e4c5 --- /dev/null +++ b/mfr/tasks/serializer.py @@ -0,0 +1,116 @@ +import copy +import types +from dataclasses import is_dataclass, fields + +import kombu.utils.json + +__registry = {} + +_ATOMIC_TYPES = frozenset({ + # Common JSON Serializable types + types.NoneType, + bool, + int, + float, + str, + # Other common types + complex, + bytes, + # Other types that are also unaffected by deepcopy + types.EllipsisType, + types.NotImplementedType, + types.CodeType, + types.BuiltinFunctionType, + types.FunctionType, + type, + range, + property, +}) + + +def serialize(obj, *, dict_factory=dict): + if not is_dataclass(obj): + raise TypeError("asdict() should be called on dataclass instances") + return _asdict_inner(obj, dict_factory) + + +def _asdict_inner(obj, dict_factory): + obj_type = type(obj) + if obj_type in _ATOMIC_TYPES: + return obj + elif is_dataclass(obj_type): + # dataclass instance: fast path for the common case + if obj_type.__name__ not in __registry: + raise TypeError('cannot serialize non-serializable class') + if dict_factory is dict: + return { + f.name: _asdict_inner(getattr(obj, f.name), dict) + for f in fields(obj) + } | {'__type': obj_type.__name__} + else: + return dict_factory([ + (f.name, _asdict_inner(getattr(obj, f.name), dict_factory)) + for f in fields(obj) + ]) | {'__type': obj_type.__name__} + elif obj_type is list: + return [_asdict_inner(v, dict_factory) for v in obj] + elif obj_type is dict: + return { + _asdict_inner(k, dict_factory): _asdict_inner(v, dict_factory) + for k, v in obj.items() + } + elif obj_type is tuple: + return tuple([_asdict_inner(v, dict_factory) for v in obj]) + elif issubclass(obj_type, tuple): + if hasattr(obj, '_fields'): + return obj_type(*[_asdict_inner(v, dict_factory) for v in obj]) + else: + return obj_type(_asdict_inner(v, dict_factory) for v in obj) + elif issubclass(obj_type, dict): + if hasattr(obj_type, 'default_factory'): + result = obj_type(obj.default_factory) + for k, v in obj.items(): + result[_asdict_inner(k, dict_factory)] = _asdict_inner(v, dict_factory) + return result + return obj_type((_asdict_inner(k, dict_factory), + _asdict_inner(v, dict_factory)) + for k, v in obj.items()) + elif issubclass(obj_type, list): + return obj_type(_asdict_inner(v, dict_factory) for v in obj) + else: + return copy.deepcopy(obj) + + +def serializable(cls): + assert is_dataclass(cls), f'class {cls.__name__} is not a dataclass, therefore cannot be serializable' + assert cls.__name__ not in __registry, 'This class has already been registered' + __registry[cls.__name__] = cls + + kombu.utils.json.register_type( + cls, + cls.__name__, + serialize, + deserialize, + ) + return cls + + +def deserialize(data): + if isinstance(data, list): + return [deserialize(item) for item in data] + if isinstance(data, dict): + data_type = data.pop('__type', None) + if not data_type: + raise TypeError('invalid deserialize payload') + if data_type not in __registry: + raise TypeError(f'type provided but type {data_type.__name__} is not declared serializable') + data_type = __registry.get(data_type) + for field in fields(data_type): + if is_dataclass(field.type) or isinstance(field.type, list): + data[field.name] = deserialize(data[field.name]) + return data_type(**data) + + raise TypeError(f'Cannot deserialize type {type(data)}') + + +__all__ = ['serializable'] diff --git a/mfr/tasks/settings.py b/mfr/tasks/settings.py index dfdc0ad88..47c173d74 100644 --- a/mfr/tasks/settings.py +++ b/mfr/tasks/settings.py @@ -35,5 +35,6 @@ task_eager_propagates = True imports = [ - 'mfr.tasks.render', + "mfr.tasks.export", + "mfr.tasks.render", ] diff --git a/tests/extensions/audio/test_renderer.py b/tests/extensions/audio/test_renderer.py index 50712fe7b..7f3317bd7 100644 --- a/tests/extensions/audio/test_renderer.py +++ b/tests/extensions/audio/test_renderer.py @@ -38,7 +38,7 @@ def renderer(metadata, file_path, url, assets_url, export_url): class TestAudioRenderer: def test_render_audio(self, renderer, url): - body = renderer.render() + body = renderer._render() assert '