From ae398e612cbb247c9537eb65266a25488fb466a9 Mon Sep 17 00:00:00 2001 From: Ilya Etingof Date: Wed, 21 Feb 2018 14:45:50 +0100 Subject: [PATCH] SNMP agent rebased on the high-level pysnmp API This change makes it possible to support all SNMP versions (including strong crypto features) as well as IPv6. The followup patches are expected to implement these new features. New `debug_snmp` option added to configuration file to facilitate troubleshooting of the SNMP issues. Closes-Bug: #1751163 Change-Id: Iec6452201f85f2f4c486e94abe8c6b6bba68840d --- doc/example.ini | 13 + virtualpdu/main.py | 31 ++- virtualpdu/pdu/pysnmp_handler.py | 223 ++++++++++-------- virtualpdu/tests/integration/pdu/test_pdu.py | 4 +- .../integration/pdu/test_pysnmp_handler.py | 35 +-- virtualpdu/tests/unit/__init__.py | 6 +- virtualpdu/tests/unit/test_pysnmp_handler.py | 207 ++++++++++------ 7 files changed, 315 insertions(+), 204 deletions(-) create mode 100644 doc/example.ini diff --git a/doc/example.ini b/doc/example.ini new file mode 100644 index 0000000..48a9ef3 --- /dev/null +++ b/doc/example.ini @@ -0,0 +1,13 @@ +[global] +libvirt_uri=test:///default +debug_snmp=no +[my_pdu] +listen_address=127.0.0.1 +listen_port=9998 +community=public +ports=5:test +[my_second_pdu] +listen_address=127.0.0.1 +listen_port=9997 +community=public +ports=2:test diff --git a/virtualpdu/main.py b/virtualpdu/main.py index 89a2ddf..037021a 100644 --- a/virtualpdu/main.py +++ b/virtualpdu/main.py @@ -34,15 +34,18 @@ def main(): config_file = sys.argv[1] except IndexError: sys.stderr.write(MISSING_CONFIG_MESSAGE) - sys.exit(1) + return 1 else: - config = configparser.RawConfigParser() + config = configparser.RawConfigParser({'debug_snmp': 'no'}) config.read(config_file) driver = get_driver_from_config(config) mapping = get_mapping_for_config(config) outlet_default_state = get_default_state_from_config(config) + debug_snmp = config.get('global', 'debug_snmp') + core = virtualpdu.core.Core(driver=driver, mapping=mapping, store={}, default_state=outlet_default_state) + pdu_threads = [] for pdu in [s for s in config.sections() if s != 'global']: @@ -52,12 +55,15 @@ def main(): apc_pdu = apc_rackpdu.APCRackPDU(pdu, core) - pdu_threads.append(pysnmp_handler.SNMPPDUHarness( - apc_pdu, - listen_address, - port, - community - )) + pdu_threads.append( + pysnmp_handler.SNMPPDUHarness( + apc_pdu, + listen_address, + port, + community, + debug_snmp=debug_snmp in ('yes', 'true', '1') + ) + ) for t in pdu_threads: t.start() @@ -66,10 +72,13 @@ def main(): for t in pdu_threads: while t.isAlive(): t.join(1) + except KeyboardInterrupt: for t in pdu_threads: t.stop() - sys.exit() + return 1 + + return 0 def parse_default_state_config(default_state): @@ -120,3 +129,7 @@ def get_default_state_from_config(conf): class UnableToParseConfig(Exception): pass + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/virtualpdu/pdu/pysnmp_handler.py b/virtualpdu/pdu/pysnmp_handler.py index dbb9741..bd67e75 100644 --- a/virtualpdu/pdu/pysnmp_handler.py +++ b/virtualpdu/pdu/pysnmp_handler.py @@ -15,117 +15,151 @@ import logging import threading -from pyasn1.codec.ber import decoder -from pyasn1.codec.ber import encoder -from pysnmp.carrier.asyncore.dgram import udp -from pysnmp.carrier.asyncore.dispatch import AsyncoreDispatcher -from pysnmp.proto import api +from pysnmp.carrier.asynsock.dgram import udp +from pysnmp import debug +from pysnmp.entity import config +from pysnmp.entity import engine +from pysnmp.entity.rfc3413 import cmdrsp +from pysnmp.entity.rfc3413 import context +from pysnmp.proto.api import v2c -# NOTE(mmitchell): Roughly from implementing-scalar-mib-objects.py in pysnmp. -# Unfortunately, that file is not part of the pysnmp package and re-use is -# not possible. # pysnmp is distributed under the BSD license. from virtualpdu.pdu import TraversableOidMapping -class SNMPPDUHandler(object): - def __init__(self, pdu, community): - self.pdu = pdu - self.community = community - self.logger = logging.getLogger(__name__) +class GetCommandResponder(cmdrsp.GetCommandResponder): - def message_handler(self, transportDispatcher, transportDomain, - transportAddress, whole_message): - while whole_message: - message_version = api.decodeMessageVersion(whole_message) - if message_version in api.protoModules: - protocol = api.protoModules[message_version] - else: - self.logger.warn( - 'Unsupported SNMP version "{}"'.format(message_version)) - return + def __init__(self, snmpEngine, snmpContext, power_unit): + super(GetCommandResponder, self).__init__(snmpEngine, snmpContext) + self.__power_unit = power_unit - request, whole_message = decoder.decode( - whole_message, asn1Spec=protocol.Message() + def handleMgmtOperation(self, snmpEngine, stateReference, + contextName, req_pdu, acInfo): + + var_binds = [] + + for oid, val in v2c.apiPDU.getVarBinds(req_pdu): + var_binds.append( + (oid, (self.__power_unit.oid_mapping[oid].value + if oid in self.__power_unit.oid_mapping + else v2c.NoSuchInstance(''))) ) - response = protocol.apiMessage.getResponse(request) - request_pdus = protocol.apiMessage.getPDU(request) - community = protocol.apiMessage.getCommunity(request) + self.sendRsp(snmpEngine, stateReference, 0, 0, var_binds) - if not self.valid_community(community): - self.logger.warn('Invalid community "{}"'.format(community)) - return + self.releaseStateInformation(stateReference) - response_pdus = protocol.apiMessage.getPDU(response) - var_binds = [] - pending_errors = [] - error_index = 0 - if request_pdus.isSameTypeWith(protocol.GetRequestPDU()): - for oid, val in protocol.apiPDU.getVarBinds(request_pdus): - if oid in self.pdu.oid_mapping: - var_binds.append( - (oid, self.pdu.oid_mapping[oid].value)) - else: - return - elif request_pdus.isSameTypeWith(protocol.GetNextRequestPDU()): - for oid, val in protocol.apiPDU.getVarBinds(request_pdus): - error_index += 1 - try: - oid = TraversableOidMapping(self.pdu.oid_mapping)\ - .next(to=oid) - val = self.pdu.oid_mapping[oid].value - except (KeyError, IndexError): - pending_errors.append( - (protocol.apiPDU.setNoSuchInstanceError, - error_index) - ) - var_binds.append((oid, val)) - elif request_pdus.isSameTypeWith(protocol.SetRequestPDU()): - for oid, val in protocol.apiPDU.getVarBinds(request_pdus): - error_index += 1 - if oid in self.pdu.oid_mapping: - self.pdu.oid_mapping[oid].value = val - var_binds.append((oid, val)) - else: - var_binds.append((oid, val)) - pending_errors.append( - (protocol.apiPDU.setNoSuchInstanceError, - error_index) +class NextCommandResponder(cmdrsp.NextCommandResponder): + + def __init__(self, snmpEngine, snmpContext, power_unit): + super(NextCommandResponder, self).__init__(snmpEngine, snmpContext) + self.__power_unit = power_unit + + def handleMgmtOperation(self, snmpEngine, stateReference, + contextName, req_pdu, acInfo): + + oid_map = TraversableOidMapping(self.__power_unit.oid_mapping) + + var_binds = [] + + for oid, val in v2c.apiPDU.getVarBinds(req_pdu): + + try: + oid = oid_map.next(to=oid) + val = self.__power_unit.oid_mapping[oid].value + + except (KeyError, IndexError): + val = v2c.NoSuchInstance('') + + var_binds.append((oid, val)) + + self.sendRsp(snmpEngine, stateReference, 0, 0, var_binds) + + self.releaseStateInformation(stateReference) + + +class SetCommandResponder(cmdrsp.SetCommandResponder): + + def __init__(self, snmpEngine, snmpContext, power_unit): + super(SetCommandResponder, self).__init__(snmpEngine, snmpContext) + self.__power_unit = power_unit + + self.__logger = logging.getLogger(__name__) + + def handleMgmtOperation(self, snmpEngine, stateReference, + contextName, req_pdu, acInfo): + + var_binds = [] + + for oid, val in v2c.apiPDU.getVarBinds(req_pdu): + if oid in self.__power_unit.oid_mapping: + try: + self.__power_unit.oid_mapping[oid].value = val + + except Exception as ex: + self.__logger.info( + 'Set value {} on power unit {} failed: {}'.format( + val, self.__power_unit.name, ex ) + ) + val = v2c.NoSuchInstance('') else: - protocol.apiPDU.setErrorStatus(response_pdus, 'genErr') + val = v2c.NoSuchInstance('') - protocol.apiPDU.setVarBinds(response_pdus, var_binds) + var_binds.append((oid, val)) - # Commit possible error indices to response PDU - for f, i in pending_errors: - f(response_pdus, i) + self.sendRsp(snmpEngine, stateReference, 0, 0, var_binds) - transportDispatcher.sendMessage( - encoder.encode(response), transportDomain, transportAddress - ) + self.releaseStateInformation(stateReference) - return whole_message - def valid_community(self, community): - return str(community) == self.community +def create_snmp_engine(power_unit, listen_address, listen_port, + community="public"): + snmp_engine = engine.SnmpEngine() + + config.addSocketTransport( + snmp_engine, + udp.domainName, + udp.UdpTransport().openServerMode((listen_address, listen_port)) + ) + + config.addV1System(snmp_engine, community, community) + + # Allow read MIB access for this user / securityModels at SNMP VACM + for snmp_version in (1, 2): + config.addVacmUser(snmp_engine, snmp_version, + community, 'noAuthNoPriv', (1,), (1,)) + + snmp_context = context.SnmpContext(snmp_engine) + + # Register SNMP Apps at the SNMP engine for particular SNMP context + GetCommandResponder(snmp_engine, snmp_context, power_unit=power_unit) + NextCommandResponder(snmp_engine, snmp_context, power_unit=power_unit) + SetCommandResponder(snmp_engine, snmp_context, power_unit=power_unit) + + return snmp_engine class SNMPPDUHarness(threading.Thread): - def __init__(self, pdu, listen_address, listen_port, community="public"): + def __init__(self, power_unit, + listen_address, listen_port, + community="public", + debug_snmp=False): super(SNMPPDUHarness, self).__init__() - self.logger = logging.getLogger(__name__) - self.pdu = pdu + self._logger = logging.getLogger(__name__) - self.snmp_handler = SNMPPDUHandler(self.pdu, community=community) + if debug_snmp: + debug.setLogger(debug.Debug('all')) + + self.snmp_engine = create_snmp_engine(power_unit, listen_address, + listen_port, community) self.listen_address = listen_address self.listen_port = listen_port - self.transportDispatcher = AsyncoreDispatcher() + self.power_unit = power_unit self._lock = threading.Lock() self._stop_requested = False @@ -135,31 +169,24 @@ class SNMPPDUHarness(threading.Thread): if self._stop_requested: return - self.logger.info("Starting PDU '{}' on {}:{}".format( - self.pdu.name, self.listen_address, self.listen_port) - ) - self.transportDispatcher.registerRecvCbFun( - self.snmp_handler.message_handler) + self._logger.info("Starting SNMP agent at {}:{} serving '{}'" + .format(self.listen_address, self.listen_port, + self.power_unit.name)) - # UDP/IPv4 - self.transportDispatcher.registerTransport( - udp.domainName, - udp.UdpSocketTransport().openServerMode( - (self.listen_address, self.listen_port)) - ) - - self.transportDispatcher.jobStarted(1) + self.snmp_engine.transportDispatcher.jobStarted(1) try: # Dispatcher will never finish as job#1 never reaches zero - self.transportDispatcher.runDispatcher() + self.snmp_engine.transportDispatcher.runDispatcher() + except Exception: - self.transportDispatcher.closeDispatcher() + self.snmp_engine.transportDispatcher.closeDispatcher() def stop(self): with self._lock: self._stop_requested = True try: - self.transportDispatcher.jobFinished(1) + self.snmp_engine.transportDispatcher.jobFinished(1) + except KeyError: pass # The job is not started yet and will not start diff --git a/virtualpdu/tests/integration/pdu/test_pdu.py b/virtualpdu/tests/integration/pdu/test_pdu.py index 6a00dca..4b1599c 100644 --- a/virtualpdu/tests/integration/pdu/test_pdu.py +++ b/virtualpdu/tests/integration/pdu/test_pdu.py @@ -28,8 +28,8 @@ class TestPDU(PDUTestCase): outlet_control_class = pdu.PDUOutletControl def test_get_unknown_oid(self): - self.assertRaises(RequestTimedOut, - self.snmp_get, enterprises + (42,)) + self.assertEqual(NoSuchInstance(''), + self.snmp_get(enterprises + (42,))) def test_set_unknown_oid(self): self.assertEqual(NoSuchInstance(''), diff --git a/virtualpdu/tests/integration/pdu/test_pysnmp_handler.py b/virtualpdu/tests/integration/pdu/test_pysnmp_handler.py index 45fc7f4..a1cd5aa 100644 --- a/virtualpdu/tests/integration/pdu/test_pysnmp_handler.py +++ b/virtualpdu/tests/integration/pdu/test_pysnmp_handler.py @@ -26,9 +26,10 @@ class TestSNMPPDUHarness(base.TestCase): def test_harness_get(self): - mock_pdu = mock.Mock() + mock_power_unit = mock.Mock() port = randint(20000, 30000) - harness = pysnmp_handler.SNMPPDUHarness(pdu=mock_pdu, + + harness = pysnmp_handler.SNMPPDUHarness(power_unit=mock_power_unit, listen_address='127.0.0.1', listen_port=port, community='bleh') @@ -42,9 +43,9 @@ class TestSNMPPDUHarness(base.TestCase): timeout=1, retries=1) - mock_pdu.oid_mapping = dict() - mock_pdu.oid_mapping[(1, 3, 6, 99)] = mock.Mock() - mock_pdu.oid_mapping[(1, 3, 6, 99)].value = univ.Integer(42) + mock_power_unit.oid_mapping = dict() + mock_power_unit.oid_mapping[(1, 3, 6, 99)] = mock.Mock() + mock_power_unit.oid_mapping[(1, 3, 6, 99)].value = univ.Integer(42) self.assertEqual(42, client.get_one((1, 3, 6, 99))) @@ -52,9 +53,9 @@ class TestSNMPPDUHarness(base.TestCase): def test_harness_set(self): - mock_pdu = mock.Mock() + mock_power_unit = mock.Mock() port = randint(20000, 30000) - harness = pysnmp_handler.SNMPPDUHarness(pdu=mock_pdu, + harness = pysnmp_handler.SNMPPDUHarness(power_unit=mock_power_unit, listen_address='127.0.0.1', listen_port=port, community='bleh') @@ -68,20 +69,20 @@ class TestSNMPPDUHarness(base.TestCase): timeout=1, retries=1) - mock_pdu.oid_mapping = dict() - mock_pdu.oid_mapping[(1, 3, 6, 98)] = mock.Mock() + mock_power_unit.oid_mapping = dict() + mock_power_unit.oid_mapping[(1, 3, 6, 98)] = mock.Mock() client.set((1, 3, 6, 98), univ.Integer(99)) self.assertEqual(univ.Integer(99), - mock_pdu.oid_mapping[(1, 3, 6, 98)].value) + mock_power_unit.oid_mapping[(1, 3, 6, 98)].value) harness.stop() def test_harness_get_next(self): - mock_pdu = mock.Mock() + mock_power_unit = mock.Mock() port = randint(20000, 30000) - harness = pysnmp_handler.SNMPPDUHarness(pdu=mock_pdu, + harness = pysnmp_handler.SNMPPDUHarness(power_unit=mock_power_unit, listen_address='127.0.0.1', listen_port=port, community='bleh') @@ -95,9 +96,9 @@ class TestSNMPPDUHarness(base.TestCase): timeout=1, retries=1) - mock_pdu.oid_mapping = dict() - mock_pdu.oid_mapping[(1, 3, 6, 1, 5)] = mock.Mock() - mock_pdu.oid_mapping[(1, 3, 6, 1, 5)].value = univ.Integer(42) + mock_power_unit.oid_mapping = dict() + mock_power_unit.oid_mapping[(1, 3, 6, 1, 5)] = mock.Mock() + mock_power_unit.oid_mapping[(1, 3, 6, 1, 5)].value = univ.Integer(42) oid, val = client.get_next((1, 3, 6, 1)) @@ -107,9 +108,9 @@ class TestSNMPPDUHarness(base.TestCase): harness.stop() def test_start_stop_threadsafety(self): - mock_pdu = mock.Mock() + mock_power_unit = mock.Mock() port = randint(20000, 30000) - harness = pysnmp_handler.SNMPPDUHarness(pdu=mock_pdu, + harness = pysnmp_handler.SNMPPDUHarness(power_unit=mock_power_unit, listen_address='127.0.0.1', listen_port=port, community='bleh') diff --git a/virtualpdu/tests/unit/__init__.py b/virtualpdu/tests/unit/__init__.py index bc0d518..b23a1d8 100644 --- a/virtualpdu/tests/unit/__init__.py +++ b/virtualpdu/tests/unit/__init__.py @@ -20,12 +20,14 @@ class TraversableMessage(object): def __getitem__(self, type_class): ret = None try: - for component in self.value.values(): + # this is required for ancient pyasn1 to work + for idx in range(len(self.value)): + component = self.value.getComponentByPosition(idx) if isinstance(component, type_class): if ret: raise KeyError() ret = component - except AttributeError: + except (TypeError, AttributeError): index = type_class ret = self.value[index] return TraversableMessage(ret) diff --git a/virtualpdu/tests/unit/test_pysnmp_handler.py b/virtualpdu/tests/unit/test_pysnmp_handler.py index 3537098..beccd84 100644 --- a/virtualpdu/tests/unit/test_pysnmp_handler.py +++ b/virtualpdu/tests/unit/test_pysnmp_handler.py @@ -16,8 +16,8 @@ import unittest from mock import Mock from mock import patch -from mock import sentinel +from pysnmp.proto.errind import UnknownPDUHandler from pysnmp.proto.rfc1902 import Integer from pysnmp.proto.rfc1902 import ObjectName from pysnmp.proto.rfc1902 import ObjectSyntax @@ -26,20 +26,20 @@ from pysnmp.proto.rfc1902 import SimpleSyntax from pysnmp.proto.rfc1905 import _BindValue from pysnmp.proto.rfc1905 import NoSuchInstance -from pysnmp.proto.rfc1905 import PDUs -from pysnmp.proto.rfc1905 import ResponsePDU from pysnmp.proto.rfc1905 import VarBindList -from virtualpdu.pdu.pysnmp_handler import SNMPPDUHandler +from virtualpdu.pdu.pysnmp_handler import create_snmp_engine from virtualpdu.tests.unit import TraversableMessage -SNMP_ERR_noSuchName = 2 -SNMP_ERR_genErr = 5 - # snmpget -v2c -c community localhost:10610 .1.1 MSG_SNMP_GET = (b'0%\x02\x01\x01\x04\tcommunity\xa0\x15\x02\x04$=W\xfd\x02\x01' b'\x00\x02\x01\x000\x070\x05\x06\x01)\x05\x00') +# snmpget -v2c -c community localhost:10610 .1.0 +MSG_SNMP_GET_UNKNOWN_OID = (b'0%\x02\x01\x01\x04\tcommunity\xa0\x15\x02\x04' + b'$=W\xfd\x02\x01\x00\x02\x01\x000\x070\x05\x06' + b'\x01(\x05\x00') + # snmpset -v2c -c community localhost:10610 .1.1 i 5 MSG_SNMP_SET = (b'0&\x02\x01\x01\x04\tcommunity\xa3\x16\x02\x04ce\xd84\x02\x01' b'\x00\x02\x01' @@ -62,103 +62,158 @@ MSG_SNMP_WRONG_COMM = (b'0+\x02\x01\x01\x04\x0fwrong_community\xa0\x15\x02' class SnmpServiceMessageReceivedTest(unittest.TestCase): def setUp(self): - self.pdu_mock = Mock() - self.pdu_mock.oid_mapping = {} - self.transport_dispatcher = Mock() - self.pdu_handler = SNMPPDUHandler(self.pdu_mock, 'community') - self.encoder_patcher = patch('virtualpdu.pdu.pysnmp_handler.encoder') - self.encoder_mock = self.encoder_patcher.start() - self.encoder_mock.return_value = sentinel.encoded_message + self.power_unit_mock = Mock() + self.power_unit_mock.oid_mapping = {} + + for pysnmp_package in ('asyncore', 'asynsock'): + try: + self.socket_patcher = patch('pysnmp.carrier.%s.dgram' + '.base.DgramSocketTransport' + '.openServerMode' % pysnmp_package) + self.socket_patcher.start() + + break + + except ImportError: + continue + + else: + raise ImportError('Monkeys failed at pysnmp patching!') + + self.snmp_engine = create_snmp_engine(self.power_unit_mock, + '127.0.0.1', 161, + 'community') def tearDown(self): - self.encoder_patcher.stop() + self.snmp_engine.transportDispatcher.closeDispatcher() + self.socket_patcher.stop() def test_set_calls_pdu_mock(self): - self.pdu_mock.oid_mapping[(1, 1)] = Mock() + self.power_unit_mock.oid_mapping[(1, 1)] = Mock() - self.pdu_handler.message_handler(self.transport_dispatcher, - sentinel.transport_domain, - sentinel.transport_address, - MSG_SNMP_SET) - self.assertEqual(self.pdu_mock.oid_mapping[(1, 1)].value, 5) + self.snmp_engine.msgAndPduDsp.receiveMessage( + self.snmp_engine, (1, 3, 6, 1), ('127.0.0.1', 12345), + MSG_SNMP_SET + ) + + self.assertEqual(self.power_unit_mock.oid_mapping[(1, 1)].value, 5) def test_set_response(self): - self.pdu_mock.oid_mapping[(1, 1)] = Mock() + self.power_unit_mock.oid_mapping[(1, 1)] = Mock() - self.pdu_handler.message_handler(self.transport_dispatcher, - sentinel.transport_domain, - sentinel.transport_address, - MSG_SNMP_SET) + patcher = patch('virtualpdu.pdu.pysnmp_handler' + '.SetCommandResponder.handleMgmtOperation') + mock = patcher.start() - message = TraversableMessage(self.encoder_mock.encode.call_args[0][0]) - varbindlist = message[PDUs][ResponsePDU][VarBindList] + self.snmp_engine.msgAndPduDsp.receiveMessage( + self.snmp_engine, (1, 3, 6, 1), ('127.0.0.1', 12345), + MSG_SNMP_SET + ) + + message = TraversableMessage(mock.call_args[0][3]) + + patcher.stop() + + varbindlist = message[VarBindList] self.assertEqual(varbindlist[0][ObjectName].value, (1, 1)) self.assertEqual(varbindlist[0][_BindValue][ObjectSyntax] [SimpleSyntax][Integer].value, Integer(5)) - def test_set_with_unknown_oid_replies_nosuchinstance(self): - self.pdu_handler.message_handler(self.transport_dispatcher, - sentinel.transport_domain, - sentinel.transport_address, - MSG_SNMP_SET) + def test_get_with_unknown_oid_replies_nosuchinstance(self): - message = TraversableMessage(self.encoder_mock.encode.call_args[0][0]) - varbindlist = message[PDUs][ResponsePDU][VarBindList] - self.assertEqual(varbindlist[0][ObjectName].value, (1, 1)) - self.assertEqual(varbindlist[0][NoSuchInstance].value, - NoSuchInstance('')) + patcher = patch('virtualpdu.pdu.pysnmp_handler' + '.GetCommandResponder.sendRsp') + mock = patcher.start() + + self.snmp_engine.msgAndPduDsp.receiveMessage( + self.snmp_engine, (1, 3, 6, 1), ('127.0.0.1', 12345), + MSG_SNMP_GET_UNKNOWN_OID + ) + + varbindlist = mock.call_args[0][4] + + patcher.stop() + + self.assertEqual(varbindlist[0][0], (1, 0)) + self.assertIsInstance(varbindlist[0][1], NoSuchInstance) def test_get(self): - self.pdu_mock.oid_mapping[(1, 1)] = Mock() - self.pdu_mock.oid_mapping[(1, 1)].value = OctetString('test') + self.power_unit_mock.oid_mapping[(1, 1)] = Mock() + self.power_unit_mock.oid_mapping[(1, 1)].value = OctetString('test') - self.pdu_handler.message_handler(self.transport_dispatcher, - sentinel.transport_domain, - sentinel.transport_address, - MSG_SNMP_GET) + patcher = patch('virtualpdu.pdu.pysnmp_handler' + '.GetCommandResponder.sendRsp') + mock = patcher.start() - message = TraversableMessage(self.encoder_mock.encode.call_args[0][0]) - varbindlist = message[PDUs][ResponsePDU][VarBindList] + self.snmp_engine.msgAndPduDsp.receiveMessage( + self.snmp_engine, (1, 3, 6, 1), ('127.0.0.1', 12345), + MSG_SNMP_GET + ) - self.assertEqual(varbindlist[0][ObjectName].value, (1, 1)) - self.assertEqual(varbindlist[0][_BindValue][ObjectSyntax] - [SimpleSyntax][OctetString].value, - OctetString("test")) + varbindlist = mock.call_args[0][4] + + patcher.stop() + + self.assertEqual(varbindlist[0][0], (1, 1)) + self.assertEqual(varbindlist[0][1], OctetString("test")) def test_get_next(self): - self.pdu_mock.oid_mapping[(1, 1)] = Mock() - self.pdu_mock.oid_mapping[(1, 2)] = Mock() - self.pdu_mock.oid_mapping[(1, 2)].value = Integer(5) + self.power_unit_mock.oid_mapping[(1, 1)] = Mock() + self.power_unit_mock.oid_mapping[(1, 2)] = Mock() + self.power_unit_mock.oid_mapping[(1, 2)].value = Integer(5) - self.pdu_handler.message_handler(self.transport_dispatcher, - sentinel.transport_domain, - sentinel.transport_address, - MSG_SNMP_WALK) + patcher = patch('virtualpdu.pdu.pysnmp_handler' + '.NextCommandResponder.sendRsp') + mock = patcher.start() - message = TraversableMessage(self.encoder_mock.encode.call_args[0][0]) - varbindlist = message[PDUs][ResponsePDU][VarBindList] - self.assertEqual(varbindlist[0][ObjectName].value, (1, 2)) - self.assertEqual(varbindlist[0][_BindValue][ObjectSyntax] - [SimpleSyntax][Integer].value, - Integer(5)) + self.snmp_engine.msgAndPduDsp.receiveMessage( + self.snmp_engine, (1, 3, 6, 1), ('127.0.0.1', 12345), + MSG_SNMP_WALK + ) - def test_unsupported_command_returns_genError(self): - self.pdu_handler.message_handler(self.transport_dispatcher, - sentinel.transport_domain, - sentinel.transport_address, - MSG_SNMP_BULK_GET) + varbindlist = mock.call_args[0][4] - message = TraversableMessage(self.encoder_mock.encode.call_args[0][0]) + patcher.stop() - self.assertEqual(message[PDUs][ResponsePDU].get_by_index(1).value, - Integer(SNMP_ERR_genErr)) + self.assertEqual(varbindlist[0][0], (1, 2)) + self.assertEqual(varbindlist[0][1], Integer(5)) + + def test_unsupported_command_returns_error(self): + patcher = patch('pysnmp.proto.mpmod.rfc2576' + '.SnmpV2cMessageProcessingModel' + '.prepareResponseMessage') + mock = patcher.start() + mock.return_value = ( + (1, 3, 6, 1), ('127.0.0.1', 12345), b'' + ) + + self.snmp_engine.msgAndPduDsp.receiveMessage( + self.snmp_engine, (1, 3, 6, 1), ('127.0.0.1', 12345), + MSG_SNMP_BULK_GET + ) + + status_info = mock.call_args[0][11] + self.assertIsInstance(status_info['errorIndication'], + UnknownPDUHandler) + + patcher.stop() def test_doesnt_reply_with_wrong_community(self): - self.pdu_handler.message_handler(self.transport_dispatcher, - sentinel.transport_domain, - sentinel.transport_address, - MSG_SNMP_WRONG_COMM) + patcher = patch('pysnmp.proto.mpmod.rfc2576' + '.SnmpV2cMessageProcessingModel' + '.prepareResponseMessage') + mock = patcher.start() + mock.return_value = ( + (1, 3, 6, 1), ('127.0.0.1', 12345), b'' + ) - self.assertFalse(self.transport_dispatcher.sendMessage.called) + self.snmp_engine.msgAndPduDsp.receiveMessage( + self.snmp_engine, (1, 3, 6, 1), ('127.0.0.1', 12345), + MSG_SNMP_WRONG_COMM + ) + + self.assertEqual(mock.call_count, 0) + + patcher.stop()