diff --git a/nodepool/config.py b/nodepool/config.py index fd2241816..d59d6a5ce 100755 --- a/nodepool/config.py +++ b/nodepool/config.py @@ -204,7 +204,7 @@ def get_provider_config(provider): # Ensure legacy configuration still works when using fake cloud if provider.get('name', '').startswith('fake'): provider['driver'] = 'fake' - driver = Drivers._get(provider['driver']) + driver = Drivers.get(provider['driver']) return driver.getProviderConfig(provider) @@ -234,7 +234,7 @@ def loadConfig(config_path): config = openConfig(config_path) # Call driver config reset now to clean global hooks like os_client_config - for driver in Drivers._drivers.values(): + for driver in Drivers.drivers.values(): driver.reset() newconfig = Config() diff --git a/nodepool/driver/__init__.py b/nodepool/driver/__init__.py index 436104a26..aef5bf089 100644 --- a/nodepool/driver/__init__.py +++ b/nodepool/driver/__init__.py @@ -32,7 +32,6 @@ class Drivers: log = logging.getLogger("nodepool.driver.Drivers") drivers = {} - _drivers = {} # TODO: replace drivers drivers_paths = None @staticmethod @@ -72,22 +71,6 @@ class Drivers: if not os.path.isdir(driver_path) or \ "__init__.py" not in os.listdir(driver_path): continue - Drivers.log.debug("%s: loading driver", driver_path) - driver_obj = {} - for name, parent_class in ( - ("provider", Provider), - ): - driver_obj[name] = Drivers._load_class( - driver, os.path.join(driver_path, "%s.py" % name), - parent_class) - if not driver_obj[name]: - break - if not driver_obj[name]: - Drivers.log.error( - "%s: skipping incorrect driver from %s.py", - driver_path, name) - continue - Drivers.drivers[driver] = driver_obj driver_obj = Drivers._load_class( driver, os.path.join(driver_path, "__init__.py"), Driver) @@ -96,7 +79,7 @@ class Drivers: "%s: skipping incorrect driver from __init__.py", driver_path) continue - Drivers._drivers[driver] = driver_obj() + Drivers.drivers[driver] = driver_obj() Drivers.drivers_paths = drivers_paths @@ -109,16 +92,6 @@ class Drivers: except KeyError: raise RuntimeError("%s: unknown driver" % name) - # TODO: replace get - @staticmethod - def _get(name): - if not Drivers._drivers: - Drivers.load() - try: - return Drivers._drivers[name] - except KeyError: - raise RuntimeError("%s: unknown driver" % name) - class Driver(object, metaclass=abc.ABCMeta): """The Driver interface @@ -145,6 +118,18 @@ class Driver(object, metaclass=abc.ABCMeta): """ pass + @abc.abstractmethod + def getProvider(self, provider_config, use_taskmanager): + """Return a Provider instance + + :arg dict provider_config: A ProviderConfig instance + + :arg bool use_taskmanager: Whether this provider should use a + task manager (i.e., perform synchronous or asynchronous + operations). + """ + pass + class Provider(object, metaclass=abc.ABCMeta): """The Provider interface diff --git a/nodepool/driver/fake/__init__.py b/nodepool/driver/fake/__init__.py index f7cb7a4d4..2d4b6fadc 100644 --- a/nodepool/driver/fake/__init__.py +++ b/nodepool/driver/fake/__init__.py @@ -16,6 +16,7 @@ import os_client_config from nodepool.driver import Driver from nodepool.driver.fake.config import FakeProviderConfig +from nodepool.driver.fake.provider import FakeProvider class FakeDriver(Driver): @@ -28,3 +29,6 @@ class FakeDriver(Driver): def getProviderConfig(self, provider): return FakeProviderConfig(self, provider) + + def getProvider(self, provider_config, use_taskmanager): + return FakeProvider(provider_config, use_taskmanager) diff --git a/nodepool/driver/openstack/__init__.py b/nodepool/driver/openstack/__init__.py index c7648edf9..d335380be 100644 --- a/nodepool/driver/openstack/__init__.py +++ b/nodepool/driver/openstack/__init__.py @@ -16,6 +16,7 @@ import os_client_config from nodepool.driver import Driver from nodepool.driver.openstack.config import OpenStackProviderConfig +from nodepool.driver.openstack.provider import OpenStackProvider class OpenStackDriver(Driver): @@ -28,3 +29,6 @@ class OpenStackDriver(Driver): def getProviderConfig(self, provider): return OpenStackProviderConfig(self, provider) + + def getProvider(self, provider_config, use_taskmanager): + return OpenStackProvider(provider_config, use_taskmanager) diff --git a/nodepool/driver/static/__init__.py b/nodepool/driver/static/__init__.py index c29c0b85d..39ab8868b 100644 --- a/nodepool/driver/static/__init__.py +++ b/nodepool/driver/static/__init__.py @@ -13,9 +13,13 @@ # limitations under the License. from nodepool.driver import Driver -from nodepool.driver.static.config import StaticProviderConfig +from nodepool.driver.static import config +from nodepool.driver.static import provider class StaticDriver(Driver): def getProviderConfig(self, provider): - return StaticProviderConfig(provider) + return config.StaticProviderConfig(provider) + + def getProvider(self, provider_config, use_taskmanager): + return provider.StaticNodeProvider(provider_config, use_taskmanager) diff --git a/nodepool/driver/test/__init__.py b/nodepool/driver/test/__init__.py index 25a4bf1ec..34a112a13 100644 --- a/nodepool/driver/test/__init__.py +++ b/nodepool/driver/test/__init__.py @@ -13,9 +13,13 @@ # limitations under the License. from nodepool.driver import Driver -from nodepool.driver.test.config import TestConfig +from nodepool.driver.test import config +from nodepool.driver.test import provider class TestDriver(Driver): def getProviderConfig(self, provider): - return TestConfig(provider) + return config.TestConfig(provider) + + def getProvider(self, provider_config, use_taskmanager): + return provider.TestProvider(provider_config) diff --git a/nodepool/driver/test/provider.py b/nodepool/driver/test/provider.py index 50761b030..88373c59c 100644 --- a/nodepool/driver/test/provider.py +++ b/nodepool/driver/test/provider.py @@ -19,7 +19,7 @@ from nodepool.driver.test import handler class TestProvider(Provider): - def __init__(self, provider, *args): + def __init__(self, provider): self.provider = provider def start(self): diff --git a/nodepool/provider_manager.py b/nodepool/provider_manager.py index 2c4e75cae..907fc86a3 100755 --- a/nodepool/provider_manager.py +++ b/nodepool/provider_manager.py @@ -23,7 +23,7 @@ from nodepool.driver import Drivers def get_provider(provider, use_taskmanager): driver = Drivers.get(provider.driver.name) - return driver['provider'](provider, use_taskmanager) + return driver.getProvider(provider, use_taskmanager) class ProviderManager(object): diff --git a/nodepool/tests/test_builder.py b/nodepool/tests/test_builder.py index 3d18cf824..f08b03e43 100644 --- a/nodepool/tests/test_builder.py +++ b/nodepool/tests/test_builder.py @@ -18,7 +18,6 @@ import uuid import fixtures from nodepool import builder, exceptions, tests -from nodepool.driver import Drivers from nodepool.driver.fake import provider as fakeprovider from nodepool import zk @@ -121,7 +120,7 @@ class TestNodePoolBuilder(tests.DBTestCase): return fake_client self.useFixture(fixtures.MockPatchObject( - Drivers.get('fake')['provider'], '_getClient', + fakeprovider.FakeProvider, '_getClient', get_fake_client)) configfile = self.setup_config('node.yaml') diff --git a/nodepool/tests/test_launcher.py b/nodepool/tests/test_launcher.py index 2809c8b59..0ce334bf6 100644 --- a/nodepool/tests/test_launcher.py +++ b/nodepool/tests/test_launcher.py @@ -21,7 +21,7 @@ import mock from nodepool import tests from nodepool import zk -from nodepool.driver import Drivers +from nodepool.driver.fake import provider as fakeprovider import nodepool.launcher from kazoo import exceptions as kze @@ -130,7 +130,7 @@ class TestLauncher(tests.DBTestCase): def fake_get_quota(): return (max_cores, max_instances, max_ram) self.useFixture(fixtures.MockPatchObject( - Drivers.get('fake')['provider'].fake_cloud, '_get_quota', + fakeprovider.FakeProvider.fake_cloud, '_get_quota', fake_get_quota )) @@ -265,7 +265,7 @@ class TestLauncher(tests.DBTestCase): def fake_get_quota(): return (max_cores, max_instances, max_ram) self.useFixture(fixtures.MockPatchObject( - Drivers.get('fake')['provider'].fake_cloud, '_get_quota', + fakeprovider.FakeProvider.fake_cloud, '_get_quota', fake_get_quota )) @@ -653,7 +653,7 @@ class TestLauncher(tests.DBTestCase): raise RuntimeError('Fake Error') self.useFixture(fixtures.MockPatchObject( - Drivers.get('fake')['provider'], 'deleteServer', fail_delete)) + fakeprovider.FakeProvider, 'deleteServer', fail_delete)) configfile = self.setup_config('node.yaml') pool = self.useNodepool(configfile, watermark_sleep=1)