Merge "NSXv3: Solve race condition in DB cert provider"
This commit is contained in:
commit
74baadec61
@ -13,6 +13,8 @@
|
|||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
|
||||||
from oslo_config import cfg
|
from oslo_config import cfg
|
||||||
from oslo_log import log as logging
|
from oslo_log import log as logging
|
||||||
@ -39,11 +41,32 @@ class DbCertProvider(client_cert.ClientCertProvider):
|
|||||||
this class maintains refcount to write/delete the file only once
|
this class maintains refcount to write/delete the file only once
|
||||||
"""
|
"""
|
||||||
EXPIRATION_ALERT_DAYS = 30 # days prior to expiration
|
EXPIRATION_ALERT_DAYS = 30 # days prior to expiration
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
def __init__(self, filename):
|
def __init__(self):
|
||||||
super(DbCertProvider, self).__init__(filename)
|
# Note: we cannot initialize filename here, because this call
|
||||||
|
# happens before neutron fork, meaning variable initialized here
|
||||||
|
# will be shared between all neutron processes (which will cause file
|
||||||
|
# collisions).
|
||||||
|
# The file can be shared between different connections within same
|
||||||
|
# process, if they happen to do the SSL handshake simultaneously.
|
||||||
|
# Such collisions are handled with refcount and locking.
|
||||||
|
super(DbCertProvider, self).__init__(None)
|
||||||
|
random.seed()
|
||||||
self.refcount = 0
|
self.refcount = 0
|
||||||
|
|
||||||
|
def _increase_and_test_first(self):
|
||||||
|
with self.lock:
|
||||||
|
self.refcount += 1
|
||||||
|
|
||||||
|
return (self.refcount == 1)
|
||||||
|
|
||||||
|
def _decrease_and_test_last(self):
|
||||||
|
with self.lock:
|
||||||
|
self.refcount -= 1
|
||||||
|
|
||||||
|
return (self.refcount == 0)
|
||||||
|
|
||||||
def _check_expiration(self, expires_in_days):
|
def _check_expiration(self, expires_in_days):
|
||||||
if expires_in_days > self.EXPIRATION_ALERT_DAYS:
|
if expires_in_days > self.EXPIRATION_ALERT_DAYS:
|
||||||
return
|
return
|
||||||
@ -57,13 +80,12 @@ class DbCertProvider(client_cert.ClientCertProvider):
|
|||||||
expires_in_days)
|
expires_in_days)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.refcount += 1
|
if not self._increase_and_test_first():
|
||||||
|
# The file was already created and not yet deleted, use it
|
||||||
if self.refcount > 1:
|
|
||||||
# The file was prepared for another connection
|
|
||||||
# and was not removed yet
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
# Choose a random filename to contain cert for the current connection
|
||||||
|
self._filename = '/tmp/.' + str(random.randint(1, 10000000))
|
||||||
try:
|
try:
|
||||||
context = q_context.get_admin_context()
|
context = q_context.get_admin_context()
|
||||||
db_storage_driver = cert_utils.DbCertificateStorageDriver(context)
|
db_storage_driver = cert_utils.DbCertificateStorageDriver(context)
|
||||||
@ -90,10 +112,13 @@ class DbCertProvider(client_cert.ClientCertProvider):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def _on_exit(self):
|
def _on_exit(self):
|
||||||
self.refcount -= 1
|
# I am the last user of this file
|
||||||
if self.refcount == 0 and os.path.isfile(self._filename):
|
if self._decrease_and_test_last():
|
||||||
os.remove(self._filename)
|
if os.path.isfile(self._filename):
|
||||||
LOG.debug("Deleted client certificate file")
|
os.remove(self._filename)
|
||||||
|
LOG.debug("Deleted client certificate file")
|
||||||
|
|
||||||
|
self._filename = None
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
self._on_exit()
|
self._on_exit()
|
||||||
@ -116,7 +141,7 @@ def get_client_cert_provider():
|
|||||||
# Cert data is stored in DB, and written to file system only
|
# Cert data is stored in DB, and written to file system only
|
||||||
# when new connection is opened, and deleted immediately after.
|
# when new connection is opened, and deleted immediately after.
|
||||||
# Pid is appended to avoid file collisions between neutron servers
|
# Pid is appended to avoid file collisions between neutron servers
|
||||||
return DbCertProvider('/tmp/.' + str(os.getpid()))
|
return DbCertProvider()
|
||||||
|
|
||||||
|
|
||||||
def get_nsxlib_wrapper(nsx_username=None, nsx_password=None, basic_auth=False):
|
def get_nsxlib_wrapper(nsx_username=None, nsx_password=None, basic_auth=False):
|
||||||
|
@ -24,7 +24,7 @@ from vmware_nsx.plugins.nsx_v3 import cert_utils
|
|||||||
from vmware_nsx.plugins.nsx_v3 import utils
|
from vmware_nsx.plugins.nsx_v3 import utils
|
||||||
|
|
||||||
|
|
||||||
class NsxV3iClientCertProviderTestCase(unittest.TestCase):
|
class NsxV3ClientCertProviderTestCase(unittest.TestCase):
|
||||||
|
|
||||||
CERT = "-----BEGIN CERTIFICATE-----\n" \
|
CERT = "-----BEGIN CERTIFICATE-----\n" \
|
||||||
"MIIDJTCCAg0CBFh36j0wDQYJKoZIhvcNAQELBQAwVzELMAkGA1UEBhMCVVMxEzAR\n" \
|
"MIIDJTCCAg0CBFh36j0wDQYJKoZIhvcNAQELBQAwVzELMAkGA1UEBhMCVVMxEzAR\n" \
|
||||||
|
Loading…
Reference in New Issue
Block a user