diff --git a/wtforms_sqlalchemy/fields.py b/wtforms_sqlalchemy/fields.py index d377fcf..e3a3bfa 100644 --- a/wtforms_sqlalchemy/fields.py +++ b/wtforms_sqlalchemy/fields.py @@ -2,10 +2,12 @@ Useful form fields for use with SQLAlchemy ORM. """ import operator +import itertools from wtforms import widgets -from wtforms.fields import SelectFieldBase +from wtforms.fields import FieldList, SelectFieldBase from wtforms.validators import ValidationError +from wtforms.widgets.core import html_params, Markup try: from sqlalchemy.orm.util import identity_key @@ -20,8 +22,319 @@ "QuerySelectMultipleField", "QueryRadioField", "QueryCheckboxField", + "ModelFieldList", ) +DELETE_BUTTON = '_MFLTW_DEL' +ADD_BUTTON = '_MFLTW_ADD' + +class _ModelFieldListTableWiget(): + prefix = '_MFLTW-' + + def __init__(self, horizontal_layout, prefix_label=True, with_table_tag=True): + self.with_table_tag = with_table_tag + self.prefix_label = prefix_label + self.horizontal_layout = horizontal_layout + + + def build_horizontal_layout(self, field, **kwargs): + html = [] + + + if self.with_table_tag: + kwargs.setdefault("id", field.id) + html.append("" % html_params(**kwargs)) + + # Add a visible invisible submit button to allow submit by pressing enter + html.append("""""") + + hidden = "" + for subfield in field: + if subfield.type in ("HiddenField", "CSRFTokenField"): + hidden += str(subfield) + else: + html.append("") + if self.prefix_label: + html.append(f"") + else: + html.append(f"") + + delete = self.parent._separator.join([subfield.id, DELETE_BUTTON]) + html.append(f"") + + if self.with_table_tag: + html.append("
{subfield.label} {subfield()}{subfield()} {subfield.label}
") + if hidden: + html.append(hidden) + + add = self.parent._separator.join([field.id, ADD_BUTTON]) + html.append(f"") + + return Markup("".join(html)) + + + def build_vertical_layout(self, field, **kwargs): + html = [] + + html.append(f"
") + + # Add a visible invisible submit button to allow submit by pressing enter + html.append("""""") + + if self.with_table_tag: + kwargs.setdefault("id", field.id) + html.append(f"") + + hidden = "" + if len(field) == 0: + html.append(f"") + else: + # Build up table head + html.append("") + for subfield in field: + for subsubfield in subfield: + html.append(f"") + break + html.append(f"") + html.append("") + + # Build up table body + html.append("") + + for subfield in field: + if subfield.type in ("HiddenField", "CSRFTokenField"): + hidden += str(subfield) + else: + html.append("") + for subsubfield in subfield: + html.append(f"") + delete = field._separator.join([subfield.id, DELETE_BUTTON]) + html.append(f"") + html.append("") + + html.append("") + + if self.with_table_tag: + html.append("
There are no {field.label.text.lower()}
{subsubfield.label}-
{subsubfield()}
") + if hidden: + html.append(hidden) + + add = field._separator.join([field.id, ADD_BUTTON]) + html.append(f"") + + html.append('
') + + return Markup("".join(html)) + + + def __call__(self, field, **kwargs): + if self.horizontal_layout: + return self.build_horizontal_layout(field, **kwargs) + else: + return self.build_vertical_layout(field, **kwargs) + + +class ModelFieldList(FieldList): + FROM_DB_ID = "_MFL_PK" + NEW_ID = "_MFL_NEW" + + def __init__(self, unbound_field, horizontal_layout=False, model=None, *args, **kwargs): + self.widget = _ModelFieldListTableWiget(horizontal_layout) + self.model = model + + super().__init__(unbound_field, *args, **kwargs) + + if not self.model: + raise ValueError("ModelFieldList requires model to be set") + + def _rebuild_form(self, formdata): + db_elements_ids = set() + db_elements_deleted = set() + + new_elements_indices = set() + new_elements_deleted = set() + + add_button_pressed = False + + prefix = self.id + self._separator + + # Examine all elements in formdata + for form_element in formdata: + if not form_element.startswith(prefix): + continue + + if form_element[len(prefix):] == ADD_BUTTON: + # _MFLTW_ADD + add_button_pressed = True + self.valid = False + continue + + parts = form_element[len(prefix):].split(self._separator) + + if parts[0] == self.FROM_DB_ID: + # _MFL_PK-10232 + _id = int(parts[1]) + if len(parts) == 3 and parts[2] == DELETE_BUTTON: + # _MFL_PK-10232-_MFLTW_DEL + db_elements_deleted.add(_id) + self.valid = False + else: + db_elements_ids.add(_id) + + # See if the element was added to the form earlier (without processing in db) + elif parts[0] == self.NEW_ID: + # _MFL_NEW-1598 + _id = int(parts[1]) + if len(parts) == 3 and parts[2] == DELETE_BUTTON: + # _MFL_NEW-1598-_MFLTW_DEL + new_elements_deleted.add(_id) + self.valid = False + else: + new_elements_indices.add(_id) + + + # Now, add rows according to results from above loop + + # First, add database form entries if they were not deleted + for _id, obj in self.object_data.items(): + if _id in db_elements_ids and _id not in db_elements_deleted: + self._add_entry(formdata=formdata, sql_obj=obj) + + # Then, add new entries if they were not deleted + for new_ix in sorted(new_elements_indices - new_elements_deleted): + self._add_entry(formdata=formdata, index=new_ix) + + # Finally, add an empty entry if the add button was pressed + if add_button_pressed: + self._add_entry() + + + def _add_entry(self, formdata=None, sql_obj=None, index=None): + assert ( + not self.max_entries or len(self.entries) < self.max_entries + ), "You cannot have more than max_entries entries in this FieldList" + + if sql_obj: + entry_type = self.FROM_DB_ID + entry_id = str(sql_obj.id) + elif index: + self.last_index = index + entry_type = self.NEW_ID + entry_id = str(index) + else: + self.last_index += 1 + entry_type = self.NEW_ID + entry_id = str(self.last_index) + + field_name = self._separator.join([self.short_name, entry_type, entry_id]) + field_id = self._separator.join([self.id, entry_type, entry_id]) + + field = self.unbound_field.bind( + form=None, + name=field_name, + prefix=self._prefix, + id=field_id, + _meta=self.meta, + translations=self._translations, + ) + + field.process(formdata, sql_obj) + self.entries.append(field) + + return field + + def process(self, formdata, data=None, extra_filters=None): + if extra_filters: + raise TypeError( + "FieldList does not accept any filters. Instead, define" + " them on the enclosed field." + ) + + self.valid = True + self.entries = [] + + self.object_data = {obj.id: obj for obj in data} if data else {} + + if formdata: + # Add entries based on formdata + self._rebuild_form(formdata) + + else: + # Add entries based on self.object_data + for obj in self.object_data.values(): + self._add_entry(formdata=None, sql_obj=obj) + + # Add entries until min_entries is reached + while len(self.entries) < self.min_entries: + self._add_entry(formdata) + + def validate(self, form, extra_validators=()): + self.errors = [] + + valid = self.valid + for subfield in self.entries: + valid = subfield.validate(form) and valid + self.errors.append(subfield.errors) + + if not any(x for x in self.errors): + self.errors = [] + + chain = itertools.chain(self.validators, extra_validators) + self._run_validation_chain(form, chain) + + return valid and len(self.errors) == 0 + + def populate_obj(self, obj, name): + relation = getattr(obj, name) + + prefix = self.id + self._separator + updated = set() + + for entry in self.entries: + if not entry.id.startswith(prefix): + continue + + parts = entry.id[len(prefix):].split(self._separator, 2) + _fake_util = type("_fake", (object,), {}) + + if parts[0] == self.FROM_DB_ID: + # _MFL_PK-10232 + _id = int(parts[1]) + + + # If the object is found in self.object_data, update it. + if obj := self.object_data.get(_id): + fake_obj = _fake_util() + fake_obj.data = obj + entry.populate_obj(fake_obj, "data") + + obj = fake_obj.data + + updated.add(obj.id) + + # If the object was newly added, add it to relation + elif parts[0] == self.NEW_ID: + # _MFL_NEW-1598 + fake_obj = _fake_util() + fake_obj.data = self.model() + + entry.populate_obj(fake_obj, "data") + + new_model = fake_obj.data + + relation.append(new_model) + + # Finally also if relation has any objects that are missing + # in self.entries, if so delete them from relation + for deleted_id in (set(self.object_data.keys()) - updated): + obj = self.object_data.get(deleted_id) + try: + ix = relation.index(obj) + except ValueError: + continue + + db_obj = relation.pop(ix) + class QuerySelectField(SelectFieldBase): """ diff --git a/wtforms_sqlalchemy/orm.py b/wtforms_sqlalchemy/orm.py index 6d41068..fc7bcea 100644 --- a/wtforms_sqlalchemy/orm.py +++ b/wtforms_sqlalchemy/orm.py @@ -6,9 +6,9 @@ from wtforms import fields as wtforms_fields from wtforms import validators from wtforms.form import Form +from sqlalchemy.orm import RelationshipProperty -from .fields import QuerySelectField -from .fields import QuerySelectMultipleField +from .fields import ModelFieldList, QuerySelectField, QuerySelectMultipleField __all__ = ( "model_fields", @@ -75,7 +75,7 @@ def get_converter(self, column): % (column.name, types[0]) ) - def convert(self, model, mapper, prop, field_args, db_session=None): + def convert(self, model, mapper, prop, field_args, db_session=None, embed=False): if not hasattr(prop, "columns") and not hasattr(prop, "direction"): return elif not hasattr(prop, "direction") and len(prop.columns) != 1: @@ -153,7 +153,7 @@ def convert(self, model, mapper, prop, field_args, db_session=None): converter = self.converters[prop.direction.name] return converter( - model=model, mapper=mapper, prop=prop, column=column, field_args=kwargs + model=model, mapper=mapper, prop=prop, column=column, field_args=kwargs, embed=embed, db_session=db_session ) @@ -233,8 +233,21 @@ def conv_PGUuid(self, field_args, **extra): def conv_ManyToOne(self, field_args, **extra): return QuerySelectField(**field_args) - @converts("MANYTOMANY", "ONETOMANY") - def conv_ManyToMany(self, field_args, **extra): + @converts("ONETOMANY") + def conv_OneToMany(self, field_args, prop, embed, db_session, **extra): + if embed: + RelatedModel = prop.entity.class_ + sub_embed = embed.get(prop.key, False) if isinstance(embed,dict) else True + if not sub_embed: + return + + RelatedModelForm = model_form(RelatedModel, db_session=db_session, embed=sub_embed, exclude=prop.back_populates) + return ModelFieldList(wtforms_fields.FormField(RelatedModelForm), model=RelatedModel) + + return QuerySelectMultipleField(**field_args) + + @converts("MANYTOMANY") + def conv_ManyToMany(self, field_args, prop, db_session, **extra): return QuerySelectMultipleField(**field_args) @@ -247,6 +260,7 @@ def model_fields( converter=None, exclude_pk=False, exclude_fk=False, + embed=False, ): """ Generate a dictionary of fields for a given SQLAlchemy model. @@ -257,6 +271,7 @@ def model_fields( converter = converter or ModelConverter() field_args = field_args or {} properties = [] + relationship_properties = [] for prop in mapper.iterate_properties: if getattr(prop, "columns", None): @@ -265,7 +280,12 @@ def model_fields( elif exclude_pk and prop.columns[0].primary_key: continue - properties.append((prop.key, prop)) + if isinstance(prop, RelationshipProperty): + relationship_properties.append((prop.key, prop)) + else: + properties.append((prop.key, prop)) + + properties += relationship_properties # ((p.key, p) for p in mapper.iterate_properties) if only: @@ -275,7 +295,7 @@ def model_fields( field_dict = {} for name, prop in properties: - field = converter.convert(model, mapper, prop, field_args.get(name), db_session) + field = converter.convert(model, mapper, prop, field_args.get(name), db_session, embed) if field is not None: field_dict[name] = field @@ -293,6 +313,7 @@ def model_form( exclude_pk=True, exclude_fk=True, type_name=None, + embed=False, ): """ Create a wtforms Form for a given SQLAlchemy model class:: @@ -325,6 +346,12 @@ def model_form( An optional boolean to force foreign keys exclusion. :param type_name: An optional string to set returned type name. + :param embed: + An optional boolean or dictionary specifying whether and/or how to embed + relations in model. If set to True, fields for all related models are + generated. If set to a dictionary, all specified fields will be embedded + (Example: embed={'student': {courses: False}}, where 'student' will get + fields but 'courses' will not) """ if not hasattr(model, "_sa_class_manager"): raise TypeError("model must be a sqlalchemy mapped model") @@ -339,5 +366,6 @@ def model_form( converter, exclude_pk=exclude_pk, exclude_fk=exclude_fk, + embed=embed, ) return type(type_name, (base_class,), field_dict)