Merge "NSXv3: Move away from locking in cert provider"

This commit is contained in:
Jenkins 2017-07-15 09:35:58 +00:00 committed by Gerrit Code Review
commit 878ad90f34

View File

@ -37,11 +37,13 @@ LOG = logging.getLogger(__name__)
class DbCertProvider(client_cert.ClientCertProvider): class DbCertProvider(client_cert.ClientCertProvider):
"""Write cert data from DB to file and delete after use """Write cert data from DB to file and delete after use
Since several connections may use same filename simultaneously, New file with random filename is created for each thread. This
this class maintains refcount to write/delete the file only once is not most efficient, but the safest way to avoid race conditions,
since backend connections can occur both before and after neutron
fork.
""" """
EXPIRATION_ALERT_DAYS = 30 # days prior to expiration EXPIRATION_ALERT_DAYS = 30 # days prior to expiration
lock = threading.Lock() _thread_local = threading.local()
def __init__(self): def __init__(self):
# Note: we cannot initialize filename here, because this call # Note: we cannot initialize filename here, because this call
@ -54,11 +56,6 @@ class DbCertProvider(client_cert.ClientCertProvider):
super(DbCertProvider, self).__init__(None) super(DbCertProvider, self).__init__(None)
random.seed() random.seed()
with self.lock:
# Initialize refcount if other threads did not do it already
if not hasattr(self, 'refcount'):
self.refcount = 0
def _check_expiration(self, expires_in_days): def _check_expiration(self, expires_in_days):
if expires_in_days > self.EXPIRATION_ALERT_DAYS: if expires_in_days > self.EXPIRATION_ALERT_DAYS:
return return
@ -72,16 +69,10 @@ class DbCertProvider(client_cert.ClientCertProvider):
expires_in_days) expires_in_days)
def __enter__(self): def __enter__(self):
with self.lock:
self.refcount += 1
if self.refcount > 1:
# The file was already created and not yet deleted, use it
return self
# No certificate file available - need to create one # No certificate file available - need to create one
# Choose a random filename to contain the certificate # Choose a random filename to contain the certificate
self._filename = '/tmp/.' + str(random.randint(1, 10000000)) self._thread_local._filename = '/tmp/.' + str(
random.randint(1, 10000000))
try: try:
context = q_context.get_admin_context() context = q_context.get_admin_context()
@ -95,15 +86,15 @@ class DbCertProvider(client_cert.ClientCertProvider):
msg = _("Unable to load from nsx-db") msg = _("Unable to load from nsx-db")
raise nsx_exc.ClientCertificateException(err_msg=msg) raise nsx_exc.ClientCertificateException(err_msg=msg)
if not os.path.exists(os.path.dirname(self._filename)): filename = self._thread_local._filename
if len(os.path.dirname(self._filename)) > 0: if not os.path.exists(os.path.dirname(filename)):
os.makedirs(os.path.dirname(self._filename)) if len(os.path.dirname(filename)) > 0:
cert_manager.export_pem(self._filename) os.makedirs(os.path.dirname(filename))
cert_manager.export_pem(filename)
expires_in_days = cert_manager.expires_in_days() expires_in_days = cert_manager.expires_in_days()
self._check_expiration(expires_in_days) self._check_expiration(expires_in_days)
except Exception as e: except Exception as e:
# refcount has to be 1 here
self._on_exit() self._on_exit()
raise e raise e
@ -111,23 +102,17 @@ class DbCertProvider(client_cert.ClientCertProvider):
return self return self
def _on_exit(self): def _on_exit(self):
self.refcount -= 1 if os.path.isfile(self._thread_local._filename):
os.remove(self._thread_local._filename)
LOG.debug("Deleted client certificate file")
if self.refcount == 0: self._thread_local._filename = None
# I am the last user of this file
if os.path.isfile(self._filename):
os.remove(self._filename)
LOG.debug("Deleted client certificate file")
self._filename = None
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
with self.lock: self._on_exit()
self._on_exit()
def filename(self): def filename(self):
with self.lock: return self._thread_local._filename
return self._filename
def get_client_cert_provider(): def get_client_cert_provider():