Merge "NSXv3: Rewrite client certificate provider"
This commit is contained in:
commit
e0b64848a5
@ -14,7 +14,6 @@
|
||||
# under the License.
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
|
||||
from oslo_config import cfg
|
||||
from oslo_log import log as logging
|
||||
@ -37,24 +36,19 @@ LOG = logging.getLogger(__name__)
|
||||
class DbCertProvider(client_cert.ClientCertProvider):
|
||||
"""Write cert data from DB to file and delete after use
|
||||
|
||||
New file with random filename is created for each thread. This
|
||||
is not most efficient, but the safest way to avoid race conditions,
|
||||
New provider object with random filename is created for each request.
|
||||
This is not most efficient, but the safest way to avoid race conditions,
|
||||
since backend connections can occur both before and after neutron
|
||||
fork.
|
||||
fork, and several concurrent requests can occupy the same thread.
|
||||
Note that new cert filename for each request does not result in new
|
||||
connection for each request (at least for now..)
|
||||
"""
|
||||
EXPIRATION_ALERT_DAYS = 30 # days prior to expiration
|
||||
_thread_local = threading.local()
|
||||
|
||||
def __init__(self):
|
||||
# Note: we cannot initialize filename here, because this call
|
||||
# happens before neutron fork, meaning variable initialized here
|
||||
# will be shared between all neutron processes (which will cause file
|
||||
# collisions).
|
||||
# The file can be shared between different connections within same
|
||||
# process, if they happen to do the SSL handshake simultaneously.
|
||||
# Such collisions are handled with refcount and locking.
|
||||
super(DbCertProvider, self).__init__(None)
|
||||
random.seed()
|
||||
self._filename = '/tmp/.' + str(random.randint(1, 10000000))
|
||||
|
||||
def _check_expiration(self, expires_in_days):
|
||||
if expires_in_days > self.EXPIRATION_ALERT_DAYS:
|
||||
@ -69,11 +63,6 @@ class DbCertProvider(client_cert.ClientCertProvider):
|
||||
expires_in_days)
|
||||
|
||||
def __enter__(self):
|
||||
# No certificate file available - need to create one
|
||||
# Choose a random filename to contain the certificate
|
||||
self._thread_local._filename = '/tmp/.' + str(
|
||||
random.randint(1, 10000000))
|
||||
|
||||
try:
|
||||
context = q_context.get_admin_context()
|
||||
db_storage_driver = cert_utils.DbCertificateStorageDriver(
|
||||
@ -86,7 +75,7 @@ class DbCertProvider(client_cert.ClientCertProvider):
|
||||
msg = _("Unable to load from nsx-db")
|
||||
raise nsx_exc.ClientCertificateException(err_msg=msg)
|
||||
|
||||
filename = self._thread_local._filename
|
||||
filename = self._filename
|
||||
if not os.path.exists(os.path.dirname(filename)):
|
||||
if len(os.path.dirname(filename)) > 0:
|
||||
os.makedirs(os.path.dirname(filename))
|
||||
@ -98,21 +87,19 @@ class DbCertProvider(client_cert.ClientCertProvider):
|
||||
self._on_exit()
|
||||
raise e
|
||||
|
||||
LOG.debug("Prepared client certificate file")
|
||||
return self
|
||||
|
||||
def _on_exit(self):
|
||||
if os.path.isfile(self._thread_local._filename):
|
||||
os.remove(self._thread_local._filename)
|
||||
LOG.debug("Deleted client certificate file")
|
||||
if os.path.isfile(self._filename):
|
||||
os.remove(self._filename)
|
||||
|
||||
self._thread_local._filename = None
|
||||
self._filename = None
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self._on_exit()
|
||||
|
||||
def filename(self):
|
||||
return self._thread_local._filename
|
||||
return self._filename
|
||||
|
||||
|
||||
def get_client_cert_provider():
|
||||
@ -128,8 +115,7 @@ def get_client_cert_provider():
|
||||
if cfg.CONF.nsx_v3.nsx_client_cert_storage.lower() == 'nsx-db':
|
||||
# Cert data is stored in DB, and written to file system only
|
||||
# when new connection is opened, and deleted immediately after.
|
||||
# Pid is appended to avoid file collisions between neutron servers
|
||||
return DbCertProvider()
|
||||
return DbCertProvider
|
||||
|
||||
|
||||
def get_nsxlib_wrapper(nsx_username=None, nsx_password=None, basic_auth=False):
|
||||
|
@ -92,7 +92,7 @@ class NsxV3ClientCertProviderTestCase(unittest.TestCase):
|
||||
|
||||
def validate_db_provider(self, expected_cert_data):
|
||||
fname = None
|
||||
with self._provider as p:
|
||||
with self._provider() as p:
|
||||
# verify cert data was exported to CERTFILE
|
||||
fname = p.filename()
|
||||
with open(fname, 'r') as f:
|
||||
@ -125,7 +125,7 @@ class NsxV3ClientCertProviderTestCase(unittest.TestCase):
|
||||
"vmware_nsx.db.db.get_certificate",
|
||||
return_value=(None, None)).start()
|
||||
self.assertRaises(nsx_exc.ClientCertificateException,
|
||||
self._provider.__enter__)
|
||||
self._provider().__enter__)
|
||||
|
||||
# now verify return to normal after failure
|
||||
mock.patch(
|
||||
@ -174,7 +174,7 @@ class NsxV3ClientCertProviderTestCase(unittest.TestCase):
|
||||
|
||||
# since PK in DB is not encrypted, we should fail to decrypt it on load
|
||||
self.assertRaises(nsx_exc.ClientCertificateException,
|
||||
self._provider.__enter__)
|
||||
self._provider().__enter__)
|
||||
|
||||
def test_basic_provider(self):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user