diff --git a/vmware_nsx/plugins/nsx_v3/utils.py b/vmware_nsx/plugins/nsx_v3/utils.py index 685b4b7f7d..3f8d5e4a41 100644 --- a/vmware_nsx/plugins/nsx_v3/utils.py +++ b/vmware_nsx/plugins/nsx_v3/utils.py @@ -14,7 +14,6 @@ # under the License. import os import random -import threading from oslo_config import cfg from oslo_log import log as logging @@ -37,24 +36,19 @@ LOG = logging.getLogger(__name__) class DbCertProvider(client_cert.ClientCertProvider): """Write cert data from DB to file and delete after use - New file with random filename is created for each thread. This - is not most efficient, but the safest way to avoid race conditions, + New provider object with random filename is created for each request. + This is not most efficient, but the safest way to avoid race conditions, since backend connections can occur both before and after neutron - fork. + fork, and several concurrent requests can occupy the same thread. + Note that new cert filename for each request does not result in new + connection for each request (at least for now..) """ EXPIRATION_ALERT_DAYS = 30 # days prior to expiration - _thread_local = threading.local() def __init__(self): - # Note: we cannot initialize filename here, because this call - # happens before neutron fork, meaning variable initialized here - # will be shared between all neutron processes (which will cause file - # collisions). - # The file can be shared between different connections within same - # process, if they happen to do the SSL handshake simultaneously. - # Such collisions are handled with refcount and locking. super(DbCertProvider, self).__init__(None) random.seed() + self._filename = '/tmp/.' + str(random.randint(1, 10000000)) def _check_expiration(self, expires_in_days): if expires_in_days > self.EXPIRATION_ALERT_DAYS: @@ -69,11 +63,6 @@ class DbCertProvider(client_cert.ClientCertProvider): expires_in_days) def __enter__(self): - # No certificate file available - need to create one - # Choose a random filename to contain the certificate - self._thread_local._filename = '/tmp/.' + str( - random.randint(1, 10000000)) - try: context = q_context.get_admin_context() db_storage_driver = cert_utils.DbCertificateStorageDriver( @@ -86,7 +75,7 @@ class DbCertProvider(client_cert.ClientCertProvider): msg = _("Unable to load from nsx-db") raise nsx_exc.ClientCertificateException(err_msg=msg) - filename = self._thread_local._filename + filename = self._filename if not os.path.exists(os.path.dirname(filename)): if len(os.path.dirname(filename)) > 0: os.makedirs(os.path.dirname(filename)) @@ -98,21 +87,19 @@ class DbCertProvider(client_cert.ClientCertProvider): self._on_exit() raise e - LOG.debug("Prepared client certificate file") return self def _on_exit(self): - if os.path.isfile(self._thread_local._filename): - os.remove(self._thread_local._filename) - LOG.debug("Deleted client certificate file") + if os.path.isfile(self._filename): + os.remove(self._filename) - self._thread_local._filename = None + self._filename = None def __exit__(self, type, value, traceback): self._on_exit() def filename(self): - return self._thread_local._filename + return self._filename def get_client_cert_provider(): @@ -128,8 +115,7 @@ def get_client_cert_provider(): if cfg.CONF.nsx_v3.nsx_client_cert_storage.lower() == 'nsx-db': # Cert data is stored in DB, and written to file system only # when new connection is opened, and deleted immediately after. - # Pid is appended to avoid file collisions between neutron servers - return DbCertProvider() + return DbCertProvider def get_nsxlib_wrapper(nsx_username=None, nsx_password=None, basic_auth=False): diff --git a/vmware_nsx/tests/unit/nsx_v3/test_client_cert.py b/vmware_nsx/tests/unit/nsx_v3/test_client_cert.py index 3b416c4bdd..ecd79dfcbe 100644 --- a/vmware_nsx/tests/unit/nsx_v3/test_client_cert.py +++ b/vmware_nsx/tests/unit/nsx_v3/test_client_cert.py @@ -92,7 +92,7 @@ class NsxV3ClientCertProviderTestCase(unittest.TestCase): def validate_db_provider(self, expected_cert_data): fname = None - with self._provider as p: + with self._provider() as p: # verify cert data was exported to CERTFILE fname = p.filename() with open(fname, 'r') as f: @@ -125,7 +125,7 @@ class NsxV3ClientCertProviderTestCase(unittest.TestCase): "vmware_nsx.db.db.get_certificate", return_value=(None, None)).start() self.assertRaises(nsx_exc.ClientCertificateException, - self._provider.__enter__) + self._provider().__enter__) # now verify return to normal after failure mock.patch( @@ -174,7 +174,7 @@ class NsxV3ClientCertProviderTestCase(unittest.TestCase): # since PK in DB is not encrypted, we should fail to decrypt it on load self.assertRaises(nsx_exc.ClientCertificateException, - self._provider.__enter__) + self._provider().__enter__) def test_basic_provider(self):