# This file is a part of the AnyBlok project
#
# Copyright (C) 2014 Jean-Sebastien SUZANNE <jssuzanne@anybox.fr>
# Copyright (C) 2021 Jean-Sebastien SUZANNE <js.suzanne@gmail.com>
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file,You can
# obtain one at http://mozilla.org/MPL/2.0/.
from logging import getLogger
from sqlalchemy import func, select
from sqlalchemy.orm.exc import NoResultFound
from anyblok import Declarations
from anyblok.common import anyblok_column_prefix
logger = getLogger(__name__)
[docs]@Declarations.register(Declarations.Core)
class Query:
"""Overload the SqlAlchemy Query"""
def __init__(self, Model, *elements, sql_statement=None):
self.Model = Model
self.elements = elements
self.sql_statement = sql_statement
if sql_statement is None:
self.sql_statement = Model.select_sql_statement(*elements)
def __getattr__(self, key, default=None):
sqla_function = getattr(self.sql_statement, key)
def wrapper(*args, **kwargs):
statement = sqla_function(*args, **kwargs)
return self.anyblok.Query(
self.Model, *self.elements, sql_statement=statement
)
return wrapper
def __iter__(self):
for res in self._execute():
yield res
def _execute(self):
res = self.Model.execute(self.sql_statement)
if self.elements:
return res
return res.scalars()
@property
def column_descriptions(self):
return self.sql_statement.column_descriptions
def count(self):
stmt = select(func.count())
stmt = stmt.select_from(self.sql_statement.subquery())
return self.Model.execute(stmt).scalars().first()
def delete(self, *args, **kwargs):
raise NotImplementedError( # pragma: no cover
"You have to use Model.delete_sql_statement()"
)
def update(self, *args, **kwargs):
raise NotImplementedError( # pragma: no cover
"You have to use Model.update_sql_statement()"
)
def first(self):
try:
return self._execute().first()
except NoResultFound as exc: # pragma: no cover
logger.debug(
"On Model %r: exc %s: query %s",
self.Model.__registry_name__,
str(exc),
str(self),
)
raise exc.__class__(
"On Model %r: %s" % (self.Model.__registry_name__, str(exc))
)
[docs] def one(self):
"""Overwrite sqlalchemy one() method to improve exception message
Add model name to query exception message
"""
try:
return self._execute().one()
except NoResultFound as exc:
logger.debug(
"On Model %r: exc %s: query %s",
self.Model.__registry_name__,
str(exc),
str(self),
)
raise exc.__class__( # pragma: no cover
"On Model %r: %s" % (self.Model.__registry_name__, str(exc))
)
def one_or_none(self):
return self._execute().one_or_none()
[docs] def all(self):
"""Return an instrumented list of the result of the query"""
res = self._execute().all()
return self.anyblok.InstrumentedList(res)
[docs] def with_perm(self, principals, permission):
"""Add authorization pre- and post-filtering to query.
This must be last in the construction chain of the query.
Queries too complicated for the authorization system to infer
safely will be refused.
:param principals: list, set or tuple of strings
:param str permission: the permission to filter for
:returns: a query-like object, with only the returning methods, such
as ``all()``, ``count()`` etc. available.
"""
return self.anyblok.wrap_query_permission(self, principals, permission)
def get_field_names_in_column_description(self):
field2get = [
x["name"]
for x in self.column_descriptions
if not hasattr(x["type"], "__table__")
]
field2get = [
(
x[len(anyblok_column_prefix) :]
if x.startswith(anyblok_column_prefix)
else x,
x,
)
for x in field2get
]
return field2get
def dictone(self):
try:
val = self.one()
except NoResultFound as exc:
msg = str(exc).replace("one()", "dictone()")
raise exc.__class__(msg)
field2get = self.get_field_names_in_column_description()
if field2get:
return {x: getattr(val, y) for x, y in field2get}
else:
return val.to_dict()
def dictfirst(self):
val = self.first()
field2get = self.get_field_names_in_column_description()
if field2get:
return {x: getattr(val, y) for x, y in field2get}
else:
return val.to_dict()
def dictall(self):
vals = self.all()
if not vals:
return [] # pragma: no cover
field2get = self.get_field_names_in_column_description()
if field2get:
return [{x: getattr(y, z) for x, z in field2get} for y in vals]
else:
return vals.to_dict()
[docs] def get(self, primary_keys=None, **kwargs):
"""Return instance of the Model
::
instance = Model.query().get(the primary key value)
or
::
instance Model.query().get(pk1 name=pk1 value, ...)
"""
if primary_keys is None:
primary_keys = kwargs
if isinstance(primary_keys, dict):
primary_keys = {
(
anyblok_column_prefix + k
if k in self.Model.hybrid_property_columns
else k
): v
for k, v in primary_keys.items()
}
return self.anyblok.session.get(self.Model, primary_keys)
def subquery(self, *args, **kwargs):
return self.sql_statement.subquery(*args, **kwargs)