75 changes: 75 additions & 0 deletions binder/plugins/progress_reporter.py
@@ -0,0 +1,75 @@
from abc import ABCMeta
from typing import List

from django.conf import settings

from binder.websocket import trigger


class ProgressReporterInterface(metaclass=ABCMeta):
"""
Generic progress reporter interface
"""

def report(self, percentage_done: float) -> None:
"""
Report the fraction of the work that has been done so far (0 <= percentage_done <= 1)
"""
pass

def report_finished(self):
"""
Called when everything is done.
"""
return self.report(1)


class ProgressReporter(ProgressReporterInterface):
"""
A very basic progress reporter. Propagates the percentage to a set of websocket rooms, which allows the frontend
to listen to the progress.

Usage example

```
progress_reporter = ProgressReporter(targets=[
{
'target': 'download',
'uuid': download.uuid,
'triggered_by': '*'
},
{
'target': 'download',
'uuid': download.uuid,
'triggered_by': download.triggered_by.pk
}
])

# 20% done
progress_reporter.report(0.2)

# 50% done
progress_reporter.report(0.5)

# 100% done
progress_reporter.report_finished()
```
"""
def __init__(self, targets: List[dict]):
self.targets = targets

def report(self, percentage_done: float):
if not (0 <= percentage_done <= 1):
raise Exception("percentage_done must be between 0 and 1")

# For testing purposes: artificially slow down reporting when DEBUG is enabled
if settings.DEBUG:
from time import sleep
sleep(0.5)

trigger({
'percentage_done': percentage_done
}, self.targets)

def report_finished(self):
return self.report(1)
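
As an aside, any object that provides `report` and `report_finished` can stand in for a `ProgressReporterInterface`. A minimal sketch of an alternative implementation that logs progress instead of pushing it over a websocket (hypothetical, not part of this PR) could look like this:

```python
import logging

from binder.plugins.progress_reporter import ProgressReporterInterface

logger = logging.getLogger(__name__)


class LoggingProgressReporter(ProgressReporterInterface):
	"""
	Hypothetical reporter that writes progress to the application log instead of websocket rooms.
	"""

	def __init__(self, label: str):
		self.label = label

	def report(self, percentage_done: float) -> None:
		# Same contract as ProgressReporter: percentage_done is a fraction between 0 and 1
		logger.info('%s: %d%% done', self.label, round(percentage_done * 100))
```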
157 changes: 114 additions & 43 deletions binder/plugins/views/csvexport.py
@@ -1,11 +1,13 @@
import abc
import csv
from collections import namedtuple
from tempfile import NamedTemporaryFile
from typing import List
from typing import List, Optional

from django.http import HttpResponse, HttpRequest

from binder.json import jsonloads
from binder.plugins.progress_reporter import ProgressReporterInterface
from binder.router import list_route


@@ -163,7 +165,48 @@ def get_response(self) -> HttpResponse:
return self.base_adapter.get_response()


class ProgressReporterAdapter:
"""
Adapter which keeps track of the statistics needed to report progress per chunk
"""
Chunk = namedtuple('Chunk', ['page', 'limit'])
def __init__(self, progress_reporter: ProgressReporterInterface, page_size):
self.progress_reporter = progress_reporter
self._total_records = 0
self._limit = page_size

self._chunks = set()
self._chunks_done = set()

def set_total_records(self, total_records):
self._total_records = total_records

def report_chunk_done(self, chunk):
self._chunks_done.add(chunk)

if self._chunks == self._chunks_done:
self.progress_reporter.report_finished()
else:
percentage = round(len(self._chunks_done) * 1.0 / len(self._chunks), 2)
self.progress_reporter.report(percentage)


def chunks(self) -> List[Chunk]:
"""
Returns all chunks (page number plus per-page limit) needed to cover the total record count
"""
self._chunks = set()
self._chunks_done = set()

records_to_go = self._total_records
page_counter = 1

while records_to_go > 0:
self._chunks.add(ProgressReporterAdapter.Chunk(page_counter, min(self._limit, records_to_go)))
page_counter += 1
records_to_go -= self._limit

return list(self._chunks)
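
(As an illustration, not part of the diff: with 25 total records and a page size of 10, `chunks()` produces pages 1 and 2 with a limit of 10 and page 3 with a limit of 5, as in the sketch below.)

```python
# Illustrative sketch only (not part of this change); passing None as the reporter is fine
# here because chunks() never touches self.progress_reporter.
adapter = ProgressReporterAdapter(progress_reporter=None, page_size=10)
adapter.set_total_records(25)

# chunks() returns a list built from a set, so sort by page number for a stable view.
print(sorted(adapter.chunks()))
# [Chunk(page=1, limit=10), Chunk(page=2, limit=10), Chunk(page=3, limit=5)]
```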

class CsvExportView:
"""
@@ -182,7 +225,7 @@ class CsvExportSettings:
"""

def __init__(self, withs, column_map, file_name=None, default_file_name='download', multi_value_delimiter=' ',
extra_permission=None, extra_params={}, csv_adapter=RequestAwareAdapter, limit=10000):
extra_permission=None, extra_params={}, csv_adapter=RequestAwareAdapter, limit=10000, page_size=10000):
"""
@param withs: String[] An array of all the withs that are necessary for this csv export
@param column_map: Tuple[] An array, with all columns of the csv file in order. Each column is represented by a tuple
@@ -197,6 +240,7 @@ def __init__(self, withs, column_map, file_name=None, default_file_name='downloa
@param response_type_mapping: Mapping between the parameter used in the custom response type
@param limit: Limit for the number of items in the csv. This is a failsafe so that you do not bring down the server
with a big query
@param page_size: The number of records fetched per chunk from the database. Progress is reported after each chunk.
"""
self.withs = withs
self.column_map = column_map
@@ -207,51 +251,49 @@ def __init__(self, withs, column_map, file_name=None, default_file_name='downloa
self.extra_params = extra_params
self.csv_adapter = csv_adapter
self.limit = limit
self.page_size = page_size


def _generate_csv_file(self, request: HttpRequest, file_adapter: CsvFileAdapter):
def _generate_csv_file(self, request: HttpRequest, file_adapter: CsvFileAdapter,
progress_reporter: Optional[ProgressReporterInterface] = None):
"""
Generate the actual data for the CSV file by performing an internal HTTP request and then converting the data to CSV format

:param progress_reporter: An optional progress reporter. Adding this will slow down the export
"""

# Sometimes we want to add an extra permission check before a csv file can be downloaded. This checks if the
# permission is set, and if the permission is set, checks if the current user has the specified permission
if self.csv_settings.extra_permission is not None:
self._require_model_perm(self.csv_settings.extra_permission, request)

# Wrap the (optional) progress reporter in an adapter that keeps track of per-chunk progress
progress_reporter_adapter = ProgressReporterAdapter(progress_reporter=progress_reporter, page_size=self.csv_settings.page_size)

# Check the total number of records. Do this by stripping all the withs and annotations. This should be very quick
total_records = self._get_filtered_queryset_base(request)[0].count()

# Bound the total number of records, to make sure that we do not overload the server. By default it is capped at
# 10000, but this can be overridden in the settings.

if self.csv_settings.limit is not None and total_records > self.csv_settings.limit:
total_records = self.csv_settings.limit
elif total_records > 10000:
total_records = 10000

progress_reporter_adapter.set_total_records(total_records)

# CSV header
file_adapter.set_columns(list(map(lambda x: x[1], self.csv_settings.column_map)))

# A bit of a hack: we overwrite some GET parameters to make sure that we can create the CSV file
mutable = request.POST._mutable
request.GET._mutable = True
request.GET['page'] = 1
request.GET['limit'] = self.csv_settings.limit if self.csv_settings.limit is not None else 'none'
request.GET['with'] = ",".join(self.csv_settings.withs)
for key, value in self.csv_settings.extra_params.items():
request.GET[key] = value
request.GET._mutable = mutable

parent_result = self.get(request)
parent_data = jsonloads(parent_result.content)

file_name = self.csv_settings.file_name
if callable(file_name):
file_name = file_name(parent_data)
if file_name is None:
file_name = self.csv_settings.default_file_name
file_adapter.set_file_name(file_name)

# CSV header
file_adapter.set_columns(list(map(lambda x: x[1], self.csv_settings.column_map)))

# Make a mapping from the withs. This creates a map. This is needed for easy looking up relations
# {
# "with_name": {
# model_id: model,
# ...
# },
# ...
# }
key_mapping = {}
for key in parent_data['with']:
key_mapping[key] = {}
for row in parent_data['with'][key]:
key_mapping[key][row['id']] = row

def get_datum(data, key, prefix=''):
"""
@@ -302,17 +344,46 @@ def get_datum(data, key, prefix=''):
else:
raise Exception("{} not found in {}".format(head_key, data))

for row in parent_data['data']:
data = []
for col_definition in self.csv_settings.column_map:
datum = get_datum(row, col_definition[0])
if len(col_definition) >= 3:
transform_function = col_definition[2]
datum = transform_function(datum, row, key_mapping)
if isinstance(datum, list):
datum = self.csv_settings.multi_value_delimiter.join(datum)
data.append(datum)
file_adapter.add_row(data)
for chunk in progress_reporter_adapter.chunks():
request.GET['limit'] = chunk.limit
request.GET['page'] = chunk.page

parent_result = self.get(request)
parent_data = jsonloads(parent_result.content)

file_name = self.csv_settings.file_name
if callable(file_name):
file_name = file_name(parent_data)
if file_name is None:
file_name = self.csv_settings.default_file_name
file_adapter.set_file_name(file_name)

# Make a mapping from the withs. This creates a map. This is needed for easy looking up relations
# {
# "with_name": {
# model_id: model,
# ...
# },
# ...
# }
key_mapping = {}
for key in parent_data['with']:
key_mapping[key] = {}
for row in parent_data['with'][key]:
key_mapping[key][row['id']] = row

for row in parent_data['data']:
data = []
for col_definition in self.csv_settings.column_map:
datum = get_datum(row, col_definition[0])
if len(col_definition) >= 3:
transform_function = col_definition[2]
datum = transform_function(datum, row, key_mapping)
if isinstance(datum, list):
datum = self.csv_settings.multi_value_delimiter.join(datum)
data.append(datum)
file_adapter.add_row(data)
progress_reporter_adapter.report_chunk_done(chunk)

@list_route(name='download', methods=['GET'])
def download(self, request):
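For orientation, a hypothetical caller that wires the new `progress_reporter` argument into an export might look roughly like the sketch below; the view, request, and download UUID are assumed to exist, and this is not code from the PR:

```python
from binder.plugins.progress_reporter import ProgressReporter
from binder.plugins.views.csvexport import CsvFileAdapter


def export_pictures_with_progress(view, request, download_uuid):
	# Push progress updates to the websocket room(s) the frontend listens on.
	reporter = ProgressReporter(targets=[
		{'target': 'download', 'uuid': download_uuid, 'triggered_by': '*'},
	])
	file_adapter = CsvFileAdapter(request)

	# Each chunk of csv_settings.page_size records triggers a report(); the final chunk
	# triggers report_finished().
	view._generate_csv_file(request=request, file_adapter=file_adapter, progress_reporter=reporter)
	return file_adapter
```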
59 changes: 59 additions & 0 deletions tests/plugins/test_csvexport.py
@@ -1,12 +1,18 @@
from unittest.mock import patch, call

from PIL import Image
from os import urandom
from tempfile import NamedTemporaryFile
import io

from django.http import QueryDict
from django.test import TestCase, Client
from django.core.files import File
from django.contrib.auth.models import User

from binder.plugins.views.csvexport import CsvFileAdapter
from binder.router import Router
from binder.views import ModelView
from ..testapp.models import Picture, Animal, Caretaker
from ..testapp.views import PictureView
import csv
@@ -151,6 +157,7 @@ def test_context_aware_download_xlsx(self):
def test_csv_export_custom_limit(self):
old_limit = PictureView.csv_settings.limit;
PictureView.csv_settings.limit = 1

response = self.client.get('/picture/download/')
self.assertEqual(200, response.status_code)
response_data = csv.reader(io.StringIO(response.content.decode("utf-8")))
@@ -203,3 +210,55 @@ def test_csv_settings_limit_none_working(self):
self.assertIsNone(next(response_data))

PictureView.csv_settings.limit = old_limit;



class ProgressReporterTest(TestCase):
def setUp(self):
router = Router()
router.register(ModelView)
self.view = PictureView()
self.view.router = router

animal = Animal(name='test')
animal.save()

for i in range(3):
picture = Picture(animal=animal)
file = CsvExportTest.temp_imagefile(50, 50, 'jpeg')
picture.file.save('picture.jpg', File(file), save=False)
picture.original_file.save('picture_copy.jpg', File(file), save=False)
picture.save()



user = User.objects.create(username='test')

class RequestAdapter:
def __init__(self):
self.GET = QueryDict(mutable=True)
self.POST = QueryDict()
self.user = user
self.request_id = 'blaat'


self.request = RequestAdapter()
self.file_adapter = CsvFileAdapter(self.request)


@patch('binder.plugins.progress_reporter.ProgressReporterInterface')
def test_progress_reporter(self, progress_reporter_mock):
self.view.csv_settings.page_size = 1
self.view._generate_csv_file(request=self.request, file_adapter=self.file_adapter,
progress_reporter=progress_reporter_mock)
progress_reporter_mock.report.assert_has_calls([call(0.33), call(0.67)])
progress_reporter_mock.report_finished.assert_called_with()

@patch('binder.plugins.progress_reporter.ProgressReporterInterface')
def test_progress_page_size(self, progress_reporter_mock):
# Page size = 2, so we get 2 chunks and only one intermediate report, at 50%
self.view.csv_settings.page_size = 2
self.view._generate_csv_file(request=self.request, file_adapter=self.file_adapter,
progress_reporter=progress_reporter_mock)
progress_reporter_mock.report.assert_called_once_with(0.5)
progress_reporter_mock.report_finished.assert_called_with()