Skip to content

Commit

Permalink
Reformatted, now makes use of dict.get-default
Browse files Browse the repository at this point in the history
  • Loading branch information
ri0t committed Sep 10, 2017
1 parent c266deb commit cf178c6
Showing 1 changed file with 61 additions and 55 deletions.
116 changes: 61 additions & 55 deletions warmongo/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from bson.errors import InvalidId
from .exceptions import InvalidSchemaException


class OutdatedCodeException(Exception):
pass

Expand All @@ -41,23 +42,24 @@ class OutdatedCodeException(Exception):
"date": datetime
}


def extend_with_default(validator_class):
validate_properties = validator_class.VALIDATORS["properties"]

def set_defaults(validator, properties, instance, schema):
for property, subschema in properties.items():
#print(property, subschema)
# print(property, subschema)
if "default" in subschema:
#print("Setting default: ", subschema['default'])
# print("Setting default: ", subschema['default'])
instance.setdefault(property, subschema["default"])

for error in validate_properties(
validator, properties, instance, schema,
validator, properties, instance, schema,
):
yield error

return validators.extend(
validator_class, {"properties" : set_defaults},
validator_class, {"properties": set_defaults},
)


Expand Down Expand Up @@ -85,7 +87,7 @@ def __init__(self, origfields={}, from_find=False, *args, **kwargs):
# populate any default fields for objects that haven't come from the DB
if not from_find:
DefaultValidatingDraft4Validator(self._schema).validate(fields)
#for field, details in self._schema["properties"].items():
# for field, details in self._schema["properties"].items():
# if "default" in details and not field in fields:
# fields[field] = details["default"]

Expand All @@ -100,14 +102,15 @@ def get(self, field, default=None):

@classmethod
def collection_name(cls):
""" Get the collection associated with this class. The convention is
to take the lowercase of the class name and pluralize it. """
if cls._schema.get("collectionName"):
return cls._schema.get("collectionName")
elif cls._schema.get("name"):
name = cls._schema.get("name")
else:
name = cls.__name__
""" Get the collection associated with this class. """
name = cls._schema.get(
"collectionName",
cls._schema.get(
"collectionName",
cls._schema.get("name",
cls.__name__)
)
)

# convert to snake case
name = (name[0] + re.sub('([A-Z])', r'_\1', name[1:])).lower()
Expand All @@ -118,17 +121,15 @@ def collection_name(cls):
def database_name(cls):
""" Get the database associated with this class. Meant to be overridden
in subclasses. """
if cls._schema.get("databaseName"):
return cls._schema.get("databaseName")
return None
return cls._schema.get("databaseName", None)

def to_dict(self):
""" Convert the object to a dict. """
return self._fields

def validate(self):
""" Validate `schema` against a dict `obj`. """
#self.validate_field("", self._schema, self._fields)
# self.validate_field("", self._schema, self._fields)
try:
pass
# TODO: Deepcopying for validation is probably not so good ;)
Expand All @@ -139,13 +140,14 @@ def validate(self):
except InvalidId:
raise ValidationError('Invalid object ID: ', fields['_id'])

# Now remove for schema validation (jsonschema knows nothing off object ids)
# Now remove for schema validation (jsonschema knows nothing
# off object ids)
del (fields['_id'])

validate(fields, self._schema)
except ValidationError as e:
raise ValidationError("Error:\n" + str(e) + "\nFields:\n" + str(self._fields))

raise ValidationError(
"Error:\n" + str(e) + "\nFields:\n" + str(self._fields))

def cast(self, fields, schema=None):
""" Cast the fields from Mongo into our format - necessary to convert
Expand All @@ -159,12 +161,14 @@ def cast(self, fields, schema=None):
schema.get("properties"):
result = dict()
for key, value in fields.items():
result[key] = self.cast(value, schema["properties"].get(key, {}))
result[key] = self.cast(value,
schema["properties"].get(key, {}))
return result
elif value_type == "array" and isinstance(fields, list) and schema.get("items"):
elif value_type == "array" and isinstance(fields, list) and schema.get(
"items"):
return [
self.cast(value, schema["items"]) for value in fields
]
]
elif value_type == "integer" and isinstance(fields, float):
# The only thing that needs to be casted: floats -> ints
return int(fields)
Expand All @@ -185,37 +189,37 @@ def __getattr__(self, attr):
raise AttributeError("Item has no attribute '%s'" % attr)


# if attr.startswith('_'):
# return super(ModelBase, self).__getattr__(attr)
#
# if attr in self._schema["properties"] and attr in self._fields:
# #print("Direct hit")
# return self._fields.get(attr)
# curschema = self._schema["properties"]
# curfields = self._fields
# path = attr
# newattr = path
#
# #print("Query path:", path)
# #print("Initial Fields:", curfields)
#
# while '.' in path:
#
# newattr, path = path.split('.', maxsplit=1)
# #print("Looking for intermediate path in ", newattr, path)
#
# if newattr in curschema and newattr in curfields:
# curschema = curschema[newattr]['properties']
# curfields = curfields[newattr]
# else:
# if attr.startswith('_'):
# return super(ModelBase, self).__getattr__(attr)
#
# if attr in self._schema["properties"] and attr in self._fields:
# #print("Direct hit")
# return self._fields.get(attr)
# curschema = self._schema["properties"]
# curfields = self._fields
# path = attr
# newattr = path
#
# #print("Query path:", path)
# #print("Initial Fields:", curfields)
#
# while '.' in path:
#
# newattr, path = path.split('.', maxsplit=1)
# #print("Looking for intermediate path in ", newattr, path)
#
# if newattr in curschema and newattr in curfields:
# curschema = curschema[newattr]['properties']
# curfields = curfields[newattr]
# else:
# raise AttributeError("Item has no intermediate attribute '%s'"
# % ( newattr))
#
#
# if newattr in curschema and newattr in curfields:
# return curfields.get(newattr)
# else:
# raise AttributeError("Item has no attribute '%s'" % ( attr))
# % ( newattr))
#
#
# if newattr in curschema and newattr in curfields:
# return curfields.get(newattr)
# else:
# raise AttributeError("Item has no attribute '%s'" % ( attr))

def __setattr__(self, attr, value):
""" Set one of the fields, with validation. Exception is on "private"
Expand All @@ -235,7 +239,8 @@ def __setattr__(self, attr, value):
validator.validate()
elif not self._schema.get("additionalProperties", True):
# not allowed to add additional properties
raise ValidationError("Additional property '%s' not allowed!" % attr)
raise ValidationError(
"Additional property '%s' not allowed!" % attr)

self._fields[attr] = value
return value
Expand All @@ -246,4 +251,5 @@ def update(self, newfields, updateId=False):
if not key == "_id" or updateId:
self.__setattr__(key, value)
except Exception as e:
raise ValidationError("Unknown Validation error: '%s' (%s)" % (e, type(e)))
raise ValidationError(
"Unknown Validation error: '%s' (%s)" % (e, type(e)))

0 comments on commit cf178c6

Please sign in to comment.