diff --git a/wtforms_sqlalchemy/fields.py b/wtforms_sqlalchemy/fields.py
index d377fcf..e3a3bfa 100644
--- a/wtforms_sqlalchemy/fields.py
+++ b/wtforms_sqlalchemy/fields.py
@@ -2,10 +2,12 @@
Useful form fields for use with SQLAlchemy ORM.
"""
import operator
+import itertools
from wtforms import widgets
-from wtforms.fields import SelectFieldBase
+from wtforms.fields import FieldList, SelectFieldBase
from wtforms.validators import ValidationError
+from wtforms.widgets.core import html_params, Markup
try:
from sqlalchemy.orm.util import identity_key
@@ -20,8 +22,319 @@
"QuerySelectMultipleField",
"QueryRadioField",
"QueryCheckboxField",
+ "ModelFieldList",
)
+DELETE_BUTTON = '_MFLTW_DEL'
+ADD_BUTTON = '_MFLTW_ADD'
+
+class _ModelFieldListTableWiget():
+ prefix = '_MFLTW-'
+
+ def __init__(self, horizontal_layout, prefix_label=True, with_table_tag=True):
+ self.with_table_tag = with_table_tag
+ self.prefix_label = prefix_label
+ self.horizontal_layout = horizontal_layout
+
+
+ def build_horizontal_layout(self, field, **kwargs):
+ html = []
+
+
+ if self.with_table_tag:
+ kwargs.setdefault("id", field.id)
+ html.append("
")
+ if hidden:
+ html.append(hidden)
+
+ add = self.parent._separator.join([field.id, ADD_BUTTON])
+ html.append(f" ")
+
+ return Markup("".join(html))
+
+
+ def build_vertical_layout(self, field, **kwargs):
+ html = []
+
+ html.append(f"")
+
+ # Add a visible invisible submit button to allow submit by pressing enter
+ html.append("""
""")
+
+ if self.with_table_tag:
+ kwargs.setdefault("id", field.id)
+ html.append(f"
")
+
+ hidden = ""
+ if len(field) == 0:
+ html.append(f"There are no {field.label.text.lower()} ")
+ else:
+ # Build up table head
+ html.append("")
+ for subfield in field:
+ for subsubfield in subfield:
+ html.append(f"{subsubfield.label} ")
+ break
+ html.append(f"- ")
+ html.append(" ")
+
+ # Build up table body
+ html.append("")
+
+ for subfield in field:
+ if subfield.type in ("HiddenField", "CSRFTokenField"):
+ hidden += str(subfield)
+ else:
+ html.append("")
+ for subsubfield in subfield:
+ html.append(f"{subsubfield()} ")
+ delete = field._separator.join([subfield.id, DELETE_BUTTON])
+ html.append(f"Delete ")
+ html.append(" ")
+
+ html.append(" ")
+
+ if self.with_table_tag:
+ html.append("
")
+ if hidden:
+ html.append(hidden)
+
+ add = field._separator.join([field.id, ADD_BUTTON])
+ html.append(f"
Add {field.label.text.lower()} ")
+
+ html.append('
')
+
+ return Markup("".join(html))
+
+
+ def __call__(self, field, **kwargs):
+ if self.horizontal_layout:
+ return self.build_horizontal_layout(field, **kwargs)
+ else:
+ return self.build_vertical_layout(field, **kwargs)
+
+
+class ModelFieldList(FieldList):
+ FROM_DB_ID = "_MFL_PK"
+ NEW_ID = "_MFL_NEW"
+
+ def __init__(self, unbound_field, horizontal_layout=False, model=None, *args, **kwargs):
+ self.widget = _ModelFieldListTableWiget(horizontal_layout)
+ self.model = model
+
+ super().__init__(unbound_field, *args, **kwargs)
+
+ if not self.model:
+ raise ValueError("ModelFieldList requires model to be set")
+
+ def _rebuild_form(self, formdata):
+ db_elements_ids = set()
+ db_elements_deleted = set()
+
+ new_elements_indices = set()
+ new_elements_deleted = set()
+
+ add_button_pressed = False
+
+ prefix = self.id + self._separator
+
+ # Examine all elements in formdata
+ for form_element in formdata:
+ if not form_element.startswith(prefix):
+ continue
+
+ if form_element[len(prefix):] == ADD_BUTTON:
+ # _MFLTW_ADD
+ add_button_pressed = True
+ self.valid = False
+ continue
+
+ parts = form_element[len(prefix):].split(self._separator)
+
+ if parts[0] == self.FROM_DB_ID:
+ # _MFL_PK-10232
+ _id = int(parts[1])
+ if len(parts) == 3 and parts[2] == DELETE_BUTTON:
+ # _MFL_PK-10232-_MFLTW_DEL
+ db_elements_deleted.add(_id)
+ self.valid = False
+ else:
+ db_elements_ids.add(_id)
+
+ # See if the element was added to the form earlier (without processing in db)
+ elif parts[0] == self.NEW_ID:
+ # _MFL_NEW-1598
+ _id = int(parts[1])
+ if len(parts) == 3 and parts[2] == DELETE_BUTTON:
+ # _MFL_NEW-1598-_MFLTW_DEL
+ new_elements_deleted.add(_id)
+ self.valid = False
+ else:
+ new_elements_indices.add(_id)
+
+
+ # Now, add rows according to results from above loop
+
+ # First, add database form entries if they were not deleted
+ for _id, obj in self.object_data.items():
+ if _id in db_elements_ids and _id not in db_elements_deleted:
+ self._add_entry(formdata=formdata, sql_obj=obj)
+
+ # Then, add new entries if they were not deleted
+ for new_ix in sorted(new_elements_indices - new_elements_deleted):
+ self._add_entry(formdata=formdata, index=new_ix)
+
+ # Finally, add an empty entry if the add button was pressed
+ if add_button_pressed:
+ self._add_entry()
+
+
+ def _add_entry(self, formdata=None, sql_obj=None, index=None):
+ assert (
+ not self.max_entries or len(self.entries) < self.max_entries
+ ), "You cannot have more than max_entries entries in this FieldList"
+
+ if sql_obj:
+ entry_type = self.FROM_DB_ID
+ entry_id = str(sql_obj.id)
+ elif index:
+ self.last_index = index
+ entry_type = self.NEW_ID
+ entry_id = str(index)
+ else:
+ self.last_index += 1
+ entry_type = self.NEW_ID
+ entry_id = str(self.last_index)
+
+ field_name = self._separator.join([self.short_name, entry_type, entry_id])
+ field_id = self._separator.join([self.id, entry_type, entry_id])
+
+ field = self.unbound_field.bind(
+ form=None,
+ name=field_name,
+ prefix=self._prefix,
+ id=field_id,
+ _meta=self.meta,
+ translations=self._translations,
+ )
+
+ field.process(formdata, sql_obj)
+ self.entries.append(field)
+
+ return field
+
+ def process(self, formdata, data=None, extra_filters=None):
+ if extra_filters:
+ raise TypeError(
+ "FieldList does not accept any filters. Instead, define"
+ " them on the enclosed field."
+ )
+
+ self.valid = True
+ self.entries = []
+
+ self.object_data = {obj.id: obj for obj in data} if data else {}
+
+ if formdata:
+ # Add entries based on formdata
+ self._rebuild_form(formdata)
+
+ else:
+ # Add entries based on self.object_data
+ for obj in self.object_data.values():
+ self._add_entry(formdata=None, sql_obj=obj)
+
+ # Add entries until min_entries is reached
+ while len(self.entries) < self.min_entries:
+ self._add_entry(formdata)
+
+ def validate(self, form, extra_validators=()):
+ self.errors = []
+
+ valid = self.valid
+ for subfield in self.entries:
+ valid = subfield.validate(form) and valid
+ self.errors.append(subfield.errors)
+
+ if not any(x for x in self.errors):
+ self.errors = []
+
+ chain = itertools.chain(self.validators, extra_validators)
+ self._run_validation_chain(form, chain)
+
+ return valid and len(self.errors) == 0
+
+ def populate_obj(self, obj, name):
+ relation = getattr(obj, name)
+
+ prefix = self.id + self._separator
+ updated = set()
+
+ for entry in self.entries:
+ if not entry.id.startswith(prefix):
+ continue
+
+ parts = entry.id[len(prefix):].split(self._separator, 2)
+ _fake_util = type("_fake", (object,), {})
+
+ if parts[0] == self.FROM_DB_ID:
+ # _MFL_PK-10232
+ _id = int(parts[1])
+
+
+ # If the object is found in self.object_data, update it.
+ if obj := self.object_data.get(_id):
+ fake_obj = _fake_util()
+ fake_obj.data = obj
+ entry.populate_obj(fake_obj, "data")
+
+ obj = fake_obj.data
+
+ updated.add(obj.id)
+
+ # If the object was newly added, add it to relation
+ elif parts[0] == self.NEW_ID:
+ # _MFL_NEW-1598
+ fake_obj = _fake_util()
+ fake_obj.data = self.model()
+
+ entry.populate_obj(fake_obj, "data")
+
+ new_model = fake_obj.data
+
+ relation.append(new_model)
+
+ # Finally also if relation has any objects that are missing
+ # in self.entries, if so delete them from relation
+ for deleted_id in (set(self.object_data.keys()) - updated):
+ obj = self.object_data.get(deleted_id)
+ try:
+ ix = relation.index(obj)
+ except ValueError:
+ continue
+
+ db_obj = relation.pop(ix)
+
class QuerySelectField(SelectFieldBase):
"""
diff --git a/wtforms_sqlalchemy/orm.py b/wtforms_sqlalchemy/orm.py
index 6d41068..fc7bcea 100644
--- a/wtforms_sqlalchemy/orm.py
+++ b/wtforms_sqlalchemy/orm.py
@@ -6,9 +6,9 @@
from wtforms import fields as wtforms_fields
from wtforms import validators
from wtforms.form import Form
+from sqlalchemy.orm import RelationshipProperty
-from .fields import QuerySelectField
-from .fields import QuerySelectMultipleField
+from .fields import ModelFieldList, QuerySelectField, QuerySelectMultipleField
__all__ = (
"model_fields",
@@ -75,7 +75,7 @@ def get_converter(self, column):
% (column.name, types[0])
)
- def convert(self, model, mapper, prop, field_args, db_session=None):
+ def convert(self, model, mapper, prop, field_args, db_session=None, embed=False):
if not hasattr(prop, "columns") and not hasattr(prop, "direction"):
return
elif not hasattr(prop, "direction") and len(prop.columns) != 1:
@@ -153,7 +153,7 @@ def convert(self, model, mapper, prop, field_args, db_session=None):
converter = self.converters[prop.direction.name]
return converter(
- model=model, mapper=mapper, prop=prop, column=column, field_args=kwargs
+ model=model, mapper=mapper, prop=prop, column=column, field_args=kwargs, embed=embed, db_session=db_session
)
@@ -233,8 +233,21 @@ def conv_PGUuid(self, field_args, **extra):
def conv_ManyToOne(self, field_args, **extra):
return QuerySelectField(**field_args)
- @converts("MANYTOMANY", "ONETOMANY")
- def conv_ManyToMany(self, field_args, **extra):
+ @converts("ONETOMANY")
+ def conv_OneToMany(self, field_args, prop, embed, db_session, **extra):
+ if embed:
+ RelatedModel = prop.entity.class_
+ sub_embed = embed.get(prop.key, False) if isinstance(embed,dict) else True
+ if not sub_embed:
+ return
+
+ RelatedModelForm = model_form(RelatedModel, db_session=db_session, embed=sub_embed, exclude=prop.back_populates)
+ return ModelFieldList(wtforms_fields.FormField(RelatedModelForm), model=RelatedModel)
+
+ return QuerySelectMultipleField(**field_args)
+
+ @converts("MANYTOMANY")
+ def conv_ManyToMany(self, field_args, prop, db_session, **extra):
return QuerySelectMultipleField(**field_args)
@@ -247,6 +260,7 @@ def model_fields(
converter=None,
exclude_pk=False,
exclude_fk=False,
+ embed=False,
):
"""
Generate a dictionary of fields for a given SQLAlchemy model.
@@ -257,6 +271,7 @@ def model_fields(
converter = converter or ModelConverter()
field_args = field_args or {}
properties = []
+ relationship_properties = []
for prop in mapper.iterate_properties:
if getattr(prop, "columns", None):
@@ -265,7 +280,12 @@ def model_fields(
elif exclude_pk and prop.columns[0].primary_key:
continue
- properties.append((prop.key, prop))
+ if isinstance(prop, RelationshipProperty):
+ relationship_properties.append((prop.key, prop))
+ else:
+ properties.append((prop.key, prop))
+
+ properties += relationship_properties
# ((p.key, p) for p in mapper.iterate_properties)
if only:
@@ -275,7 +295,7 @@ def model_fields(
field_dict = {}
for name, prop in properties:
- field = converter.convert(model, mapper, prop, field_args.get(name), db_session)
+ field = converter.convert(model, mapper, prop, field_args.get(name), db_session, embed)
if field is not None:
field_dict[name] = field
@@ -293,6 +313,7 @@ def model_form(
exclude_pk=True,
exclude_fk=True,
type_name=None,
+ embed=False,
):
"""
Create a wtforms Form for a given SQLAlchemy model class::
@@ -325,6 +346,12 @@ def model_form(
An optional boolean to force foreign keys exclusion.
:param type_name:
An optional string to set returned type name.
+ :param embed:
+ An optional boolean or dictionary specifying whether and/or how to embed
+ relations in model. If set to True, fields for all related models are
+ generated. If set to a dictionary, all specified fields will be embedded
+ (Example: embed={'student': {courses: False}}, where 'student' will get
+ fields but 'courses' will not)
"""
if not hasattr(model, "_sa_class_manager"):
raise TypeError("model must be a sqlalchemy mapped model")
@@ -339,5 +366,6 @@ def model_form(
converter,
exclude_pk=exclude_pk,
exclude_fk=exclude_fk,
+ embed=embed,
)
return type(type_name, (base_class,), field_dict)