diff --git a/ceilometer/storage/impl_sqlalchemy.py b/ceilometer/storage/impl_sqlalchemy.py index 5c1e4fbc2..107e4fcd0 100644 --- a/ceilometer/storage/impl_sqlalchemy.py +++ b/ceilometer/storage/impl_sqlalchemy.py @@ -141,15 +141,16 @@ class Connection(base.Connection): def __init__(self, conf): url = conf.database.connection if url == 'sqlite://': - url = os.environ.get('CEILOMETER_TEST_SQL_URL', url) - LOG.info('connecting to %s', url) - self.session = sqlalchemy_session.get_session() + conf.database.connection = \ + os.environ.get('CEILOMETER_TEST_SQL_URL', url) def upgrade(self, version=None): - migration.db_sync(self.session.get_bind(), version=version) + session = sqlalchemy_session.get_session() + migration.db_sync(session.get_bind(), version=version) def clear(self): - engine = self.session.get_bind() + session = sqlalchemy_session.get_session() + engine = session.get_bind() for table in reversed(Base.metadata.sorted_tables): engine.execute(table.delete()) @@ -159,56 +160,57 @@ class Connection(base.Connection): :param data: a dictionary such as returned by ceilometer.meter.meter_message_from_counter """ - if data['source']: - source = self.session.query(Source).get(data['source']) - if not source: - source = Source(id=data['source']) - self.session.add(source) - else: - source = None + session = sqlalchemy_session.get_session() + with session.begin(): + if data['source']: + source = session.query(Source).get(data['source']) + if not source: + source = Source(id=data['source']) + session.add(source) + else: + source = None - # create/update user && project, add/update their sources list - if data['user_id']: - user = self.session.merge(User(id=str(data['user_id']))) - if not filter(lambda x: x.id == source.id, user.sources): - user.sources.append(source) - else: - user = None + # create/update user && project, add/update their sources list + if data['user_id']: + user = session.merge(User(id=str(data['user_id']))) + if not filter(lambda x: x.id == source.id, user.sources): + user.sources.append(source) + else: + user = None - if data['project_id']: - project = self.session.merge(Project(id=str(data['project_id']))) - if not filter(lambda x: x.id == source.id, project.sources): - project.sources.append(source) - else: - project = None + if data['project_id']: + project = session.merge(Project(id=str(data['project_id']))) + if not filter(lambda x: x.id == source.id, project.sources): + project.sources.append(source) + else: + project = None - # Record the updated resource metadata - rmetadata = data['resource_metadata'] + # Record the updated resource metadata + rmetadata = data['resource_metadata'] - resource = self.session.merge(Resource(id=str(data['resource_id']))) - if not filter(lambda x: x.id == source.id, resource.sources): - resource.sources.append(source) - resource.project = project - resource.user = user - # Current metadata being used and when it was last updated. - resource.resource_metadata = rmetadata - # Autoflush didn't catch this one, requires manual flush. - self.session.flush() + resource = session.merge(Resource(id=str(data['resource_id']))) + if not filter(lambda x: x.id == source.id, resource.sources): + resource.sources.append(source) + resource.project = project + resource.user = user + # Current metadata being used and when it was last updated. + resource.resource_metadata = rmetadata - # Record the raw data for the meter. - meter = Meter(counter_type=data['counter_type'], - counter_unit=data['counter_unit'], - counter_name=data['counter_name'], resource=resource) - self.session.add(meter) - if not filter(lambda x: x.id == source.id, meter.sources): - meter.sources.append(source) - meter.project = project - meter.user = user - meter.timestamp = data['timestamp'] - meter.resource_metadata = rmetadata - meter.counter_volume = data['counter_volume'] - meter.message_signature = data['message_signature'] - meter.message_id = data['message_id'] + # Record the raw data for the meter. + meter = Meter(counter_type=data['counter_type'], + counter_unit=data['counter_unit'], + counter_name=data['counter_name'], resource=resource) + session.add(meter) + if not filter(lambda x: x.id == source.id, meter.sources): + meter.sources.append(source) + meter.project = project + meter.user = user + meter.timestamp = data['timestamp'] + meter.resource_metadata = rmetadata + meter.counter_volume = data['counter_volume'] + meter.message_signature = data['message_signature'] + meter.message_id = data['message_id'] + session.flush() return @@ -217,7 +219,8 @@ class Connection(base.Connection): :param source: Optional source filter. """ - query = self.session.query(User.id) + session = sqlalchemy_session.get_session() + query = session.query(User.id) if source is not None: query = query.filter(User.sources.any(id=source)) return (x[0] for x in query.all()) @@ -227,7 +230,8 @@ class Connection(base.Connection): :param source: Optional source filter. """ - query = self.session.query(Project.id) + session = sqlalchemy_session.get_session() + query = session.query(Project.id) if source: query = query.filter(Project.sources.any(id=source)) return (x[0] for x in query.all()) @@ -245,7 +249,8 @@ class Connection(base.Connection): :param metaquery: Optional dict with metadata to match on. :param resource: Optional resource filter. """ - query = self.session.query(Meter,).group_by(Meter.resource_id) + session = sqlalchemy_session.get_session() + query = session.query(Meter,).group_by(Meter.resource_id) if user is not None: query = query.filter(Meter.user_id == user) if source is not None: @@ -288,7 +293,8 @@ class Connection(base.Connection): :param source: Optional source filter. :param metaquery: Optional dict with metadata to match on. """ - query = self.session.query(Resource) + session = sqlalchemy_session.get_session() + query = session.query(Resource) if user is not None: query = query.filter(Resource.user_id == user) if source is not None: @@ -327,7 +333,8 @@ class Connection(base.Connection): if limit == 0: return - query = self.session.query(Meter) + session = sqlalchemy_session.get_session() + query = session.query(Meter) query = make_query_from_filter(query, sample_filter, require_meter=False) if limit: @@ -358,15 +365,17 @@ class Connection(base.Connection): def _make_volume_query(self, sample_filter, counter_volume_func): """Returns complex Meter counter_volume query for max and sum.""" - subq = self.session.query(Meter.id) + session = sqlalchemy_session.get_session() + subq = session.query(Meter.id) subq = make_query_from_filter(subq, sample_filter, require_meter=False) subq = subq.subquery() - mainq = self.session.query(Resource.id, counter_volume_func) + mainq = session.query(Resource.id, counter_volume_func) mainq = mainq.join(Meter).group_by(Resource.id) return mainq.filter(Meter.id.in_(subq)) def _make_stats_query(self, sample_filter): - query = self.session.query( + session = sqlalchemy_session.get_session() + query = session.query( func.min(Meter.timestamp).label('tsmin'), func.max(Meter.timestamp).label('tsmax'), func.avg(Meter.counter_volume).label('avg'), @@ -469,7 +478,8 @@ class Connection(base.Connection): :param enabled: Optional boolean to list disable alarm. :param alarm_id: Optional alarm_id to return one alarm. """ - query = self.session.query(Alarm) + session = sqlalchemy_session.get_session() + query = session.query(Alarm) if name is not None: query = query.filter(Alarm.name == name) if enabled is not None: @@ -488,17 +498,19 @@ class Connection(base.Connection): :param alarm: the new Alarm to update """ - if alarm.alarm_id: - alarm_row = self.session.merge(Alarm(id=alarm.alarm_id)) - self._alarm_model_to_row(alarm, alarm_row) - else: - self.session.merge(User(id=alarm.user_id)) - self.session.merge(Project(id=alarm.project_id)) + session = sqlalchemy_session.get_session() + with session.begin(): + if alarm.alarm_id: + alarm_row = session.merge(Alarm(id=alarm.alarm_id)) + self._alarm_model_to_row(alarm, alarm_row) + else: + session.merge(User(id=alarm.user_id)) + session.merge(Project(id=alarm.project_id)) - alarm_row = self._alarm_model_to_row(alarm) - self.session.add(alarm_row) + alarm_row = self._alarm_model_to_row(alarm) + session.add(alarm_row) - self.session.flush() + session.flush() return self._row_to_alarm_model(alarm_row) def delete_alarm(self, alarm_id): @@ -506,32 +518,37 @@ class Connection(base.Connection): :param alarm_id: ID of the alarm to delete """ - self.session.query(Alarm).filter(Alarm.id == alarm_id).delete() - self.session.flush() + session = sqlalchemy_session.get_session() + with session.begin(): + session.query(Alarm).filter(Alarm.id == alarm_id).delete() + session.flush() - def _get_unique(self, key): - return self.session.query(UniqueName)\ - .filter(UniqueName.key == key).first() + def _get_unique(self, session, key): + return session.query(UniqueName).filter(UniqueName.key == key).first() - def _get_or_create_unique_name(self, key): + def _get_or_create_unique_name(self, key, session=None): """Find the UniqueName entry for a given key, creating one if necessary. This may result in a flush. """ - unique = self._get_unique(key) - if not unique: - unique = UniqueName(key=key) - self.session.add(unique) - self.session.flush() + if session is None: + session = sqlalchemy_session.get_session() + with session.begin(subtransactions=True): + unique = self._get_unique(session, key) + if not unique: + unique = UniqueName(key=key) + session.add(unique) + session.flush() return unique - def _make_trait(self, trait_model, event): + def _make_trait(self, trait_model, event, session=None): """Make a new Trait from a Trait model. Doesn't flush or add to session. """ - name = self._get_or_create_unique_name(trait_model.name) + name = self._get_or_create_unique_name(trait_model.name, + session=session) value_map = Trait._value_map values = {'t_string': None, 't_float': None, 't_int': None, 't_datetime': None} @@ -541,21 +558,23 @@ class Connection(base.Connection): values[value_map[trait_model.dtype]] = value return Trait(name, event, trait_model.dtype, **values) - def _record_event(self, event_model): + def _record_event(self, session, event_model): """Store a single Event, including related Traits. """ - unique = self._get_or_create_unique_name(event_model.event_name) + with session.begin(subtransactions=True): + unique = self._get_or_create_unique_name(event_model.event_name, + session=session) - generated = utils.dt_to_decimal(event_model.generated) - event = Event(unique, generated) - self.session.add(event) + generated = utils.dt_to_decimal(event_model.generated) + event = Event(unique, generated) + session.add(event) - new_traits = [] - if event_model.traits: - for trait in event_model.traits: - t = self._make_trait(trait, event) - self.session.add(t) - new_traits.append(t) + new_traits = [] + if event_model.traits: + for trait in event_model.traits: + t = self._make_trait(trait, event, session=session) + session.add(t) + new_traits.append(t) # Note: we don't flush here, explicitly (unless a new uniquename # does it). Otherwise, just wait until all the Events are staged. @@ -569,10 +588,11 @@ class Connection(base.Connection): Flush when they're all added, unless new UniqueNames are added along the way. """ - events = [self._record_event(event_model) - for event_model in event_models] - - self.session.flush() + session = sqlalchemy_session.get_session() + with session.begin(): + events = [self._record_event(session, event_model) + for event_model in event_models] + session.flush() # Update the models with the underlying DB ID. for model, actual in zip(event_models, events): @@ -590,46 +610,49 @@ class Connection(base.Connection): start = utils.dt_to_decimal(event_filter.start) end = utils.dt_to_decimal(event_filter.end) - sub_query = self.session.query(Event.id)\ - .join(Trait, Trait.event_id == Event.id)\ - .filter(Event.generated >= start, Event.generated <= end) + session = sqlalchemy_session.get_session() + with session.begin(): + sub_query = session.query(Event.id)\ + .join(Trait, Trait.event_id == Event.id)\ + .filter(Event.generated >= start, Event.generated <= end) - if event_filter.event_name: - event_name = self._get_unique(event_filter.event_name) - sub_query = sub_query.filter(Event.unique_name == event_name) + if event_filter.event_name: + event_name = self._get_unique(session, event_filter.event_name) + sub_query = sub_query.filter(Event.unique_name == event_name) - if event_filter.traits: - for key, value in event_filter.traits.iteritems(): - if key == 'key': - key = self._get_unique(value) - sub_query = sub_query.filter(Trait.name == key) - elif key == 't_string': - sub_query = sub_query.filter(Trait.t_string == value) - elif key == 't_int': - sub_query = sub_query.filter(Trait.t_int == value) - elif key == 't_datetime': - dt = utils.dt_to_decimal(value) - sub_query = sub_query.filter(Trait.t_datetime == dt) - elif key == 't_float': - sub_query = sub_query.filter(Trait.t_datetime == value) + if event_filter.traits: + for key, value in event_filter.traits.iteritems(): + if key == 'key': + key = self._get_unique(session, value) + sub_query = sub_query.filter(Trait.name == key) + elif key == 't_string': + sub_query = sub_query.filter(Trait.t_string == value) + elif key == 't_int': + sub_query = sub_query.filter(Trait.t_int == value) + elif key == 't_datetime': + dt = utils.dt_to_decimal(value) + sub_query = sub_query.filter(Trait.t_datetime == dt) + elif key == 't_float': + sub_query = sub_query.filter(Trait.t_datetime == value) - sub_query = sub_query.subquery() + sub_query = sub_query.subquery() - all_data = self.session.query(Trait)\ - .join(sub_query, Trait.event_id == sub_query.c.id) + all_data = session.query(Trait)\ + .join(sub_query, Trait.event_id == sub_query.c.id) - # Now convert the sqlalchemy objects back into Models ... - event_models_dict = {} - for trait in all_data.all(): - event = event_models_dict.get(trait.event_id) - if not event: - generated = utils.decimal_to_dt(trait.event.generated) - event = api_models.Event(trait.event.unique_name.key, - generated, []) - event_models_dict[trait.event_id] = event - value = trait.get_value() - trait_model = api_models.Trait(trait.name.key, trait.t_type, value) - event.append_trait(trait_model) + # Now convert the sqlalchemy objects back into Models ... + event_models_dict = {} + for trait in all_data.all(): + event = event_models_dict.get(trait.event_id) + if not event: + generated = utils.decimal_to_dt(trait.event.generated) + event = api_models.Event(trait.event.unique_name.key, + generated, []) + event_models_dict[trait.event_id] = event + value = trait.get_value() + trait_model = api_models.Trait(trait.name.key, trait.t_type, + value) + event.append_trait(trait_model) event_models = event_models_dict.values() return sorted(event_models, key=operator.attrgetter('generated')) diff --git a/tests/storage/test_impl_sqlalchemy.py b/tests/storage/test_impl_sqlalchemy.py index 9d095f5c2..027490def 100644 --- a/tests/storage/test_impl_sqlalchemy.py +++ b/tests/storage/test_impl_sqlalchemy.py @@ -80,13 +80,15 @@ class UniqueNameTest(base.EventTest, EventTestBase): u1 = self.conn._get_or_create_unique_name("foo") self.assertTrue(u1.id >= 0) u2 = self.conn._get_or_create_unique_name("foo") - self.assertEqual(u1, u2) + self.assertEqual(u1.id, u2.id) + self.assertEqual(u1.key, u2.key) def test_new_unique(self): u1 = self.conn._get_or_create_unique_name("foo") self.assertTrue(u1.id >= 0) u2 = self.conn._get_or_create_unique_name("blah") - self.assertNotEqual(u1, u2) + self.assertNotEqual(u1.id, u2.id) + self.assertNotEqual(u1.key, u2.key) class EventTest(base.EventTest, EventTestBase):