# This file is a part of the AnyBlok project
#
# Copyright (C) 2016 Jean-Sebastien SUZANNE <jssuzanne@anybox.fr>
# Copyright (C) 2017 Jean-Sebastien SUZANNE <jssuzanne@anybox.fr>
# Copyright (C) 2018 Jean-Sebastien SUZANNE <jssuzanne@anybox.fr>
# Copyright (C) 2019 Jean-Sebastien SUZANNE <js.suzanne@gmail.com>
# Copyright (C) 2020 Jean-Sebastien SUZANNE <js.suzanne@gmail.com>
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file,You can
# obtain one at http://mozilla.org/MPL/2.0/.
import decimal
import time
from base64 import b64decode, b64encode
from datetime import date, datetime, timedelta
from hashlib import md5
from inspect import ismethod
from json import dumps, loads
from logging import getLogger
import pytz
from dateutil.parser import parse
from sqlalchemy import JSON as SA_JSON
from sqlalchemy import CheckConstraint, and_, or_, types
from sqlalchemy.schema import Column as SA_Column
from sqlalchemy.schema import Sequence as SA_Sequence
from sqlalchemy_utils.types.color import ColorType
from sqlalchemy_utils.types.email import EmailType
from sqlalchemy_utils.types.encrypted.encrypted_type import StringEncryptedType
from sqlalchemy_utils.types.password import Password as SAU_PWD
from sqlalchemy_utils.types.password import PasswordType
from sqlalchemy_utils.types.phone_number import PhoneNumberType
from sqlalchemy_utils.types.scalar_coercible import ScalarCoercible
from sqlalchemy_utils.types.url import URLType
from sqlalchemy_utils.types.uuid import UUIDType
from anyblok.config import Configuration
from .common import sgdb_in
from .field import Field, FieldException
from .mapper import ModelAttribute, ModelAttributeAdapter
# pycountry is an optional dependency: both names stay None when it is not
# installed (the Country column checks ``pycountry is None``).
pycountry = None
python_pycountry_type = None
try:
    import pycountry

    # Load the country database eagerly (private pycountry API) before
    # reading ``data_class``, the class CountryType uses to recognise
    # country objects.
    if not pycountry.countries._is_loaded:
        pycountry.countries._load()
    python_pycountry_type = pycountry.countries.data_class
except ImportError:
    # optional dependency missing: Country columns will refuse to work
    pass
logger = getLogger(__name__)
def wrap_default(registry, namespace, default_val):
    """Return default wrapper

    A column declared with a string ``default`` may name a method of the
    model. The returned wrapper looks the model up in the registry when it
    is called and, if ``default_val`` is a bound method (e.g. a
    classmethod) that does not collide with a column or a field of the
    same name, returns the result of calling it. Otherwise the string
    itself is used as the default value, and a warning is logged when the
    attribute exists but can not be used.

    :param registry: the current registry
    :param namespace: the namespace of the model
    :param default_val: the string given as ``default``
    :return: default wrapper
    """

    def wrapper():
        """Return wrapper

        :return: the value returned by the model method, or
            ``default_val`` itself
        """
        Model = registry.get(namespace)
        if not hasattr(Model, default_val):
            return default_val

        func = getattr(Model, default_val)
        if not ismethod(func):
            # a regular instance method read from the class is a plain
            # function, not a bound method, so it can not be called here
            logger.warning(
                "The attribute %r is already declared as a default "
                "value on the Model %r, an instance method with the same "
                "name already exists",
                default_val,
                namespace,
            )
        elif default_val in Model.loaded_columns:
            logger.warning(
                "The attribute %r is already declared as a default "
                "value on the Model %r, a column with the same name "
                "already exists",
                default_val,
                namespace,
            )
        elif default_val in Model.loaded_fields:
            logger.warning(
                "The attribute %r is already declared as a default "
                "value on the Model %r, a field with the same name "
                "already exists",
                default_val,
                namespace,
            )
        else:
            return func()

        return default_val

    return wrapper
class ColumnDefaultValue:
    """Defer the computation of a column default until mapping time.

    Wraps a factory called as ``callable(registry, namespace, fieldname,
    properties)``; whatever it returns is used as the ``default`` of the
    SQLAlchemy column.
    """

    def __init__(self, callable):
        self.callable = callable

    def get_default_callable(self, registry, namespace, fieldname, properties):
        """Build the default by calling the wrapped factory

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        :return: whatever the factory returns (the default callable)
        """
        factory = self.callable
        return factory(registry, namespace, fieldname, properties)
class CompareType:
    """Validate that a foreign key links columns of compatible types.

    Comparators registered with :meth:`add_comparator` are matched on the
    exact classes of both types; when none matches,
    :meth:`default_comparator` requires both types to share the same class.
    """

    # (type1 class, type2 class, comparator) triples, shared class-wide so
    # that every registered comparator is visible to ``validate``
    comparators = []

    @classmethod
    def default_comparator(cls, col1, type1, col2, type2):
        """Raise when both types are not instances of the same class

        :raises FieldException: if ``type1`` and ``type2`` differ
        """
        if type1.__class__ is not type2.__class__:
            raise FieldException(
                "You can't add a foreign key using columns with different "
                "types `{model1!s}.{col1!s}` pointing to "
                "`{model2!s}.{col2!s}` have different types "
                "{type1!r} -> {type2!r}".format(
                    model1=col1.model_name,
                    col1=col1.attribute_name,
                    model2=col2.model_name,
                    col2=col2.attribute_name,
                    type1=type1.__class__,
                    type2=type2.__class__,
                )
            )

    @classmethod
    def add_comparator(cls, type1, type2):
        """Decorator registering a comparator for ``(type1, type2)``

        :param type1: class of the local column type
        :param type2: class of the remote column type
        :return: decorator returning the decorated function unchanged
        """

        def wrapper(funct):
            cls.comparators.append((type1, type2, funct))
            return funct

        return wrapper

    @classmethod
    def validate(cls, col1, type1, col2, type2):
        """Run the first matching comparator, else the default one"""
        for cls1, cls2, funct in cls.comparators:
            if type1.__class__ is cls1 and type2.__class__ is cls2:
                funct(col1, type1, col2, type2)
                return

        cls.default_comparator(col1, type1, col2, type2)
class NoDefaultValue:
    """Sentinel marking a column declared without ``default``.

    Used instead of ``None`` so that ``default=None`` remains a valid,
    explicit default value.
    """
class Column(Field):
    """Column class

    Base class of every column type. This class can't be instantiated:
    concrete subclasses must provide ``sqlalchemy_type``.
    """

    use_hybrid_property = True
    foreign_key = None
    sqlalchemy_type = None
    type = None

    def __init__(self, *args, **kwargs):
        """Initialize the column

        :param label: label of this field
        :type label: str
        :param foreign_key: remote column targeted by this column
        :param sequence: name of the database sequence attached to the
            column
        :param db_column_name: name of the column in the database, the
            field name is used when it is not given
        :param default: default value of the column; ``None`` is a valid
            explicit default
        :param encrypt_key: when given, the column is encrypted
        """
        self.forbid_instance(Column)
        assert self.sqlalchemy_type
        kwargs.pop("type_", None)
        self.sequence = None
        if "foreign_key" in kwargs:
            self.foreign_key = ModelAttributeAdapter(kwargs.pop("foreign_key"))
        if "sequence" in kwargs:
            self.sequence = SA_Sequence(kwargs.pop("sequence"))

        self.db_column_name = kwargs.pop("db_column_name", None)
        # NoDefaultValue rather than None, so that ``default=None`` can be
        # told apart from "no default at all"
        self.default_val = kwargs.pop("default", NoDefaultValue)
        self.encrypt_key = kwargs.pop("encrypt_key", None)
        super(Column, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties list for autodoc

        :return: autodoc properties
        """
        res = super(Column, self).autodoc_get_properties()
        res["foreign_key"] = self.foreign_key
        res["DB column name"] = self.db_column_name
        # report the "no default" sentinel as None so that it is matched by
        # the ("default", None) entry of autodoc_omit_property_values
        res["default"] = (
            None if self.default_val is NoDefaultValue else self.default_val
        )
        res["is crypted"] = bool(self.encrypt_key)
        return res

    autodoc_omit_property_values = Field.autodoc_omit_property_values.union(
        (
            ("foreign_key", None),
            ("DB column name", None),
            ("default", None),
            ("is crypted", False),
        )
    )

    def native_type(self, registry):
        """Return the native SqlAlchemy type

        :param registry:
        :rtype: sqlalchemy native type
        """
        return self.sqlalchemy_type

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        """Return the instance of the real field

        :param registry: current registry
        :param namespace: name of the model
        :param fieldname: name of the field
        :param properties: known properties of the model
        :rtype: sqlalchemy column instance
        """
        self.format_label(fieldname)
        args = self.args
        kwargs = self.kwargs.copy()
        # ``self.kwargs.copy()`` is shallow: copy ``info`` too, so that the
        # label (and whatever format_foreign_key may store there) is never
        # written into a dict supplied by the declaration, which could be
        # shared between several columns
        kwargs["info"] = dict(kwargs.get("info") or {})
        args = self.format_foreign_key(
            registry, namespace, fieldname, args, kwargs
        )
        kwargs["info"]["label"] = self.label
        if self.sequence:
            args = (self.sequence,) + args

        db_column_name = self.db_column_name or fieldname

        if self.default_val is not NoDefaultValue:
            if isinstance(self.default_val, str):
                # the string may name a method of the model, resolved when
                # the default is computed
                kwargs["default"] = wrap_default(
                    registry, namespace, self.default_val
                )
            elif isinstance(self.default_val, ColumnDefaultValue):
                kwargs["default"] = self.default_val.get_default_callable(
                    registry, namespace, fieldname, properties
                )
            else:
                kwargs["default"] = self.default_val

        sqlalchemy_type = self.native_type(registry)
        if self.encrypt_key:
            encrypt_key = self.format_encrypt_key(registry, namespace)
            sqlalchemy_type = self.get_encrypt_key_type(
                registry, sqlalchemy_type, encrypt_key
            )

        return SA_Column(db_column_name, sqlalchemy_type, *args, **kwargs)

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap ``sqlalchemy_type`` in a StringEncryptedType

        On MySQL / MariaDB the encrypted value is stored in a
        ``String(64)``.
        """
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.String(64)

        return sqlalchemy_type

    def must_be_declared_as_attr(self):
        """Return True if the column have a foreign key to a remote column"""
        return self.foreign_key is not None
class ForbiddenPrimaryKey:
    """Mixin to forbid primary key on column type

    It is listed before :class:`Column` in the bases (see ``Float`` and
    ``Decimal``) so that its ``get_sqlalchemy_mapping`` runs first and
    delegates to the next class of the MRO.
    """

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        is_primary_key = self.kwargs.get("primary_key") is True
        if not is_primary_key:
            return super().get_sqlalchemy_mapping(
                registry, namespace, fieldname, properties
            )

        raise FieldException(
            f"{self.__class__} column `{namespace}.{fieldname}` "
            "are not allowed as primary key"
        )
class Integer(Column):
    """Integer column

    A primary key integer column gets ``autoincrement=True`` unless the
    declaration sets ``autoincrement`` itself.
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Integer


        @Declarations.register(Declarations.Model)
        class Test:

            x = Integer(default=1)
    """

    sqlalchemy_type = types.Integer

    def __init__(self, *args, **kwargs):
        super(Integer, self).__init__(*args, **kwargs)
        if self.kwargs.get("primary_key") is True:
            self.kwargs.setdefault("autoincrement", True)
class BigInteger(Column):
    """Big integer column, mapped on ``sqlalchemy.types.BigInteger``
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import BigInteger


        @Declarations.register(Declarations.Model)
        class Test:

            x = BigInteger(default=1)
    """

    sqlalchemy_type = types.BigInteger
class Boolean(Column):
    """Boolean column, mapped on ``sqlalchemy.types.Boolean``
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Boolean


        @Declarations.register(Declarations.Model)
        class Test:

            x = Boolean(default=True)
    """

    sqlalchemy_type = types.Boolean
class Float(ForbiddenPrimaryKey, Column):
    """Float column, mapped on ``sqlalchemy.types.Float``

    A Float column can not be used as primary key.
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Float


        @Declarations.register(Declarations.Model)
        class Test:

            x = Float(default=1.0)
    """

    sqlalchemy_type = types.Float
"""
Added *process_result_value* at the class *DECIMAL*, because
this method is necessary for encrypt the column
"""
types.DECIMAL.process_result_value = lambda self, value, dialect: value
class Decimal(ForbiddenPrimaryKey, Column):
    """Decimal column, mapped on ``sqlalchemy.types.DECIMAL``

    A Decimal column can not be used as primary key.
    ::

        from decimal import Decimal as D
        from anyblok.declarations import Declarations
        from anyblok.column import Decimal


        @Declarations.register(Declarations.Model)
        class Test:

            x = Decimal(default=D('1.1'))
    """

    sqlalchemy_type = types.DECIMAL

    def getter_format_value(self, value):
        """Return the value, rebuilt as ``decimal.Decimal`` if encrypted"""
        if value is None:
            return None  # pragma: no cover

        return decimal.Decimal(value) if self.encrypt_key else value
class Date(Column):
    """Date column
    ::

        from datetime import date
        from anyblok.declarations import Declarations
        from anyblok.column import Date


        @Declarations.register(Declarations.Model)
        class Test:

            x = Date(default=date.today)

    Pass the callable ``date.today``, not ``date.today()``: a call is
    evaluated once, when the model is declared, so every row would get
    that same day as default.
    """

    sqlalchemy_type = types.Date
def convert_string_to_datetime(value):
    """Convert a given value to datetime

    :param value: ``None``, a ``datetime``, a ``date`` or a string parsed
        with ``dateutil.parser.parse``
    :return: datetime value (``None`` for ``None``); a ``date`` becomes
        midnight of that day
    :raises FieldException: when the value has any other type
    """
    if value is None:
        return None

    # datetime must be tested before date: datetime is a subclass of date
    if isinstance(value, datetime):
        return value

    if isinstance(value, date):
        return datetime.combine(value, datetime.min.time())

    if isinstance(value, str):
        return parse(value)

    raise FieldException(
        "We can't convert this value %s to datetime" % (value,)
    )
def add_timezone_on_datetime(dt, default_timezone):
    """Localize a naive datetime with ``default_timezone``

    :param dt: datetime, or ``None``
    :param default_timezone: timezone object exposing ``localize``
    :return: ``dt`` unchanged when it is ``None`` or already aware,
        otherwise ``default_timezone.localize(dt)``
    """
    if dt is None or dt.tzinfo is not None:
        return dt

    return default_timezone.localize(dt)
class DateTimeType(types.TypeDecorator):
    """Timezone-aware ``DateTime`` type used by :class:`DateTime` columns

    Bound values are converted to datetimes (strings are parsed, dates get
    midnight) and localized with the field's ``default_timezone`` when
    naive. When the field is encrypted, values travel as ISO 8601 strings.
    """

    impl = types.DateTime(timezone=True)
    cache_ok = True

    def __init__(self, field):
        self.field = field
        self.default_timezone = field.default_timezone

    def process_bind_param(self, value, engine):
        as_datetime = add_timezone_on_datetime(
            convert_string_to_datetime(value), self.default_timezone
        )
        if not self.field.encrypt_key:
            return as_datetime

        return as_datetime.isoformat()

    def process_result_value(self, value, dialect):
        if not self.field.encrypt_key:
            return value

        return convert_string_to_datetime(value)

    @property
    def python_type(self):
        return datetime  # pragma: no cover
class DateTime(Column):
    """DateTime column

    Naive datetimes are localized with the default timezone: the
    ``default_timezone`` argument, else the ``default_timezone``
    configuration entry, else ``time.tzname[0]``.
    ::

        from datetime import datetime
        from anyblok.declarations import Declarations
        from anyblok.column import DateTime


        @Declarations.register(Declarations.Model)
        class Test:

            x = DateTime(default=datetime.now)
    """

    def __init__(self, *args, **kwargs):
        self.auto_update = kwargs.pop("auto_update", False)
        default_timezone = kwargs.pop(
            "default_timezone", Configuration.get("default_timezone")
        )
        if not default_timezone:
            default_timezone = time.tzname[0]

        if isinstance(default_timezone, str):
            default_timezone = pytz.timezone(default_timezone)

        self.default_timezone = default_timezone
        self.sqlalchemy_type = DateTimeType(self)
        super(DateTime, self).__init__(*args, **kwargs)

    def getter_format_value(self, value):
        """Return the value as a timezone-aware datetime (or None)"""
        value = convert_string_to_datetime(value)
        return add_timezone_on_datetime(value, self.default_timezone)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        # start from Column's properties, like every other column type
        res = super(DateTime, self).autodoc_get_properties()
        res["is auto updated"] = self.auto_update
        if self.default_timezone:
            res["default timezone"] = self.default_timezone

        return res
class TimeStampType(DateTimeType):
    """``TIMESTAMP`` variant of :class:`DateTimeType`, used by TimeStamp"""

    impl = types.TIMESTAMP(timezone=True)
    cache_ok = True

    @property
    def python_type(self):
        # timestamps are handled as ``datetime`` objects; the module-level
        # name ``time`` is the time *module*, not a type
        return datetime  # pragma: no cover
class TimeStamp(DateTime):
    """TimeStamp column

    Behaves like :class:`DateTime` (same value formatting), but the
    database type is ``TIMESTAMP``.
    ::

        from datetime import datetime
        from anyblok.declarations import Declarations
        from anyblok.column import TimeStamp


        @Declarations.register(Declarations.Model)
        class Test:

            x = TimeStamp(default=datetime.now)
    """

    def __init__(self, *args, **kwargs):
        super(TimeStamp, self).__init__(*args, **kwargs)
        # replace the DateTimeType set by DateTime.__init__
        self.sqlalchemy_type = TimeStampType(self)
class Time(Column):
    """Time column, mapped on ``sqlalchemy.types.Time``
    ::

        from datetime import time
        from anyblok.declarations import Declarations
        from anyblok.column import Time


        @Declarations.register(Declarations.Model)
        class Test:

            x = Time(default=time())
    """

    sqlalchemy_type = types.Time
class Interval(Column):
    """Datetime interval column

    When encrypted, the timedelta is stored as a JSON object holding its
    ``days``, ``seconds`` and ``microseconds``.
    ::

        from datetime import timedelta
        from anyblok.declarations import Declarations
        from anyblok.column import Interval


        @Declarations.register(Declarations.Model)
        class Test:

            x = Interval(default=timedelta(days=5))
    """

    sqlalchemy_type = types.Interval

    def native_type(self, registry):
        """Return a VARCHAR when encrypted, the Interval type otherwise"""
        if self.encrypt_key:
            return types.VARCHAR(1024)

        return self.sqlalchemy_type

    def setter_format_value(self, value):
        """Serialize the timedelta to JSON when the column is encrypted

        ``None`` is kept as is, so that a NULL value can be stored.
        """
        if self.encrypt_key and value is not None:
            value = dumps(
                {
                    attr: getattr(value, attr)
                    for attr in ("days", "seconds", "microseconds")
                }
            )

        return value

    def getter_format_value(self, value):
        """Rebuild the timedelta from JSON when the column is encrypted

        ``None`` (NULL in database) is returned unchanged.
        """
        if self.encrypt_key and value is not None:
            value = timedelta(**loads(value))

        return value
class StringType(types.TypeDecorator):
    """String type binding ``False`` as an empty string"""

    impl = types.String
    cache_ok = True

    def process_bind_param(self, value, engine):
        return "" if value is False else value

    def process_result_value(self, value, dialect):
        return value
class String(Column):
    """String column, ``size`` characters long (64 by default)
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import String


        @Declarations.register(Declarations.Model)
        class Test:

            x = String(default='test')
    """

    def __init__(self, *args, **kwargs):
        size = kwargs.pop("size", 64)
        kwargs.pop("type_", None)
        self.size = size
        self.sqlalchemy_type = StringType(size)
        super(String, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc, including the size

        :return: autodoc properties
        """
        properties = super(String, self).autodoc_get_properties()
        properties["size"] = self.size
        return properties

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type for encryption; on MySQL / MariaDB store it in a
        String of at least 64 characters"""
        encrypted = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            encrypted.impl = types.String(max(self.size, 64))

        return encrypted
class Enum(Column):
    """Enum column

    ``enum_cls`` (required) should be an enum class.
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Enum
        import enum


        class MyEnumClass(enum.Enum):
            one = 1
            two = 2
            three = 3


        @Declarations.register(Declarations.Model)
        class Test:

            x = Enum(enum_cls=MyEnumClass, default=MyEnumClass.one)
    """

    def __init__(self, *args, **kwargs):
        self.enum_cls = kwargs.pop("enum_cls")
        self.sqlalchemy_type = types.Enum(self.enum_cls)
        super(Enum, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(Enum, self).autodoc_get_properties()
        res["enum_cls"] = repr(self.enum_cls)
        return res
class MsSQLPasswordType(PasswordType):
    """PasswordType variant used on MsSQL, always loaded as a VARCHAR"""

    impl = types.VARCHAR(1024)

    def load_dialect_impl(self, dialect):
        varchar = types.VARCHAR(self.length)
        return dialect.type_descriptor(varchar)
class Password(Column):
    """Password column

    The value is hashed by ``sqlalchemy_utils``' ``PasswordType``,
    configured with ``crypt_context``. A Password column can not have a
    foreign key.
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Password


        @Declarations.register(Declarations.Model)
        class Test:

            x = Password(crypt_context={'schemes': ['md5_crypt']})

    Usage::

        test = Test.insert()
        test.x = 'mypassword'
        test.x
        ==> Password object with encrypt value, the value can not be read
        test.x == 'mypassword'
        ==> True

    .. warning::

        The column type Password can not be used to query::

            Test.query().filter(Test.x == 'mypassword').count()
            ==> 0
    """

    def __init__(self, *args, **kwargs):
        self.size = kwargs.pop("size", 64)
        self.crypt_context = kwargs.pop("crypt_context", {})
        kwargs.pop("type_", None)
        if "foreign_key" in kwargs:
            raise FieldException("Column Password can not have a foreign key")

        self.sqlalchemy_type = PasswordType(
            max_length=self.size, **self.crypt_context
        )
        super(Password, self).__init__(*args, **kwargs)

    def native_type(self, registry):
        """Return the native SqlAlchemy type (a VARCHAR variant on MsSQL)"""
        if sgdb_in(registry.engine, ["MsSQL"]):
            return MsSQLPasswordType(max_length=self.size, **self.crypt_context)

        return self.sqlalchemy_type

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(Password, self).autodoc_get_properties()
        res["size"] = self.size
        res["Crypt context"] = self.crypt_context
        return res
class TextType(types.TypeDecorator):
    """Text type binding ``False`` as an empty string"""

    impl = types.Text
    cache_ok = True

    def process_bind_param(self, value, engine):
        return "" if value is False else value

    def process_result_value(self, value, dialect):
        return value
class Text(Column):
    """Text column, for strings without length limit
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Text


        @Declarations.register(Declarations.Model)
        class Test:

            x = Text(default='test')
    """

    sqlalchemy_type = TextType

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type for encryption; on MySQL / MariaDB store it in a
        Text column"""
        encrypted = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            encrypted.impl = types.Text()

        return encrypted
class StrSelection(str):
    """Class representing the data of one column Selection

    Subclasses generated by :class:`SelectionType` set ``selections`` (the
    JSON text of either a dict key -> label, or the name of a model method
    returning the selections), ``registry`` and ``namespace``.
    """

    selections = dumps({})
    registry = None
    namespace = None

    def get_selections(self):
        """Return a dict of selections

        :return: selections dict
        """
        decoded = loads(self.selections)
        if isinstance(decoded, str):
            # name of a method of the model providing the selections
            model = self.registry.get(self.namespace)
            return dict(getattr(model, decoded)())

        if isinstance(decoded, dict):
            return decoded

        return None

    def validate(self):
        """Validate if the key is in the selections

        :return: True or False
        """
        key = super().__str__()
        return key in self.get_selections()

    @property
    def label(self):
        """Return the label corresponding to the selection key

        :return: the label
        """
        key = super().__str__()
        return self.get_selections()[key]
class SelectionType(types.TypeDecorator):
    """Generic type for Column Selection

    Stores the key of the selection in a String of ``size`` characters.
    Bound values are wrapped in a dedicated :class:`StrSelection` subclass
    which knows the selections, the registry and the namespace.
    """

    impl = types.String
    cache_ok = True

    def __init__(self, selections, size, registry=None, namespace=None):
        super(SelectionType, self).__init__(length=size)
        self.size = size
        if isinstance(selections, (dict, str)):
            self.selections = selections
        elif isinstance(selections, (list, tuple)):
            self.selections = dict(selections)
        else:
            raise FieldException(  # pragma: no cover
                "selection wainting 'dict', get %r" % type(selections)
            )

        if isinstance(self.selections, dict):
            for key in self.selections:
                if not isinstance(key, str):
                    raise FieldException("The key must be a str")

                # keys are stored in a column of ``size`` characters, which
                # the ``size`` argument of the column can enlarge
                if size is not None and len(key) > size:
                    raise Exception(  # pragma: no cover
                        "%r is too long %r, waiting max %s or use size arg"
                        % (key, len(key), size)
                    )

        self.selections = dumps(self.selections)
        self._StrSelection = type(
            "StrSelection",
            (StrSelection,),
            {
                "selections": self.selections,
                "registry": registry,
                "namespace": namespace,
            },
        )

    @property
    def python_type(self):
        return self._StrSelection

    def process_bind_param(self, value, engine):
        if value is not None:
            value = self.python_type(value)

        return value

    def process_result_value(self, value, dialect):
        return value
class Selection(Column):
    """Selection column
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Selection


        @Declarations.register(Declarations.Model)
        class Test:
            STATUS = (
                (u'draft', u'Draft'),
                (u'done', u'Done'),
            )

            x = Selection(selections=STATUS, size=64, default=u'draft')
    """

    def __init__(self, *args, **kwargs):
        self.selections = kwargs.pop("selections", tuple())
        self.size = kwargs.pop("size", 64)
        # the real SelectionType needs the registry, it is built in
        # get_sqlalchemy_mapping; this placeholder only satisfies the
        # ``assert self.sqlalchemy_type`` of Column.__init__
        self.sqlalchemy_type = "tmp value for assert"
        super(Selection, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(Selection, self).autodoc_get_properties()
        res["selections"] = self.selections
        res["size"] = self.size
        return res

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        """Return sqlalchmy mapping

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        :return: instance of the real field
        """
        self.sqlalchemy_type = SelectionType(
            self.selections, self.size, registry=registry, namespace=namespace
        )
        return super(Selection, self).get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties
        )

    def update_description(self, registry, model, res):
        """Update model description

        :param registry: the current registry
        :param model:
        :param res:
        """
        super(Selection, self).update_description(registry, model, res)
        sqlalchemy_type = SelectionType(
            self.selections, self.size, registry=registry, namespace=model
        )
        values = sqlalchemy_type._StrSelection().get_selections()
        res["selections"] = list(values.items())

    def must_be_copied_before_declaration(self):
        """Return True if selections is an instance of str.

        In the case of the field selection is a mixin, it must be copied or
        the selection method can fail
        """
        return isinstance(self.selections, str)

    def update_properties(self, registry, namespace, fieldname, properties):
        """Update column properties

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        """
        super(Selection, self).update_properties(
            registry, namespace, fieldname, properties
        )
        self.fieldname = fieldname
        properties["add_in_table_args"].append(self)

    def update_table_args(self, registry, Model):
        """Return check constraints to limit the value

        :param registry:
        :param Model:
        :return: list of checkConstraint
        """
        if self.encrypt_key:
            # dont add constraint because the state is crypted and nobody
            # can add new entry
            return []

        if sgdb_in(registry.engine, ["MariaDB", "MsSQL", "MySQL"]):
            # No check constraint on these databases
            return []

        selections = loads(self.sqlalchemy_type.selections)
        if isinstance(selections, dict):
            keys = list(selections.keys())
        else:
            values = getattr(Model, selections)()
            if isinstance(values, (list, tuple)):
                values = dict(values)

            keys = list(values.keys())

        if not keys:
            return []

        # the constraint must target the real database column
        column_name = self.db_column_name or self.fieldname
        # SQL string literals escape a single quote by doubling it
        literals = ["'%s'" % str(key).replace("'", "''") for key in keys]
        if len(literals) > 1:
            constraint = '"%s" in (%s)' % (column_name, ", ".join(literals))
        else:
            constraint = '"%s" = %s' % (column_name, literals[0])

        # the name only depends on the field name and the sorted keys, so
        # it stays stable between two loads of the same declaration
        key = md5()
        key.update(str(sorted(keys)).encode("utf-8"))
        name = self.fieldname + "_" + key.hexdigest() + "_types"
        return [CheckConstraint(constraint, name=name)]

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type for encryption; on MySQL / MariaDB store it in a
        String of at least 64 characters"""
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.String(max(self.size, 64))

        return sqlalchemy_type
"""
Added *process_result_value* at the class *JSON*, because
this method is necessary for encrypt the column
"""
types.JSON.process_result_value = lambda self, value, dialect: value
class Json(Column):
    """JSON column

    When encrypted, the value is serialized with ``json.dumps`` before
    being stored and decoded with ``json.loads`` when read.
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Json


        @Declarations.register(Declarations.Model)
        class Test:

            x = Json()
    """

    sqlalchemy_type = types.JSON(none_as_null=True)

    def setter_format_value(self, value):
        return dumps(value) if self.encrypt_key else value

    def getter_format_value(self, value):
        if value is not None and self.encrypt_key:
            return loads(value)

        return value

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type for encryption; on MySQL / MariaDB store it in a
        Text column"""
        encrypted = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            encrypted.impl = types.Text()

        return encrypted
class LargeBinary(Column):
    """Large binary column

    When encrypted, the bytes are stored as base64 text.
    ::

        from os import urandom
        from anyblok.declarations import Declarations
        from anyblok.column import LargeBinary

        blob = urandom(100000)


        @Declarations.register(Declarations.Model)
        class Test:

            x = LargeBinary(default=blob)
    """

    sqlalchemy_type = types.LargeBinary

    def native_type(self, registry):
        """Return a Text type when encrypted, LargeBinary otherwise"""
        if self.encrypt_key:
            return types.Text

        return self.sqlalchemy_type

    def setter_format_value(self, value):
        """Encode the bytes as base64 text when the column is encrypted

        ``None`` is kept as is, so that a NULL value can be stored.
        """
        if self.encrypt_key and value is not None:
            value = b64encode(value).decode("utf-8")

        return value

    def getter_format_value(self, value):
        """Decode the base64 text when the column is encrypted

        ``None`` (NULL in database) is returned unchanged.
        """
        if self.encrypt_key and value is not None:
            value = b64decode(value.encode("utf-8"))

        return value

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type for encryption; on MySQL / MariaDB store it in a
        Text column"""
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.Text()

        return sqlalchemy_type
class Sequence(String):
    """Sequence column

    The value is given by ``System.Sequence`` through the default of the
    column; a Sequence column can define neither a foreign key nor a
    default value.
    ::

        from anyblok.column import Sequence


        @Declarations.register(Declarations.Model)
        class Test:

            x = Sequence()

    If you wish to ensure no gap in the sequence::

        from anyblok.column import Sequence


        @Declarations.register(Declarations.Model)
        class Test:

            x = Sequence(no_gap=True, code="SO", formater="{code}-{seq:06d}")

    .. warning::

        Keep in mind `no_gap=True` will raise an
        `sqlalchemy.exc.OperationalError: (psycopg2.errors.LockNotAvailable)`
        exception in case a concurrent transaction do not release the lock
        while getting the next value.

    usage with `no_gap=True`::

        >>> Test.insert().x
        "SO-000001"
        >>> Test.insert().x
        "SO-000002"
        >>> registry.rollback()
        >>> Test.insert().x
        "SO-000001"
    """

    def __init__(self, *args, **kwargs):
        if "foreign_key" in kwargs:
            raise FieldException(
                "Sequence column can not define a foreign key"
                " %r" % kwargs["foreign_key"]
            )

        if "default" in kwargs:
            raise FieldException(
                "Sequence column can not define a default value"
            )

        # the default is built at mapping time, once the registry is known
        kwargs["default"] = ColumnDefaultValue(self.wrap_default)
        self.code = kwargs.pop("code", None)
        # NOTE(review): ``start`` is stored but not forwarded to the
        # sequence registration below
        self.start = kwargs.pop("start", 1)
        self.formater = kwargs.pop("formater", None)
        self.no_gap = kwargs.pop("no_gap", None)
        super(Sequence, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(Sequence, self).autodoc_get_properties()
        res["formater"] = self.formater
        res["no_gap"] = self.no_gap
        return res

    def wrap_default(self, registry, namespace, fieldname, properties):
        """Return default wrapper

        Register the sequence in
        ``registry._need_sequence_to_create_if_not_exist`` and return the
        callable giving its next value.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model (unused)
        :return: callable returning the next value of the sequence
        """
        pending = getattr(
            registry, "_need_sequence_to_create_if_not_exist", None
        )
        if pending is None:
            registry._need_sequence_to_create_if_not_exist = []

        code = self.code or "%s=>%s" % (namespace, fieldname)
        registry._need_sequence_to_create_if_not_exist.append(
            {"code": code, "formater": self.formater, "no_gap": self.no_gap}
        )

        # keep one positional parameter: SQLAlchemy inspects the signature
        # of a callable default to decide whether to pass its context
        def default_value(context, *args, **kwargs):
            """Return next sequence value

            :return: next value of the sequence ``code``
            """
            return registry.System.Sequence.nextvalBy(code=code)

        return default_value
class Color(Column):
    """Color column.

    `See colour package on pypi <https://pypi.python.org/pypi/colour/>`_

    The ``size`` argument (20 by default) gives the maximal length.
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Color


        @Declarations.register(Declarations.Model)
        class Test:

            x = Color(default='green')
    """

    def __init__(self, *args, **kwargs):
        length = kwargs.pop("size", 20)
        kwargs.pop("type_", None)
        self.max_length = length
        self.sqlalchemy_type = ColorType(length)
        super(Color, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc, including the size

        :return: autodoc properties
        """
        properties = super(Color, self).autodoc_get_properties()
        properties["size"] = self.max_length
        return properties

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type for encryption; on MySQL / MariaDB store it in a
        String of at least 64 characters"""
        encrypted = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            encrypted.impl = types.String(max(self.max_length, 64))

        return encrypted
class UUID(Column):
    """UUID column

    The optional ``binary`` and ``native`` arguments are forwarded to
    ``sqlalchemy_utils``' ``UUIDType``.
    ::

        from anyblok.column import UUID


        @Declarations.register(Declarations.Model)
        class Test:

            x = UUID()
    """

    def __init__(self, *args, **kwargs):
        uuid_kwargs = {}
        for option in ("binary", "native"):
            if option in kwargs:
                uuid_kwargs[option] = kwargs.pop(option)

        self.binary = uuid_kwargs.get("binary")
        self.native = uuid_kwargs.get("native")
        self.sqlalchemy_type = UUIDType(**uuid_kwargs)
        super(UUID, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc, including binary and native

        :return: autodoc properties
        """
        properties = super(UUID, self).autodoc_get_properties()
        properties.update(binary=self.binary, native=self.native)
        return properties
# Opt URLType into SQLAlchemy's statement cache; waiting for a fix in
# sqlalchemy_utils, after which this line can be removed.
URLType.cache_ok = True
class URL(Column):
    """URL column, values are handled as ``furl`` objects
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import URL


        @Declarations.register(Declarations.Model)
        class Test:

            x = URL(default='doc.anyblok.org')
    """

    sqlalchemy_type = URLType

    def setter_format_value(self, value):
        """Return formatted url value

        :param value: a string (converted to ``furl``) or any other value,
            returned unchanged
        :return: formatted value
        """
        from furl import furl

        if isinstance(value, str):
            return furl(value)

        return value
class PhoneNumber(Column):
    """PhoneNumber column

    ``region`` (``"FR"`` by default) and ``max_length`` (20 by default) are
    forwarded to ``sqlalchemy_utils``' ``PhoneNumberType``.
    ::

        from anyblok.declarations import Declarations
        from anyblok.column import PhoneNumber


        @Declarations.register(Declarations.Model)
        class Test:

            x = PhoneNumber(default='+120012301')

    .. note:: ``phonenumbers`` >= **8.9.5** distribution is required
    """

    def __init__(self, region="FR", max_length=20, *args, **kwargs):
        kwargs.pop("type_", None)
        self.region, self.max_length = region, max_length
        self.sqlalchemy_type = PhoneNumberType(
            region=region, max_length=max_length
        )
        super(PhoneNumber, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc, including region and max_length

        :return: autodoc properties
        """
        properties = super(PhoneNumber, self).autodoc_get_properties()
        properties["region"] = self.region
        properties["max_length"] = self.max_length
        return properties

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type for encryption; on MySQL / MariaDB store it in a
        String of at least 64 characters"""
        encrypted = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            encrypted.impl = types.String(max(self.max_length, 64))

        return encrypted
"""
Added *process_result_value* at the class *EmailType*, because
this method is necessary for encrypt the column
"""
EmailType.process_result_value = lambda self, value, dialect: value
EmailType.cache_ok = True # waiting fix from sqlalchemy_utils
class Email(Column):
    """Email column, mapped on ``sqlalchemy_utils``' ``EmailType``
    ::

        from anyblok.column import Email


        @Declarations.register(Declarations.Model)
        class Test:

            x = Email()
    """

    sqlalchemy_type = EmailType
class CountryType(types.TypeDecorator, ScalarCoercible):
    """Generic type for Column Country

    Country objects are written as their ``alpha_3`` code in a
    ``Unicode(3)`` column and read back with
    ``pycountry.countries.get(alpha_3=...)``.
    """

    impl = types.Unicode(3)
    cache_ok = True
    python_type = python_pycountry_type

    def process_bind_param(self, value, dialect):
        # anything that is not a (truthy) country object is passed through
        if not value or not isinstance(value, self.python_type):
            return value
        return value.alpha_3

    def process_result_value(self, value, dialect):
        return self._coerce(value)

    def _coerce(self, value):
        if value is None or isinstance(value, self.python_type):
            return value  # pragma: no cover
        return pycountry.countries.get(alpha_3=value)
class Country(Column):
    """Country column.

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import Country
        from pycountry import countries


        @Declarations.register(Declarations.Model)
        class Test:

            x = Country(default=countries.get(alpha_2='FR'))

    :param mode: attribute of the pycountry objects used as the keys of
        ``choices`` (``'alpha_2'`` by default)
    """

    sqlalchemy_type = CountryType

    def __init__(self, mode="alpha_2", *args, **kwargs):
        self.mode = mode
        if pycountry is None:
            raise FieldException(  # pragma: no cover
                "'pycountry' package is required for use 'CountryType'"
            )

        # NOTE(review): choices are keyed by ``mode`` while CountryType
        # always stores ``alpha_3`` codes; confirm this is intended.
        self.choices = dict(
            (getattr(country, mode), country.name)
            for country in pycountry.countries
        )
        super(Country, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return autodoc properties, extended with mode and choices.

        :return: autodoc properties
        """
        props = super(Country, self).autodoc_get_properties()
        props["mode"] = self.mode
        props["choices"] = self.choices
        return props

    def update_properties(self, registry, namespace, fieldname, properties):
        """Update column properties

        Remembers the field name (needed by ``update_table_args``) and adds
        this column to ``properties["add_in_table_args"]``.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        """
        super(Country, self).update_properties(
            registry, namespace, fieldname, properties
        )
        self.fieldname = fieldname
        properties["add_in_table_args"].append(self)

    def update_table_args(self, registry, Model):
        """Return the CHECK constraint limiting values to alpha_3 codes.

        No constraint is produced for encrypted columns, nor on MariaDB and
        MsSQL.

        :param registry:
        :param Model:
        :return: list of checkConstraint
        """
        if self.encrypt_key:
            # stored values are encrypted: a CHECK on the plain codes
            # would reject every entry
            return []

        if sgdb_in(registry.engine, ["MariaDB", "MsSQL"]):
            # No Check constraint for these backends
            return []

        codes = [country.alpha_3 for country in pycountry.countries]
        check = '"%s" in (\'%s\')' % (self.fieldname, "', '".join(codes))
        # the name embeds an md5 of the sorted code list, presumably so it
        # changes whenever the set of known countries changes
        digest = md5(str(sorted(codes)).encode("utf-8")).hexdigest()
        return [
            CheckConstraint(
                check, name="%s_%s_types" % (self.fieldname, digest)
            )
        ]
def model_validator_all(Model):
    """Accept every model."""
    return True


def model_validator_is_sql(Model):
    """Accept models whose ``is_sql`` attribute is exactly ``True``."""
    return Model.is_sql is True


def model_validator_is_not_sql(Model):
    """Accept models rejected by :func:`model_validator_is_sql`."""
    is_sql = model_validator_is_sql(Model)
    return not is_sql


def model_validator_is_view(Model):
    """Accept models declaring a ``__view__`` attribute."""
    return hasattr(Model, "__view__")


def model_validator_is_not_view(Model):
    """Accept models without a ``__view__`` attribute."""
    is_view = model_validator_is_view(Model)
    return not is_view


def model_validator_in_namespace(namespace):
    """Build a validator matching models by registry-name prefix.

    :param namespace: a prefix string, or a Model; for a Model the prefix is
        ``"<registry name>."`` so only its sub-models match, not itself
    :return: validator callable
    """
    prefix = namespace
    if hasattr(namespace, "__registry_name__"):
        prefix = f"{namespace.__registry_name__}."

    def validator(Model):
        return Model.__registry_name__.startswith(prefix)

    return validator


def merge_validators(*validators):
    """Combine validators: accept only what every validator accepts.

    Evaluation stops at the first rejecting validator; with no validators
    everything is accepted.
    """

    def validator(obj):
        for check in validators:
            if not check(obj):
                return False
        return True

    return validator
class StrModelSelection(str):
    """Class representing the data of one column ModelSelection

    The string value is a model registry name. Per-column subclasses
    provide ``validator``, ``registry`` and a class-level ``selections``
    cache.
    """

    validator = None
    registry = None
    selections = None

    def get_selections(self):
        """Return the selectable models, cached on the class.

        Keys are the registry names of models accepted by ``validator``.
        Values are the first line of the model docstring, or the registry
        name when that is empty.

        :return: selections dict
        """
        if self.selections:
            return self.selections

        found = {}
        for name, Model in self.registry.loaded_namespaces.items():
            if self.validator(Model):
                doc = Model.__doc__
                found[name] = doc and doc.split("\n")[0] or name

        # stored on the class so every value of this column shares it
        self.__class__.selections = found
        return self.selections

    def validate(self):
        """Tell whether the value is one of the selections.

        :return: True or False
        """
        raw = super(StrModelSelection, self).__str__()
        return raw in self.get_selections()

    @property
    def Model(self):
        """Return the Model registered under this name, or None if empty.

        :return:
        """
        raw = super(StrModelSelection, self).__str__()
        if not raw:
            return None  # pragma: no cover
        return self.registry.get(raw)
class ModelSelectionType(ScalarCoercible, types.TypeDecorator):
    """Generic type for Column ModelSelection

    Stored as ``String(256)``. Values are wrapped in a per-instance
    subclass of ``StrModelSelection``.

    :param validator: callable taking a Model, or the name of a method of
        the ``namespace`` model; ``None`` accepts every model
    """

    impl = types.String
    cache_ok = True

    def __init__(self, validator, registry=None, namespace=None):
        super(ModelSelectionType, self).__init__(length=256)
        check = model_validator_all if validator is None else validator

        def _validator(obj, Model):
            if isinstance(check, str):
                return getattr(registry.get(namespace), check)(Model)
            return check(Model)

        self._StrModelSelection = type(
            "StrModelSelection",
            (StrModelSelection,),
            dict(validator=_validator, registry=registry, selections=None),
        )

    @property
    def python_type(self):
        return self._StrModelSelection

    def process_bind_param(self, value, engine):
        if value is None:
            return value
        # a Model may be bound directly: store its registry name
        return self.python_type(getattr(value, "__registry_name__", value))

    def process_result_value(self, value, dialect):
        return self._coerce(value)

    def _coerce(self, value):
        if value is None or isinstance(value, self._StrModelSelection):
            return value
        return self.python_type(value)
class ModelSelection(Column):
    """ModelSelection column

    Allow to Reference an AnyBlok Model: the value is the registry name of
    a model accepted by ``validator``.

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import ModelSelection


        @Declarations.register(Declarations.Model)
        class Test:

            x = ModelSelection(
                default='Model.System.Blok',
                validator="_x_validator"
            )

            @classmethod
            def _x_validator(cls, Model):
                return True or False

    ``validator`` is either a callable taking the Model or the name of a
    method of the declaring model; without it every model is accepted.
    """

    def __init__(self, *args, **kwargs):
        self.size = 256  # used by comparator
        self.validator = kwargs.pop("validator", None)
        # NOTE(review): placeholder until get_sqlalchemy_mapping builds the
        # real type; presumably required by a check in Column.__init__.
        self.sqlalchemy_type = "tmp value for assert"
        super(ModelSelection, self).__init__(*args, **kwargs)
        default = self.default_val
        if default and hasattr(default, "__registry_name__"):
            # a Model given as default is reduced to its registry name
            self.default_val = default.__registry_name__

    def autodoc_get_properties(self):
        """Return autodoc properties, extended with the validator.

        :return: autodoc properties
        """
        props = super(ModelSelection, self).autodoc_get_properties()
        props["validator"] = str(self.validator)
        return props

    def getter_format_value(self, value):
        """Wrap the stored value into the column's string subclass.

        :param value:
        :return:
        """
        return (
            None if value is None else self.sqlalchemy_type.python_type(value)
        )

    def setter_format_value(self, value):
        """Return value or raise exception if the given value is invalid

        :param value: a registry name, a Model, or None
        :exception FieldException: the model is not in the selections
        :return: the registry name (or None)
        """
        if value is None:
            return value

        if hasattr(value, "__registry_name__"):
            value = value.__registry_name__

        candidate = self.sqlalchemy_type.python_type(value)
        if not candidate.validate():
            raise FieldException(
                "%r is not in the selections (%s)"
                % (value, ", ".join(candidate.get_selections()))
            )
        return value

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        """Return sqlalchmy mapping

        Builds the real ``ModelSelectionType`` now that the registry and
        namespace are known.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        :return: instance of the real field
        """
        self.sqlalchemy_type = ModelSelectionType(
            self.validator, registry=registry, namespace=namespace
        )
        return super(ModelSelection, self).get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties
        )

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type into ``StringEncryptedType`` (String(256) storage
        on MySQL / MariaDB)."""
        encrypted_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            encrypted_type.impl = types.String(256)
        return encrypted_type

    def update_description(self, registry, model, res):
        """Update model description with the available selections

        :param registry: the current registry
        :param model:
        :param res:
        """
        super(ModelSelection, self).update_description(registry, model, res)
        selection_type = ModelSelectionType(
            self.validator, registry=registry, namespace=model
        )
        selections = selection_type._StrModelSelection().get_selections()
        res["selections"] = list(selections.items())
def fieldToModelAttribute(field):
    """Convert an AnyBlok model attribute into a ``ModelAttribute``.

    Objects without ``anyblok_field_name`` are returned unchanged.
    """
    if not hasattr(field, "anyblok_field_name"):
        return field

    return ModelAttribute(
        field.anyblok_registry_name,
        field.anyblok_field_name,
    )
def _field_type(field):
    """Return the AnyBlok field definition behind a model attribute."""
    return fieldToModelAttribute(field).get_type(field.from_model().anyblok)


def field_validator_all(field):
    """Accept every field."""
    return True


def field_validator_is_field(field):
    """Accept attributes that are neither a Column nor a RelationShip."""
    return not (
        field_validator_is_column(field)
        or field_validator_is_relationship(field)
    )


def field_validator_is_not_field(field):
    """Accept attributes rejected by :func:`field_validator_is_field`."""
    is_field = field_validator_is_field(field)
    return not is_field


def field_validator_is_column(field):
    """Accept attributes defined by a ``Column``."""
    return isinstance(_field_type(field), Column)


def field_validator_is_not_column(field):
    """Accept attributes not defined by a ``Column``."""
    is_column = field_validator_is_column(field)
    return not is_column


def field_validator_is_relationship(field):
    """Accept attributes defined by a ``RelationShip``."""
    # imported here rather than at module level, like the original code
    from .relationship import RelationShip

    return isinstance(_field_type(field), RelationShip)


def field_validator_is_not_relationship(field):
    """Accept attributes not defined by a ``RelationShip``."""
    is_relationship = field_validator_is_relationship(field)
    return not is_relationship


def field_validator_is_named(*names):
    """Build a validator accepting fields whose name is in ``names``."""

    def validator(field):
        return field.anyblok_field_name in names

    return validator


def field_validator_is_from_types(*field_types):
    """Build a validator accepting fields defined by one of ``field_types``.

    The vararg is named ``field_types`` so it does not shadow the module's
    ``types`` import.
    """

    def validator(field):
        return isinstance(_field_type(field), field_types)

    return validator
class StrModelFieldSelection(str):
    """Class representing the data of one column ModelFieldSelection

    The string value is the ``str(ModelAttribute(...))`` key of a model
    field. Per-column subclasses provide the validators, ``registry`` and
    a class-level ``selections`` cache.
    """

    model_validator = None
    field_validator = None
    registry = None
    selections = None

    def get_selections(self):
        """Return the selectable fields, cached on the class.

        Keys are ``str(ModelAttribute(namespace, field))``. Values are
        ``"<model doc first line or namespace> : <field>"``.

        :return: selections dict
        """
        if self.selections:
            return self.selections

        found = {}
        for namespace, Model in self.registry.loaded_namespaces.items():
            if not self.model_validator(Model):
                continue

            doc = Model.__doc__
            label = doc and doc.split("\n")[0] or namespace
            for field in Model.fields_name():
                if self.field_validator(getattr(Model, field)):
                    key = str(ModelAttribute(namespace, field))
                    found[key] = label + " : " + field

        # stored on the class so every value of this column shares it
        self.__class__.selections = found
        return self.selections

    def validate(self):
        """Tell whether the value is one of the selections.

        :return: True or False
        """
        raw = super(StrModelFieldSelection, self).__str__()
        return raw in self.get_selections()

    @property
    def field(self):
        """Return the model attribute designated by the value, or None.

        :return:
        """
        raw = super(StrModelFieldSelection, self).__str__()
        if not raw:
            return None  # pragma: no cover
        return ModelAttributeAdapter(raw).get_attribute(self.registry)
class ModelFieldSelectionType(ScalarCoercible, types.TypeDecorator):
    """Generic type for Column ModelFieldSelection

    Stored as ``String(256)``. Values are wrapped in a per-instance
    subclass of ``StrModelFieldSelection``. Each validator is a callable or
    the name of a method of the ``namespace`` model. The defaults are
    ``model_validator_is_sql`` and ``field_validator_all``.
    """

    impl = types.String
    cache_ok = True

    def __init__(
        self, model_validator, field_validator, registry=None, namespace=None
    ):
        super(ModelFieldSelectionType, self).__init__(length=256)
        check_model = (
            model_validator_is_sql
            if model_validator is None
            else model_validator
        )
        check_field = (
            field_validator_all if field_validator is None else field_validator
        )

        def _model_validator(obj, Model):
            if isinstance(check_model, str):
                return getattr(registry.get(namespace), check_model)(Model)
            return check_model(Model)

        def _field_validator(obj, field):
            if isinstance(check_field, str):
                return getattr(registry.get(namespace), check_field)(field)
            return check_field(field)

        self._StrModelFieldSelection = type(
            "StrModelFieldSelection",
            (StrModelFieldSelection,),
            dict(
                model_validator=_model_validator,
                field_validator=_field_validator,
                registry=registry,
                selections=None,
            ),
        )

    @property
    def python_type(self):
        return self._StrModelFieldSelection

    def process_bind_param(self, value, engine):
        if value is None:
            return value
        return self.python_type(fieldToModelAttribute(value))

    def process_result_value(self, value, dialect):
        return self._coerce(value)

    def _coerce(self, value):
        if value is None or isinstance(value, self._StrModelFieldSelection):
            return value
        return self.python_type(value)
class ModelFieldSelection(Column):
    """ModelFieldSelection column

    Allow to reference a field of an AnyBlok Model. The value is the string
    form of ``ModelAttribute(<registry name>, <field name>)``, e.g.
    ``'Model.System.Blok => name'``.

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import ModelFieldSelection


        @Declarations.register(Declarations.Model)
        class Test:

            x = ModelFieldSelection(
                default='Model.System.Blok => name',
                model_validator="_x_model_validator",
                field_validator="_x_field_validator",
            )

            @classmethod
            def _x_model_validator(cls, Model):
                return True or False

            @classmethod
            def _x_field_validator(cls, field):
                return True or False

    Each validator is either a callable or the name of a method of the
    declaring model.
    """

    def __init__(self, *args, **kwargs):
        self.size = 256  # used by comparator
        self.model_validator = kwargs.pop("model_validator", None)
        self.field_validator = kwargs.pop("field_validator", None)
        # NOTE(review): placeholder until get_sqlalchemy_mapping builds the
        # real type; presumably required by a check in Column.__init__.
        self.sqlalchemy_type = "tmp value for assert"
        super(ModelFieldSelection, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(ModelFieldSelection, self).autodoc_get_properties()
        res["model_validator"] = str(self.model_validator)
        res["field_validator"] = str(self.field_validator)
        return res

    def getter_format_value(self, value):
        """Wrap the stored value into the column's string subclass

        :param value:
        :return:
        """
        if value is None:
            return None

        return self.sqlalchemy_type.python_type(value)

    def setter_format_value(self, value):
        """Return value or raise exception if the given value is invalid

        :param value: a selection key, a model attribute, or None
        :exception FieldException: the field is not in the selections
        :return: the selection key (or None)
        """
        if value is not None:
            if hasattr(value, "anyblok_field_name"):
                # a model attribute was given: store its selection key
                value = str(fieldToModelAttribute(value))

            val = self.sqlalchemy_type.python_type(value)
            if not val.validate():
                raise FieldException(
                    "%r is not in the selections (%s)"
                    % (value, ", ".join(val.get_selections()))
                )

        return value

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        """Return sqlalchmy mapping

        Builds the real ``ModelFieldSelectionType`` now that the registry
        and namespace are known.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        :return: instance of the real field
        """
        self.sqlalchemy_type = ModelFieldSelectionType(
            self.model_validator,
            self.field_validator,
            registry=registry,
            namespace=namespace,
        )
        return super(ModelFieldSelection, self).get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties
        )

    def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
        """Wrap the type into ``StringEncryptedType`` (String(256) storage
        on MySQL / MariaDB)."""
        sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
        if sgdb_in(registry.engine, ["MySQL", "MariaDB"]):
            sqlalchemy_type.impl = types.String(256)

        return sqlalchemy_type

    def update_description(self, registry, model, res):
        """Update model description with the available selections

        :param registry: the current registry
        :param model:
        :param res:
        """
        super(ModelFieldSelection, self).update_description(
            registry, model, res
        )
        sqlalchemy_type = ModelFieldSelectionType(
            self.model_validator,
            self.field_validator,
            registry=registry,
            namespace=model,
        )
        values = sqlalchemy_type._StrModelFieldSelection().get_selections()
        res["selections"] = list(values.items())
def instanceToDict(instance):
    """Normalize an instance reference into a ``{model, primary_keys}`` dict.

    :param instance: an AnyBlok instance, or a dict already holding the
        ``model`` and ``primary_keys`` entries
    :return: the dict (returned as-is when a dict is given)
    :exception FieldException: a dict misses one of the required entries
    """
    if isinstance(instance, dict):
        if "model" not in instance:
            raise FieldException(
                "The model entry is required : %r" % instance
            )
        if "primary_keys" not in instance:
            raise FieldException(
                "The primary_keys entry is required : %r" % instance
            )
        return instance

    return dict(
        model=instance.__registry_name__,
        primary_keys=instance.to_primary_keys(),
    )


def instance_validator_all(instance):
    """Accept every instance."""
    return True
class ModelReferenceType(types.JSON):
    """Generic type for Column ModelReference

    The reference is stored as JSON built by ``instanceToDict``::

        {"model": <registry name>, "primary_keys": {<name>: <value>, ...}}

    ``None`` is persisted as SQL ``NULL`` (``none_as_null=True``).
    """

    cache_ok = True

    class comparator_factory(SA_JSON.Comparator):
        """Comparator adding ``is_`` and ``with_models`` filters."""

        def is_(self, instance):
            """Filter rows referencing ``instance``.

            :param instance: an AnyBlok instance or a reference dict
            :return: SQL expression matching the model name and every
                primary key (compared as strings)
            """
            reference = instanceToDict(instance)
            filters = [
                self.expr["model"].as_string() == reference["model"],
            ]
            primary_keys = reference.get("primary_keys", {})
            for pk_name, pk_value in primary_keys.items():
                filters.append(
                    self.expr["primary_keys"][pk_name].as_string()
                    == str(pk_value)
                )

            return and_(*filters)

        def with_models(self, *namespaces):
            """Filter rows referencing an instance of one of ``namespaces``.

            :param namespaces: registry names or Models
            :return: SQL expression
            """
            filters = []
            for namespace in namespaces:
                if hasattr(namespace, "__registry_name__"):
                    namespace = namespace.__registry_name__

                filters.append(self.expr["model"].as_string() == namespace)

            if len(filters) == 1:
                return filters[0]

            return or_(*filters)

    def __init__(
        self,
        model_validator,
        instance_validator,
        registry=None,
        namespace=None,
    ):
        """
        :param model_validator: callable taking a Model, or the name of a
            method of the ``namespace`` model; defaults to
            ``model_validator_is_sql``
        :param instance_validator: callable taking the referenced instance,
            or the name of a method of that instance called without
            arguments; defaults to ``instance_validator_all``
        """
        if model_validator is None:
            model_validator = model_validator_is_sql

        if instance_validator is None:
            instance_validator = instance_validator_all

        def validate_model(value):
            Model = registry.get(value)
            if isinstance(model_validator, str):
                return getattr(registry.get(namespace), model_validator)(Model)

            return model_validator(Model)

        def validate_instance(value):
            instance = registry.get(value["model"]).from_primary_keys(
                **value["primary_keys"]
            )
            if isinstance(instance_validator, str):
                return getattr(instance, instance_validator)()

            return instance_validator(instance)

        self.validate_model = validate_model
        self.validate_instance = validate_instance
        self.model_selections = None
        self.model_validator = model_validator
        self.instance_validator = instance_validator
        self.registry = registry
        self.namespace = namespace
        super(ModelReferenceType, self).__init__(none_as_null=True)

    def get_instance(self, value):
        """Load the referenced instance; falsy values are returned as-is."""
        if value:
            value = instanceToDict(value)
            value = self.registry.get(value["model"]).from_primary_keys(
                **value["primary_keys"]
            )

        return value

    def get_model_selections(self):
        """Return a dict of selections

        Keys are registry names of models accepted by ``validate_model``.
        Values are the first docstring line, or the registry name.

        :return: selections dict
        """
        if not self.model_selections:
            self.model_selections = {
                namespace: (
                    Model.__doc__ and Model.__doc__.split("\n")[0] or namespace
                )
                for namespace, Model in self.registry.loaded_namespaces.items()
                if self.validate_model(namespace)
            }

        return self.model_selections
class ModelReference(Json):
    """ModelReference column

    Allow to Reference an AnyBlok instance of Model. The value is stored as
    ``{"model": <registry name>, "primary_keys": {...}}``.

    ::

        from anyblok.declarations import Declarations
        from anyblok.column import ModelReference


        @Declarations.register(Declarations.Model)
        class Test:

            x = ModelReference(
                model_validator="_x_model_validator",
                instance_validator="_x_instance_validator",
            )

            @classmethod
            def _x_model_validator(cls, Model):
                return True or False

    A string ``model_validator`` names a method of the declaring model and
    receives the candidate Model. A string ``instance_validator`` names a
    method of the *referenced* instance and is called without arguments::

        @Declarations.register(Declarations.Model)
        class Referenced:

            def _x_instance_validator(self):
                return True or False

    Querying::

        anyblok.Test.query().filter(Test.x.is_(instance))
        anyblok.Test.query().filter(Test.x.with_models(anyblok.System.Blok))
    """

    def __init__(self, *args, **kwargs):
        self.model_validator = kwargs.pop("model_validator", None)
        self.instance_validator = kwargs.pop("instance_validator", None)
        if "default" in kwargs:
            # keep the user default; the column default becomes the closure
            # returned by ``self.wrap_default`` (presumably called by
            # ColumnDefaultValue once registry/namespace are known)
            self.real_default_value = kwargs.pop("default")
            kwargs["default"] = ColumnDefaultValue(self.wrap_default)

        super(ModelReference, self).__init__(*args, **kwargs)

    def autodoc_get_properties(self):
        """Return properties for autodoc

        :return: autodoc properties
        """
        res = super(ModelReference, self).autodoc_get_properties()
        res["model_validator"] = str(self.model_validator)
        res["instance_validator"] = str(self.instance_validator)
        return res

    def getter_format_value(self, value):
        """Return the referenced instance instead of the stored dict

        :param value:
        :return:
        """
        value = super(ModelReference, self).getter_format_value(value)
        if isinstance(value, dict):
            value = self._sqlalchemy_type.get_instance(value)

        return value

    def validate_dict(self, value):
        """Check a reference dict against the validators

        :param value: reference dict or None
        :exception FieldException: invalid model, missing or rejected
            instance
        :return: value
        """
        if value is not None:
            if not self._sqlalchemy_type.validate_model(value["model"]):
                raise FieldException(
                    "The model of %r is not in the selections" % value
                )

            if not self._sqlalchemy_type.get_instance(value):
                raise FieldException("%r is not existing" % value)

            if not self._sqlalchemy_type.validate_instance(value):
                raise FieldException("%r is not a valid choice" % value)

        return value

    def setter_format_value(self, value):
        """Return value or raise exception if the given value is invalid

        :param value: instance, reference dict or None
        :exception FieldException
        :return: the validated reference dict (or None)
        """
        if value is not None:
            value = self.validate_dict(instanceToDict(value))

        return value

    def get_sqlalchemy_mapping(
        self, registry, namespace, fieldname, properties
    ):
        """Return sqlalchmy mapping

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        :return: instance of the real field
        """
        self._sqlalchemy_type = self.sqlalchemy_type = ModelReferenceType(
            self.model_validator,
            self.instance_validator,
            registry=registry,
            namespace=namespace,
        )
        return super(ModelReference, self).get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties
        )

    def update_description(self, registry, model, res):
        """Update model description with the selectable models

        :param registry: the current registry
        :param model:
        :param res:
        """
        super(ModelReference, self).update_description(registry, model, res)
        sqlalchemy_type = ModelReferenceType(
            self.model_validator,
            self.instance_validator,
            registry=registry,
            namespace=model,
        )
        values = sqlalchemy_type.get_model_selections()
        res["model_selections"] = list(values.items())

    def wrap_default(self, registry, namespace, fieldname, properties):
        """Return default wrapper

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param fieldname: the fieldname of the model
        :param properties: the properties of the model
        :return: callable returning the validated reference dict
        """

        def default_value():
            default = self.real_default_value
            if isinstance(default, str):
                # resolved with the module-level ``wrap_default`` helper
                default = wrap_default(registry, namespace, default)()

            return self.validate_dict(instanceToDict(default))

        return default_value
@CompareType.add_comparator(String, String)
@CompareType.add_comparator(String, Selection)
@CompareType.add_comparator(String, Sequence)
@CompareType.add_comparator(String, ModelSelection)
@CompareType.add_comparator(String, ModelFieldSelection)
def compare_strings(col1, type1, col2, type2):
    """Refuse a foreign key between string-based columns of unequal size.

    :param col1: the column pointing to ``col2``
    :param type1: field definition of ``col1`` (its ``size`` is compared)
    :param col2: the column pointed to
    :param type2: field definition of ``col2`` (its ``size`` is compared)
    :exception FieldException: the sizes differ
    """
    size1, size2 = type1.size, type2.size
    if size1 == size2:
        return

    message = (
        "You can't add a foreign key using based String columns with "
        "different size `{model1!s}.{col1!s}` pointing to "
        "`{model2!s}.{col2!s}` have different sizes {type1!r}({size1:d}) "
        "-> {type2!r}({size2:d})"
    ).format(
        model1=col1.model_name,
        col1=col1.attribute_name,
        model2=col2.model_name,
        col2=col2.attribute_name,
        type1=type1.__class__,
        type2=type2.__class__,
        size1=size1,
        size2=size2,
    )
    raise FieldException(message)
@CompareType.add_comparator(String, Color)
def compare_string_to_color(col1, type1, col2, type2):
    """Refuse a foreign key from a String to a Color of a different size.

    :param col1: the String column pointing to ``col2``
    :param type1: String field definition (``size`` is compared)
    :param col2: the Color column pointed to
    :param type2: Color field definition (``max_length`` is compared)
    :exception FieldException: the sizes differ
    """
    size1, size2 = type1.size, type2.max_length
    if size1 == size2:
        return

    message = (
        "You can't add a foreign key using based String columns with "
        "different size `{model1!s}.{col1!s}` pointing to "
        "`{model2!s}.{col2!s}` have different sizes {type1!r}({size1:d}) "
        "-> {type2!r}({size2:d})"
    ).format(
        model1=col1.model_name,
        col1=col1.attribute_name,
        model2=col2.model_name,
        col2=col2.attribute_name,
        type1=type1.__class__,
        type2=type2.__class__,
        size1=size1,
        size2=size2,
    )
    raise FieldException(message)