[973924]: / qiita_db / sql_connection.py

Download this file

524 lines (445 with data), 17.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
r"""
SQL Connection object (:mod:`qiita_db.sql_connection`)
======================================================
.. currentmodule:: qiita_db.sql_connection
This modules provides wrappers for the psycopg2 module to allow easy use of
transaction blocks and SQL execution/data retrieval.
This module provides the variable TRN, which is the transaction available
to use in the system. The singleton pattern is applied and this works as long
as the system remains single-threaded.
Classes
-------
.. autosummary::
:toctree: generated/
Transaction
"""
# -----------------------------------------------------------------------------
# Copyright (c) 2014--, The Qiita Development Team.
#
# Distributed under the terms of the BSD 3-clause License.
#
# The full license is in the file LICENSE, distributed with this software.
# -----------------------------------------------------------------------------
from contextlib import contextmanager
from itertools import chain
from functools import wraps
from psycopg2 import (connect, ProgrammingError, Error as PostgresError,
OperationalError, errorcodes)
from psycopg2.extras import DictCursor
from psycopg2.extensions import TRANSACTION_STATUS_IDLE
from qiita_core.qiita_settings import qiita_config
def _checker(func):
"""Decorator to check that methods are executed inside the context"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if self._contexts_entered == 0:
raise RuntimeError(
"Operation not permitted. Transaction methods can only be "
"invoked within the context manager.")
return func(self, *args, **kwargs)
return wrapper
class Transaction(object):
"""A context manager that encapsulates a DB transaction
A transaction is defined by a series of consecutive queries that need to
be applied to the database as a single block.
Raises
------
RuntimeError
If the transaction methods are invoked outside a context.
Notes
-----
When the execution leaves the context manager, any remaining queries in
the transaction will be executed and committed.
"""
def __init__(self, admin=False):
self._queries = []
self._results = []
self._contexts_entered = 0
self._connection = None
self._post_commit_funcs = []
self._post_rollback_funcs = []
self.admin = admin
def _open_connection(self):
# If the connection already exists and is not closed, don't do anything
if self._connection is not None and self._connection.closed == 0:
return
try:
if self.admin:
self._connection = connect(
user=qiita_config.admin_user,
password=qiita_config.admin_password,
host=qiita_config.host,
port=qiita_config.port)
self._connection.autocommit = True
else:
self._connection = connect(user=qiita_config.user,
password=qiita_config.password,
database=qiita_config.database,
host=qiita_config.host,
port=qiita_config.port)
except OperationalError as e:
# catch three known common exceptions and raise runtime errors
try:
etype = str(e).split(':')[1].split()[0]
except IndexError:
# we recieved a really unanticipated error without a colon
etype = ''
if etype == 'database':
etext = ('This is likely because the database `%s` has not '
'been created or has been dropped.' %
qiita_config.database)
elif etype == 'role':
etext = ('This is likely because the user string `%s` '
'supplied in your configuration file `%s` is '
'incorrect or not an authorized postgres user.' %
(qiita_config.user, qiita_config.conf_fp))
elif etype == 'Connection':
etext = ('This is likely because postgres isn\'t '
'running. Check that postgres is correctly '
'installed and is running.')
else:
# we recieved a really unanticipated error with a colon
etext = ''
ebase = ('An OperationalError with the following message occured'
'\n\n\t%s\n%s For more information, review `INSTALL.md`'
' in the Qiita installation base directory.')
raise RuntimeError(ebase % (str(e), etext))
def close(self):
if self._connection is not None:
self._connection.close()
@contextmanager
def _get_cursor(self):
"""Returns a postgres cursor
Returns
-------
psycopg2.cursor
The psycopg2 cursor
Raises
------
RuntimeError
if the cursor cannot be created
"""
self._open_connection()
try:
with self._connection.cursor(cursor_factory=DictCursor) as cur:
yield cur
except PostgresError as e:
raise RuntimeError("Cannot get postgres cursor: %s" % e)
def __enter__(self):
self._open_connection()
self._contexts_entered += 1
return self
def _clean_up(self, exc_type):
if exc_type is not None:
# An exception occurred during the execution of the transaction
# Make sure that we leave the DB w/o any modification
self.rollback()
elif self._queries:
# There are still queries to be executed, execute them
# It is safe to use the execute method here, as internally is
# wrapped in a try/except and rollbacks in case of failure
self.execute()
self.commit()
elif self._connection.get_transaction_status() != \
TRANSACTION_STATUS_IDLE:
# There are no queries to be executed, however, the transaction
# is still not committed. Commit it so the changes are not lost
self.commit()
def __exit__(self, exc_type, exc_value, traceback):
# We only need to perform some action if this is the last context
# that we are entering
if self._contexts_entered == 1:
# We need to wrap the entire function in a try/finally because
# at the end we need to decrement _contexts_entered
try:
self._clean_up(exc_type)
finally:
self._contexts_entered -= 1
else:
self._contexts_entered -= 1
def _raise_execution_error(self, sql, sql_args, error):
"""Rollbacks the current transaction and raises a useful error
The error message contains the name of the transaction, the failed
query, the arguments of the failed query and the error generated.
Raises
------
ValueError
"""
self.rollback()
try:
ec_lu = errorcodes.lookup(error.pgcode)
raise ValueError(
"Error running SQL: %s. MSG: %s\n" % (ec_lu, str(error)))
# the order of except statements is important, do not change
except (KeyError, AttributeError, TypeError) as error:
raise ValueError("Error running SQL query: %s" % str(error))
except ValueError as error:
raise ValueError("Error running SQL query: %s" % str(error))
@_checker
def add(self, sql, sql_args=None, many=False):
"""Add a sql query to the transaction
Parameters
----------
sql : str
The sql query
sql_args : list, tuple or dict of objects, optional
The arguments to the sql query
many : bool, optional
Whether or not we should add the query multiple times to the
transaction
Raises
------
TypeError
If `sql_args` is provided and is not a list, tuple or dict
RuntimeError
If invoked outside a context
Notes
-----
If `many` is true, `sql_args` should be a list of lists, tuples or
dicts, in which each element of the list contains the parameters for
one SQL query of the many. Each element on the list is all the
parameters for a single one of the many queries added. The amount of
SQL queries added to the list is len(sql_args).
"""
if not many:
sql_args = [sql_args]
for args in sql_args:
if args:
if not isinstance(args, (list, tuple, dict)):
raise TypeError("sql_args should be a list, tuple or dict."
" Found %s" % type(args))
self._queries.append((sql, args))
def _execute(self):
"""Internal function that actually executes the transaction
The `execute` function exposed in the API wraps this one to make sure
that we catch any exception that happens in here and we rollback the
transaction
"""
with self._get_cursor() as cur:
for sql, sql_args in self._queries:
# Execute the current SQL command
try:
cur.execute(sql, sql_args)
except Exception as e:
# We catch any exception as we want to make sure that we
# rollback every time that something went wrong
self._raise_execution_error(sql, sql_args, e)
try:
res = cur.fetchall()
except ProgrammingError:
# At this execution point, we don't know if the sql query
# that we executed should retrieve values from the database
# If the query was not supposed to retrieve any value
# (e.g. an INSERT without a RETURNING clause), it will
# raise a ProgrammingError. Otherwise it will just return
# an empty list
res = None
except PostgresError as e:
# Some other error happened during the execution of the
# query, so we need to rollback
self._raise_execution_error(sql, sql_args, e)
# Store the results of the current query
self._results.append(res)
# wipe out the already executed queries
self._queries = []
return self._results
@_checker
def execute(self):
"""Executes the transaction
Returns
-------
list of DictCursor
The results of all the SQL queries in the transaction
Raises
------
RuntimeError
If invoked outside a context
Notes
-----
If any exception occurs during the execution transaction, a rollback
is executed and no changes are reflected in the database.
When calling execute, the transaction will never be committed, it will
be automatically committed when leaving the context
See Also
--------
execute_fetchlast
execute_fetchindex
execute_fetchflatten
"""
try:
return self._execute()
except Exception:
self.rollback()
raise
@_checker
def execute_fetchlast(self):
"""Executes the transaction and returns the last result
This is a convenient function that is equivalent to
`self.execute()[-1][0][0]`
Returns
-------
object
The first value of the last SQL query executed
See Also
--------
execute
execute_fetchindex
execute_fetchflatten
"""
return self.execute()[-1][0][0]
@_checker
def execute_fetchindex(self, idx=-1):
"""Executes the transaction and returns the results of the `idx` query
This is a convenient function that is equivalent to
`self.execute()[idx]
Parameters
----------
idx : int, optional
The index of the query to return the result. It defaults to -1, the
last query.
Returns
-------
DictCursor
The results of the `idx` query in the transaction
See Also
--------
execute
execute_fetchlast
execute_fetchflatten
"""
return self.execute()[idx]
@_checker
def execute_fetchflatten(self, idx=-1):
"""Executes the transaction and returns the flattened results of the
`idx` query
This is a convenient function that is equivalent to
`chain.from_iterable(self.execute()[idx])`
Parameters
----------
idx : int, optional
The index of the query to return the result. It defaults to -1, the
last query.
Returns
-------
list of objects
The flattened results of the `idx` query
See Also
--------
execute
execute_fetchlast
execute_fetchindex
"""
return list(chain.from_iterable(self.execute()[idx]))
def _funcs_executor(self, funcs, func_str):
error_msg = []
for f, args, kwargs in funcs:
try:
f(*args, **kwargs)
except Exception as e:
error_msg.append(str(e))
# The functions in these two lines are mutually exclusive. When one of
# them is executed, we can restore both of them.
self._post_commit_funcs = []
self._post_rollback_funcs = []
if error_msg:
raise RuntimeError(
"An error occurred during the post %s commands:\n%s"
% (func_str, "\n".join(error_msg)))
@_checker
def commit(self):
"""Commits the transaction and reset the queries
Raises
------
RuntimeError
If invoked outside a context
"""
# Reset the queries, the results and the index
self._queries = []
self._results = []
try:
self._connection.commit()
except Exception:
self._connection.close()
raise
# Execute the post commit functions
self._funcs_executor(self._post_commit_funcs, "commit")
@_checker
def rollback(self):
"""Rollbacks the transaction and reset the queries
Raises
------
RuntimeError
If invoked outside a context
"""
# Reset the queries, the results and the index
self._queries = []
self._results = []
if self._connection is not None and self._connection.closed == 0:
try:
self._connection.rollback()
except Exception:
self._connection.close()
raise
# Execute the post rollback functions
self._funcs_executor(self._post_rollback_funcs, "rollback")
@property
def index(self):
return len(self._queries) + len(self._results)
@_checker
def add_post_commit_func(self, func, *args, **kwargs):
"""Adds a post commit function
The function added will be executed after the next commit in the
transaction, unless a rollback is executed. This is useful, for
example, to perform some filesystem clean up once the transaction is
committed.
Parameters
----------
func : function
The function to add for the post commit functions
args : tuple
The arguments of the function
kwargs : dict
The keyword arguments of the function
"""
self._post_commit_funcs.append((func, args, kwargs))
@_checker
def add_post_rollback_func(self, func, *args, **kwargs):
"""Adds a post rollback function
The function added will be executed after the next rollback in the
transaction, unless a commit is executed. This is useful, for example,
to restore the filesystem in case a rollback occurs, avoiding leaving
the database and the filesystem in an out of sync state.
Parameters
----------
func : function
The function to add for the post rollback functions
args : tuple
The arguments of the function
kwargs : dict
The keyword arguments of the function
"""
self._post_rollback_funcs.append((func, args, kwargs))
# Singleton pattern, create the transaction for the entire system
TRN = Transaction()
TRNADMIN = Transaction(admin=True)
def perform_as_transaction(sql, parameters=None):
"""Opens, adds and executes sql as a single transaction
Parameters
----------
sql : str
The SQL to execute
parameters: object, optional
The object of parameters to pass to the TRN.add command
"""
with TRN:
if parameters:
TRN.add(sql, parameters)
else:
TRN.add(sql)
TRN.execute()
def create_new_transaction():
"""Creates a new global transaction
This is needed when using multiprocessing
"""
global TRN
TRN = Transaction()