Merge "NSXv3: Rewrite client certificate provider"

This commit is contained in:
Jenkins 2017-09-07 17:58:58 +00:00 committed by Gerrit Code Review
commit e0b64848a5
2 changed files with 15 additions and 29 deletions

View File

@ -14,7 +14,6 @@
# under the License. # under the License.
import os import os
import random import random
import threading
from oslo_config import cfg from oslo_config import cfg
from oslo_log import log as logging from oslo_log import log as logging
@ -37,24 +36,19 @@ LOG = logging.getLogger(__name__)
class DbCertProvider(client_cert.ClientCertProvider): class DbCertProvider(client_cert.ClientCertProvider):
"""Write cert data from DB to file and delete after use """Write cert data from DB to file and delete after use
New file with random filename is created for each thread. This New provider object with random filename is created for each request.
is not most efficient, but the safest way to avoid race conditions, This is not most efficient, but the safest way to avoid race conditions,
since backend connections can occur both before and after neutron since backend connections can occur both before and after neutron
fork. fork, and several concurrent requests can occupy the same thread.
Note that new cert filename for each request does not result in new
connection for each request (at least for now..)
""" """
EXPIRATION_ALERT_DAYS = 30 # days prior to expiration EXPIRATION_ALERT_DAYS = 30 # days prior to expiration
_thread_local = threading.local()
def __init__(self): def __init__(self):
# Note: we cannot initialize filename here, because this call
# happens before neutron fork, meaning variable initialized here
# will be shared between all neutron processes (which will cause file
# collisions).
# The file can be shared between different connections within same
# process, if they happen to do the SSL handshake simultaneously.
# Such collisions are handled with refcount and locking.
super(DbCertProvider, self).__init__(None) super(DbCertProvider, self).__init__(None)
random.seed() random.seed()
self._filename = '/tmp/.' + str(random.randint(1, 10000000))
def _check_expiration(self, expires_in_days): def _check_expiration(self, expires_in_days):
if expires_in_days > self.EXPIRATION_ALERT_DAYS: if expires_in_days > self.EXPIRATION_ALERT_DAYS:
@ -69,11 +63,6 @@ class DbCertProvider(client_cert.ClientCertProvider):
expires_in_days) expires_in_days)
def __enter__(self): def __enter__(self):
# No certificate file available - need to create one
# Choose a random filename to contain the certificate
self._thread_local._filename = '/tmp/.' + str(
random.randint(1, 10000000))
try: try:
context = q_context.get_admin_context() context = q_context.get_admin_context()
db_storage_driver = cert_utils.DbCertificateStorageDriver( db_storage_driver = cert_utils.DbCertificateStorageDriver(
@ -86,7 +75,7 @@ class DbCertProvider(client_cert.ClientCertProvider):
msg = _("Unable to load from nsx-db") msg = _("Unable to load from nsx-db")
raise nsx_exc.ClientCertificateException(err_msg=msg) raise nsx_exc.ClientCertificateException(err_msg=msg)
filename = self._thread_local._filename filename = self._filename
if not os.path.exists(os.path.dirname(filename)): if not os.path.exists(os.path.dirname(filename)):
if len(os.path.dirname(filename)) > 0: if len(os.path.dirname(filename)) > 0:
os.makedirs(os.path.dirname(filename)) os.makedirs(os.path.dirname(filename))
@ -98,21 +87,19 @@ class DbCertProvider(client_cert.ClientCertProvider):
self._on_exit() self._on_exit()
raise e raise e
LOG.debug("Prepared client certificate file")
return self return self
def _on_exit(self): def _on_exit(self):
if os.path.isfile(self._thread_local._filename): if os.path.isfile(self._filename):
os.remove(self._thread_local._filename) os.remove(self._filename)
LOG.debug("Deleted client certificate file")
self._thread_local._filename = None self._filename = None
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
self._on_exit() self._on_exit()
def filename(self): def filename(self):
return self._thread_local._filename return self._filename
def get_client_cert_provider(): def get_client_cert_provider():
@ -128,8 +115,7 @@ def get_client_cert_provider():
if cfg.CONF.nsx_v3.nsx_client_cert_storage.lower() == 'nsx-db': if cfg.CONF.nsx_v3.nsx_client_cert_storage.lower() == 'nsx-db':
# Cert data is stored in DB, and written to file system only # Cert data is stored in DB, and written to file system only
# when new connection is opened, and deleted immediately after. # when new connection is opened, and deleted immediately after.
# Pid is appended to avoid file collisions between neutron servers return DbCertProvider
return DbCertProvider()
def get_nsxlib_wrapper(nsx_username=None, nsx_password=None, basic_auth=False): def get_nsxlib_wrapper(nsx_username=None, nsx_password=None, basic_auth=False):

View File

@ -92,7 +92,7 @@ class NsxV3ClientCertProviderTestCase(unittest.TestCase):
def validate_db_provider(self, expected_cert_data): def validate_db_provider(self, expected_cert_data):
fname = None fname = None
with self._provider as p: with self._provider() as p:
# verify cert data was exported to CERTFILE # verify cert data was exported to CERTFILE
fname = p.filename() fname = p.filename()
with open(fname, 'r') as f: with open(fname, 'r') as f:
@ -125,7 +125,7 @@ class NsxV3ClientCertProviderTestCase(unittest.TestCase):
"vmware_nsx.db.db.get_certificate", "vmware_nsx.db.db.get_certificate",
return_value=(None, None)).start() return_value=(None, None)).start()
self.assertRaises(nsx_exc.ClientCertificateException, self.assertRaises(nsx_exc.ClientCertificateException,
self._provider.__enter__) self._provider().__enter__)
# now verify return to normal after failure # now verify return to normal after failure
mock.patch( mock.patch(
@ -174,7 +174,7 @@ class NsxV3ClientCertProviderTestCase(unittest.TestCase):
# since PK in DB is not encrypted, we should fail to decrypt it on load # since PK in DB is not encrypted, we should fail to decrypt it on load
self.assertRaises(nsx_exc.ClientCertificateException, self.assertRaises(nsx_exc.ClientCertificateException,
self._provider.__enter__) self._provider().__enter__)
def test_basic_provider(self): def test_basic_provider(self):