Skip to content

Adds support for marshmallow @post_load #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions flask_apispec/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# -*- coding: utf-8 -*-
try:
from collections.abc import Mapping
except ImportError: # Python 2
from collections import Mapping

from six.moves import http_client as http

import flask

import marshmallow as ma
import werkzeug
from six.moves import http_client as http
from webargs import flaskparser

from flask_apispec import utils

import marshmallow as ma

MARSHMALLOW_VERSION_INFO = tuple(
[int(part) for part in ma.__version__.split('.') if part.isdigit()]
Expand Down Expand Up @@ -43,8 +46,11 @@ def call_view(self, *args, **kwargs):
parsed = parser.parse(schema, locations=option['kwargs']['locations'])
if getattr(schema, 'many', False):
args += tuple(parsed)
else:
elif isinstance(parsed, Mapping):
kwargs.update(parsed)
else:
args += (parsed, )

return self.func(*args, **kwargs)

def marshal_result(self, unpacked, status_code):
Expand Down
27 changes: 26 additions & 1 deletion tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json

from flask import make_response
from marshmallow import fields, Schema
from marshmallow import fields, Schema, post_load

from flask_apispec.utils import Ref
from flask_apispec.views import MethodResource
Expand All @@ -30,6 +30,31 @@ def view(**kwargs):
res = client.get('/', {'name': 'freddie'})
assert res.json == {'name': 'freddie'}

def test_use_kwargs_schema_with_post_load(self, app, client):
class User:
def __init__(self, name):
self.name = name

def update(self, name):
self.name = name

class ArgSchema(Schema):
name = fields.Str()

@post_load
def make_object(self, data):
return User(**data)

@app.route('/', methods=('POST', ))
@use_kwargs(ArgSchema())
def view(user):
assert isinstance(user, User)
return {'name': user.name}

data = {'name': 'freddie'}
res = client.post('/', data)
assert res.json == data

def test_use_kwargs_schema_many(self, app, client):
class ArgSchema(Schema):
name = fields.Str()
Expand Down