diff --git a/pydictable/core.py b/pydictable/core.py index 3ba6084..c0bd343 100644 --- a/pydictable/core.py +++ b/pydictable/core.py @@ -1,7 +1,8 @@ import inspect from datetime import datetime from enum import Enum -from typing import Dict, get_type_hints, Optional, Union, List +from itertools import chain +from typing import Dict, get_type_hints, Union from pydictable.field import StrField, IntField, FloatField, BoolField, ListField, MultiTypeField, UnionField, \ NoneField, ObjectField, DataValidationError, EnumField, DatetimeField @@ -70,8 +71,14 @@ def __clear_default_field_values(self): self.__setattr__(attr, None) def __apply_dict(self, d: dict): - for attr, field in self.__class__.__get_fields().items(): - self.__setattr__(attr, field.from_dict(d.get(self.__get_field_key(attr), field.default))) + _updated_attributes = [] + for attr, field in chain(self.__class__.__get_fields().items(), d.items()): + if isinstance(field, Field): + self.__setattr__(attr, field.from_dict(d.get(self.__get_field_key(attr), field.default))) + _updated_attributes.append(self.__get_field_key(attr)) + continue + if attr not in _updated_attributes: + self.__setattr__(attr, field) def __validate_dict(self, raw_values: dict): for attr, field in self.__get_fields().items(): @@ -102,8 +109,14 @@ def __validate(self): def to_dict(self) -> dict: d = {} - for attr, field in self.__class__.__get_fields().items(): - d[self.__get_field_key(attr)] = field.to_dict(self.__getattribute__(attr)) + _updated_attributes = [] + for attr, field in chain(self.__class__.__get_fields().items(), self.__dict__.items()): + if isinstance(field, Field): + d[self.__get_field_key(attr)] = field.to_dict(self.__getattribute__(attr)) + _updated_attributes.append(attr) + continue + if attr not in _updated_attributes: + d[attr] = field return d @classmethod diff --git a/pydictable/test_core.py b/pydictable/test_core.py index 1c429c6..bb75072 100644 --- a/pydictable/test_core.py +++ b/pydictable/test_core.py @@ -658,3 +658,29 @@ class UserInfo(DictAble): UserInfo(dict={'dob': 1673442076263}) except DataValidationError as e: self.assertEqual(e.err, 'Pre check failed: Invalid value 123 for field name') + + def test_return_all_input_dict_fields(self): + class Address(DictAble): + pin_code: int = IntField(default=560090) + street: str = StrField() + + input_dict = {'street': 'RT Nagar', 'city': 'Bengaluru'} + address = Address(dict=input_dict) + self.assertEqual(address.pin_code, 560090) + self.assertEqual(address.street, 'RT Nagar') + self.assertEqual(address.city, 'Bengaluru') + self.assertEqual(len(address.to_dict()), 3) + + class Email(DictAble): + to: str = StrField(required=True, key='to_email') + subject: str = StrField(required=True, default="General inquiry") + body: str = StrField(required=True) + + input_dict = {'to_email': 'testing@gmail.com', 'body': 'Hello', 'cc': 'testing1@gmail.com', 'bcc': 't@test.com'} + email = Email(dict=input_dict) + self.assertEqual(email.to, 'testing@gmail.com') + self.assertEqual(email.subject, 'General inquiry') + self.assertEqual(email.body, 'Hello') + self.assertEqual(email.cc, 'testing1@gmail.com') + self.assertEqual(email.bcc, 't@test.com') + self.assertEqual(len(email.to_dict()), 5)