class ProgressReporterInterface(metaclass=ABCMeta):
    """
    Generic progress reporter interface.

    The base implementation discards every report, so it can be used as a
    no-op reporter. Subclasses override ``report`` to publish the progress.
    """

    def report(self, percentage_done: float) -> None:
        """
        Report that we have done a certain percentage (0 <= percentage <= 1).

        :param percentage_done: Fraction of the work that is finished.
        """
        pass

    def report_finished(self):
        """
        Called when everything is done. Delegates to ``report`` with a value of 1.
        """
        return self.report(1)


class ProgressReporter(ProgressReporterInterface):
    """
    A very basic progress reporter. Propagates the percentage done to a set of websocket rooms. This allows the
    frontend to listen to the progress.

    Usage example

    ```
    progress_reporter = ProgressReporter(targets=[
        {
            'target': 'download',
            'uuid': download.uuid,
            'triggered_by': '*'
        },
        {
            'target': 'download',
            'uuid': download.uuid,
            'triggered_by': download.triggered_by.pk
        }
    ])

    # 20% done
    progress_reporter.report(0.2)

    # 50% done
    progress_reporter.report(0.5)

    # 100% done
    progress_reporter.report_finished()
    ```
    """
    def __init__(self, targets: List[dict]):
        """
        :param targets: The room descriptions that are handed to ``trigger`` on every report.
        """
        self.targets = targets

    def report(self, percentage_done: float):
        """
        Send ``{'percentage_done': percentage_done}`` to all configured targets.

        :param percentage_done: Fraction of the work that is finished, 0 <= percentage_done <= 1
        :raises ValueError: If percentage_done is outside of [0, 1] (this includes NaN)
        """
        # ValueError is a subclass of Exception, so callers catching the old generic Exception keep working.
        if not (0 <= percentage_done <= 1):
            raise ValueError("percentage_done must be between 0 and 1")

        # For testing purposes
        # NOTE(review): this delays every report by half a second whenever DEBUG is on; presumably so that the
        # progress is visible while developing. Confirm that this is still wanted.
        if settings.DEBUG:
            from time import sleep
            sleep(0.5)

        trigger({
            'percentage_done': percentage_done
        }, self.targets)

    # report_finished is inherited from ProgressReporterInterface; it calls self.report(1).
class ProgressReporterAdapter:
    """
    Adapter which keeps track of some statistics to do progress reporting.

    The total amount of records is split into chunks (a page number plus the amount of records on that page).
    Every time a chunk is marked as done, the fraction of finished chunks is forwarded to the progress reporter.
    """
    Chunk = namedtuple('Chunk', ['page', 'limit'])

    def __init__(self, progress_reporter: Optional['ProgressReporterInterface'], page_size):
        """
        :param progress_reporter: Reporter that receives the progress. May be None, in which case the chunks are
            still tracked, but nothing is reported.
        :param page_size: Maximum amount of records in one chunk.
        """
        self.progress_reporter = progress_reporter
        self._total_records = 0
        self._limit = page_size

        self._chunks = set()
        self._chunks_done = set()

    def set_total_records(self, total_records):
        """
        Set the amount of records that the next call to ``chunks`` divides into pages.
        """
        self._total_records = total_records

    def report_chunk_done(self, chunk):
        """
        Mark a chunk as done, and report the new progress.

        Reports finished when all planned chunks are done. Otherwise reports the fraction of chunks that are
        done, rounded to 2 decimals.
        """
        self._chunks_done.add(chunk)

        # The reporter is optional; without one there is nobody to notify.
        if self.progress_reporter is None:
            return

        if self._chunks == self._chunks_done:
            self.progress_reporter.report_finished()
        elif self._chunks:
            # Only divide when chunks have been planned; this prevents a ZeroDivisionError when this method is
            # called before chunks().
            percentage = round(len(self._chunks_done) / len(self._chunks), 2)
            self.progress_reporter.report(percentage)

    def chunks(self) -> List[Chunk]:
        """
        Returns all the pages, with corresponding limits for this progress reporter.

        The chunks are returned in ascending page order, so that handling them in order yields the records in
        the order of the underlying query. Calling this method resets the bookkeeping of finished chunks.

        :raises ValueError: If the page size is not a positive number (the pagination would never terminate).
        """
        if self._limit is None or self._limit <= 0:
            raise ValueError("page_size must be a positive number")

        ordered_chunks = []
        records_to_go = self._total_records
        page_counter = 1

        while records_to_go > 0:
            # The last page only contains the records that are left.
            ordered_chunks.append(ProgressReporterAdapter.Chunk(page_counter, min(self._limit, records_to_go)))
            page_counter += 1
            records_to_go -= self._limit

        # The set is only used for the bookkeeping; the ordered list is what the caller iterates over.
        self._chunks = set(ordered_chunks)
        self._chunks_done = set()

        return ordered_chunks
@param withs: String[] An array of all the withs that are necessary for this csv export @param column_map: Tuple[] An array, with all columns of the csv file in order. Each column is represented by a tuple @@ -197,6 +240,7 @@ def __init__(self, withs, column_map, file_name=None, default_file_name='downloa @param response_type_mapping: Mapping between the parameter used in the custom response type @param limit: Limit for amount of items in the csv. This is a fail save that you do not bring down the server with a big query + @param page_size: The amount of records we get per chunk from the database. Progress is reported between each chunk. """ self.withs = withs self.column_map = column_map @@ -207,51 +251,49 @@ def __init__(self, withs, column_map, file_name=None, default_file_name='downloa self.extra_params = extra_params self.csv_adapter = csv_adapter self.limit = limit + self.page_size = page_size - def _generate_csv_file(self, request: HttpRequest, file_adapter: CsvFileAdapter): + def _generate_csv_file(self, request: HttpRequest, file_adapter: CsvFileAdapter, + progress_reporter: Optional[ProgressReporterInterface] = None): + """ + Generate the actual data for the CSV file, by doing a HTTP request, and then parsing the data to CSV Format + + :param progress_reporter An optional progress reporter. Adding this will slow down the query + """ # Sometimes we want to add an extra permission check before a csv file can be downloaded. This checks if the # permission is set, and if the permission is set, checks if the current user has the specified permission if self.csv_settings.extra_permission is not None: self._require_model_perm(self.csv_settings.extra_permission, request) + # First, if we have a progress_reporter, + progress_reporter_adapter = ProgressReporterAdapter(progress_reporter=progress_reporter, page_size=self.csv_settings.page_size) + + # Check the total amount of records. Do this by stripping all the withs and annotations. 
This should be very quick + total_records = self._get_filtered_queryset_base(request)[0].count() + + # Bound the total amount of records, to make sure that we do not ddos the server. By default is capped at 10000m + # but can be overwritten in the settings + + + if self.csv_settings.limit is not None and total_records > self.csv_settings.limit: + total_records = self.csv_settings.limit + elif total_records > 10000: + total_records = 10000 + + progress_reporter_adapter.set_total_records(total_records) + + # CSV header + file_adapter.set_columns(list(map(lambda x: x[1], self.csv_settings.column_map))) + # # A bit of a hack. We overwrite some get parameters, to make sure that we can create the CSV file - mutable = request.POST._mutable request.GET._mutable = True request.GET['page'] = 1 request.GET['limit'] = self.csv_settings.limit if self.csv_settings.limit is not None else 'none' request.GET['with'] = ",".join(self.csv_settings.withs) for key, value in self.csv_settings.extra_params.items(): request.GET[key] = value - request.GET._mutable = mutable - - parent_result = self.get(request) - parent_data = jsonloads(parent_result.content) - - file_name = self.csv_settings.file_name - if callable(file_name): - file_name = file_name(parent_data) - if file_name is None: - file_name = self.csv_settings.default_file_name - file_adapter.set_file_name(file_name) - - # CSV header - file_adapter.set_columns(list(map(lambda x: x[1], self.csv_settings.column_map))) - - # Make a mapping from the withs. This creates a map. This is needed for easy looking up relations - # { - # "with_name": { - # model_id: model, - # ... - # }, - # ... 
- # } - key_mapping = {} - for key in parent_data['with']: - key_mapping[key] = {} - for row in parent_data['with'][key]: - key_mapping[key][row['id']] = row def get_datum(data, key, prefix=''): """ @@ -302,17 +344,46 @@ def get_datum(data, key, prefix=''): else: raise Exception("{} not found in {}".format(head_key, data)) - for row in parent_data['data']: - data = [] - for col_definition in self.csv_settings.column_map: - datum = get_datum(row, col_definition[0]) - if len(col_definition) >= 3: - transform_function = col_definition[2] - datum = transform_function(datum, row, key_mapping) - if isinstance(datum, list): - datum = self.csv_settings.multi_value_delimiter.join(datum) - data.append(datum) - file_adapter.add_row(data) + for chunk in progress_reporter_adapter.chunks(): + request.GET['limit'] = chunk.limit + request.GET['page'] = chunk.page + + parent_result = self.get(request) + parent_data = jsonloads(parent_result.content) + + file_name = self.csv_settings.file_name + if callable(file_name): + file_name = file_name(parent_data) + if file_name is None: + file_name = self.csv_settings.default_file_name + file_adapter.set_file_name(file_name) + + # Make a mapping from the withs. This creates a map. This is needed for easy looking up relations + # { + # "with_name": { + # model_id: model, + # ... + # }, + # ... 
class ProgressReporterTest(TestCase):
    """
    Tests that generating a CSV file reports its progress once per chunk (page) of records.
    """
    def setUp(self):
        router = Router()
        router.register(ModelView)
        self.view = PictureView()
        self.view.router = router

        # The tests below overwrite csv_settings.page_size. The settings object looks to be shared with
        # PictureView itself (other tests access PictureView.csv_settings directly), so put the original value
        # back afterwards to prevent leaking into other tests.
        csv_settings = self.view.csv_settings
        self.addCleanup(setattr, csv_settings, 'page_size', csv_settings.page_size)

        animal = Animal(name='test')
        animal.save()

        # 3 pictures, so that the page sizes 1 and 2 result in 3 and 2 chunks respectively
        for _ in range(3):
            picture = Picture(animal=animal)
            image_file = CsvExportTest.temp_imagefile(50, 50, 'jpeg')
            picture.file.save('picture.jpg', File(image_file), save=False)
            picture.original_file.save('picture_copy.jpg', File(image_file), save=False)
            picture.save()

        user = User.objects.create(username='test')

        class RequestAdapter:
            """
            Minimal stand-in for a request object, with only the attributes these tests set up.
            """
            def __init__(self):
                self.GET = QueryDict(mutable=True)
                self.POST = QueryDict()
                self.user = user
                self.request_id = 'blaat'

        self.request = RequestAdapter()
        self.file_adapter = CsvFileAdapter(self.request)

    @patch('binder.plugins.progress_reporter.ProgressReporterInterface')
    def test_progress_reporter(self, progress_reporter_mock):
        # Page size = 1. So we get 3 chunks: 33%, 67% and finished
        self.view.csv_settings.page_size = 1
        self.view._generate_csv_file(request=self.request, file_adapter=self.file_adapter,
                                     progress_reporter=progress_reporter_mock)
        progress_reporter_mock.report.assert_has_calls([call(0.33), call(0.67)])
        progress_reporter_mock.report_finished.assert_called_once_with()

    @patch('binder.plugins.progress_reporter.ProgressReporterInterface')
    def test_progress_page_size(self, progress_reporter_mock):
        # Page size = 2. So we get 2 chunks, so only one call at 50%
        self.view.csv_settings.page_size = 2
        self.view._generate_csv_file(request=self.request, file_adapter=self.file_adapter,
                                     progress_reporter=progress_reporter_mock)
        progress_reporter_mock.report.assert_called_once_with(0.5)
        progress_reporter_mock.report_finished.assert_called_once_with()