Switch to side-by-side view

--- a
+++ b/qiita_db/sql_connection.py
@@ -0,0 +1,523 @@
+r"""
+SQL Connection object (:mod:`qiita_db.sql_connection`)
+======================================================
+
+.. currentmodule:: qiita_db.sql_connection
+
+This module provides wrappers for the psycopg2 module to allow easy use of
+transaction blocks and SQL execution/data retrieval.
+
+This module provides the variable TRN, which is the transaction available
+to use in the system. The singleton pattern is applied and this works as long
+as the system remains single-threaded.
+
+Classes
+-------
+
+.. autosummary::
+   :toctree: generated/
+
+   Transaction
+"""
+# -----------------------------------------------------------------------------
+# Copyright (c) 2014--, The Qiita Development Team.
+#
+# Distributed under the terms of the BSD 3-clause License.
+#
+# The full license is in the file LICENSE, distributed with this software.
+# -----------------------------------------------------------------------------
+from contextlib import contextmanager
+from itertools import chain
+from functools import wraps
+
+from psycopg2 import (connect, ProgrammingError, Error as PostgresError,
+                      OperationalError, errorcodes)
+from psycopg2.extras import DictCursor
+from psycopg2.extensions import TRANSACTION_STATUS_IDLE
+
+from qiita_core.qiita_settings import qiita_config
+
+
+def _checker(func):
+    """Decorator that rejects method calls made outside the context manager"""
+    @wraps(func)
+    def guarded(self, *args, **kwargs):
+        # A non-zero counter means at least one `with` block is still open
+        if self._contexts_entered != 0:
+            return func(self, *args, **kwargs)
+        raise RuntimeError("Operation not permitted. Transaction methods "
+                           "can only be invoked within the context manager.")
+    return guarded
+
+
+class Transaction(object):
+    """A context manager that encapsulates a DB transaction
+
+    A transaction is defined by a series of consecutive queries that need to
+    be applied to the database as a single block.
+
+    Raises
+    ------
+    RuntimeError
+        If the transaction methods are invoked outside a context.
+
+    Notes
+    -----
+    When the execution leaves the context manager, any remaining queries in
+    the transaction will be executed and committed.
+    """
+    def __init__(self, admin=False):
+        self.admin = admin              # True: connect with admin credentials
+        self._connection = None         # created later by _open_connection
+        self._contexts_entered = 0      # depth of nested `with` blocks
+        self._queries = []              # (sql, args) pairs not yet executed
+        self._results = []              # one entry per executed query
+        self._post_commit_funcs = []    # (func, args, kwargs) run on commit
+        self._post_rollback_funcs = []  # (func, args, kwargs) run on rollback
+
+    def _open_connection(self):
+        """Open the postgres connection unless a live one already exists"""
+        current = self._connection
+        if current is not None and current.closed == 0:
+            return
+
+        try:
+            if self.admin:
+                # The admin connection does not select a database and is
+                # switched to autocommit
+                self._connection = connect(
+                    user=qiita_config.admin_user,
+                    password=qiita_config.admin_password,
+                    host=qiita_config.host, port=qiita_config.port)
+                self._connection.autocommit = True
+            else:
+                self._connection = connect(
+                    user=qiita_config.user, password=qiita_config.password,
+                    database=qiita_config.database,
+                    host=qiita_config.host, port=qiita_config.port)
+        except OperationalError as e:
+            # Turn the failure into a RuntimeError, adding a hint chosen from
+            # the first word that follows the first colon of the message
+            try:
+                kind = str(e).split(':')[1].split()[0]
+            except IndexError:
+                # No colon (or nothing after it): no hint can be given
+                kind = ''
+            hint = ''
+            if kind == 'database':
+                hint = ('This is likely because the database `%s` has not '
+                        'been created or has been dropped.' %
+                        qiita_config.database)
+            elif kind == 'role':
+                hint = ('This is likely because the user string `%s` '
+                        'supplied in your configuration file `%s` is '
+                        'incorrect or not an authorized postgres user.' %
+                        (qiita_config.user, qiita_config.conf_fp))
+            elif kind == 'Connection':
+                hint = ('This is likely because postgres isn\'t '
+                        'running. Check that postgres is correctly '
+                        'installed and is running.')
+            raise RuntimeError(
+                'An OperationalError with the following message occured'
+                '\n\n\t%s\n%s For more information, review `INSTALL.md`'
+                ' in the Qiita installation base directory.' % (str(e), hint))
+
+    def close(self):
+        if self._connection is not None:  # nothing to close if never opened
+            self._connection.close()
+
+    @contextmanager
+    def _get_cursor(self):
+        """Context manager that yields a postgres cursor
+
+        Yields
+        ------
+        psycopg2.extras.DictCursor
+            A cursor created on the (possibly just opened) connection
+
+        Raises
+        ------
+        RuntimeError
+            If a psycopg2 error is raised while creating or using the cursor
+        """
+        self._open_connection()
+
+        try:
+            with self._connection.cursor(cursor_factory=DictCursor) as cur:
+                yield cur
+        except PostgresError as e:
+            raise RuntimeError("Cannot get postgres cursor: %s" % e)
+
+    def __enter__(self):
+        self._open_connection()      # returns at once if already connected
+        self._contexts_entered += 1  # nesting depth, decremented by __exit__
+        return self
+
+    def _clean_up(self, exc_type):
+        """Finish the transaction when the outermost context is exited"""
+        if exc_type is not None:
+            # The block raised: undo everything done in the transaction
+            self.rollback()
+            return
+
+        if self._queries:
+            # Run the pending queries first; execute() rolls back on failure
+            self.execute()
+            self.commit()
+            return
+
+        status = self._connection.get_transaction_status()
+        if status != TRANSACTION_STATUS_IDLE:
+            self.commit()  # executed but uncommitted work: persist it
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Leave one context level, finishing the transaction on the last"""
+        try:
+            # Only the outermost context commits or rolls back. The clean up
+            # runs while the counter is still 1 so that the methods guarded
+            # by `_checker` can be called from it
+            if self._contexts_entered == 1:
+                self._clean_up(exc_type)
+        finally:
+            # Always decrement, even if the clean up raised, so the nesting
+            # depth stays in sync with the `with` blocks
+            self._contexts_entered -= 1
+
+    def _raise_execution_error(self, sql, sql_args, error):
+        """Rollbacks the current transaction and raises a ValueError
+
+        The message carries the text of `error`, prefixed with the symbolic
+        name of its postgres error code when one can be looked up. `sql` and
+        `sql_args` are accepted but are not included in the message.
+
+        Raises
+        ------
+        ValueError
+        """
+        self.rollback()
+        try:
+            code_name = errorcodes.lookup(error.pgcode)
+        except (KeyError, AttributeError, TypeError):
+            # No usable pgcode (e.g. not a postgres error). Report `error`
+            # itself rather than the failed lookup, which would hide it
+            raise ValueError("Error running SQL query: %s" % str(error))
+        raise ValueError("Error running SQL query: Error running SQL: %s. "
+                         "MSG: %s\n" % (code_name, str(error)))
+
+    @_checker
+    def add(self, sql, sql_args=None, many=False):
+        """Queue one SQL query, or several, for the next execution
+
+        Parameters
+        ----------
+        sql : str
+            The SQL statement to queue
+        sql_args : list, tuple or dict of objects, optional
+            The arguments of the SQL statement
+        many : bool, optional
+            If True, `sql_args` is a list holding one set of arguments per
+            query, and `sql` is queued once for each of them
+
+        Raises
+        ------
+        TypeError
+            If a non-empty set of arguments is not a list, tuple or dict
+        RuntimeError
+            If invoked outside a context
+
+        Notes
+        -----
+        Nothing is sent to the database here: the (sql, args) pairs are only
+        stored until the transaction is executed. Empty or missing arguments
+        are stored without any type check. When `many` is True, len(sql_args)
+        queries are queued; if one set of arguments fails the type check, the
+        sets that precede it remain queued.
+        """
+        # Normalize to a list holding one set of arguments per query
+        arg_sets = sql_args if many else [sql_args]
+
+        for params in arg_sets:
+            # Falsy arguments (None, empty containers) skip the type check
+            if params and not isinstance(params, (list, tuple, dict)):
+                raise TypeError("sql_args should be a list, tuple or dict."
+                                " Found %s" % type(params))
+            self._queries.append((sql, params))
+
+    def _execute(self):
+        """Run every pending query and store its results
+
+        This is the worker behind `execute`, which wraps it to roll back the
+        transaction when anything raised here escapes.
+
+        Returns
+        -------
+        list
+            One entry per query executed so far: the fetched rows, or None if
+            the query had nothing to fetch. Results of earlier executions
+            are kept until a commit or a rollback.
+        """
+        with self._get_cursor() as cursor:
+            for query, params in self._queries:
+                try:
+                    cursor.execute(query, params)
+                except Exception as exc:
+                    # Deliberately broad: whatever went wrong, the helper
+                    # rolls back and raises a ValueError
+                    self._raise_execution_error(query, params, exc)
+
+                try:
+                    rows = cursor.fetchall()
+                except ProgrammingError:
+                    # There was nothing to fetch, e.g. an INSERT without a
+                    # RETURNING clause
+                    rows = None
+                except PostgresError as exc:
+                    # Any other postgres failure aborts the transaction
+                    self._raise_execution_error(query, params, exc)
+
+                self._results.append(rows)
+
+        # Everything ran, so nothing is pending anymore
+        self._queries = []
+
+        return self._results
+
+    @_checker
+    def execute(self):
+        """Executes the queries queued in the transaction
+
+        Returns
+        -------
+        list
+            One entry per query executed so far in the transaction
+
+        Raises
+        ------
+        RuntimeError
+            If invoked outside a context
+
+        Notes
+        -----
+        Any failure rolls the transaction back and re-raises the exception.
+        Executing never commits: that happens when leaving the outermost
+        context, or through an explicit call to `commit`.
+
+        See Also
+        --------
+        execute_fetchlast
+        execute_fetchindex
+        execute_fetchflatten
+        """
+        try:
+            results = self._execute()
+        except Exception:
+            self.rollback()
+            raise
+        return results
+
+    @_checker
+    def execute_fetchlast(self):
+        """Executes the transaction and returns a single value
+
+        Shortcut for `self.execute()[-1][0][0]`
+
+        Returns
+        -------
+        object
+            The first value of the first row of the last SQL query executed
+
+        See Also
+        --------
+        execute
+        execute_fetchindex
+        execute_fetchflatten
+        """
+        last_rows = self.execute()[-1]
+        return last_rows[0][0]
+
+    @_checker
+    def execute_fetchindex(self, idx=-1):
+        """Executes the transaction and returns the results of one query
+
+        Shortcut for `self.execute()[idx]`
+
+        Parameters
+        ----------
+        idx : int, optional
+            Position of the query whose results are wanted. Defaults to -1,
+            the last query.
+
+        Returns
+        -------
+        list or None
+            The rows fetched for that query; None if it had nothing to fetch
+
+        See Also
+        --------
+        execute
+        execute_fetchlast
+        execute_fetchflatten
+        """
+        results = self.execute()
+        return results[idx]
+
+    @_checker
+    def execute_fetchflatten(self, idx=-1):
+        """Executes the transaction and flattens the results of one query
+
+        Shortcut for `list(chain.from_iterable(self.execute()[idx]))`
+
+        Parameters
+        ----------
+        idx : int, optional
+            Position of the query whose results are wanted. Defaults to -1,
+            the last query.
+
+        Returns
+        -------
+        list of objects
+            The values of every row of that query, concatenated row by row
+            into a single flat list
+
+        See Also
+        --------
+        execute
+        execute_fetchlast
+        execute_fetchindex
+        """
+        rows = self.execute()[idx]
+        return list(chain.from_iterable(rows))
+
+    def _funcs_executor(self, funcs, func_str):
+        """Run the given callbacks, reporting every failure at once"""
+        failures = []
+        for callback, args, kwargs in funcs:
+            try:
+                callback(*args, **kwargs)
+            except Exception as exc:
+                failures.append(str(exc))
+        # Commit and rollback callbacks are mutually exclusive: once either
+        # set has run, both lists are reset
+        self._post_commit_funcs, self._post_rollback_funcs = [], []
+        if failures:
+            raise RuntimeError(
+                "An error occurred during the post %s commands:\n%s"
+                % (func_str, "\n".join(failures)))
+
+    @_checker
+    def commit(self):
+        """Commits the transaction, then runs the post commit functions
+
+        Raises
+        ------
+        RuntimeError
+            If invoked outside a context, or if any post commit function
+            fails
+        """
+        # The pending queries and the stored results do not outlive a commit
+        self._results, self._queries = [], []
+        try:
+            self._connection.commit()
+        except Exception:
+            # Close the connection on a failed commit and let the error out
+            self._connection.close()
+            raise
+        self._funcs_executor(self._post_commit_funcs, "commit")
+
+    @_checker
+    def rollback(self):
+        """Rollbacks the transaction, then runs the post rollback functions
+
+        Raises
+        ------
+        RuntimeError
+            If invoked outside a context, or if any post rollback function
+            fails
+        """
+        # The pending queries and the stored results are dropped
+        self._results, self._queries = [], []
+        # Only talk to the database when there is an open connection
+        conn = self._connection
+        if conn is not None and conn.closed == 0:
+            try:
+                conn.rollback()
+            except Exception:
+                conn.close()
+                raise
+        self._funcs_executor(self._post_rollback_funcs, "rollback")
+
+    @property
+    def index(self):  # pending queries plus results already stored
+        return len(self._queries) + len(self._results)
+
+    @_checker
+    def add_post_commit_func(self, func, *args, **kwargs):
+        """Registers a function to run after the next commit
+
+        The function is discarded, without being called, if a rollback happens
+        first. Useful, for example, to clean up the filesystem once the
+        changes are committed.
+
+        Parameters
+        ----------
+        func : function
+            The function to call after the commit
+        args : tuple
+            The positional arguments passed to `func`
+        kwargs : dict
+            The keyword arguments passed to `func`
+        """
+        entry = (func, args, kwargs)
+        self._post_commit_funcs.append(entry)
+
+    @_checker
+    def add_post_rollback_func(self, func, *args, **kwargs):
+        """Registers a function to run after the next rollback
+
+        The function is discarded, without being called, if a commit happens
+        first. Useful, for example, to restore the filesystem so that it does
+        not end up out of sync with the database.
+
+        Parameters
+        ----------
+        func : function
+            The function to call after the rollback
+        args : tuple
+            The positional arguments passed to `func`
+        kwargs : dict
+            The keyword arguments passed to `func`
+        """
+        entry = (func, args, kwargs)
+        self._post_rollback_funcs.append(entry)
+
+
+# Singleton pattern: the regular and the admin transactions shared system-wide
+TRN = Transaction()
+TRNADMIN = Transaction(admin=True)
+
+
+def perform_as_transaction(sql, parameters=None):
+    """Runs one SQL statement inside the global transaction TRN
+
+    Parameters
+    ----------
+    sql : str
+        The SQL to execute
+    parameters: object, optional
+        The arguments of `sql`, passed on to TRN.add. Any falsy value means
+        that the query takes no arguments.
+
+    The results of the query are not returned.
+    """
+    with TRN:
+        TRN.add(sql, parameters or None)
+        TRN.execute()
+
+
+def create_new_transaction():
+    """Rebinds the global TRN to a brand new Transaction
+
+    This is needed when using multiprocessing. TRNADMIN is left untouched.
+    """
+    global TRN
+    TRN = Transaction()