# Source code for anyblok.model.factory

# This file is a part of the AnyBlok project
#    Copyright (C) 2018 Jean-Sebastien SUZANNE <[email protected]>
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at http://mozilla.org/MPL/2.0/.
from .exceptions import ModelFactoryException
from anyblok.field import Field, FieldException
from sqlalchemy.sql import table, and_
from sqlalchemy.orm import Query, mapper, relationship
from .exceptions import ViewException
from anyblok.common import anyblok_column_prefix
from .view import CreateView, DropView

def has_sql_fields(bases):
    """ Tells whether the model has at least one SQL field or not

    Every attribute name found in the ``__dict__`` of each base is resolved
    with ``getattr``; the model is considered SQL as soon as one value is a
    ``Field`` instance, or resolving it raises ``FieldException``.

    :param bases: list of Model's Class
    :rtype: boolean
    """
    for base in bases:
        # snapshot the names first: ``getattr`` on a class may run descriptor
        # code, which must not be able to break the iteration
        for name in list(base.__dict__):
            try:
                value = getattr(base, name)
            except FieldException:
                # field function case already computed
                return True

            if isinstance(value, Field):
                return True

    return False

class BaseFactory:
    """Base class of the model factories.

    A factory is bound to a registry; subclasses choose the core classes a
    model inherits from and build the final model class.
    """

    def __init__(self, registry):
        """
        :param registry: registry whose attributes (such as
                         ``loaded_cores``) the subclasses read
        """
        self.registry = registry

    def insert_core_bases(self, bases, properties):
        """Add the registry core classes to ``bases``; must be overridden.

        :param bases: list of Model's Class
        :param properties: properties of the model
        :exception: ModelFactoryException
        """
        raise ModelFactoryException('Must be overwritten')

    def build_model(self, modelname, bases, properties):
        """Create and return the model class; must be overridden.

        :param modelname: name of the class to create
        :param bases: list of Model's Class
        :param properties: properties of the model
        :exception: ModelFactoryException
        """
        raise ModelFactoryException('Must be overwritten')
class ModelFactory(BaseFactory):
    """Factory building the class of a model, SQL-backed or not."""

    def insert_core_bases(self, bases, properties):
        """Add the registry core classes needed by the model to ``bases``.

        When the bases declare SQL fields, the ``SqlBase`` cores and the
        declarative base are added. Otherwise ``__tablename__`` is removed
        from ``properties``. The ``Base`` cores are appended in both cases.

        :param bases: list of Model's Class, extended in place
        :param properties: properties of the model, may be modified in place
        """
        if has_sql_fields(bases):
            bases.extend(self.registry.loaded_cores['SqlBase'])
            bases.append(self.registry.declarativebase)
        else:
            # remove tablename to inherit from a sqlmodel; ``pop`` tolerates
            # a model whose properties never declared one
            properties.pop('__tablename__', None)

        bases.extend(self.registry.loaded_cores['Base'])

    def build_model(self, modelname, bases, properties):
        """Create the model class.

        :param modelname: name of the class to create
        :param bases: list of Model's Class
        :param properties: namespace of the new class
        :rtype: the new class
        """
        return type(modelname, tuple(bases), properties)
def get_columns(view, columns):
    """Resolve ``'<source> => <name>'`` column specifications on ``view``.

    :param view: object exposing its columns through the ``c`` attribute
    :param columns: list or tuple of specifications, or one string holding
                    specifications separated by ``', '``
    :rtype: list of the columns of ``view`` named after each ``' => '``
    """
    if not isinstance(columns, (list, tuple)):
        # ``split`` returns ``[columns]`` when there is no separator, so a
        # single specification needs no special case
        columns = columns.split(', ')

    return [getattr(view.c, spec.split(' => ')[1]) for spec in columns]
class ViewFactory(BaseFactory):
    """Factory building models mapped on an SQL view instead of a table.

    The view query comes from the model's ``sqlalchemy_view_declaration``;
    ``DropView``/``CreateView`` DDL is attached to the create/drop events of
    the registry's declarative metadata.
    """

    def insert_core_bases(self, bases, properties):
        """Add the ``SqlViewBase`` cores then the ``Base`` cores to ``bases``.

        :param bases: list of Model's Class, extended in place
        :param properties: properties of the model (not used here)
        """
        bases.extend(self.registry.loaded_cores['SqlViewBase'])
        bases.extend(self.registry.loaded_cores['Base'])

    def build_model(self, modelname, bases, properties):
        """Create the model class and map it on its view.

        :param modelname: name of the class to create
        :param bases: list of Model's Class
        :param properties: namespace of the new class
        :rtype: the new class
        :exception: ViewException
        """
        Model = type(modelname, tuple(bases), properties)
        self.apply_view(Model, properties)
        return Model

    def apply_view(self, base, properties):
        """ Transform the sqlmodel to view model

        The view table is taken, in order of preference, from
        ``base.__view__``, then from ``registry.loaded_views``. Otherwise it
        is built from ``base.sqlalchemy_view_declaration()`` and cached in
        ``registry.loaded_views``. The model is then mapped on that view
        with its primary key columns.

        :param base: Model cls
        :param properties: properties of the model
        :exception: ViewException
        """
        tablename = base.__tablename__
        if hasattr(base, '__view__'):
            view = base.__view__
        elif tablename in self.registry.loaded_views:
            view = self.registry.loaded_views[tablename]
        else:
            if not hasattr(base, 'sqlalchemy_view_declaration'):
                raise ViewException(
                    "%r.'sqlalchemy_view_declaration' is required to "
                    "define the query to apply of the view" % base)

            view = table(tablename)
            self.registry.loaded_views[tablename] = view
            selectable = getattr(base, 'sqlalchemy_view_declaration')()

            if isinstance(selectable, Query):
                selectable = selectable.subquery()

            # expose every column of the query on the view table
            for column in selectable.c:
                column._make_proxy(view)

            metadata = self.registry.declarativebase.metadata
            # (re)create the view around the metadata create/drop cycle
            DropView(tablename).execute_at('before-create', metadata)
            CreateView(tablename, selectable).execute_at(
                'after-create', metadata)
            DropView(tablename).execute_at('before-drop', metadata)

        pks = [col for col in properties['loaded_columns']
               if getattr(getattr(base, anyblok_column_prefix + col),
                          'primary_key', False)]
        if not pks:
            raise ViewException("%r has no primary key defined" % base)

        pks = [getattr(view.c, name) for name in pks]

        mapper_properties = self.get_mapper_properties(base, view, properties)
        setattr(base, '__view__', view)
        view_mapper = mapper(
            base, view, primary_key=pks, properties=mapper_properties)
        setattr(base, '__mapper__', view_mapper)

    def get_mapper_properties(self, base, view, properties):
        """Build the ``properties`` mapping given to the view mapper.

        Starting from ``base.define_mapper_args()``:

        - attributes without ``anyblok_field`` are mapped on the view column
          of the same name;
        - the others are rebuilt as ``relationship`` whose foreign keys (and,
          for a self-referencing model, join) use the view columns.

        :param base: Model cls
        :param view: the view table the model is mapped on
        :param properties: properties of the model
        :rtype: dict
        :exception: ViewException
        """
        mapper_properties = base.define_mapper_args()
        for field in properties['loaded_columns']:
            attribute = properties[anyblok_column_prefix + field]
            if not hasattr(attribute, 'anyblok_field'):
                mapper_properties[field] = getattr(view.c, field)
                continue

            anyblok_field = attribute.anyblok_field
            kwargs = anyblok_field.kwargs.copy()
            if 'foreign_keys' in kwargs:
                # presumably a string like "[table.col1, table.col2]" --
                # strip the brackets and keep only the column names
                foreign_keys = kwargs['foreign_keys'][1:-1].split(', ')
                kwargs['foreign_keys'] = [
                    getattr(view.c, fk.split('.')[1]) for fk in foreign_keys]

            if anyblok_field.model.model_name == base.__registry_name__:
                # self-referencing relationship: the join must be expressed
                # on the view's own columns
                remote_columns = get_columns(
                    view, kwargs['info']['remote_columns'])
                local_columns = get_columns(
                    view, kwargs['info']['local_columns'])
                if len(remote_columns) != len(local_columns):
                    raise ViewException(
                        "%r: relationship %r has %d remote column(s) but %d "
                        "local column(s)" % (
                            base, field, len(remote_columns),
                            len(local_columns)))

                conditions = [remote == local for remote, local
                              in zip(remote_columns, local_columns)]
                if len(conditions) == 1:
                    primaryjoin = conditions[0]
                else:
                    primaryjoin = and_(*conditions)

                kwargs['remote_side'] = remote_columns
                kwargs['primaryjoin'] = primaryjoin
                Model = base
            else:
                Model = self.registry.get(anyblok_field.model.model_name)

            mapper_properties[field] = relationship(Model, **kwargs)

        return mapper_properties