# Source code for anyblok.column

# This file is a part of the AnyBlok project
#
#    Copyright (C) 2016 Jean-Sebastien SUZANNE <jssuzanne@anybox.fr>
#    Copyright (C) 2017 Jean-Sebastien SUZANNE <jssuzanne@anybox.fr>
#    Copyright (C) 2018 Jean-Sebastien SUZANNE <jssuzanne@anybox.fr>
#    Copyright (C) 2019 Jean-Sebastien SUZANNE <js.suzanne@gmail.com>
#    Copyright (C) 2020 Jean-Sebastien SUZANNE <js.suzanne@gmail.com>
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file,You can
# obtain one at http://mozilla.org/MPL/2.0/.
import decimal
import time
from base64 import b64decode, b64encode
from datetime import date, datetime, timedelta
from hashlib import md5
from inspect import ismethod
from json import dumps, loads
from logging import getLogger

import pytz
from dateutil.parser import parse
from sqlalchemy import JSON as SA_JSON
from sqlalchemy import CheckConstraint, and_, or_, types
from sqlalchemy.schema import Column as SA_Column
from sqlalchemy.schema import Sequence as SA_Sequence
from sqlalchemy_utils.types.color import ColorType
from sqlalchemy_utils.types.email import EmailType
from sqlalchemy_utils.types.encrypted.encrypted_type import StringEncryptedType
from sqlalchemy_utils.types.password import Password as SAU_PWD
from sqlalchemy_utils.types.password import PasswordType
from sqlalchemy_utils.types.phone_number import PhoneNumberType
from sqlalchemy_utils.types.scalar_coercible import ScalarCoercible
from sqlalchemy_utils.types.url import URLType
from sqlalchemy_utils.types.uuid import UUIDType

from anyblok.config import Configuration

from .common import sgdb_in
from .field import Field, FieldException
from .mapper import ModelAttribute, ModelAttributeAdapter

# Optional dependency: pycountry. When it is not installed both names below
# stay None (presumably checked by a country column defined later in the
# file — not visible here, verify).
pycountry = None
python_pycountry_type = None
try:
    import pycountry

    # Force the country database to be loaded at import time.
    # NOTE(review): relies on the private ``_is_loaded`` / ``_load``
    # attributes of pycountry; confirm against the installed version.
    if not pycountry.countries._is_loaded:
        pycountry.countries._load()

    # Class of the objects stored in ``pycountry.countries``.
    python_pycountry_type = pycountry.countries.data_class
except ImportError:
    pass


# Module level logger, used for the default value warnings below.
logger = getLogger(__name__)


def wrap_default(registry, namespace, default_val):
    """Return a callable that lazily resolves a string default value.

    When a column is declared with ``default="name"``, ``name`` may refer to
    a method of the model. The returned wrapper looks up the model in the
    registry when it is called:

    * if ``name`` is a bound method (e.g. a classmethod) and is neither a
      loaded column nor a loaded field, its result is the default value;
    * if ``name`` collides with another attribute, a warning is logged and
      the string itself is used as default value;
    * if the model has no such attribute, the string itself is used.

    :param registry: the current registry
    :param namespace: the namespace of the model
    :param default_val: the string default value, possibly a method name
    :return: default wrapper, callable without argument
    """

    def wrapper():
        """Return the computed default value or ``default_val`` itself.

        :return: default val
        """
        Model = registry.get(namespace)
        if not hasattr(Model, default_val):
            return default_val

        func = getattr(Model, default_val)
        if not ismethod(func):
            # Not a bound method: e.g. a plain function looked up on the
            # class, which can't be called without an instance.
            logger.warning(
                "The attribute %r is already declared as a default "
                "value on the Model %r, an instance method with the same "
                "name already exists",
                default_val,
                namespace,
            )
        elif default_val in Model.loaded_columns:
            logger.warning(
                "The attribute %r is already declared as a default "
                "value on the Model %r, a column with the same name "
                "already exists",
                default_val,
                namespace,
            )
        elif default_val in Model.loaded_fields:
            logger.warning(
                "The attribute %r is already declared as a default "
                "value on the Model %r, a field with the same name "
                "already exists",
                default_val,
                namespace,
            )
        else:
            return func()

        return default_val

    return wrapper


class ColumnDefaultValue:
    """Wrap a factory used to build the default callable of a column.

    The factory is stored as is and only invoked, with the mapping
    context, by :meth:`get_default_callable`.
    """

    def __init__(self, callable):
        self.callable = callable

    def get_default_callable(self, registry, namespace, fieldname, properties):
        """Build the default callable by invoking the stored factory.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        :return: whatever the factory returns (the default callable)
        """
        factory = self.callable
        context = (registry, namespace, fieldname, properties)
        return factory(*context)


class CompareType:
    """Validate that two column types may be linked by a foreign key.

    Specific comparators can be registered for a pair of type classes with
    :meth:`add_comparator`; any other pair falls back to
    :meth:`default_comparator`, which requires the exact same class.
    """

    # (type1 class, type2 class, comparator) triples registered by
    # ``add_comparator``; deliberately shared at class level.
    comparators = []

    @classmethod
    def default_comparator(cls, col1, type1, col2, type2):
        """Raise if ``type1`` and ``type2`` are not of the exact same class.

        :param col1: local column (with ``model_name``/``attribute_name``)
        :param type1: type of the local column
        :param col2: remote column (with ``model_name``/``attribute_name``)
        :param type2: type of the remote column
        :exception FieldException: the types have different classes
        """
        if type1.__class__ is not type2.__class__:
            raise FieldException(
                "You can't add a foreign key using columns with different "
                "types `{model1!s}.{col1!s}` pointing to "
                "`{model2!s}.{col2!s}` have different types "
                "{type1!r} -> {type2!r}".format(
                    model1=col1.model_name,
                    col1=col1.attribute_name,
                    model2=col2.model_name,
                    col2=col2.attribute_name,
                    type1=type1.__class__,
                    type2=type2.__class__,
                )
            )

    @classmethod
    def add_comparator(cls, type1, type2):
        """Decorator registering a comparator for ``(type1, type2)``.

        :param type1: class of the local column type
        :param type2: class of the remote column type
        :return: decorator returning the decorated function unchanged
        """

        def wrapper(funct):
            cls.comparators.append((type1, type2, funct))
            return funct

        return wrapper

    @classmethod
    def validate(cls, col1, type1, col2, type2):
        """Run the first comparator registered for the exact type classes.

        Falls back on :meth:`default_comparator` when none matches.
        """
        for cls1, cls2, funct in cls.comparators:
            if type1.__class__ is cls1 and type2.__class__ is cls2:
                funct(col1, type1, col2, type2)
                return

        cls.default_comparator(col1, type1, col2, type2)


class NoDefaultValue:
    """Sentinel meaning that no default value was given to the column.

    It is used instead of ``None`` so that an explicit ``default=None``
    can be told apart from no default at all.
    """


class Column(Field):
    """Column class

    Base class of every AnyBlok column. This class can't be instantiated
    directly: subclasses must define ``sqlalchemy_type``.
    """

    use_hybrid_property = True
    foreign_key = None
    sqlalchemy_type = None
    type = None

    def __init__(self, *args, **kwargs):
        """Initialize the column

        Consumes the column specific keyword arguments (``type_``,
        ``foreign_key``, ``sequence``, ``db_column_name``, ``default``,
        ``encrypt_key``) and forwards the remaining ones to :class:`Field`.

        :param label: label of this field
        :type label: str
        """
        self.forbid_instance(Column)
        assert self.sqlalchemy_type
        self.sequence = None
        # ``type_`` is never forwarded to the Field
        kwargs.pop("type_", None)

        if "foreign_key" in kwargs:
            self.foreign_key = ModelAttributeAdapter(kwargs.pop("foreign_key"))

        if "sequence" in kwargs:
            self.sequence = SA_Sequence(kwargs.pop("sequence"))

        self.db_column_name = kwargs.pop("db_column_name", None)
        # NoDefaultValue (not None) marks the absence of default, so an
        # explicit ``default=None`` is still honoured
        self.default_val = kwargs.pop("default", NoDefaultValue)
        self.encrypt_key = kwargs.pop("encrypt_key", None)
        super(Column, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties list for autodoc

        :return: autodoc properties
        """
        res = super(Column, self).autodoc_get_properties()
        res["foreign_key"] = self.foreign_key
        res["DB column name"] = self.db_column_name
        res["default"] = self.default_val
        res["is crypted"] = bool(self.encrypt_key)
        return res

    autodoc_omit_property_values = Field.autodoc_omit_property_values.union(
        (
            ("foreign_key", None),
            ("DB column name", None),
            ("default", None),
            ("is crypted", False),
        )
    )

    def native_type(self, registry):
        """Return the native SqlAlchemy type

        :param registry: the current registry (used by some subclasses)
        :rtype: sqlalchemy native type
        """
        return self.sqlalchemy_type

    def format_foreign_key(self, registry, namespace, fieldname, args, kwargs):
        """Format a foreign key

        When the column declares a foreign key, check that both column types
        are compatible, append the value of ``foreign_key.get_fk`` to
        ``args`` and record the foreign key in ``kwargs["info"]`` (mutated
        in place).

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param args: positional arguments of the SQLAlchemy column
        :param kwargs: keyword arguments of the SQLAlchemy column, must
            contain an ``info`` dict
        :return: the (possibly extended) positional arguments
        """
        if self.foreign_key:
            CompareType.validate(
                ModelAttribute(namespace, fieldname),
                self,
                self.foreign_key,
                self.foreign_key.get_type(registry),
            )
            args = args + (self.foreign_key.get_fk(registry),)
            kwargs["info"].update(
                {
                    "foreign_key": self.foreign_key.get_fk_name(registry),
                    "remote_model": self.foreign_key.model_name,
                }
            )

        return args

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        """Return the instance of the real field

        :param registry: current registry
        :param namespace: name of the model
        :param fieldname: name of the field
        :param properties: known properties of the model
        :rtype: sqlalchemy column instance
        """
        self.format_label(fieldname)
        args = self.args
        kwargs = self.kwargs.copy()
        # copy the info dict too: the shallow copy above would otherwise
        # let the updates below leak into the declared field kwargs
        kwargs["info"] = dict(kwargs.get("info", {}))
        args = self.format_foreign_key(
            registry, namespace, fieldname, args, kwargs
        )
        kwargs["info"]["label"] = self.label
        if self.sequence:
            args = (self.sequence,) + args

        db_column_name = self.db_column_name or fieldname

        if self.default_val is not NoDefaultValue:
            if isinstance(self.default_val, str):
                # the string may name a method of the model: resolve lazily
                kwargs["default"] = wrap_default(
                    registry, namespace, self.default_val
                )
            elif isinstance(self.default_val, ColumnDefaultValue):
                kwargs["default"] = self.default_val.get_default_callable(
                    registry, namespace, fieldname, properties
                )
            else:
                kwargs["default"] = self.default_val

        sqlalchemy_type = self.native_type(registry)

        if self.encrypt_key:
            encrypt_key = self.format_encrypt_key(registry, namespace)
            sqlalchemy_type = self.get_encrypt_key_type(
                registry, sqlalchemy_type, encrypt_key
            )

        return SA_Column(db_column_name, sqlalchemy_type, *args, **kwargs)

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap ``sqlalchemy_type`` in a ``StringEncryptedType``

        On MySQL / MariaDB the storage type is forced to ``String(64)``.
        """
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.String(64)

        return sqlalchemy_type

    def format_encrypt_key(self, registry, namespace):
        """Format and return the encryption key

        ``encrypt_key=True`` means: use ``default_encrypt_key`` from the
        configuration.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :return: encrypt key wrapper
        :exception FieldException: no key could be found
        """
        encrypt_key = self.encrypt_key
        if encrypt_key is True:
            encrypt_key = Configuration.get("default_encrypt_key")

        if not encrypt_key:
            raise FieldException(  # pragma: no cover
                "No encrypt_key defined in the configuration"
            )

        def wrapper():
            """Return the encrypt key

            If ``encrypt_key`` names a bound method of the model which is
            neither a column nor a field, return its result, else return
            ``encrypt_key`` itself.
            """
            Model = registry.get(namespace)
            if hasattr(Model, encrypt_key):
                func = getattr(Model, encrypt_key)
                if (
                    ismethod(func)
                    and encrypt_key not in Model.loaded_columns
                    and encrypt_key not in Model.loaded_fields
                ):
                    return func()

            return encrypt_key

        return wrapper

    def must_be_declared_as_attr(self):
        """Return True if the column have a foreign key to a remote column"""
        return self.foreign_key is not None
class ForbiddenPrimaryKey:
    """Mixin to forbid primary key on column type"""

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        """Refuse ``primary_key=True``, else delegate to the next class

        :exception FieldException: the column is declared as primary key
        """
        if self.kwargs.get("primary_key") is True:
            raise FieldException(
                f"{self.__class__} column `{namespace}.{fieldname}` "
                "are not allowed as primary key"
            )

        return super().get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties
        )
class Integer(Column):
    """Integer column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Integer


        @Declarations.register(Declarations.Model)
        class Test:

            x = Integer(default=1)

    """

    sqlalchemy_type = types.Integer

    def __init__(self, *args, **kwargs):
        super(Integer, self).__init__(*args, **kwargs)
        # an integer primary key is auto incremented unless the user
        # explicitly chose otherwise
        if self.kwargs.get("primary_key") is True:
            if "autoincrement" not in self.kwargs:
                self.kwargs["autoincrement"] = True
class BigInteger(Column):
    """Big integer column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import BigInteger


        @Declarations.register(Declarations.Model)
        class Test:

            x = BigInteger(default=1)

    """

    sqlalchemy_type = types.BigInteger
class Boolean(Column):
    """Boolean column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Boolean


        @Declarations.register(Declarations.Model)
        class Test:

            x = Boolean(default=True)

    """

    sqlalchemy_type = types.Boolean
class Float(ForbiddenPrimaryKey, Column):
    """Float column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Float


        @Declarations.register(Declarations.Model)
        class Test:

            x = Float(default=1.0)

    """

    sqlalchemy_type = types.Float


# Add *process_result_value* to the class *DECIMAL*: this method is needed
# to encrypt the column (the value is returned unchanged).
types.DECIMAL.process_result_value = lambda self, value, dialect: value
class Decimal(ForbiddenPrimaryKey, Column):
    """Decimal column

    ::

        from decimal import Decimal as D
        from anyblok.declarations import Declarations
        from anyblok.column import Decimal


        @Declarations.register(Declarations.Model)
        class Test:

            x = Decimal(default=D('1.1'))

    """

    sqlalchemy_type = types.DECIMAL

    def setter_format_value(self, value):
        """Format the given value to decimal if needed

        Encrypted columns store the textual representation instead.

        :param value:
        :return: decimal value (str if the column is encrypted)
        """
        if value is not None:
            if self.encrypt_key:
                value = str(value)
            elif not isinstance(value, decimal.Decimal):
                value = decimal.Decimal(value)

        return value

    def getter_format_value(self, value):
        """Convert back the text of an encrypted column to decimal"""
        if value is None:
            return None  # pragma: no cover

        if self.encrypt_key:
            value = decimal.Decimal(value)

        return value
class Date(Column):
    """Date column

    ::

        from datetime import date
        from anyblok.declarations import Declarations
        from anyblok.column import Date


        @Declarations.register(Declarations.Model)
        class Test:

            x = Date(default=date.today())

    """

    sqlalchemy_type = types.Date
def convert_string_to_datetime(value):
    """Convert a given value to datetime

    A ``date`` becomes a datetime at midnight, a ``str`` is parsed.

    :param value: None, datetime, date or str
    :return: datetime value (None if value is None)
    :exception FieldException: the value can not be converted
    """
    if value is None:
        return None
    elif isinstance(value, datetime):
        # checked before date: datetime is a subclass of date
        return value
    elif isinstance(value, date):
        return datetime.combine(value, datetime.min.time())
    elif isinstance(value, str):
        return parse(value)

    raise FieldException(
        "We can't convert this value %s to datetime" % (value,)
    )


def add_timezone_on_datetime(dt, default_timezone):
    """Localize a naive datetime with the default timezone

    Aware datetimes and None are returned unchanged.

    :param dt: datetime or None
    :param default_timezone: timezone object with a ``localize`` method
    :return: aware datetime or None
    """
    if dt is not None:
        if dt.tzinfo is None:
            dt = default_timezone.localize(dt)

    return dt


class DateTimeType(types.TypeDecorator):
    """Timezone aware DateTime type, ISO text when the field is encrypted"""

    impl = types.DateTime(timezone=True)
    cache_ok = True

    def __init__(self, field):
        self.default_timezone = field.default_timezone
        self.field = field

    def process_bind_param(self, value, engine):
        value = convert_string_to_datetime(value)
        value = add_timezone_on_datetime(value, self.default_timezone)
        # None has no isoformat: leave NULL values untouched
        if self.field.encrypt_key and value is not None:
            return value.isoformat()

        return value

    def process_result_value(self, value, dialect):
        if self.field.encrypt_key:
            return convert_string_to_datetime(value)

        return value

    @property
    def python_type(self):
        return datetime  # pragma: no cover
class DateTime(Column):
    """DateTime column

    ::

        from datetime import datetime
        from anyblok.declarations import Declarations
        from anyblok.column import DateTime


        @Declarations.register(Declarations.Model)
        class Test:

            x = DateTime(default=datetime.now)

    """

    def __init__(self, *args, **kwargs):
        self.auto_update = kwargs.pop("auto_update", False)
        default_timezone = kwargs.pop(
            "default_timezone", Configuration.get("default_timezone")
        )
        if not default_timezone:
            # fall back on the local (non DST) timezone name
            default_timezone = time.tzname[0]

        if isinstance(default_timezone, str):
            default_timezone = pytz.timezone(default_timezone)

        self.default_timezone = default_timezone
        self.sqlalchemy_type = DateTimeType(self)
        super(DateTime, self).__init__(*args, **kwargs)

    def setter_format_value(self, value):
        """Return converted and timezone aware value

        :param value: None, datetime, date or str
        :return: aware datetime or None
        """
        value = convert_string_to_datetime(value)
        return add_timezone_on_datetime(value, self.default_timezone)

    def getter_format_value(self, value):
        """Return converted and timezone aware value"""
        value = convert_string_to_datetime(value)
        return add_timezone_on_datetime(value, self.default_timezone)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        # super(DateTime, ...) so the Column properties (foreign key,
        # default, ...) are documented too
        res = super(DateTime, self).autodoc_get_properties()
        res["is auto updated"] = self.auto_update
        if self.default_timezone:
            res["default timezone"] = self.default_timezone

        return res
class TimeStampType(DateTimeType):
    """DateTimeType variant stored as a timezone aware TIMESTAMP"""

    impl = types.TIMESTAMP(timezone=True)
    cache_ok = True

    @property
    def python_type(self):
        # the handled values are datetime instances (not the time module)
        return datetime  # pragma: no cover


class TimeStamp(DateTime):
    """TimeStamp column

    ::

        from datetime import datetime
        from anyblok.declarations import Declarations
        from anyblok.column import TimeStamp


        @Declarations.register(Declarations.Model)
        class Test:

            x = TimeStamp(default=datetime.now)

    """

    def __init__(self, *args, **kwargs):
        super(TimeStamp, self).__init__(*args, **kwargs)
        # replace the DateTimeType built by DateTime.__init__
        self.sqlalchemy_type = TimeStampType(self)

    def getter_format_value(self, value):
        """Return converted and timezone aware value"""
        value = convert_string_to_datetime(value)
        return add_timezone_on_datetime(value, self.default_timezone)
class Time(Column):
    """Time column

    ::

        from datetime import time
        from anyblok.declarations import Declarations
        from anyblok.column import Time


        @Declarations.register(Declarations.Model)
        class Test:

            x = Time(default=time())

    """

    sqlalchemy_type = types.Time
class Interval(Column):
    """Datetime interval column

    ::

        from datetime import timedelta
        from anyblok.declarations import Declarations
        from anyblok.column import Interval


        @Declarations.register(Declarations.Model)
        class Test:

            x = Interval(default=timedelta(days=5))

    """

    sqlalchemy_type = types.Interval

    def native_type(self, registry):
        """Encrypted intervals are stored as JSON text"""
        if self.encrypt_key:
            return types.VARCHAR(1024)

        return self.sqlalchemy_type

    def setter_format_value(self, value):
        """Serialize the timedelta as JSON when the column is encrypted"""
        if self.encrypt_key and value is not None:
            value = dumps(
                {
                    x: getattr(value, x)
                    for x in ["days", "seconds", "microseconds"]
                }
            )

        return value

    def getter_format_value(self, value):
        """Rebuild the timedelta from JSON when the column is encrypted"""
        if self.encrypt_key and value is not None:
            value = timedelta(**loads(value))

        return value
class StringType(types.TypeDecorator):
    """String type storing ``False`` as an empty string"""

    impl = types.String
    cache_ok = True

    def process_bind_param(self, value, engine):
        if value is False:
            value = ""

        return value

    def process_result_value(self, value, dialect):
        return value
class String(Column):
    """String column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import String


        @Declarations.register(Declarations.Model)
        class Test:

            x = String(default='test')

    """

    def __init__(self, *args, **kwargs):
        self.size = kwargs.pop("size", 64)
        kwargs.pop("type_", None)
        self.sqlalchemy_type = StringType(self.size)
        super(String, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(String, self).autodoc_get_properties()
        res["size"] = self.size
        return res

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type in ``StringEncryptedType``

        On MySQL / MariaDB, reserve at least 64 characters.
        """
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.String(max(self.size, 64))

        return sqlalchemy_type
class Enum(Column):
    """Enum column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Enum
        import enum

        class MyEnumClass(enum.Enum):
            one = 1
            two = 2
            three = 3

        @Declarations.register(Declarations.Model)
        class Test:
            x = Enum(enum_cls=MyEnumClass, default=MyEnumClass.one)

    enum_cls should be an enum class
    """

    def __init__(self, *args, **kwargs):
        self.enum_cls = kwargs.pop("enum_cls")
        self.sqlalchemy_type = types.Enum(self.enum_cls)
        super(Enum, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(Enum, self).autodoc_get_properties()
        res["enum_cls"] = repr(self.enum_cls)
        return res


class MsSQLPasswordType(PasswordType):
    """PasswordType stored as VARCHAR, used for MsSQL"""

    impl = types.VARCHAR(1024)

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(types.VARCHAR(self.length))
class Password(Column):
    """Password column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Password


        @Declarations.register(Declarations.Model)
        class Test:

            x = Password(crypt_context={'schemes': ['md5_crypt']})

        =========================================

        test = Test.insert()
        test.x = 'mypassword'

        test.x
        ==> Password object with encrypt value, the value can not be read

        test.x == 'mypassword'
        ==> True

    .. warning::

        the column type Password can not be queried::

            Test.query().filter(Test.x == 'mypassword').count()
            ==> 0
    """

    def __init__(self, *args, **kwargs):
        self.size = kwargs.pop("size", 64)
        crypt_context = kwargs.pop("crypt_context", {})
        self.crypt_context = crypt_context
        kwargs.pop("type_", None)
        if "foreign_key" in kwargs:
            raise FieldException("Column Password can not have a foreign key")

        self.sqlalchemy_type = PasswordType(
            max_length=self.size, **crypt_context
        )
        super(Password, self).__init__(*args, **kwargs)

    def setter_format_value(self, value):
        """Hash the clear value and wrap it in a Password object

        :param value: clear password
        :return: sqlalchemy_utils Password object
        """
        value = self.sqlalchemy_type.context.hash(value).encode("utf8")
        value = SAU_PWD(value, context=self.sqlalchemy_type.context)
        return value

    def native_type(self, registry):
        """Return the native SqlAlchemy type"""
        if sgdb_in(registry.engine, ["MsSQL"]):
            return MsSQLPasswordType(max_length=self.size, **self.crypt_context)

        return self.sqlalchemy_type

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(Password, self).autodoc_get_properties()
        res["size"] = self.size
        res["Crypt context"] = self.crypt_context
        return res
class TextType(types.TypeDecorator):
    """Text type storing ``False`` as an empty string"""

    impl = types.Text
    cache_ok = True

    def process_bind_param(self, value, engine):
        if value is False:
            value = ""

        return value

    def process_result_value(self, value, dialect):
        return value
class Text(Column):
    """Text column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Text


        @Declarations.register(Declarations.Model)
        class Test:

            x = Text(default='test')

    """

    sqlalchemy_type = TextType

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type in ``StringEncryptedType`` (Text on MySQL/MariaDB)"""
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.Text()

        return sqlalchemy_type
class StrSelection(str):
    """Class representing the data of one column Selection

    Subclasses set ``selections`` to a JSON dump of either a dict
    ``{key: label}`` or the name of a model method returning the
    selections.
    """

    selections = dumps({})
    registry = None
    namespace = None

    def get_selections(self):
        """Return a dict of selections

        :return: selections dict (None if ``selections`` is neither a dict
            nor a method name)
        """
        selections = loads(self.selections)
        if isinstance(selections, dict):
            return selections

        if isinstance(selections, str):
            m = self.registry.get(self.namespace)
            return dict(getattr(m, selections)())

        return None

    def validate(self):
        """Validate if the key is in the selections

        :return: True or False
        """
        a = super(StrSelection, self).__str__()
        return a in self.get_selections().keys()

    @property
    def label(self):
        """Return the label corresponding to the selection key

        :return: label
        """
        a = super(StrSelection, self).__str__()
        return self.get_selections()[a]
class SelectionType(types.TypeDecorator):
    """Generic type for Column Selection"""

    impl = types.String
    cache_ok = True

    def __init__(self, selections, size, registry=None, namespace=None):
        """
        :param selections: dict, list/tuple of (key, label) pairs, or the
            name (str) of a model method returning them
        :param size: max length of the stored keys
        :param registry: the current registry
        :param namespace: the namespace of the model
        :exception FieldException: invalid selections or non str key
        """
        super(SelectionType, self).__init__(length=size)
        self.size = size
        if isinstance(selections, (dict, str)):
            self.selections = selections
        elif isinstance(selections, (list, tuple)):
            self.selections = dict(selections)
        else:
            raise FieldException(  # pragma: no cover
                "selection wainting 'dict', get %r" % type(selections)
            )

        if isinstance(self.selections, dict):
            for k in self.selections.keys():
                if not isinstance(k, str):
                    raise FieldException("The key must be a str")

                # compare with the real column size, as the message says
                if size is not None and len(k) > size:
                    raise Exception(  # pragma: no cover
                        "%r is too long %r, waiting max %s or use size arg"
                        % (k, len(k), size)
                    )

        self.selections = dumps(self.selections)
        self._StrSelection = type(
            "StrSelection",
            (StrSelection,),
            {
                "selections": self.selections,
                "registry": registry,
                "namespace": namespace,
            },
        )

    @property
    def python_type(self):
        return self._StrSelection

    def process_bind_param(self, value, engine):
        if value is not None:
            value = self.python_type(value)

        return value

    def process_result_value(self, value, dialect):
        return value
class Selection(Column):
    """Selection column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Selection


        @Declarations.register(Declarations.Model)
        class Test:
            STATUS = (
                (u'draft', u'Draft'),
                (u'done', u'Done'),
            )

            x = Selection(selections=STATUS, size=64, default=u'draft')

    """

    def __init__(self, *args, **kwargs):
        self.selections = kwargs.pop("selections", tuple())
        self.size = kwargs.pop("size", 64)
        # the real SelectionType needs the registry and the namespace, it is
        # built in get_sqlalchemy_mapping; this placeholder only satisfies
        # the ``assert self.sqlalchemy_type`` of Column.__init__
        self.sqlalchemy_type = "tmp value for assert"
        super(Selection, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(Selection, self).autodoc_get_properties()
        res["selections"] = self.selections
        res["size"] = self.size
        return res

    def getter_format_value(self, value):
        """Return formatted value

        :param value:
        :return: StrSelection instance or None
        """
        if value is None:
            return None

        return self.sqlalchemy_type.python_type(value)

    def setter_format_value(self, value):
        """Return value or raise exception if the given value is invalid

        :param value:
        :exception FieldException: value is not a selection key
        :return: value
        """
        if value is not None:
            val = self.sqlalchemy_type.python_type(value)
            if not val.validate():
                raise FieldException(
                    "%r is not in the selections (%s)"
                    % (value, ", ".join(val.get_selections()))
                )

        return value

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        """Return sqlalchmy mapping

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        :return: instance of the real field
        """
        self.sqlalchemy_type = SelectionType(
            self.selections, self.size, registry=registry, namespace=namespace
        )
        return super(Selection, self).get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties
        )

    def update_description(self, registry, model, res):
        """Update model description

        :param registry: the current registry
        :param model: the namespace of the model
        :param res: description dict, updated in place
        """
        super(Selection, self).update_description(registry, model, res)
        sqlalchemy_type = SelectionType(
            self.selections, self.size, registry=registry, namespace=model
        )
        values = sqlalchemy_type._StrSelection().get_selections()
        res["selections"] = list(values.items())

    def must_be_copied_before_declaration(self):
        """Return True if selections is an instance of str.

        In the case of the field selection is a mixin, it must be copied or
        the selection method can fail
        """
        return isinstance(self.selections, str)

    def update_properties(self, registry, namespace, fieldname, properties):
        """Update column properties

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        """
        super(Selection, self).update_properties(
            registry, namespace, fieldname, properties
        )
        self.fieldname = fieldname
        properties["add_in_table_args"].append(self)

    def update_table_args(self, registry, Model):
        """Return check constraints to limit the value

        :param registry: the current registry
        :param Model: the model class
        :return: list of CheckConstraint
        """
        if self.encrypt_key:
            # dont add constraint because the state is crypted and nobody
            # can add new entry
            return []

        if sgdb_in(registry.engine, ["MariaDB", "MsSQL", "MySQL"]):
            # No check constraint on MariaDB, MsSQL and MySQL
            return []

        selections = loads(self.sqlalchemy_type.selections)
        if isinstance(selections, dict):
            enum = selections.keys()
        else:
            enum = getattr(Model, selections)()
            if isinstance(enum, (list, tuple)):
                enum = dict(enum)

            enum = enum.keys()

        if len(enum) > 1:
            constraint = """"%s" in ('%s')""" % (
                self.fieldname,
                "', '".join(enum),
            )
        elif enum:
            constraint = """"%s" = '%s'""" % (self.fieldname, list(enum)[0])
        else:
            constraint = None

        if constraint:
            # deterministic constraint name built from the sorted keys
            enum = sorted(enum)
            key = md5()
            key.update(str(enum).encode("utf-8"))
            name = self.fieldname + "_" + key.hexdigest() + "_types"
            return [CheckConstraint(constraint, name=name)]

        return []

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type in ``StringEncryptedType``

        On MySQL / MariaDB, reserve at least 64 characters.
        """
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.String(max(self.size, 64))

        return sqlalchemy_type


# Add *process_result_value* to the class *JSON*: this method is needed to
# encrypt the column (the value is returned unchanged).
types.JSON.process_result_value = lambda self, value, dialect: value
class Json(Column):
    """JSON column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Json


        @Declarations.register(Declarations.Model)
        class Test:

            x = Json()

    """

    sqlalchemy_type = types.JSON(none_as_null=True)

    def setter_format_value(self, value):
        """Serialize the value as JSON text when the column is encrypted"""
        if self.encrypt_key:
            value = dumps(value)

        return value

    def getter_format_value(self, value):
        """Deserialize the JSON text when the column is encrypted"""
        if value is None:
            return None

        if self.encrypt_key:
            value = loads(value)

        return value

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type in ``StringEncryptedType`` (Text on MySQL/MariaDB)"""
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.Text()

        return sqlalchemy_type
class LargeBinary(Column):
    """Large binary column

    ::

        from os import urandom
        from anyblok.declarations import Declarations
        from anyblok.column import LargeBinary

        blob = urandom(100000)


        @Declarations.register(Declarations.Model)
        class Test:

            x = LargeBinary(default=blob)

    """

    sqlalchemy_type = types.LargeBinary

    def native_type(self, registry):
        """Encrypted binaries are stored as base64 text"""
        if self.encrypt_key:
            return types.Text

        return self.sqlalchemy_type

    def setter_format_value(self, value):
        """Base64 encode the bytes when the column is encrypted"""
        if self.encrypt_key and value is not None:
            value = b64encode(value).decode("utf-8")

        return value

    def getter_format_value(self, value):
        """Base64 decode the text when the column is encrypted"""
        if self.encrypt_key and value is not None:
            value = b64decode(value.encode("utf-8"))

        return value

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type in ``StringEncryptedType`` (Text on MySQL/MariaDB)"""
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.Text()

        return sqlalchemy_type
class Sequence(String):
    """Sequence column

    ::

        from anyblok.column import Sequence


        @Declarations.register(Declarations.Model)
        class Test:

            x = Sequence()

    If you wish ensure no gap in the sequence::

        from anyblok.column import Sequence


        @Declarations.register(Declarations.Model)
        class Test:

            x = Sequence(no_gap=True, code="SO", formater="{code}-{seq:06d}")

    .. warning::

        Keep in mind `no_gap=True` will raise an
        `sqlalchemy.exc.OperationalError: (psycopg2.errors.LockNotAvailable)`
        exception in case a concurrent transaction do not release the lock
        while getting the next value.

    usage with `no_gap=True`::

        >>> Test.insert().x
        "SO-000001"
        >>> Test.insert().x
        "SO-000002"
        >>> registry.rollback()
        >>> Test.insert().x
        "SO-000001"

    """

    def __init__(self, *args, **kwargs):
        if "foreign_key" in kwargs:
            raise FieldException(
                "Sequence column can not define a foreign key %r"
                % kwargs["foreign_key"]
            )
        if "default" in kwargs:
            raise FieldException(
                "Sequence column can not define a default value"
            )

        kwargs["default"] = ColumnDefaultValue(self.wrap_default)
        self.code = kwargs.pop("code", None)
        self.start = kwargs.pop("start", 1)
        self.formater = kwargs.pop("formater", None)
        self.no_gap = kwargs.pop("no_gap", None)
        super(Sequence, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(Sequence, self).autodoc_get_properties()
        res["formater"] = self.formater
        res["no_gap"] = self.no_gap
        return res

    def wrap_default(self, registry, namespace, fieldname, properties):
        """Return default wrapper

        Register the sequence in
        ``registry._need_sequence_to_create_if_not_exist`` and return a
        callable giving its next value.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model (unused)
        :return: callable returning the next sequence value
        """
        if getattr(registry, "_need_sequence_to_create_if_not_exist", None) is None:
            registry._need_sequence_to_create_if_not_exist = []

        code = self.code or "%s=>%s" % (namespace, fieldname)
        registry._need_sequence_to_create_if_not_exist.append(
            {"code": code, "formater": self.formater, "no_gap": self.no_gap}
        )

        def default_value(*args, **kwargs):
            """Return next sequence value"""
            return registry.System.Sequence.nextvalBy(code=code)

        return default_value
class Color(Column):
    """Color column.
    `See colour package on pypi <https://pypi.python.org/pypi/colour/>`_

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Color


        @Declarations.register(Declarations.Model)
        class Test:

            x = Color(default='green')

    """

    def __init__(self, *args, **kwargs):
        self.max_length = max_length = kwargs.pop("size", 20)
        kwargs.pop("type_", None)
        self.sqlalchemy_type = ColorType(max_length)
        super(Color, self).__init__(*args, **kwargs)

    def setter_format_value(self, value):
        """Convert a str value to the python type of ColorType

        :param value:
        :return: color object (other values are returned unchanged)
        """
        if isinstance(value, str):
            value = self.sqlalchemy_type.python_type(value)

        return value

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(Color, self).autodoc_get_properties()
        res["size"] = self.max_length
        return res

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type in ``StringEncryptedType``

        On MySQL / MariaDB, reserve at least 64 characters.
        """
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.String(max(self.max_length, 64))

        return sqlalchemy_type
class UUID(Column):
    """UUID column

    ::

        from anyblok.column import UUID


        @Declarations.register(Declarations.Model)
        class Test:

            x = UUID()

    """

    def __init__(self, *args, **kwargs):
        # ``binary`` and ``native`` are UUIDType options. Only the ones given
        # explicitly are forwarded; the attributes are None when absent.
        uuid_options = {
            option: kwargs.pop(option)
            for option in ("binary", "native")
            if option in kwargs
        }
        self.binary = uuid_options.get("binary")
        self.native = uuid_options.get("native")
        self.sqlalchemy_type = UUIDType(**uuid_options)
        super().__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc, with the UUID options added

        :return: autodoc properties
        """
        properties = super().autodoc_get_properties()
        properties["binary"] = self.binary
        properties["native"] = self.native
        return properties


URLType.cache_ok = True  # waiting fix from sqlalchemy_utils


class URL(Column):
    """URL column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import URL


        @Declarations.register(Declarations.Model)
        class Test:

            x = URL(default='doc.anyblok.org')

    """

    sqlalchemy_type = URLType

    def setter_format_value(self, value):
        """Convert a string url into a ``furl`` object

        :param value: url as a string, a ``furl`` object or None
        :return: the converted value; non-string values are returned unchanged
        """
        from furl import furl

        if isinstance(value, str):
            return furl(value)

        return value
class PhoneNumber(Column):
    """PhoneNumber column

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import PhoneNumber


        @Declarations.register(Declarations.Model)
        class Test:

            x = PhoneNumber(default='+120012301')

    .. note:: ``phonenumbers`` >= **8.9.5** distribution is required

    """

    def __init__(self, region="FR", max_length=20, *args, **kwargs):
        # The SQLAlchemy type is always PhoneNumberType; ``type_`` is dropped.
        kwargs.pop("type_", None)
        self.region = region
        self.max_length = max_length
        self.sqlalchemy_type = PhoneNumberType(
            region=region, max_length=max_length
        )
        super().__init__(*args, **kwargs)

    def setter_format_value(self, value):
        """Convert a non-empty string into a phone number object

        The string is parsed with the column ``region``.

        :param value: phone number as a string, or an already converted value
        :return: the converted value; other values are returned unchanged
        """
        if value and isinstance(value, str):
            return self.sqlalchemy_type.python_type(value, self.region)

        return value

    def autodoc_get_properties(self):
        """Return properties for autodoc, with region and max_length added

        :return: autodoc properties
        """
        properties = super().autodoc_get_properties()
        properties["region"] = self.region
        properties["max_length"] = self.max_length
        return properties

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap ``sqlalchemy_type`` in a StringEncryptedType

        On MySQL / MariaDB, the underlying string type is given an explicit
        length of at least 64 characters.

        :param registry: the current registry
        :param sqlalchemy_type: the type to encrypt
        :param encrypt_key: the encryption key
        :return: the encrypted type
        """
        encrypted_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            encrypted_type.impl = types.String(max(self.max_length, 64))

        return encrypted_type


# ``EmailType`` gets an identity ``process_result_value`` because that method
# is needed to encrypt the column.
EmailType.process_result_value = lambda self, value, dialect: value
EmailType.cache_ok = True  # waiting fix from sqlalchemy_utils
class Email(Column):
    """Email column

    ::

        from anyblok.column import Email


        @Declarations.register(Declarations.Model)
        class Test:

            x = Email()

    """

    sqlalchemy_type = EmailType

    def setter_format_value(self, value):
        """Normalize an email address to lower case

        :param value: the email address, or None
        :return: the lower-cased address; None is returned unchanged
        """
        return None if value is None else value.lower()
class CountryType(types.TypeDecorator, ScalarCoercible):
    """Generic type for Column Country

    Country objects are stored as their ``alpha_3`` code in a 3-character
    unicode column.
    """

    impl = types.Unicode(3)
    cache_ok = True
    python_type = python_pycountry_type

    def process_bind_param(self, value, dialect):
        # Country objects are sent as their alpha_3 code; anything else
        # (None, an already encoded code...) is sent as is.
        return (
            value.alpha_3
            if value and isinstance(value, self.python_type)
            else value
        )

    def process_result_value(self, value, dialect):
        return self._coerce(value)

    def _coerce(self, value):
        """Turn an alpha_3 code back into a pycountry country

        None and values that are already countries are returned unchanged.
        """
        if value is None or isinstance(value, self.python_type):
            return value  # pragma: no cover

        return pycountry.countries.get(alpha_3=value)
class Country(Column):
    """Country column.

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Country
        from pycountry import countries


        @Declarations.register(Declarations.Model)
        class Test:

            x = Country(default=countries.get(alpha_2='FR'))

    """

    sqlalchemy_type = CountryType

    def __init__(self, mode="alpha_2", *args, **kwargs):
        self.mode = mode
        if pycountry is None:
            raise FieldException(  # pragma: no cover
                "'pycountry' package is required for use 'CountryType'"
            )

        # Map each country's ``mode`` identifier to its display name.
        self.choices = {
            getattr(country, mode): country.name
            for country in pycountry.countries
        }
        super().__init__(*args, **kwargs)

    def setter_format_value(self, value):
        """Convert a country identifier into a pycountry country

        The value is looked up by the column ``mode`` attribute. The result of
        ``pycountry.countries.lookup`` is the fallback; it is computed eagerly,
        so a value ``lookup`` cannot resolve makes it raise.

        :param value: a country identifier, a country object or a falsy value
        :return: the country; falsy values and countries are returned unchanged
        """
        if not value or isinstance(value, self.sqlalchemy_type.python_type):
            return value

        fallback = pycountry.countries.lookup(value)
        return pycountry.countries.get(
            **{self.mode: value, "default": fallback}
        )

    def autodoc_get_properties(self):
        """Return properties for autodoc, with mode and choices added

        :return: autodoc properties
        """
        properties = super().autodoc_get_properties()
        properties["mode"] = self.mode
        properties["choices"] = self.choices
        return properties

    def update_properties(self, registry, namespace, fieldname, properties):
        """Update column properties and register the table-args hook

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        """
        super().update_properties(registry, namespace, fieldname, properties)
        self.fieldname = fieldname
        # The column adds itself here; ``update_table_args`` then supplies the
        # check constraint.
        properties["add_in_table_args"].append(self)

    def update_table_args(self, registry, Model):
        """Return a check constraint limiting the value to known alpha_3 codes

        No constraint is added for encrypted columns (stored values are
        encrypted, so they cannot be checked against the codes) or on
        MariaDB / MsSQL.

        :param registry: the current registry
        :param Model: the model
        :return: list of CheckConstraint
        """
        if self.encrypt_key or sgdb_in(registry.engine, ["MariaDB", "MsSQL"]):
            return []

        codes = [country.alpha_3 for country in pycountry.countries]
        constraint = """"%s" in ('%s')""" % (self.fieldname, "', '".join(codes))
        # The constraint name hashes the sorted code list, so it does not
        # depend on pycountry's iteration order.
        digest = md5(str(sorted(codes)).encode("utf-8")).hexdigest()
        name = self.fieldname + "_" + digest + "_types"
        return [CheckConstraint(constraint, name=name)]
def model_validator_all(Model):
    """Model validator accepting every model

    :param Model: the model to check (ignored)
    :return: always True
    """
    return True


def model_validator_is_sql(Model):
    """Accept only the models whose ``is_sql`` attribute is exactly True

    :param Model: the model to check
    :return: True or False
    """
    return Model.is_sql is True


def model_validator_is_not_sql(Model):
    """Negation of ``model_validator_is_sql``"""
    return not model_validator_is_sql(Model)


def model_validator_is_view(Model):
    """Accept only the models which define a ``__view__`` attribute"""
    return hasattr(Model, "__view__")


def model_validator_is_not_view(Model):
    """Negation of ``model_validator_is_view``"""
    return not model_validator_is_view(Model)


def model_validator_in_namespace(namespace):
    """Build a validator accepting the models under ``namespace``

    The match is a plain ``str.startswith`` on ``Model.__registry_name__``.
    When ``namespace`` is a model (it has a ``__registry_name__``), a trailing
    ``.`` is appended, so only its sub-namespaces match, not the model itself.
    A string namespace is used as given, without adding a dot.

    :param namespace: a registry name (str) or a model
    :return: a ``validator(Model)`` callable
    """
    if hasattr(namespace, "__registry_name__"):
        namespace = f"{namespace.__registry_name__}."

    def validator(Model):
        return Model.__registry_name__.startswith(namespace)

    return validator


def merge_validators(*validators):
    """Combine validators: the result accepts ``obj`` only if all of them do

    Evaluation stops at the first validator returning a falsy value.

    :param validators: callables taking a single argument
    :return: a ``validator(obj)`` callable
    """

    def validator(obj):
        return all(v(obj) for v in validators)

    return validator


class StrModelSelection(str):
    """Class representing the data of one column ModelSelection

    The string value is the registry name of the selected model. The class
    attributes below are placeholders: ``ModelSelectionType`` builds one
    subclass per type instance (with ``type()``) that fills them in.
    """

    # callable(self, Model) -> bool, set on the generated subclass
    validator = None
    # registry used to resolve the models
    registry = None
    # cache of ``get_selections``, stored on the generated subclass
    selections = None

    def get_selections(self):
        """Return a dict of selections

        Maps each loaded registry name accepted by ``validator`` to the first
        line of the model docstring, or to the registry name itself when the
        docstring is missing (or its first line is empty). The result is
        cached on the class; an empty result is recomputed on the next call.

        :return: selections dict
        """
        if not self.selections:
            self.__class__.selections = {
                k: v.__doc__ and v.__doc__.split("\n")[0] or k
                for k, v in self.registry.loaded_namespaces.items()
                if self.validator(v)
            }

        return self.selections

    def validate(self):
        """Validate if the key is in the selections

        :return: True or False
        """
        a = super(StrModelSelection, self).__str__()
        return a in self.get_selections().keys()

    @property
    def Model(self):
        """Return the class corresponding to the selection key

        :return: the model from the registry, or None for an empty key
        """
        a = super(StrModelSelection, self).__str__()
        if a:
            return self.registry.get(a)

        return None  # pragma: no cover


class ModelSelectionType(ScalarCoercible, types.TypeDecorator):
    """Generic type for Column ModelSelection

    Stored as a string of length 256; values are exposed as instances of a
    ``StrModelSelection`` subclass bound to the given validator and registry.
    """

    impl = types.String
    cache_ok = True

    def __init__(self, validator, registry=None, namespace=None):
        """
        :param validator: a ``callable(Model)``, or the name (str) of an
                          attribute of the model registered at ``namespace``
                          to call with the candidate Model; defaults to
                          ``model_validator_all``
        :param registry: the current registry
        :param namespace: registry name of the model owning the column
        """
        super(ModelSelectionType, self).__init__(length=256)
        if validator is None:
            validator = model_validator_all

        # Stored in the generated class dict, so it is called as a bound
        # method: ``obj`` is the StrModelSelection instance (unused).
        def _validator(obj, Model):
            if isinstance(validator, str):
                return getattr(registry.get(namespace), validator)(Model)

            return validator(Model)

        # A dedicated subclass per type instance, so the selections cache is
        # not shared with other columns.
        self._StrModelSelection = type(
            "StrModelSelection",
            (StrModelSelection,),
            {
                "validator": _validator,
                "registry": registry,
                "selections": None,
            },
        )

    @property
    def python_type(self):
        return self._StrModelSelection

    def process_bind_param(self, value, engine):
        """Accept a model (anything with ``__registry_name__``) or a string"""
        if value is not None:
            if hasattr(value, "__registry_name__"):
                value = value.__registry_name__

            value = self.python_type(value)

        return value

    def process_result_value(self, value, dialect):
        return self._coerce(value)

    def _coerce(self, value):
        """Wrap a non-None value in the bound StrModelSelection subclass"""
        if value is not None and not isinstance(value, self._StrModelSelection):
            value = self.python_type(value)

        return value


class ModelSelection(Column):
    """ModelSelection column

    Reference an AnyBlok Model by its registry name. The optional
    ``validator`` (a callable or the name of a classmethod on the model)
    restricts which models can be selected.

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import ModelSelection


        @Declarations.register(Declarations.Model)
        class Test:

            x = ModelSelection(
                default='Model.System.Blok',
                validator="_x_validator"
            )

            @classmethod
            def _x_validator(cls, Model):
                return True or False

    """

    def __init__(self, *args, **kwargs):
        self.size = 256  # used by comparator
        self.validator = None
        if "validator" in kwargs:
            self.validator = kwargs.pop("validator")

        # Placeholder; the real type needs the registry and is built in
        # get_sqlalchemy_mapping. NOTE(review): presumably the parent
        # constructor asserts a type is set — confirm.
        self.sqlalchemy_type = "tmp value for assert"
        super(ModelSelection, self).__init__(*args, **kwargs)
        # A model given as default is stored by its registry name.
        if self.default_val and hasattr(self.default_val, "__registry_name__"):
            self.default_val = self.default_val.__registry_name__

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(ModelSelection, self).autodoc_get_properties()
        res["validator"] = str(self.validator)
        return res

    def getter_format_value(self, value):
        """Return formatted value

        :param value: the stored registry name, or None
        :return: None, or the value wrapped in the bound StrModelSelection
        """
        if value is None:
            return None

        return self.sqlalchemy_type.python_type(value)

    def setter_format_value(self, value):
        """Return value or raise exception if the given value is invalid

        A model is converted to its registry name; the plain name (not the
        StrModelSelection wrapper) is returned.

        :param value: a model, a registry name or None
        :exception FieldException: the value is not in the selections
        :return: the registry name, or None
        """
        if value is not None:
            if hasattr(value, "__registry_name__"):
                value = value.__registry_name__

            val = self.sqlalchemy_type.python_type(value)
            if not val.validate():
                raise FieldException(
                    "%r is not in the selections (%s)"
                    % (value, ", ".join(val.get_selections()))
                )

        return value

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        """Return sqlalchemy mapping

        Replaces the placeholder type with a ModelSelectionType bound to the
        registry and namespace.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        :return: instance of the real field
        """
        self.sqlalchemy_type = ModelSelectionType(
            self.validator, registry=registry, namespace=namespace
        )
        return super(ModelSelection, self).get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties
        )

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap ``sqlalchemy_type`` in a StringEncryptedType

        On MySQL / MariaDB, the underlying string type is set to length 256.
        """
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.String(256)

        return sqlalchemy_type

    def update_description(self, registry, model, res):
        """Update model description with the available selections

        :param registry: the current registry
        :param model: the model namespace
        :param res: the description dict, updated in place with a
                    ``selections`` list of ``(registry name, label)`` pairs
        """
        super(ModelSelection, self).update_description(registry, model, res)
        # A fresh type, hence a fresh (uncached) selections computation.
        sqlalchemy_type = ModelSelectionType(
            self.validator, registry=registry, namespace=model
        )
        values = sqlalchemy_type._StrModelSelection().get_selections()
        res["selections"] = [(k, v) for k, v in values.items()]


def fieldToModelAttribute(field):
    """Convert a model field (it has ``anyblok_field_name``) to a ModelAttribute

    Any other value is returned unchanged.
    """
    if hasattr(field, "anyblok_field_name"):
        field = ModelAttribute(
            field.anyblok_registry_name,
            field.anyblok_field_name,
        )

    return field


def field_validator_all(field):
    """Field validator accepting every field"""
    return True


def field_validator_is_field(field):
    """Accept fields which are neither a Column nor a RelationShip"""
    return field_validator_is_not_column(
        field
    ) and field_validator_is_not_relationship(field)


def field_validator_is_not_field(field):
    """Negation of ``field_validator_is_field``"""
    return not field_validator_is_field(field)


def field_validator_is_column(field):
    """Accept fields whose declared type is a Column"""
    return isinstance(
        fieldToModelAttribute(field).get_type(field.from_model().anyblok),
        Column,
    )


def field_validator_is_not_column(field):
    """Negation of ``field_validator_is_column``"""
    return not field_validator_is_column(field)


def field_validator_is_relationship(field):
    """Accept fields whose declared type is a RelationShip"""
    # Local import, presumably to avoid a circular import with .relationship.
    from .relationship import RelationShip

    return isinstance(
        fieldToModelAttribute(field).get_type(field.from_model().anyblok),
        RelationShip,
    )


def field_validator_is_not_relationship(field):
    """Negation of ``field_validator_is_relationship``"""
    return not field_validator_is_relationship(field)


def field_validator_is_named(*names):
    """Build a validator accepting the fields whose name is in ``names``"""

    def validator(field):
        return field.anyblok_field_name in names

    return validator


def field_validator_is_from_types(*types):
    """Build a validator accepting fields whose declared type is an instance
    of one of ``types``

    Note: the ``types`` parameter shadows the module-level sqlalchemy
    ``types`` inside this function.
    """

    def validator(field):
        return isinstance(
            fieldToModelAttribute(field).get_type(field.from_model().anyblok),
            types,
        )

    return validator


class StrModelFieldSelection(str):
    """Class representing the data of one column ModelFieldSelection

    The string value is ``str(ModelAttribute(namespace, field))``. The class
    attributes are placeholders filled by a subclass that
    ``ModelFieldSelectionType`` builds with ``type()``.
    """

    # callable(self, Model) -> bool, set on the generated subclass
    model_validator = None
    # callable(self, field) -> bool, set on the generated subclass
    field_validator = None
    # registry used to resolve models and fields
    registry = None
    # cache of ``get_selections``, stored on the generated subclass
    selections = None

    def get_selections(self):
        """Return a dict of selections

        Keys are ``str(ModelAttribute(namespace, field))`` for every field
        accepted by ``field_validator`` of every loaded model accepted by
        ``model_validator``. Values are ``"<label> : <field>"``, where label
        is the first line of the model docstring or the namespace. The
        result is cached on the class; an empty result is recomputed.

        :return: selections dict
        """
        if not self.selections:
            self.__class__.selections = {
                str(ModelAttribute(namespace, field)): (
                    Model.__doc__ and Model.__doc__.split("\n")[0] or namespace
                )
                + " : "
                + field
                for namespace, Model in self.registry.loaded_namespaces.items()
                if self.model_validator(Model)
                for field in Model.fields_name()
                if self.field_validator(getattr(Model, field))
            }

        return self.selections

    def validate(self):
        """Validate if the key is in the selections

        :return: True or False
        """
        a = super(StrModelFieldSelection, self).__str__()
        return a in self.get_selections().keys()

    @property
    def field(self):
        """Return the class corresponding to the selection key

        :return: the attribute resolved through ModelAttributeAdapter, or
                 None for an empty key
        """
        a = super(StrModelFieldSelection, self).__str__()
        if a:
            return ModelAttributeAdapter(a).get_attribute(self.registry)

        return None  # pragma: no cover


class ModelFieldSelectionType(ScalarCoercible, types.TypeDecorator):
    """Generic type for Column ModelFieldSelection

    Stored as a string of length 256; values are exposed as instances of a
    ``StrModelFieldSelection`` subclass bound to the validators and registry.
    """

    impl = types.String
    cache_ok = True

    def __init__(
        self, model_validator, field_validator, registry=None, namespace=None
    ):
        """
        :param model_validator: ``callable(Model)`` or name (str) of an
                                attribute of the model at ``namespace``;
                                defaults to ``model_validator_is_sql``
        :param field_validator: ``callable(field)`` or name (str) of an
                                attribute of the model at ``namespace``;
                                defaults to ``field_validator_all``
        :param registry: the current registry
        :param namespace: registry name of the model owning the column
        """
        super(ModelFieldSelectionType, self).__init__(length=256)
        if model_validator is None:
            model_validator = model_validator_is_sql

        if field_validator is None:
            field_validator = field_validator_all

        # Both closures are stored in the generated class dict and therefore
        # called as bound methods: ``obj`` is the StrModelFieldSelection.
        def _model_validator(obj, Model):
            if isinstance(model_validator, str):
                return getattr(registry.get(namespace), model_validator)(Model)

            return model_validator(Model)

        def _field_validator(obj, field):
            if isinstance(field_validator, str):
                return getattr(registry.get(namespace), field_validator)(field)

            return field_validator(field)

        self._StrModelFieldSelection = type(
            "StrModelFieldSelection",
            (StrModelFieldSelection,),
            {
                "model_validator": _model_validator,
                "field_validator": _field_validator,
                "registry": registry,
                "selections": None,
            },
        )

    @property
    def python_type(self):
        return self._StrModelFieldSelection

    def process_bind_param(self, value, engine):
        """Accept a model field (converted to ModelAttribute) or a string"""
        if value is not None:
            value = self.python_type(fieldToModelAttribute(value))

        return value

    def process_result_value(self, value, dialect):
        return self._coerce(value)

    def _coerce(self, value):
        """Wrap a non-None value in the bound StrModelFieldSelection subclass"""
        if value is not None and not isinstance(
            value, self._StrModelFieldSelection
        ):
            value = self.python_type(value)

        return value


class ModelFieldSelection(Column):
    """ModelFieldSelection column

    Reference a field of an AnyBlok Model, stored as
    ``str(ModelAttribute(namespace, field))``. The optional validators (a
    callable or the name of a classmethod on the model) restrict which
    models and which fields can be selected.

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import ModelFieldSelection


        @Declarations.register(Declarations.Model)
        class Test:

            x = ModelFieldSelection(
                default='Model.System.Blok => name',
                model_validator="_x_model_validator",
                field_validator="_x_field_validator",
            )

            @classmethod
            def _x_model_validator(cls, Model):
                return True or False

            @classmethod
            def _x_field_validator(cls, field):
                return True or False

    """

    def __init__(self, *args, **kwargs):
        self.size = 256  # used by comparator
        self.model_validator = None
        if "model_validator" in kwargs:
            self.model_validator = kwargs.pop("model_validator")

        self.field_validator = None
        if "field_validator" in kwargs:
            self.field_validator = kwargs.pop("field_validator")

        # Placeholder; the real type is built in get_sqlalchemy_mapping.
        self.sqlalchemy_type = "tmp value for assert"
        super(ModelFieldSelection, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(ModelFieldSelection, self).autodoc_get_properties()
        res["model_validator"] = str(self.model_validator)
        res["field_validator"] = str(self.field_validator)
        return res

    def getter_format_value(self, value):
        """Return formatted value

        :param value: the stored string, or None
        :return: None, or the value wrapped in the bound
                 StrModelFieldSelection
        """
        if value is None:
            return None

        return self.sqlalchemy_type.python_type(value)

    def setter_format_value(self, value):
        """Return value or raise exception if the given value is invalid

        A model field is converted to ``str(ModelAttribute(...))``; that
        plain string is returned.

        :param value: a model field, a string or None
        :exception FieldException: the value is not in the selections
        :return: the string value, or None
        """
        if value is not None:
            if hasattr(value, "anyblok_field_name"):
                value = str(
                    ModelAttribute(
                        value.anyblok_registry_name,
                        value.anyblok_field_name,
                    )
                )

            val = self.sqlalchemy_type.python_type(value)
            if not val.validate():
                raise FieldException(
                    "%r is not in the selections (%s)"
                    % (value, ", ".join(val.get_selections()))
                )

        return value

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        """Return sqlalchemy mapping

        Replaces the placeholder type with a ModelFieldSelectionType bound to
        the registry and namespace.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        :return: instance of the real field
        """
        self.sqlalchemy_type = ModelFieldSelectionType(
            self.model_validator,
            self.field_validator,
            registry=registry,
            namespace=namespace,
        )
        return super(ModelFieldSelection, self).get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties
        )

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap ``sqlalchemy_type`` in a StringEncryptedType

        On MySQL / MariaDB, the underlying string type is set to length 256.
        """
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.String(256)

        return sqlalchemy_type

    def update_description(self, registry, model, res):
        """Update model description with the available selections

        :param registry: the current registry
        :param model: the model namespace
        :param res: the description dict, updated in place with a
                    ``selections`` list of ``(key, label)`` pairs
        """
        super(ModelFieldSelection, self).update_description(
            registry, model, res
        )
        sqlalchemy_type = ModelFieldSelectionType(
            self.model_validator,
            self.field_validator,
            registry=registry,
            namespace=model,
        )
        values = sqlalchemy_type._StrModelFieldSelection().get_selections()
        res["selections"] = [(k, v) for k, v in values.items()]


def instanceToDict(instance):
    """Normalize an instance reference to ``{"model": ..., "primary_keys": ...}``

    A non-dict is treated as a model instance and converted with its
    ``__registry_name__`` and ``to_primary_keys()``. A dict is returned as
    is, once both entries are checked to be present.

    :exception FieldException: a dict misses ``model`` or ``primary_keys``
    """
    if not isinstance(instance, dict):
        return dict(
            model=instance.__registry_name__,
            primary_keys=instance.to_primary_keys(),
        )
    elif "model" not in instance:
        raise FieldException("The model entry is required : %r" % instance)
    elif "primary_keys" not in instance:
        raise FieldException(
            "The primary_keys entry is required : %r" % instance
        )

    return instance


def instance_validator_all(instance):
    """Instance validator accepting every instance"""
    return True


class ModelReferenceType(types.JSON):
    """Generic type for Column ModelReference

    JSON storage of ``{"model": registry name, "primary_keys": {...}}``.
    """

    cache_ok = True

    class comparator_factory(SA_JSON.Comparator):
        """Add ``is_`` and ``with_models`` filters on the JSON column"""

        def is_(self, instance):
            """Filter on the model and on each primary key, compared as str"""
            value = instanceToDict(instance)
            filters = [
                self.expr["model"].as_string() == value["model"],
            ]
            # ``value`` is rebound by the loop; the items were already
            # fetched from the dict, so iteration is unaffected.
            for pk, value in value.get("primary_keys", {}).items():
                filters.append(
                    self.expr["primary_keys"][pk].as_string() == str(value)
                )

            return and_(*filters)

        def with_models(self, *namespaces):
            """Filter on the referenced model (models or registry names)"""
            filters = []
            for namespace in namespaces:
                if hasattr(namespace, "__registry_name__"):
                    namespace = namespace.__registry_name__

                filters.append(self.expr["model"].as_string() == namespace)

            if len(filters) == 1:
                return filters[0]

            return or_(*filters)

    def __init__(
        self,
        model_validator,
        instance_validator,
        registry=None,
        namespace=None,
    ):
        """
        :param model_validator: ``callable(Model)`` or name (str) of an
                                attribute of the model at ``namespace``;
                                defaults to ``model_validator_is_sql``
        :param instance_validator: ``callable(instance)`` or name (str) of a
                                   method called without arguments on the
                                   referenced instance; defaults to
                                   ``instance_validator_all``
        :param registry: the current registry
        :param namespace: registry name of the model owning the column
        """
        if model_validator is None:
            model_validator = model_validator_is_sql

        if instance_validator is None:
            instance_validator = instance_validator_all

        def validate_model(value):
            # ``value`` is a registry name
            Model = registry.get(value)
            if isinstance(model_validator, str):
                return getattr(registry.get(namespace), model_validator)(Model)

            return model_validator(Model)

        def validate_instance(value):
            # ``value`` is a dict as returned by instanceToDict
            instance = registry.get(value["model"]).from_primary_keys(
                **value["primary_keys"]
            )
            if isinstance(instance_validator, str):
                return getattr(instance, instance_validator)()

            return instance_validator(instance)

        self.validate_model = validate_model
        self.validate_instance = validate_instance
        self.model_selections = None
        self.model_validator = model_validator
        self.instance_validator = instance_validator
        self.registry = registry
        self.namespace = namespace
        super(ModelReferenceType, self).__init__(none_as_null=True)

    def get_instance(self, value):
        """Load the referenced instance; falsy values are returned unchanged"""
        if value:
            value = instanceToDict(value)
            value = self.registry.get(value["model"]).from_primary_keys(
                **value["primary_keys"]
            )

        return value

    def get_model_selections(self):
        """Return a dict of selections

        Maps each loaded registry name accepted by ``validate_model`` to the
        first line of the model docstring, or to the registry name. Cached
        on this type instance; an empty result is recomputed.

        :return: selections dict
        """
        if not self.model_selections:
            self.model_selections = {
                namespace: (
                    Model.__doc__ and Model.__doc__.split("\n")[0] or namespace
                )
                for namespace, Model in self.registry.loaded_namespaces.items()
                if self.validate_model(namespace)
            }

        return self.model_selections


class ModelReference(Json):
    """ModelReference column

    Reference an AnyBlok instance of a Model, stored as JSON
    ``{"model": ..., "primary_keys": {...}}``

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import ModelReference


        @Declarations.register(Declarations.Model)
        class Test:

            x = ModelReference(
                model_validator="_x_model_validator",
                instance_validator="_x_instance_validator",
            )

            @classmethod
            def _x_model_validator(cls, Model):
                return True or False

            @classmethod
            def _x_instance_validator(cls, field):
                return True or False

    Queries::

        anyblok.Test.query().filter(Test.x.is_(instance))
        anyblok.Test.query().filter(Test.x.with_models(anyblok.System.Blok))
    """

    def __init__(self, *args, **kwargs):
        self.model_validator = None
        if "model_validator" in kwargs:
            self.model_validator = kwargs.pop("model_validator")

        self.instance_validator = None
        if "instance_validator" in kwargs:
            self.instance_validator = kwargs.pop("instance_validator")

        # The user default is kept aside and validated lazily by the wrapper
        # built in ``wrap_default``.
        if "default" in kwargs:
            self.real_default_value = kwargs.pop("default")
            kwargs["default"] = ColumnDefaultValue(self.wrap_default)

        super(ModelReference, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(ModelReference, self).autodoc_get_properties()
        res["model_validator"] = str(self.model_validator)
        res["instance_validator"] = str(self.instance_validator)
        return res

    def getter_format_value(self, value):
        """Return formatted value

        :param value: the stored JSON value
        :return: None, the referenced instance for a dict value, or the
                 value returned by the Json getter
        """
        value = super(ModelReference, self).getter_format_value(value)
        if value is None:
            return None

        if isinstance(value, dict):
            value = self._sqlalchemy_type.get_instance(value)

        return value

    def validate_dict(self, value):
        """Check a ``{"model", "primary_keys"}`` dict against the validators

        :exception FieldException: model not allowed, instance not found or
                                   instance not allowed
        :return: the same value
        """
        if value is not None:
            if not self._sqlalchemy_type.validate_model(value["model"]):
                raise FieldException(
                    "The model of %r is not in the selections" % value
                )

            if not self._sqlalchemy_type.get_instance(value):
                raise FieldException("%r is not existing" % value)

            if not self._sqlalchemy_type.validate_instance(value):
                raise FieldException("%r is not a valid choice" % value)

        return value

    def setter_format_value(self, value):
        """Return value or raise exception if the given value is invalid

        :param value: an instance, a reference dict or None
        :exception FieldException: see ``validate_dict``
        :return: the reference dict, or None
        """
        if value is not None:
            value = self.validate_dict(instanceToDict(value))

        return value

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        """Return sqlalchemy mapping

        Builds the ModelReferenceType bound to the registry and keeps it both
        as ``sqlalchemy_type`` and ``_sqlalchemy_type``.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        :return: instance of the real field
        """
        self._sqlalchemy_type = self.sqlalchemy_type = ModelReferenceType(
            self.model_validator,
            self.instance_validator,
            registry=registry,
            namespace=namespace,
        )
        return super(ModelReference, self).get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties
        )

    def update_description(self, registry, model, res):
        """Update model description with the available models

        :param registry: the current registry
        :param model: the model namespace
        :param res: the description dict, updated in place with a
                    ``model_selections`` list of ``(registry name, label)``
        """
        super(ModelReference, self).update_description(registry, model, res)
        sqlalchemy_type = ModelReferenceType(
            self.model_validator,
            self.instance_validator,
            registry=registry,
            namespace=model,
        )
        values = sqlalchemy_type.get_model_selections()
        res["model_selections"] = [(k, v) for k, v in values.items()]

    def wrap_default(self, registry, namespace, fieldname, properties):
        """Return default wrapper

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :return: a callable returning the validated reference dict
        """

        def default_value():
            if isinstance(self.real_default_value, str):
                # ``wrap_default`` here is the module-level helper (name lookup
                # inside a method body goes to the module globals, not to this
                # method). NOTE(review): presumably it turns the string into a
                # callable producing the default instance — confirm.
                return self.validate_dict(
                    instanceToDict(
                        wrap_default(
                            registry, namespace, self.real_default_value
                        )()
                    )
                )

            return self.validate_dict(instanceToDict(self.real_default_value))

        return default_value


# Foreign keys between these string based columns require the same size.
@CompareType.add_comparator(String, String)
@CompareType.add_comparator(String, Selection)
@CompareType.add_comparator(String, Sequence)
@CompareType.add_comparator(String, ModelSelection)
@CompareType.add_comparator(String, ModelFieldSelection)
def compare_strings(col1, type1, col2, type2):
    """Raise FieldException when both column types have different ``size``"""
    if type1.size != type2.size:
        raise FieldException(
            "You can't add a foreign key using based String columns with "
            "different size `{model1!s}.{col1!s}` pointing to "
            "`{model2!s}.{col2!s}` have different sizes {type1!r}({size1:d}) "
            "-> {type2!r}({size2:d})".format(
                model1=col1.model_name,
                col1=col1.attribute_name,
                model2=col2.model_name,
                col2=col2.attribute_name,
                type1=type1.__class__,
                type2=type2.__class__,
                size1=type1.size,
                size2=type2.size,
            )
        )


@CompareType.add_comparator(String, Color)
def compare_string_to_color(col1, type1, col2, type2):
    """Raise FieldException when String ``size`` != Color ``max_length``"""
    if type1.size != type2.max_length:
        raise FieldException(
            "You can't add a foreign key using based String columns with "
            "different size `{model1!s}.{col1!s}` pointing to "
            "`{model2!s}.{col2!s}` have different sizes {type1!r}({size1:d}) "
            "-> {type2!r}({size2:d})".format(
                model1=col1.model_name,
                col1=col1.attribute_name,
                model2=col2.model_name,
                col2=col2.attribute_name,
                type1=type1.__class__,
                type2=type2.__class__,
                size1=type1.size,
                size2=type2.max_length,
            )
        )