From 4761084e6535f0b18d80529fcb6467c39db4b0dd Mon Sep 17 00:00:00 2001 From: Michael Krotscheck Date: Tue, 21 Apr 2015 15:08:35 -0700 Subject: [PATCH] Added session parameter to various db_api methods. Various entity methods were automatically generating their session. This permits us to pass those session instances through, so it can be used in non-api plugins. Change-Id: Iebfd0d20b02d18317e8340dfe104c742ffa65d46 --- storyboard/db/api/base.py | 13 +++++++------ storyboard/db/api/subscriptions.py | 28 ++++++++++++++++------------ 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/storyboard/db/api/base.py b/storyboard/db/api/base.py index 63822d54..990976fb 100644 --- a/storyboard/db/api/base.py +++ b/storyboard/db/api/base.py @@ -188,7 +188,7 @@ def entity_get(kls, entity_id, filter_non_public=False, session=None): def entity_get_all(kls, filter_non_public=False, marker=None, limit=None, - sort_field='id', sort_dir='asc', **kwargs): + sort_field='id', sort_dir='asc', session=None, **kwargs): # Sanity checks, in case someone accidentally explicitly passes in 'None' if not sort_field: sort_field = 'id' @@ -196,7 +196,7 @@ def entity_get_all(kls, filter_non_public=False, marker=None, limit=None, sort_dir = 'asc' # Construct the query - query = model_query(kls) + query = model_query(kls, session) # Sanity check on input parameters query = apply_query_filters(query=query, model=kls, **kwargs) @@ -229,9 +229,9 @@ def entity_get_all(kls, filter_non_public=False, marker=None, limit=None, return entities -def entity_get_count(kls, **kwargs): +def entity_get_count(kls, session=None, **kwargs): # Construct the query - query = model_query(kls) + query = model_query(kls, session) # Sanity check on input parameters query = apply_query_filters(query=query, model=kls, **kwargs) @@ -329,8 +329,9 @@ def entity_update(kls, entity_id, values, session=None): return entity -def entity_hard_delete(kls, entity_id): - session = get_session() +def entity_hard_delete(kls, entity_id, session=None): + if not session: + session = get_session() try: with session.begin(subtransactions=True): diff --git a/storyboard/db/api/subscriptions.py b/storyboard/db/api/subscriptions.py index aee561ba..ec1c4cb4 100644 --- a/storyboard/db/api/subscriptions.py +++ b/storyboard/db/api/subscriptions.py @@ -16,8 +16,8 @@ from sqlalchemy import distinct from storyboard.db.api import base as api_base -import storyboard.db.api.timeline_events as timeline_api from storyboard.db import models +from storyboard.db.models import TimeLineEvent SUPPORTED_TYPES = { 'project': models.Project, @@ -64,7 +64,7 @@ def subscription_delete(subscription_id): api_base.entity_hard_delete(models.Subscription, subscription_id) -def subscription_get_all_subscriber_ids(resource, resource_id): +def subscription_get_all_subscriber_ids(resource, resource_id, session=None): '''Test subscription discovery. The tested algorithm is as follows: If you're subscribed to a project_group, you will be notified about @@ -93,7 +93,9 @@ def subscription_get_all_subscriber_ids(resource, resource_id): # If we accidentally pass a timeline_event, we're actually going to treat # it as a story. if resource == 'timeline_event': - event = timeline_api.event_get(resource_id) + event = api_base.entity_get(TimeLineEvent, + resource_id, + session=session) if event: resource = 'story' resource_id = event.story_id @@ -111,14 +113,14 @@ def subscription_get_all_subscriber_ids(resource, resource_id): # resource id remains pristine. if resource == 'story': # Get this story's tasks - query = api_base.model_query(models.Task.id) \ + query = api_base.model_query(models.Task.id, session=session) \ .filter(models.Task.story_id.in_(affected['story'])) affected['task'] = affected['task'] \ - .union(r for (r, ) in query.all()) + .union(r for (r,) in query.all()) elif resource == 'task': # Get this tasks's story - query = api_base.model_query(models.Task.story_id) \ + query = api_base.model_query(models.Task.story_id, session=session) \ .filter(models.Task.id == resource_id) affected['story'].add(query.first().story_id) @@ -126,32 +128,34 @@ def subscription_get_all_subscriber_ids(resource, resource_id): # If there are tasks, there will also be projects. if affected['task']: # Get all the tasks's projects - query = api_base.model_query(distinct(models.Task.project_id)) \ + query = api_base.model_query(distinct(models.Task.project_id), + session=session) \ .filter(models.Task.id.in_(affected['task'])) affected['project'] = affected['project'] \ - .union(r for (r, ) in query.all()) + .union(r for (r,) in query.all()) # If there are projects, there will also be project groups. if affected['project']: # Get all the projects' groups. query = api_base.model_query( - distinct(models.project_group_mapping.c.project_group_id)) \ + distinct(models.project_group_mapping.c.project_group_id), + session=session) \ .filter(models.project_group_mapping.c.project_id .in_(affected['project'])) affected['project_group'] = affected['project_group'] \ - .union(r for (r, ) in query.all()) + .union(r for (r,) in query.all()) # Load all subscribers. subscribers = set() for affected_type in affected: query = api_base.model_query(distinct( - models.Subscription.user_id)) \ + models.Subscription.user_id), session=session) \ .filter(models.Subscription.target_type == affected_type) \ .filter(models.Subscription.target_id.in_(affected[affected_type])) results = query.all() - subscribers = subscribers.union(r for (r, ) in results) + subscribers = subscribers.union(r for (r,) in results) return subscribers