Merge "NSXv3: Solve race condition in DB cert provider"

This commit is contained in:
Jenkins 2017-05-26 09:27:33 +00:00 committed by Gerrit Code Review
commit 74baadec61
2 changed files with 38 additions and 13 deletions

View File

@ -13,6 +13,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import os
import random
import threading
from oslo_config import cfg
from oslo_log import log as logging
@ -39,11 +41,32 @@ class DbCertProvider(client_cert.ClientCertProvider):
this class maintains refcount to write/delete the file only once
"""
EXPIRATION_ALERT_DAYS = 30 # days prior to expiration
lock = threading.Lock()
def __init__(self, filename):
super(DbCertProvider, self).__init__(filename)
def __init__(self):
# Note: we cannot initialize filename here, because this call
# happens before neutron fork, meaning variable initialized here
# will be shared between all neutron processes (which will cause file
# collisions).
# The file can be shared between different connections within same
# process, if they happen to do the SSL handshake simultaneously.
# Such collisions are handled with refcount and locking.
super(DbCertProvider, self).__init__(None)
random.seed()
self.refcount = 0
def _increase_and_test_first(self):
with self.lock:
self.refcount += 1
return (self.refcount == 1)
def _decrease_and_test_last(self):
with self.lock:
self.refcount -= 1
return (self.refcount == 0)
def _check_expiration(self, expires_in_days):
if expires_in_days > self.EXPIRATION_ALERT_DAYS:
return
@ -57,13 +80,12 @@ class DbCertProvider(client_cert.ClientCertProvider):
expires_in_days)
def __enter__(self):
self.refcount += 1
if self.refcount > 1:
# The file was prepared for another connection
# and was not removed yet
if not self._increase_and_test_first():
# The file was already created and not yet deleted, use it
return self
# Choose a random filename to contain cert for the current connection
self._filename = '/tmp/.' + str(random.randint(1, 10000000))
try:
context = q_context.get_admin_context()
db_storage_driver = cert_utils.DbCertificateStorageDriver(context)
@ -90,11 +112,14 @@ class DbCertProvider(client_cert.ClientCertProvider):
return self
def _on_exit(self):
self.refcount -= 1
if self.refcount == 0 and os.path.isfile(self._filename):
# I am the last user of this file
if self._decrease_and_test_last():
if os.path.isfile(self._filename):
os.remove(self._filename)
LOG.debug("Deleted client certificate file")
self._filename = None
def __exit__(self, type, value, traceback):
self._on_exit()
@ -116,7 +141,7 @@ def get_client_cert_provider():
# Cert data is stored in DB, and written to file system only
# when new connection is opened, and deleted immediately after.
# Pid is appended to avoid file collisions between neutron servers
return DbCertProvider('/tmp/.' + str(os.getpid()))
return DbCertProvider()
def get_nsxlib_wrapper(nsx_username=None, nsx_password=None, basic_auth=False):

View File

@ -24,7 +24,7 @@ from vmware_nsx.plugins.nsx_v3 import cert_utils
from vmware_nsx.plugins.nsx_v3 import utils
class NsxV3iClientCertProviderTestCase(unittest.TestCase):
class NsxV3ClientCertProviderTestCase(unittest.TestCase):
CERT = "-----BEGIN CERTIFICATE-----\n" \
"MIIDJTCCAg0CBFh36j0wDQYJKoZIhvcNAQELBQAwVzELMAkGA1UEBhMCVVMxEzAR\n" \