diff --git a/.gitignore b/.gitignore index b8f00fed..46ffd526 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,8 @@ node_modules # Jet Brains .idea + +# VS code & pipenv ignore +Pipfile +Pipfile.lock +.vscode/settings.json diff --git a/flask_restx/tools.py b/flask_restx/tools.py new file mode 100644 index 00000000..1cb6a44c --- /dev/null +++ b/flask_restx/tools.py @@ -0,0 +1,149 @@ +# coding:utf-8 +__all__ = ["createApiModel", "_get_res"] +from sqlalchemy.inspection import inspect +from . import fields +from sqlalchemy import types +from .model import Model + +_not_allowed = ["TypeEngine", "TypeDecorator", "UserDefinedType", "PickleType"] +conversion = { + "INT": "Integer", + "CHAR": "String", + "VARCHAR": "String", + "NCHAR": "String", + "NVARCHAR": "String", + "TEXT": "String", + "Text": "String", + "FLOAT": "Float", + "NUMERIC": "Float", + "REAL": "Float", + "DECIMAL": "Float", + "TIMESTAMP": "DateTime", + "DATETIME": "DateTime", + "CLOB": "Raw", + "BLOB": "Raw", + "BINARY": "Raw", + "VARBINARY": "Raw", + "BOOLEAN": "Boolean", + "BIGINT": "Integer", + "SMALLINT": "Integer", + "INTEGER": "Integer", + "DATE": "Date", + "TIME": "String", + "String": "String", + "Integer": "Integer", + "SmallInteger": "Integer", + "BigInteger": "Integer", + "Numeric": "Float", + "Float": "Float", + "DateTime": "DateTime", + "Date": "Date", + "Time": "String", + "LargeBinary": "Raw", + "Boolean": "Boolean", + "Unicode": "String", + "Concatenable": "String", + "UnicodeText": "String", + "Interval": "List", + "Enum": "List", + "Indexable": "List", + "ARRAY": "List", + "JSON": "List", +} + +fieldtypes = [r for r in types.__all__ if r not in _not_allowed] + + +def _get_res( + table, modelname: str = None, readonlyfields: list = [], show: list = [] +) -> Model: + """Private function to obtain model_columns as a list + + Args: + table: SQLalchemy Table + modelname (Optional[str], optional): Custom model name. if it's is None then the modelname will be the capitalized tablename. + readonlyfields (Optional[list], optional): Set readonly fields. Defaults to []. + show (Optional[list], optional): Set shown fields. Defaults to []. + + Return: + Model + """ + + res = {} + foreignsmapped = [] + # reading from sqlalchemy column into flask-restx column + for fieldname, col in table.__table__.columns.items(): + tipo = col.type + isprimarykey = col.primary_key and fieldname not in show + params = {} + fieldnameinreadonly = fieldname in readonlyfields + if isprimarykey or fieldnameinreadonly: + params = {"readonly": True} + if not col.nullable and (not fieldnameinreadonly) and (not isprimarykey): + params["required"] = True + if col.default is not None: + if isinstance(col.default.arg, (str, float, int, bytearray, bytes)): + params["default"] = col.default.arg + _tipo = str(tipo).split("(")[0] + if _tipo in fieldtypes: + if hasattr(tipo, "length"): + params["max_length"] = tipo.length + if len(col.foreign_keys) > 0: + foreignsmapped.extend(list(col.foreign_keys)) + res[fieldname] = getattr(fields, conversion[_tipo])(**params) + # cheking for relationships + relationitems = [] + try: + relationitems = inspect(table).relationships.items() + except: + # It could faild in composed primary keys + pass + # implementing relationship columns + for field, relationship in relationitems: + if relationship.backref != table.__tablename__: + continue + try: + col = list(relationship.local_columns)[0] + tipo = col.type + _tipo = str(tipo).split("(")[0] + if _tipo in fieldtypes: + outparams = {} + if hasattr(tipo, "length"): + params["max_length"] = tipo.length + if field in readonlyfields: + outparams["readonly"] = True + if col.foreign_keys is not None: + foreignsmapped.extend(list(col.foreign_keys)) + if relationship.uselist: + res[field] = fields.List( + getattr(fields, conversion[_tipo])(**params), **outparams + ) + else: + for key, value in outparams.items(): + params[key] = value + res[field] = getattr(fields, conversion[_tipo])(**params) + except: + continue + if modelname in ("", None): + modelname = table.__tablename__.lower().capitalize() + return res + + +def createApiModel( + api, table, modelname: str = None, readonlyfields: list = [], show: list = [] +) -> Model: + """Create a basic Flask-restx ApiModel by given an sqlachemy Table and a flask-restx api. + Requires sqlalchemy + + Args: + api: Flask-restx api + table: SqlalchemyTable + modelname (Optional[str], optional): Custom model name. if it's is None then the modelname will be the capitalized tablename. + readonlyfields (Optional[list], optional): Set readonly fields. Defaults to []. + show (Optional[list], optional): Set shown fields. Defaults to []. + + Return: + Model + """ + res = _get_res(table, modelname, readonlyfields, show) + return api.model(modelname, res) diff --git a/requirements/test.pip b/requirements/test.pip index e4d58140..70d7623b 100644 --- a/requirements/test.pip +++ b/requirements/test.pip @@ -10,4 +10,5 @@ pytest-profiling==1.7.0 tzlocal invoke==2.2.0 twine==3.8.0 -setuptools +sqlalchemy==1.4.44 +setuptools \ No newline at end of file diff --git a/tests/test_tools_createApiModel.py b/tests/test_tools_createApiModel.py new file mode 100644 index 00000000..accc69c9 --- /dev/null +++ b/tests/test_tools_createApiModel.py @@ -0,0 +1,101 @@ +# codign:utf-8 +from pathlib import Path +from sys import path + +p = Path(__file__).parent.parent.resolve() +path.insert(0, str(p)) + +from flask_restx import fields +from flask_restx.tools import _get_res + +SQLALCHEMY_AVAILABLE = False +try: + + from sqlalchemy import ( + Boolean, + Column, + Date, + DateTime, + Float, + ForeignKey, + Integer, + String, + ) + from sqlalchemy.orm import declarative_base, relationship +except ImportError: + print("ERROR") + SQLALCHEMY_AVAILABLE = False + + +if SQLALCHEMY_AVAILABLE: + + Base = declarative_base() + + class Unrelated(Base): + __tablename__ = "unrelated_table" + id = Column(Integer, primary_key=True) + string = Column(String(30)) + float = Column(Float()) + datetime = Column(DateTime()) + boolean = Column(Boolean()) + date = Column(Date()) + + class User(Base): + __tablename__ = "user_account" + id = Column(Integer, primary_key=True) + name = Column(String(30)) + fullname = Column(String) + addresses = relationship( + "Address", back_populates="user", cascade="all, delete-orphan" + ) + + def __repr__(self): + return f"""User(id={self.id!r}, name={self.name!r}, + fullname={self.fullname!r})""" + + class Address(Base): + __tablename__ = "address" + id = Column(Integer, primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey("user_account.id"), nullable=False) + user = relationship("User", back_populates="addresses") + + def __repr__(self): + return f"""Address(id={self.id!r}, + email_address={self.email_address!r})""" + + def _checkestruct(own, expected): + for key, value in expected.items(): + try: + dc1 = own[key].__dict__ + dc2 = value.__dict__ + if not all((dc1.get(k) == v for k, v in dc2.items())): + return False + except KeyError: + return False + return True + + class CreateApiModel_test(object): + def test_table_without_relationships(self, *args, **kwargs): + mymodel = { + "id": fields.Integer(readonly=True), + "string": fields.String(max_length=30), + "float": fields.Float(), + "boolean": fields.Boolean(), + "date": fields.Date(), + } + assert _checkestruct(_get_res(Unrelated), mymodel) + + def test_table_with_single_relationship(self, *args, **kwargs): + # TODO + # mymodel = {'':,'':} + # assert _checkestruct(_get_res(Address), mymodel) + assert True + + def test_editable_primary_key(self, *args, **kwargs): + # TODO + assert True + + def test_making_not_editable_fields(self, *args, **kwargs): + # TODO + assert True