# Source code for anyblok.model

# This file is a part of the AnyBlok project
#    Copyright (C) 2014 Jean-Sebastien SUZANNE <>
#    Copyright (C) 2017 Jean-Sebastien SUZANNE <>
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at http://mozilla.org/MPL/2.0/.
from anyblok.registry import RegistryManager
from anyblok import Declarations
from anyblok.field import Field, FieldException
from anyblok.relationship import RelationShip
from anyblok.column import Column
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import DDLElement
from sqlalchemy.sql import table
from sqlalchemy.orm import Query, mapper, synonym
from sqlalchemy import inspection
from anyblok.common import TypeList
from copy import deepcopy
from sqlalchemy.ext.declarative import declared_attr
from anyblok.mapper import ModelAttribute
from anyblok.common import anyblok_column_prefix
from texttable import Texttable
from .plugins import get_model_plugins
from .exceptions import ModelException, ViewException

class CreateView(DDLElement):
    """ DDL element describing the creation of an SQL view

    :param name: name of the view to create
    :param selectable: query whose result is exposed by the view
    """

    def __init__(self, name, selectable):
        self.name = name
        self.selectable = selectable

class DropView(DDLElement):
    """ DDL element describing the removal of an SQL view

    :param name: name of the view to drop
    """

    def __init__(self, name):
        self.name = name

@compiles(CreateView)
def compile_create_view(element, compiler, **kw):
    """ Render a ``CreateView`` element as ``CREATE VIEW <name> AS <query>``

    The query is rendered with the dialect's SQL compiler. Bound parameters
    are inlined (``literal_binds``) because a DDL statement is executed
    without parameters.

    :param element: the ``CreateView`` instance to render
    :param compiler: the DDL compiler of the current dialect
    :rtype: str
    """
    return "CREATE VIEW %s AS %s" % (
        element.name,
        compiler.sql_compiler.process(element.selectable, literal_binds=True))

@compiles(DropView)
def compile_drop_view(element, compiler, **kw):
    """ Render a ``DropView`` element as ``DROP VIEW IF EXISTS <name>``

    :param element: the ``DropView`` instance to render
    :param compiler: the DDL compiler of the current dialect (unused)
    :rtype: str
    """
    return "DROP VIEW IF EXISTS %s" % element.name

def has_sql_fields(bases):
    """ Tell whether the model has at least one field

    An attribute found in a base's ``__dict__`` counts as a field when its
    class has ``Field`` in its MRO. Reading an attribute that raises
    ``FieldException`` also counts as a field.

    :param bases: list of Model's Class
    :rtype: boolean
    """
    for base in bases:
        for p in base.__dict__.keys():
            try:
                value = getattr(base, p)
                if hasattr(value, '__class__'):
                    if Field in value.__class__.__mro__:
                        return True
            except FieldException:
                # field function case already computed
                return True

    return False

def has_sqlalchemy_fields(base):
    """ Tell whether ``base`` directly holds an SQLAlchemy attribute

    Only the values of ``base.__dict__`` are examined; inherited attributes
    are ignored. A value counts when ``inspection.inspect`` (called with
    ``raiseerr=False``) returns something other than ``None`` for it.

    :param base: the class to check
    :rtype: boolean
    """
    return any(
        inspection.inspect(value, raiseerr=False) is not None
        for value in base.__dict__.values()
    )

def is_in_mro(cls, base, attr):
    """ Tell whether ``cls`` appears in the MRO of the class of ``base.attr``

    :param cls: the class looked for
    :param base: the object holding the attribute
    :param attr: name of the attribute read on ``base``
    :rtype: boolean
    """
    value_class = getattr(base, attr).__class__
    mro = value_class.__mro__
    return cls in mro

def get_fields(base, without_relationship=False, only_relationship=False,
               without_column=False):
    """ Return the fields for a model

    Only the attributes found in ``base.__dict__`` are considered.

    :param base: Model Class
    :param without_relationship: Do not return the relationship field
    :param only_relationship: return only the relationship field
    :param without_column: Do not return the column field
    :rtype: dict with name of the field in key and instance of Field in value
    """
    fields = {}
    for p in base.__dict__.keys():
        try:
            if hasattr(getattr(base, p), '__class__'):
                if without_relationship and is_in_mro(RelationShip, base, p):
                    continue

                if without_column and is_in_mro(Column, base, p):
                    continue

                if only_relationship and not is_in_mro(RelationShip, base, p):
                    continue

                if is_in_mro(Field, base, p):
                    fields[p] = getattr(base, p)

        except FieldException:
            # reading the attribute raised FieldException: it is
            # deliberately left out of the returned fields
            pass

    return fields

def autodoc_fields(declaration_cls, model_cls):
    """Produces autodocumentation table for the fields.

    Exposed as a function in order to be reusable by a simple export,
    e.g., from anyblok.mixin.

    :param declaration_cls: not used by this function
    :param model_cls: the class whose fields are documented
    :return: the rendered table followed by a blank line, or ``''`` when
        ``model_cls`` has no field
    """
    if not has_sql_fields([model_cls]):
        return ''

    rows = [['Fields', '']]
    rows.extend([x, y.autodoc()]
                for x, y in get_fields(model_cls).items())
    # named text_table so the module-level ``table`` import is not shadowed
    text_table = Texttable(max_width=0)
    text_table.set_cols_valign(["m", "t"])
    # the first row ('Fields', '') becomes the table header
    text_table.add_rows(rows)
    return text_table.draw() + '\n\n'

# NOTE(review): this decorator was missing from the scraped source; without
# it the pre_assemble/assemble/initialize callbacks below are never wired.
# Confirm the keyword names against Declarations.add_declaration_type.
@Declarations.add_declaration_type(isAnEntry=True,
                                   pre_assemble='pre_assemble_callback',
                                   assemble='assemble_callback',
                                   initialize='initialize_callback')
class Model:
    """ The Model class is used to define or inherit an SQL table.

    Add new model class::

        @Declarations.register(Declarations.Model)
        class MyModelclass:
            pass

    Remove a model class::

        Declarations.unregister(Declarations.Model.MyModelclass,
                                MyModelclass)

    There are three Model families:

    * No SQL Model: these models have no field, so no table
    * SQL Model: these models declare at least one field and are mapped on
      a table
    * SQL View Model: it is a model mapped with a SQL View, the insert, update
      delete method are forbidden by the database

    Each model has a:

    * registry name: compose by the parent + . + class model name
    * table name: compose by the parent + '_' + class model name

    The table name can be overloaded by the attribute tablename. the wanted
    value are a string (name of the table) of a model in the declaration.

    .. warning::

        Two models can have the same table name, both models are mapped on
        the table. But they must have the same column.
    """

    # presumably read by AnyBlok's autodoc support — verify
    autodoc_anyblok_kwargs = True

    autodoc_anyblok_bases = True

    autodoc_anyblok_fields = True

    @classmethod
    def pre_assemble_callback(cls, registry):
        """ Attach a ``call_plugins`` dispatcher to the registry

        ``registry.call_plugins(method, *args, **kwargs)`` calls ``method``
        on every plugin returned by ``get_model_plugins(registry)`` that
        defines it; plugins without that method are skipped.

        :param registry: the registry being assembled
        """
        plugins = get_model_plugins(registry)

        def call_plugins(method, *args, **kwargs):
            """call the method on each plugin"""
            for plugin in plugins:
                if hasattr(plugin, method):
                    getattr(plugin, method)(*args, **kwargs)

        registry.call_plugins = call_plugins

    @classmethod
    def register(self, parent, name, cls_, **kwargs):
        """ add new sub registry in the registry

        :param parent: Existing global registry
        :param name: Name of the new registry to add it
        :param cls_: Class Interface to add in registry
        :param tablename: (keyword) overload the table name, either as a
            string or as a declared model whose ``__tablename__`` is reused
        """
        _registryname = parent.__registry_name__ + '.' + name
        if 'tablename' in kwargs:
            tablename = kwargs.pop('tablename')
            if not isinstance(tablename, str):
                tablename = tablename.__tablename__

        elif hasattr(parent, name):
            # the namespace already exists: keep its table name
            tablename = getattr(parent, name).__tablename__
        else:
            if parent is Declarations or parent is Declarations.Model:
                tablename = name.lower()
            elif hasattr(parent, '__tablename__'):
                tablename = parent.__tablename__
                tablename += '_' + name.lower()

        if not hasattr(parent, name):
            # create the namespace placeholder on the parent
            p = {
                '__tablename__': tablename,
                '__registry_name__': _registryname,
                'use': lambda x: ModelAttribute(_registryname, x),
            }
            ns = type(name, tuple(), p)
            setattr(parent, name, ns)

        if parent is Declarations:
            return

        kwargs['__registry_name__'] = _registryname
        kwargs['__tablename__'] = tablename

        RegistryManager.add_entry_in_register(
            'Model', _registryname, cls_, **kwargs)
        setattr(cls_, '__anyblok_kwargs__', kwargs)

    @classmethod
    def unregister(self, entry, cls_):
        """ Remove the Interface from the registry

        :param entry: entry declaration of the model where the ``cls_``
            must be removed
        :param cls_: Class Interface to remove in registry
        """
        # NOTE(review): body missing from the scraped source; restored —
        # confirm against RegistryManager's API
        RegistryManager.remove_in_register(cls_)

    @classmethod
    def declare_field(cls, registry, name, field, namespace, properties,
                      transformation_properties):
        """ Declare the field/column/relationship to put in the properties
        of the model

        Nothing is done when ``name`` is already in
        ``properties['loaded_columns']``.

        :param registry: the current  registry
        :param name: name of the field / column or relationship
        :param field: the declaration field / column or relationship
        :param namespace: the namespace of the model
        :param properties: the properties of the model
        :param transformation_properties: forwarded to the ``declare_field``
            plugin hook
        """
        if name in properties['loaded_columns']:
            return

        if field.must_be_duplicate_before_added():
            field = deepcopy(field)

        attr_name = name
        if field.use_hybrid_property:
            # the SQLAlchemy mapping is stored under a prefixed name and the
            # plain name gets the value of field.get_property (see below)
            attr_name = anyblok_column_prefix + name

        if field.must_be_declared_as_attr():
            # All the declaration are seen as mixin for sqlalchemy
            # some of them need de be defered for the initialisation
            # cause of the mixin as relation ship and column with foreign key
            def wrapper(cls):
                return field.get_sqlalchemy_mapping(
                    registry, namespace, name, properties)

            properties[attr_name] = declared_attr(wrapper)
            properties[attr_name].anyblok_field = field
        else:
            properties[attr_name] = field.get_sqlalchemy_mapping(
                registry, namespace, name, properties)

        if field.use_hybrid_property:
            properties[name] = field.get_property(
                registry, namespace, name, properties)

        registry.call_plugins('declare_field', name, field, namespace,
                              properties, transformation_properties)

        # NOTE(review): restored; nothing else visible fills
        # ``loaded_columns``, yet the guard above, apply_view and
        # replace_properties_by_synonym all read it — confirm upstream
        properties['loaded_columns'].append(name)
        field.update_properties(registry, namespace, name, properties)

    @classmethod
    def transform_base(cls, registry, namespace, base, properties):
        """ Detect specific declaration which must define by registry

        Every attribute of ``base`` is offered to the
        ``transform_base_attribute`` plugin hook, then ``base`` itself to the
        ``transform_base`` hook. Plugins may fill ``new_type_properties``.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param base: One of the base of the model
        :param properties: the properties of the model
        :rtype: new base
        """
        new_type_properties = {}
        for attr in dir(base):
            method = getattr(base, attr)
            registry.call_plugins(
                'transform_base_attribute',
                attr, method, namespace, base, properties, new_type_properties)

        registry.call_plugins(
            'transform_base', namespace, base, properties, new_type_properties)

        if new_type_properties:
            # put the plugin-provided attributes in a class placed before base
            return [type(namespace, (), new_type_properties), base]

        return [base]

    @classmethod
    def insert_in_bases(cls, registry, namespace, bases,
                        transformation_properties, properties):
        """ Add in the declared namespaces new base.

        A new empty class is inserted at index 0 of ``bases`` and handed to
        the ``insert_in_bases`` plugin hook.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param bases: the bases of the model, updated in place
        :param transformation_properties: the properties of the model
        :param properties: assembled attributes of the namespace
        """
        new_base = type(namespace, (), {})
        bases.insert(0, new_base)
        registry.call_plugins('insert_in_bases', new_base, namespace,
                              properties, transformation_properties)

    @classmethod
    def raise_if_has_sqlalchemy(cls, base):
        """ Raise ModelException when ``base`` holds an SQLAlchemy attribute

        :param base: the class to check
        :exception: ModelException
        """
        if has_sqlalchemy_fields(base):
            raise ModelException(
                "the base %r have an SQLAlchemy attribute" % base)

    @classmethod
    def load_namespace_first_step(cls, registry, namespace):
        """ Return the properties of the declared bases for a namespace.
        This is the first step because some actions need to know all the
        properties

        Results are cached in ``registry.loaded_namespaces_first_step``.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :rtype: dict of the known properties
        """
        if namespace in registry.loaded_namespaces_first_step:
            return registry.loaded_namespaces_first_step[namespace]

        bases = []
        properties = {'__depends__': set()}
        ns = registry.loaded_registries[namespace]

        for b in ns['bases']:
            bases.append(b)

            for b_ns in b.__anyblok_bases__:
                if b_ns.__registry_name__.startswith('Model.'):
                    properties['__depends__'].add(b_ns.__registry_name__)

                ps = cls.load_namespace_first_step(registry,
                                                   b_ns.__registry_name__)
                # NOTE(review): merge restored from upstream — keys already
                # collected here win, only missing keys are taken from ps
                ps = ps.copy()
                ps.update(properties)
                properties.update(ps)

        for b in bases:
            fields = get_fields(b)
            for p, f in fields.items():
                if p not in properties:
                    properties[p] = f

        if '__tablename__' in ns['properties']:
            properties['__tablename__'] = ns['properties']['__tablename__']

        registry.loaded_namespaces_first_step[namespace] = properties
        return properties

    @classmethod
    def apply_inheritance_base(cls, registry, namespace, ns, bases,
                               realregistryname, properties,
                               transformation_properties):
        """ Add the declared bases of ``ns`` and of its parents to ``bases``

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param ns: the registry entry of the namespace
        :param bases: the bases of the model, updated in place
        :param realregistryname: the name of the model if the namespace is a
            mixin
        :param properties: the properties of the model
        :param transformation_properties: forwarded to the plugins
        :exception: ModelException
        """
        # remove doublon
        for b in ns['bases']:
            if b in bases:
                continue

            kwargs = {
                'namespace': realregistryname} if realregistryname else {}
            bases.append(b, **kwargs)

            if b.__doc__ and '__doc__' not in properties:
                properties['__doc__'] = b.__doc__

            for b_ns in b.__anyblok_bases__:
                brn = b_ns.__registry_name__
                if brn in registry.loaded_registries['Mixin_names']:
                    tp = transformation_properties
                    if realregistryname:
                        bs, ps = cls.load_namespace_second_step(
                            registry, brn, realregistryname=realregistryname,
                            transformation_properties=tp)
                    else:
                        bs, ps = cls.load_namespace_second_step(
                            registry, brn, realregistryname=namespace,
                            transformation_properties=tp)
                elif brn in registry.loaded_registries['Model_names']:
                    bs, ps = cls.load_namespace_second_step(registry, brn)
                else:
                    raise ModelException(
                        "You have not to inherit the %r "
                        "Only the 'Mixin' and %r types are allowed" % (
                            brn, cls.__name__))

                bases += bs

    @classmethod
    def init_core_properties_and_bases(cls, registry, bases, properties):
        """ Reset the field bookkeeping and add the core bases

        :param registry: the current registry
        :param bases: the bases of the model, updated in place
        :param properties: the properties of the model, updated in place
        """
        properties['loaded_columns'] = []
        properties['hybrid_property_columns'] = []
        properties['loaded_fields'] = {}
        if properties['is_sql_view']:
            bases.extend([x for x in registry.loaded_cores['SqlViewBase']])
        elif has_sql_fields(bases):
            bases.extend([x for x in registry.loaded_cores['SqlBase']])
            # NOTE(review): restored from upstream — the declarative base is
            # what makes the model a mapped SQL table; confirm
            bases.append(registry.declarativebase)
        else:
            # remove tablename to inherit from a sqlmodel
            del properties['__tablename__']

        bases.extend([x for x in registry.loaded_cores['Base']])

    @classmethod
    def declare_all_fields(cls, registry, namespace, bases, properties,
                           transformation_properties):
        """ Declare every field found in ``bases`` on the model

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param bases: the bases of the model
        :param properties: the properties of the model
        :param transformation_properties: forwarded to declare_field
        """
        # do in the first time the fields and columns
        # because for the relationship on the same model
        # the primary keys must exist before the relationship
        # load all the base before do relationship because primary key
        # can be come from inherit
        for b in bases:
            for p, f in get_fields(b,
                                   without_relationship=True).items():
                cls.declare_field(
                    registry, p, f, namespace, properties,
                    transformation_properties)

        for b in bases:
            for p, f in get_fields(b, only_relationship=True).items():
                cls.declare_field(
                    registry, p, f, namespace, properties,
                    transformation_properties)

    @classmethod
    def apply_existing_table(cls, registry, namespace, tablename, properties,
                             bases, transformation_properties):
        """ Map the model on a table already declared by another model

        The existing ``__table__`` is reused and only the non-column fields
        of ``bases`` are declared.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param tablename: name of the already declared table
        :param properties: the properties of the model
        :param bases: the bases of the model
        :param transformation_properties: forwarded to declare_field
        """
        if '__tablename__' in properties:
            del properties['__tablename__']

        for t in registry.loaded_namespaces.keys():
            m = registry.loaded_namespaces[t]
            if m.is_sql:
                if getattr(m, '__tablename__'):
                    if m.__tablename__ == tablename:
                        properties['__table__'] = m.__table__
                        # NOTE(review): only read by the remaining loop
                        # iterations
                        tablename = namespace.replace('.', '_').lower()

        for b in bases:
            for p, f in get_fields(b,
                                   without_column=True).items():
                cls.declare_field(
                    registry, p, f, namespace, properties,
                    transformation_properties)

    @classmethod
    def load_namespace_second_step(cls, registry, namespace,
                                   realregistryname=None,
                                   transformation_properties=None):
        """ Return the bases and the properties of the namespace

        For a model namespace, the model class is built, stored in the
        registry, and returned as the only base with empty properties.

        :param registry: the current registry
        :param namespace: the namespace of the model
        :param realregistryname: the name of the model if the namespace is a
            mixin
        :param transformation_properties: properties shared with the model
            plugins; a new dict is used when None
        :rtype: the list of the bases and the properties
        :exception: ModelException
        """
        if namespace in registry.loaded_namespaces:
            return [registry.loaded_namespaces[namespace]], {}

        if transformation_properties is None:
            transformation_properties = {}

        bases = TypeList(cls, registry, namespace, transformation_properties)
        ns = registry.loaded_registries[namespace]
        properties = ns['properties'].copy()

        # NOTE(review): call head restored; the plugin hook name (with its
        # 'tranformation' spelling) must match the plugins — confirm
        registry.call_plugins('initialisation_tranformation_properties',
                              properties, transformation_properties)

        if 'is_sql_view' not in properties:
            properties['is_sql_view'] = False

        cls.apply_inheritance_base(registry, namespace, ns, bases,
                                   realregistryname, properties,
                                   transformation_properties)

        if namespace in registry.loaded_registries['Model_names']:
            tablename = properties['__tablename__']
            modelname = namespace.replace('.', '')
            cls.init_core_properties_and_bases(registry, bases, properties)

            if tablename in registry.declarativebase.metadata.tables:
                cls.apply_existing_table(
                    registry, namespace, tablename, properties,
                    bases, transformation_properties)
            else:
                cls.declare_all_fields(registry, namespace, bases, properties,
                                       transformation_properties)

            # NOTE(review): restored from upstream — confirm that the
            # registry exposes ``registry_base`` in this version
            bases.append(registry.registry_base)
            cls.insert_in_bases(registry, namespace, bases,
                                transformation_properties, properties)
            bases = [type(modelname, tuple(bases), properties)]
            if properties['is_sql_view']:
                cls.apply_view(namespace, tablename, bases[0], registry,
                               properties)

            properties = {}
            registry.add_in_registry(namespace, bases[0])
            registry.loaded_namespaces[namespace] = bases[0]
            registry.call_plugins('after_model_construction', bases[0],
                                  namespace, transformation_properties)

        return bases, properties

    @classmethod
    def replace_properties_by_synonym(cls, properties):
        """ Map every loaded column name on its prefixed attribute

        :param properties: the properties of the model, updated in place
        """
        for field in properties['loaded_columns']:
            properties[field] = synonym(anyblok_column_prefix + field)

    @classmethod
    def apply_view(cls, namespace, tablename, base, registry, properties):
        """ Transform the sqlmodel to view model

        The view is built once per table name and kept in
        ``registry.loaded_views``. Its DDL is attached to the metadata's
        before-create, after-create and before-drop events.

        :param namespace: Namespace of the model
        :param tablename: Name of the table of the model
        :param base: Model cls
        :param registry: current registry
        :param properties: properties of the model
        :exception: ViewException
        """
        if hasattr(base, '__view__'):
            view = base.__view__
        elif tablename in registry.loaded_views:
            view = registry.loaded_views[tablename]
        else:
            if not hasattr(base, 'sqlalchemy_view_declaration'):
                raise ViewException(
                    "%r.'sqlalchemy_view_declaration' is required to "
                    "define the query to apply of the view" % namespace)

            view = table(tablename)
            registry.loaded_views[tablename] = view
            selectable = getattr(base, 'sqlalchemy_view_declaration')()

            if isinstance(selectable, Query):
                selectable = selectable.subquery()

            for c in selectable.c:
                # NOTE(review): restored from the SQLAlchemy "Views" recipe;
                # copies each selected column onto the view table
                c._make_proxy(view)

            DropView(tablename).execute_at(
                'before-create', registry.declarativebase.metadata)
            CreateView(tablename, selectable).execute_at(
                'after-create', registry.declarativebase.metadata)
            DropView(tablename).execute_at(
                'before-drop', registry.declarativebase.metadata)

        pks = [col for col in properties['loaded_columns']
               if getattr(base, anyblok_column_prefix + col).primary_key]

        if not pks:
            raise ViewException(
                "%r have any primary key defined" % namespace)

        pks = [getattr(view.c, x) for x in pks]
        mapper(base, view, primary_key=pks)
        setattr(base, '__view__', view)

    @classmethod
    def assemble_callback(cls, registry):
        """ Assemble callback is called to assemble all the Model
        from the installed bloks

        :param registry: registry to update
        """
        registry.loaded_namespaces_first_step = {}
        registry.loaded_views = {}

        # get all the information to create a namespace
        for namespace in registry.loaded_registries['Model_names']:
            cls.load_namespace_first_step(registry, namespace)

        # create the namespace with all the information come from first
        # step
        for namespace in registry.loaded_registries['Model_names']:
            cls.load_namespace_second_step(registry, namespace)

    @classmethod
    def initialize_callback(cls, registry):
        """ initialize callback is called after assembling all entries

        This callback updates the database information about

        * Model
        * Column
        * RelationShip

        :param registry: registry to update
        """
        # NOTE(review): the loop body, update_list and uninstall_all calls
        # were missing from the scraped source and were restored — confirm
        for Model in registry.loaded_namespaces.values():
            Model.initialize_model()

        Blok = registry.System.Blok
        if not registry.withoutautomigration:
            Model = registry.System.Model
            Model.update_list()

        bloks = Blok.list_by_state('touninstall')
        Blok.uninstall_all(*bloks)
        return Blok.apply_state(*registry.ordered_loaded_bloks)