diff --git a/storyboard/db/models.py b/storyboard/db/models.py index eb0f8103..3b5085c8 100644 --- a/storyboard/db/models.py +++ b/storyboard/db/models.py @@ -313,8 +313,10 @@ class Story(FullText, ModelBuilder, Base): description = Column(UnicodeText()) is_bug = Column(Boolean, default=True) private = Column(Boolean, default=False) - tasks = relationship('Task', backref='story') - events = relationship('TimeLineEvent', backref='story') + tasks = relationship('Task', backref='story', + cascade="all, delete-orphan") + events = relationship('TimeLineEvent', backref='story', + cascade="all, delete-orphan") tags = relationship('StoryTag', secondary='story_storytags') permissions = relationship('Permission', secondary='story_permissions') diff --git a/storyboard/tests/db/api/test_stories.py b/storyboard/tests/db/api/test_stories.py index 6ddddbaa..93898ac8 100644 --- a/storyboard/tests/db/api/test_stories.py +++ b/storyboard/tests/db/api/test_stories.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from storyboard.db.api import stories +from storyboard.db.api import stories as stories_api +from storyboard.db.api import tasks as tasks_api +from storyboard.db.api import timeline_events as events_api from storyboard.tests.db import base @@ -28,7 +30,7 @@ class StoriesTest(base.BaseDbTestCase): } def test_create_story(self): - self._test_create(self.story_01, stories.story_create) + self._test_create(self.story_01, stories_api.story_create) def test_update_story(self): delta = { @@ -36,4 +38,34 @@ class StoriesTest(base.BaseDbTestCase): 'description': u'New Description' } self._test_update(self.story_01, delta, - stories.story_create, stories.story_update) + stories_api.story_create, stories_api.story_update) + + def test_delete_story(self): + # This test uses mock_data + story_id = 1 + # Verify that we can look up a story with tasks and events + story = stories_api.story_get_simple(story_id) + self.assertIsNotNone(story) + tasks = tasks_api.task_get_all(story_id=story_id) + self.assertEqual(len(tasks), 3) + task_ids = [t.id for t in tasks] + events = events_api.events_get_all(story_id=story_id) + self.assertEqual(len(events), 3) + event_ids = [e.id for e in events] + + # Delete the story + stories_api.story_delete(story_id) + story = stories_api.story_get_simple(story_id) + self.assertIsNone(story) + # Verify that the story's tasks were deleted + tasks = tasks_api.task_get_all(story_id=story_id) + self.assertEqual(len(tasks), 0) + for tid in task_ids: + task = tasks_api.task_get(task_id=tid) + self.assertIsNone(task) + # And the events + events = events_api.events_get_all(story_id=story_id) + self.assertEqual(len(events), 0) + for eid in event_ids: + event = events_api.event_get(event_id=eid) + self.assertIsNone(event)