--- a +++ b/qiita_db/sql_connection.py @@ -0,0 +1,523 @@ +r""" +SQL Connection object (:mod:`qiita_db.sql_connection`) +====================================================== + +.. currentmodule:: qiita_db.sql_connection + +This modules provides wrappers for the psycopg2 module to allow easy use of +transaction blocks and SQL execution/data retrieval. + +This module provides the variable TRN, which is the transaction available +to use in the system. The singleton pattern is applied and this works as long +as the system remains single-threaded. + +Classes +------- + +.. autosummary:: + :toctree: generated/ + + Transaction +""" +# ----------------------------------------------------------------------------- +# Copyright (c) 2014--, The Qiita Development Team. +# +# Distributed under the terms of the BSD 3-clause License. +# +# The full license is in the file LICENSE, distributed with this software. +# ----------------------------------------------------------------------------- +from contextlib import contextmanager +from itertools import chain +from functools import wraps + +from psycopg2 import (connect, ProgrammingError, Error as PostgresError, + OperationalError, errorcodes) +from psycopg2.extras import DictCursor +from psycopg2.extensions import TRANSACTION_STATUS_IDLE + +from qiita_core.qiita_settings import qiita_config + + +def _checker(func): + """Decorator to check that methods are executed inside the context""" + @wraps(func) + def wrapper(self, *args, **kwargs): + if self._contexts_entered == 0: + raise RuntimeError( + "Operation not permitted. Transaction methods can only be " + "invoked within the context manager.") + return func(self, *args, **kwargs) + return wrapper + + +class Transaction(object): + """A context manager that encapsulates a DB transaction + + A transaction is defined by a series of consecutive queries that need to + be applied to the database as a single block. + + Raises + ------ + RuntimeError + If the transaction methods are invoked outside a context. + + Notes + ----- + When the execution leaves the context manager, any remaining queries in + the transaction will be executed and committed. + """ + def __init__(self, admin=False): + self._queries = [] + self._results = [] + self._contexts_entered = 0 + self._connection = None + self._post_commit_funcs = [] + self._post_rollback_funcs = [] + self.admin = admin + + def _open_connection(self): + # If the connection already exists and is not closed, don't do anything + if self._connection is not None and self._connection.closed == 0: + return + + try: + if self.admin: + self._connection = connect( + user=qiita_config.admin_user, + password=qiita_config.admin_password, + host=qiita_config.host, + port=qiita_config.port) + self._connection.autocommit = True + else: + self._connection = connect(user=qiita_config.user, + password=qiita_config.password, + database=qiita_config.database, + host=qiita_config.host, + port=qiita_config.port) + except OperationalError as e: + # catch three known common exceptions and raise runtime errors + try: + etype = str(e).split(':')[1].split()[0] + except IndexError: + # we recieved a really unanticipated error without a colon + etype = '' + if etype == 'database': + etext = ('This is likely because the database `%s` has not ' + 'been created or has been dropped.' % + qiita_config.database) + elif etype == 'role': + etext = ('This is likely because the user string `%s` ' + 'supplied in your configuration file `%s` is ' + 'incorrect or not an authorized postgres user.' % + (qiita_config.user, qiita_config.conf_fp)) + elif etype == 'Connection': + etext = ('This is likely because postgres isn\'t ' + 'running. Check that postgres is correctly ' + 'installed and is running.') + else: + # we recieved a really unanticipated error with a colon + etext = '' + ebase = ('An OperationalError with the following message occured' + '\n\n\t%s\n%s For more information, review `INSTALL.md`' + ' in the Qiita installation base directory.') + raise RuntimeError(ebase % (str(e), etext)) + + def close(self): + if self._connection is not None: + self._connection.close() + + @contextmanager + def _get_cursor(self): + """Returns a postgres cursor + + Returns + ------- + psycopg2.cursor + The psycopg2 cursor + + Raises + ------ + RuntimeError + if the cursor cannot be created + """ + self._open_connection() + + try: + with self._connection.cursor(cursor_factory=DictCursor) as cur: + yield cur + except PostgresError as e: + raise RuntimeError("Cannot get postgres cursor: %s" % e) + + def __enter__(self): + self._open_connection() + self._contexts_entered += 1 + return self + + def _clean_up(self, exc_type): + if exc_type is not None: + # An exception occurred during the execution of the transaction + # Make sure that we leave the DB w/o any modification + self.rollback() + elif self._queries: + # There are still queries to be executed, execute them + # It is safe to use the execute method here, as internally is + # wrapped in a try/except and rollbacks in case of failure + self.execute() + self.commit() + elif self._connection.get_transaction_status() != \ + TRANSACTION_STATUS_IDLE: + # There are no queries to be executed, however, the transaction + # is still not committed. Commit it so the changes are not lost + self.commit() + + def __exit__(self, exc_type, exc_value, traceback): + # We only need to perform some action if this is the last context + # that we are entering + if self._contexts_entered == 1: + # We need to wrap the entire function in a try/finally because + # at the end we need to decrement _contexts_entered + try: + self._clean_up(exc_type) + finally: + self._contexts_entered -= 1 + else: + self._contexts_entered -= 1 + + def _raise_execution_error(self, sql, sql_args, error): + """Rollbacks the current transaction and raises a useful error + The error message contains the name of the transaction, the failed + query, the arguments of the failed query and the error generated. + + Raises + ------ + ValueError + """ + self.rollback() + + try: + ec_lu = errorcodes.lookup(error.pgcode) + raise ValueError( + "Error running SQL: %s. MSG: %s\n" % (ec_lu, str(error))) + # the order of except statements is important, do not change + except (KeyError, AttributeError, TypeError) as error: + raise ValueError("Error running SQL query: %s" % str(error)) + except ValueError as error: + raise ValueError("Error running SQL query: %s" % str(error)) + + @_checker + def add(self, sql, sql_args=None, many=False): + """Add a sql query to the transaction + + Parameters + ---------- + sql : str + The sql query + sql_args : list, tuple or dict of objects, optional + The arguments to the sql query + many : bool, optional + Whether or not we should add the query multiple times to the + transaction + + Raises + ------ + TypeError + If `sql_args` is provided and is not a list, tuple or dict + RuntimeError + If invoked outside a context + + Notes + ----- + If `many` is true, `sql_args` should be a list of lists, tuples or + dicts, in which each element of the list contains the parameters for + one SQL query of the many. Each element on the list is all the + parameters for a single one of the many queries added. The amount of + SQL queries added to the list is len(sql_args). + """ + if not many: + sql_args = [sql_args] + + for args in sql_args: + if args: + if not isinstance(args, (list, tuple, dict)): + raise TypeError("sql_args should be a list, tuple or dict." + " Found %s" % type(args)) + self._queries.append((sql, args)) + + def _execute(self): + """Internal function that actually executes the transaction + The `execute` function exposed in the API wraps this one to make sure + that we catch any exception that happens in here and we rollback the + transaction + """ + with self._get_cursor() as cur: + for sql, sql_args in self._queries: + # Execute the current SQL command + try: + cur.execute(sql, sql_args) + except Exception as e: + # We catch any exception as we want to make sure that we + # rollback every time that something went wrong + self._raise_execution_error(sql, sql_args, e) + + try: + res = cur.fetchall() + except ProgrammingError: + # At this execution point, we don't know if the sql query + # that we executed should retrieve values from the database + # If the query was not supposed to retrieve any value + # (e.g. an INSERT without a RETURNING clause), it will + # raise a ProgrammingError. Otherwise it will just return + # an empty list + res = None + except PostgresError as e: + # Some other error happened during the execution of the + # query, so we need to rollback + self._raise_execution_error(sql, sql_args, e) + + # Store the results of the current query + self._results.append(res) + + # wipe out the already executed queries + self._queries = [] + + return self._results + + @_checker + def execute(self): + """Executes the transaction + + Returns + ------- + list of DictCursor + The results of all the SQL queries in the transaction + + Raises + ------ + RuntimeError + If invoked outside a context + + Notes + ----- + If any exception occurs during the execution transaction, a rollback + is executed and no changes are reflected in the database. + When calling execute, the transaction will never be committed, it will + be automatically committed when leaving the context + + See Also + -------- + execute_fetchlast + execute_fetchindex + execute_fetchflatten + """ + try: + return self._execute() + except Exception: + self.rollback() + raise + + @_checker + def execute_fetchlast(self): + """Executes the transaction and returns the last result + + This is a convenient function that is equivalent to + `self.execute()[-1][0][0]` + + Returns + ------- + object + The first value of the last SQL query executed + + See Also + -------- + execute + execute_fetchindex + execute_fetchflatten + """ + return self.execute()[-1][0][0] + + @_checker + def execute_fetchindex(self, idx=-1): + """Executes the transaction and returns the results of the `idx` query + + This is a convenient function that is equivalent to + `self.execute()[idx] + + Parameters + ---------- + idx : int, optional + The index of the query to return the result. It defaults to -1, the + last query. + + Returns + ------- + DictCursor + The results of the `idx` query in the transaction + + See Also + -------- + execute + execute_fetchlast + execute_fetchflatten + """ + return self.execute()[idx] + + @_checker + def execute_fetchflatten(self, idx=-1): + """Executes the transaction and returns the flattened results of the + `idx` query + + This is a convenient function that is equivalent to + `chain.from_iterable(self.execute()[idx])` + + Parameters + ---------- + idx : int, optional + The index of the query to return the result. It defaults to -1, the + last query. + + Returns + ------- + list of objects + The flattened results of the `idx` query + + See Also + -------- + execute + execute_fetchlast + execute_fetchindex + """ + return list(chain.from_iterable(self.execute()[idx])) + + def _funcs_executor(self, funcs, func_str): + error_msg = [] + for f, args, kwargs in funcs: + try: + f(*args, **kwargs) + except Exception as e: + error_msg.append(str(e)) + # The functions in these two lines are mutually exclusive. When one of + # them is executed, we can restore both of them. + self._post_commit_funcs = [] + self._post_rollback_funcs = [] + if error_msg: + raise RuntimeError( + "An error occurred during the post %s commands:\n%s" + % (func_str, "\n".join(error_msg))) + + @_checker + def commit(self): + """Commits the transaction and reset the queries + + Raises + ------ + RuntimeError + If invoked outside a context + """ + # Reset the queries, the results and the index + self._queries = [] + self._results = [] + try: + self._connection.commit() + except Exception: + self._connection.close() + raise + # Execute the post commit functions + self._funcs_executor(self._post_commit_funcs, "commit") + + @_checker + def rollback(self): + """Rollbacks the transaction and reset the queries + + Raises + ------ + RuntimeError + If invoked outside a context + """ + # Reset the queries, the results and the index + self._queries = [] + self._results = [] + + if self._connection is not None and self._connection.closed == 0: + try: + self._connection.rollback() + except Exception: + self._connection.close() + raise + # Execute the post rollback functions + self._funcs_executor(self._post_rollback_funcs, "rollback") + + @property + def index(self): + return len(self._queries) + len(self._results) + + @_checker + def add_post_commit_func(self, func, *args, **kwargs): + """Adds a post commit function + + The function added will be executed after the next commit in the + transaction, unless a rollback is executed. This is useful, for + example, to perform some filesystem clean up once the transaction is + committed. + + Parameters + ---------- + func : function + The function to add for the post commit functions + args : tuple + The arguments of the function + kwargs : dict + The keyword arguments of the function + """ + self._post_commit_funcs.append((func, args, kwargs)) + + @_checker + def add_post_rollback_func(self, func, *args, **kwargs): + """Adds a post rollback function + + The function added will be executed after the next rollback in the + transaction, unless a commit is executed. This is useful, for example, + to restore the filesystem in case a rollback occurs, avoiding leaving + the database and the filesystem in an out of sync state. + + Parameters + ---------- + func : function + The function to add for the post rollback functions + args : tuple + The arguments of the function + kwargs : dict + The keyword arguments of the function + """ + self._post_rollback_funcs.append((func, args, kwargs)) + + +# Singleton pattern, create the transaction for the entire system +TRN = Transaction() +TRNADMIN = Transaction(admin=True) + + +def perform_as_transaction(sql, parameters=None): + """Opens, adds and executes sql as a single transaction + + Parameters + ---------- + sql : str + The SQL to execute + parameters: object, optional + The object of parameters to pass to the TRN.add command + """ + with TRN: + if parameters: + TRN.add(sql, parameters) + else: + TRN.add(sql) + TRN.execute() + + +def create_new_transaction(): + """Creates a new global transaction + + This is needed when using multiprocessing + """ + global TRN + TRN = Transaction()