Added session parameter to various db_api methods.

Various entity methods were automatically generating their session.
This permits us to pass those session instances through, so it can be
used in non-api plugins.

Change-Id: Iebfd0d20b02d18317e8340dfe104c742ffa65d46
This commit is contained in:
Michael Krotscheck 2015-04-21 15:08:35 -07:00
parent 128191f6e5
commit 4761084e65
2 changed files with 23 additions and 18 deletions

View File

@ -188,7 +188,7 @@ def entity_get(kls, entity_id, filter_non_public=False, session=None):
def entity_get_all(kls, filter_non_public=False, marker=None, limit=None,
sort_field='id', sort_dir='asc', **kwargs):
sort_field='id', sort_dir='asc', session=None, **kwargs):
# Sanity checks, in case someone accidentally explicitly passes in 'None'
if not sort_field:
sort_field = 'id'
@ -196,7 +196,7 @@ def entity_get_all(kls, filter_non_public=False, marker=None, limit=None,
sort_dir = 'asc'
# Construct the query
query = model_query(kls)
query = model_query(kls, session)
# Sanity check on input parameters
query = apply_query_filters(query=query, model=kls, **kwargs)
@ -229,9 +229,9 @@ def entity_get_all(kls, filter_non_public=False, marker=None, limit=None,
return entities
def entity_get_count(kls, **kwargs):
def entity_get_count(kls, session=None, **kwargs):
# Construct the query
query = model_query(kls)
query = model_query(kls, session)
# Sanity check on input parameters
query = apply_query_filters(query=query, model=kls, **kwargs)
@ -329,8 +329,9 @@ def entity_update(kls, entity_id, values, session=None):
return entity
def entity_hard_delete(kls, entity_id):
session = get_session()
def entity_hard_delete(kls, entity_id, session=None):
if not session:
session = get_session()
try:
with session.begin(subtransactions=True):

View File

@ -16,8 +16,8 @@
from sqlalchemy import distinct
from storyboard.db.api import base as api_base
import storyboard.db.api.timeline_events as timeline_api
from storyboard.db import models
from storyboard.db.models import TimeLineEvent
SUPPORTED_TYPES = {
'project': models.Project,
@ -64,7 +64,7 @@ def subscription_delete(subscription_id):
api_base.entity_hard_delete(models.Subscription, subscription_id)
def subscription_get_all_subscriber_ids(resource, resource_id):
def subscription_get_all_subscriber_ids(resource, resource_id, session=None):
'''Test subscription discovery. The tested algorithm is as follows:
If you're subscribed to a project_group, you will be notified about
@ -93,7 +93,9 @@ def subscription_get_all_subscriber_ids(resource, resource_id):
# If we accidentally pass a timeline_event, we're actually going to treat
# it as a story.
if resource == 'timeline_event':
event = timeline_api.event_get(resource_id)
event = api_base.entity_get(TimeLineEvent,
resource_id,
session=session)
if event:
resource = 'story'
resource_id = event.story_id
@ -111,14 +113,14 @@ def subscription_get_all_subscriber_ids(resource, resource_id):
# resource id remains pristine.
if resource == 'story':
# Get this story's tasks
query = api_base.model_query(models.Task.id) \
query = api_base.model_query(models.Task.id, session=session) \
.filter(models.Task.story_id.in_(affected['story']))
affected['task'] = affected['task'] \
.union(r for (r, ) in query.all())
.union(r for (r,) in query.all())
elif resource == 'task':
# Get this tasks's story
query = api_base.model_query(models.Task.story_id) \
query = api_base.model_query(models.Task.story_id, session=session) \
.filter(models.Task.id == resource_id)
affected['story'].add(query.first().story_id)
@ -126,32 +128,34 @@ def subscription_get_all_subscriber_ids(resource, resource_id):
# If there are tasks, there will also be projects.
if affected['task']:
# Get all the tasks's projects
query = api_base.model_query(distinct(models.Task.project_id)) \
query = api_base.model_query(distinct(models.Task.project_id),
session=session) \
.filter(models.Task.id.in_(affected['task']))
affected['project'] = affected['project'] \
.union(r for (r, ) in query.all())
.union(r for (r,) in query.all())
# If there are projects, there will also be project groups.
if affected['project']:
# Get all the projects' groups.
query = api_base.model_query(
distinct(models.project_group_mapping.c.project_group_id)) \
distinct(models.project_group_mapping.c.project_group_id),
session=session) \
.filter(models.project_group_mapping.c.project_id
.in_(affected['project']))
affected['project_group'] = affected['project_group'] \
.union(r for (r, ) in query.all())
.union(r for (r,) in query.all())
# Load all subscribers.
subscribers = set()
for affected_type in affected:
query = api_base.model_query(distinct(
models.Subscription.user_id)) \
models.Subscription.user_id), session=session) \
.filter(models.Subscription.target_type == affected_type) \
.filter(models.Subscription.target_id.in_(affected[affected_type]))
results = query.all()
subscribers = subscribers.union(r for (r, ) in results)
subscribers = subscribers.union(r for (r,) in results)
return subscribers