Add many test cases for the RPC protocol and start making a Task structure.
This commit is contained in:
parent
7ead822716
commit
697f98b929
2
.gitignore
vendored
2
.gitignore
vendored
@ -11,3 +11,5 @@ subunit-output.txt
|
|||||||
test-report.xml
|
test-report.xml
|
||||||
twisted/plugins/dropin.cache
|
twisted/plugins/dropin.cache
|
||||||
twistd.log
|
twistd.log
|
||||||
|
.coverage
|
||||||
|
_trial_coverage/
|
||||||
|
3
Makefile
3
Makefile
@ -15,6 +15,9 @@ else
|
|||||||
trial --random 0 ${UNITTESTS}
|
trial --random 0 ${UNITTESTS}
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
coverage:
|
||||||
|
coverage run --source=${CODEDIR} --branch `which trial` ${UNITTESTS} && coverage html -d _trial_coverage --omit="*/tests/*"
|
||||||
|
|
||||||
env:
|
env:
|
||||||
./scripts/bootstrap-virtualenv.sh
|
./scripts/bootstrap-virtualenv.sh
|
||||||
|
|
||||||
|
14
README.md
14
README.md
@ -37,21 +37,21 @@ Fatal Error:
|
|||||||
|
|
||||||
* `log`: (Agent->Server) Log a structured message from the Agent.
|
* `log`: (Agent->Server) Log a structured message from the Agent.
|
||||||
* `status`: (Server->Agent) Uptime, version, and other fields reported.
|
* `status`: (Server->Agent) Uptime, version, and other fields reported.
|
||||||
* `task_status`: (Agent->Server) Update status of a task. Task has a `.state`, which is `running`, `error` or `complete`. `running` will additionally contain `.eta` and `.percent`, a measure of how much work estimated to remain in seconds and how much work is done. Once `error` or `complete` is sent, no more updates will be sent. `error` state includes an additional human readable `.msg` field.
|
* `task_status`: (Agent->Server) Update status of a task. Task has an `.task_id` which was previously designated. Task has a `.state`, which is `running`, `error` or `complete`. `running` will additionally contain `.eta` and `.percent`, a measure of how much work estimated to remain in seconds and how much work is done. Once `error` or `complete` is sent, no more updates will be sent. `error` state includes an additional human readable `.msg` field.
|
||||||
|
|
||||||
|
|
||||||
#### Decommission
|
#### Decommission
|
||||||
|
|
||||||
* `decom.disk_erase`: (Server->Agent) Erase all attached block devices securely. Returns a Task ID.
|
* `decom.disk_erase`: (Server->Agent) Erase all attached block devices securely. Takes a `task_id`.
|
||||||
* `decom.firmware_secure`: (Server->Agent) Update Firmwares/BIOS versions and settings. Returns a Task ID.
|
* `decom.firmware_secure`: (Server->Agent) Update Firmwares/BIOS versions and settings. Takes a `task_id`.
|
||||||
* `decom.qc`: (Server->Agent) Run quality control checks on chassis model. Includes sending specifications of chassis (cpu types, disks, etc). Returns a Task ID.
|
* `decom.qc`: (Server->Agent) Run quality control checks on chassis model. Includes sending specifications of chassis (cpu types, disks, etc). Takes a `task_id`.
|
||||||
|
|
||||||
|
|
||||||
#### Standbye
|
#### Standbye
|
||||||
|
|
||||||
* `standbye.cache_images`: (Server->Agent) Cache an set of image UUID on local storage. Ordered in priority, chassis may only cache a subset depending on local storage. Returns a Task ID.
|
* `standbye.cache_images`: (Server->Agent) Cache an set of image UUID on local storage. Ordered in priority, chassis may only cache a subset depending on local storage. Takes a `task_id`.
|
||||||
* `standbye.prepare_image`: (Server->Agent) Prepare a image UUID to be ran. Returns a Task ID.
|
* `standbye.prepare_image`: (Server->Agent) Prepare a image UUID to be ran. Takes a `task_id`.
|
||||||
* `standbye.run_image`: (Server->Agent) Run an image UUID. Must include Config Drive Settings. Agent will write config drive, and setup grub. If the Agent can detect a viable kexec target it will kexec into it, otherwise reboot. Returns a Task ID.
|
* `standbye.run_image`: (Server->Agent) Run an image UUID. Must include Config Drive Settings. Agent will write config drive, and setup grub. If the Agent can detect a viable kexec target it will kexec into it, otherwise reboot. Takes a `task_id`.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,3 +4,5 @@ pep8==1.4.6
|
|||||||
pyflakes==0.7.3
|
pyflakes==0.7.3
|
||||||
junitxml==0.7
|
junitxml==0.7
|
||||||
python-subunit==0.0.15
|
python-subunit==0.0.15
|
||||||
|
mock==1.0.1
|
||||||
|
coverage==3.6
|
@ -27,7 +27,7 @@ class StandbyAgent(TeethClient):
|
|||||||
def __init__(self, addrs):
|
def __init__(self, addrs):
|
||||||
super(StandbyAgent, self).__init__(addrs)
|
super(StandbyAgent, self).__init__(addrs)
|
||||||
self._addHandler('v1', 'prepare_image', self.prepare_image)
|
self._addHandler('v1', 'prepare_image', self.prepare_image)
|
||||||
log.info('Starting agent', addrs=addrs)
|
log.msg('Starting agent', addrs=addrs)
|
||||||
|
|
||||||
def prepare_image(self, image_id):
|
def prepare_image(self, image_id):
|
||||||
"""Prepare an Image."""
|
"""Prepare an Image."""
|
||||||
|
@ -98,7 +98,7 @@ class TeethClient(MultiService, object):
|
|||||||
self._running = False
|
self._running = False
|
||||||
dl = []
|
dl = []
|
||||||
for client in self._clients:
|
for client in self._clients:
|
||||||
dl.append(client.loseConnectionSoon(timeout=0.05))
|
dl.append(client.abortConnection())
|
||||||
return DeferredList(dl)
|
return DeferredList(dl)
|
||||||
|
|
||||||
def remove_endpoint(self, host, port):
|
def remove_endpoint(self, host, port):
|
||||||
|
@ -32,6 +32,8 @@ def configure():
|
|||||||
|
|
||||||
structlog.configure(
|
structlog.configure(
|
||||||
context_class=dict,
|
context_class=dict,
|
||||||
|
logger_factory=structlog.twisted.LoggerFactory(),
|
||||||
|
wrapper_class=structlog.twisted.BoundLogger,
|
||||||
cache_logger_on_first_use=True)
|
cache_logger_on_first_use=True)
|
||||||
|
|
||||||
|
|
||||||
@ -39,4 +41,6 @@ def get_logger():
|
|||||||
"""
|
"""
|
||||||
Get a logger instance.
|
Get a logger instance.
|
||||||
"""
|
"""
|
||||||
|
configure()
|
||||||
|
|
||||||
return structlog.get_logger()
|
return structlog.get_logger()
|
||||||
|
@ -18,10 +18,8 @@ import simplejson as json
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet import task
|
from twisted.protocols import policies
|
||||||
from twisted.internet import reactor
|
|
||||||
from twisted.protocols.basic import LineReceiver
|
from twisted.protocols.basic import LineReceiver
|
||||||
from twisted.python.failure import Failure
|
|
||||||
from teeth_agent import __version__ as AGENT_VERSION
|
from teeth_agent import __version__ as AGENT_VERSION
|
||||||
from teeth_agent.events import EventEmitter
|
from teeth_agent.events import EventEmitter
|
||||||
from teeth_agent.logging import get_logger
|
from teeth_agent.logging import get_logger
|
||||||
@ -73,7 +71,9 @@ class RPCError(RPCMessage, RuntimeError):
|
|||||||
self._raw_message = message
|
self._raw_message = message
|
||||||
|
|
||||||
|
|
||||||
class RPCProtocol(LineReceiver, EventEmitter):
|
class RPCProtocol(LineReceiver,
|
||||||
|
EventEmitter,
|
||||||
|
policies.TimeoutMixin):
|
||||||
"""
|
"""
|
||||||
Twisted Protocol handler for the RPC Protocol of the Teeth
|
Twisted Protocol handler for the RPC Protocol of the Teeth
|
||||||
Agent <-> Endpoint communication.
|
Agent <-> Endpoint communication.
|
||||||
@ -96,29 +96,34 @@ class RPCProtocol(LineReceiver, EventEmitter):
|
|||||||
self._pending_command_deferreds = {}
|
self._pending_command_deferreds = {}
|
||||||
self._fatal_error = False
|
self._fatal_error = False
|
||||||
self._log = log.bind(host=address.host, port=address.port)
|
self._log = log.bind(host=address.host, port=address.port)
|
||||||
|
self._timeOut = 60
|
||||||
|
|
||||||
def loseConnectionSoon(self, timeout=10):
|
def timeoutConnection(self):
|
||||||
"""Attempt to disconnect from the transport as 'nicely' as possible. """
|
"""Action called when the connection has hit a timeout."""
|
||||||
self._log.info('Trying to disconnect.')
|
self.transport.abortConnection()
|
||||||
self.transport.loseConnection()
|
|
||||||
return task.deferLater(reactor, timeout, self.transport.abortConnection)
|
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self):
|
||||||
"""TCP hard. We made it. Maybe."""
|
"""TCP hard. We made it. Maybe."""
|
||||||
super(RPCProtocol, self).connectionMade()
|
super(RPCProtocol, self).connectionMade()
|
||||||
self._log.info('Connection established.')
|
self._log.msg('Connection established.')
|
||||||
self.transport.setTcpKeepAlive(True)
|
self.transport.setTcpKeepAlive(True)
|
||||||
self.transport.setTcpNoDelay(True)
|
self.transport.setTcpNoDelay(True)
|
||||||
self.emit('connect')
|
self.emit('connect')
|
||||||
|
|
||||||
|
def sendLine(self, line):
|
||||||
|
"""Send a line of content to our peer."""
|
||||||
|
self.resetTimeout()
|
||||||
|
super(RPCProtocol, self).sendLine(line)
|
||||||
|
|
||||||
def lineReceived(self, line):
|
def lineReceived(self, line):
|
||||||
"""Process a line of data."""
|
"""Process a line of data."""
|
||||||
|
self.resetTimeout()
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
|
|
||||||
if not line:
|
if not line:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._log.debug('Got Line', line=line)
|
self._log.msg('Got Line', line=line)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message = json.loads(line)
|
message = json.loads(line)
|
||||||
@ -127,16 +132,16 @@ class RPCProtocol(LineReceiver, EventEmitter):
|
|||||||
|
|
||||||
if 'fatal_error' in message:
|
if 'fatal_error' in message:
|
||||||
# TODO: Log what happened?
|
# TODO: Log what happened?
|
||||||
self.loseConnectionSoon()
|
self.transport.abortConnection()
|
||||||
return
|
return
|
||||||
|
|
||||||
if not message.get('id', None):
|
|
||||||
return self.fatal_error("protocol violation: missing message id.")
|
|
||||||
|
|
||||||
if not message.get('version', None):
|
if not message.get('version', None):
|
||||||
return self.fatal_error("protocol violation: missing message version.")
|
return self.fatal_error("protocol violation: missing message version.")
|
||||||
|
|
||||||
elif 'method' in message:
|
if not message.get('id', None):
|
||||||
|
return self.fatal_error("protocol violation: missing message id.")
|
||||||
|
|
||||||
|
if 'method' in message:
|
||||||
if not message.get('params', None):
|
if not message.get('params', None):
|
||||||
return self.fatal_error("protocol violation: missing message params.")
|
return self.fatal_error("protocol violation: missing message params.")
|
||||||
|
|
||||||
@ -145,24 +150,24 @@ class RPCProtocol(LineReceiver, EventEmitter):
|
|||||||
|
|
||||||
elif 'error' in message:
|
elif 'error' in message:
|
||||||
msg = RPCError(self, message)
|
msg = RPCError(self, message)
|
||||||
self._handle_response(message)
|
self._handle_response(msg)
|
||||||
|
|
||||||
elif 'result' in message:
|
elif 'result' in message:
|
||||||
|
|
||||||
msg = RPCResponse(self, message)
|
msg = RPCResponse(self, message)
|
||||||
self._handle_response(message)
|
self._handle_response(msg)
|
||||||
else:
|
else:
|
||||||
return self.fatal_error('protocol error: malformed message.')
|
return self.fatal_error('protocol error: malformed message.')
|
||||||
|
|
||||||
def fatal_error(self, message):
|
def fatal_error(self, message):
|
||||||
"""Send a fatal error message, and disconnect."""
|
"""Send a fatal error message, and disconnect."""
|
||||||
self._log.error('sending a fatal error', message=message)
|
self._log.msg('sending a fatal error', message=message)
|
||||||
if not self._fatal_error:
|
if not self._fatal_error:
|
||||||
self._fatal_error = True
|
self._fatal_error = True
|
||||||
self.sendLine(self.encoder.encode({
|
self.sendLine(self.encoder.encode({
|
||||||
'fatal_error': message
|
'fatal_error': message
|
||||||
}))
|
}))
|
||||||
self.loseConnectionSoon()
|
self.transport.abortConnection()
|
||||||
|
|
||||||
def send_command(self, method, params, timeout=60):
|
def send_command(self, method, params, timeout=60):
|
||||||
"""Send a new command."""
|
"""Send a new command."""
|
||||||
@ -196,11 +201,13 @@ class RPCProtocol(LineReceiver, EventEmitter):
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
def _handle_response(self, message):
|
def _handle_response(self, message):
|
||||||
d = self.pending_command_deferreds.pop(message['id'])
|
d = self._pending_command_deferreds.pop(message.id, None)
|
||||||
|
|
||||||
|
if not d:
|
||||||
|
return self.fatal_error("protocol violation: unknown message id referenced.")
|
||||||
|
|
||||||
if isinstance(message, RPCError):
|
if isinstance(message, RPCError):
|
||||||
f = Failure(message)
|
d.errback(message)
|
||||||
d.errback(f)
|
|
||||||
else:
|
else:
|
||||||
d.callback(message)
|
d.callback(message)
|
||||||
|
|
||||||
@ -208,7 +215,7 @@ class RPCProtocol(LineReceiver, EventEmitter):
|
|||||||
d = self.emit('command', message)
|
d = self.emit('command', message)
|
||||||
|
|
||||||
if len(d) == 0:
|
if len(d) == 0:
|
||||||
return self.fatal_error("protocol violation: unsupported command")
|
return self.fatal_error("protocol violation: unsupported command.")
|
||||||
|
|
||||||
# TODO: do we need to wait on anything here?
|
# TODO: do we need to wait on anything here?
|
||||||
pass
|
pass
|
||||||
@ -224,9 +231,10 @@ class TeethAgentProtocol(RPCProtocol):
|
|||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.address = address
|
self.address = address
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.on('connect', self._on_connect)
|
self.once('connect', self._once_connect)
|
||||||
|
|
||||||
|
def _once_connect(self, event):
|
||||||
|
|
||||||
def _on_connect(self, event):
|
|
||||||
def _response(result):
|
def _response(result):
|
||||||
self._log.msg('Handshake successful', connection_id=result['id'])
|
self._log.msg('Handshake successful', connection_id=result['id'])
|
||||||
|
|
||||||
|
100
teeth_agent/task.py
Normal file
100
teeth_agent/task.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2013 Rackspace, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from twisted.application.service import MultiService
|
||||||
|
from twisted.application.internet import TimerService
|
||||||
|
from teeth_agent.logging import get_logger
|
||||||
|
log = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['Task', 'PrepareImageTask']
|
||||||
|
|
||||||
|
|
||||||
|
class Task(MultiService, object):
|
||||||
|
"""
|
||||||
|
Task to execute, reporting status periodically to TeethClient instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task_name = 'task_undefined'
|
||||||
|
|
||||||
|
def __init__(self, client, task_id, task_name, reporting_interval=10):
|
||||||
|
super(Task, self).__init__()
|
||||||
|
self.setName(self.task_name)
|
||||||
|
self._client = client
|
||||||
|
self._id = task_id
|
||||||
|
self._percent = 0
|
||||||
|
self._reporting_interval = reporting_interval
|
||||||
|
self._state = 'starting'
|
||||||
|
self._timer = TimerService(self._reporting_interval, self._tick)
|
||||||
|
self._timer.setServiceParent(self)
|
||||||
|
self._error_msg = None
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""Run the Task."""
|
||||||
|
# setServiceParent actually starts the task if it is already running
|
||||||
|
# so we run it in start.
|
||||||
|
self.setServiceParent(self._client)
|
||||||
|
|
||||||
|
def _tick(self):
|
||||||
|
if not self.running:
|
||||||
|
# log.debug("_tick called while not running :()")
|
||||||
|
return
|
||||||
|
return self._client.update_task_status(self)
|
||||||
|
|
||||||
|
def error(self, message):
|
||||||
|
"""Error out running of the task."""
|
||||||
|
self._error_msg = message
|
||||||
|
self._state = 'error'
|
||||||
|
self.stopService()
|
||||||
|
|
||||||
|
def complete(self):
|
||||||
|
"""Complete running of the task."""
|
||||||
|
self._state = 'complete'
|
||||||
|
self.stopService()
|
||||||
|
|
||||||
|
def startService(self):
|
||||||
|
"""Start the Service."""
|
||||||
|
super(Task, self).startService()
|
||||||
|
self._state = 'running'
|
||||||
|
|
||||||
|
def stopService(self):
|
||||||
|
"""Stop the Service."""
|
||||||
|
super(Task, self).stopService()
|
||||||
|
|
||||||
|
if not self._client.running:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._state not in ['error', 'complete']:
|
||||||
|
log.err("told to shutdown before task could complete, marking as error.")
|
||||||
|
self._error_msg = 'service being shutdown'
|
||||||
|
self._state = 'error'
|
||||||
|
|
||||||
|
self._client.finish_task(self)
|
||||||
|
|
||||||
|
|
||||||
|
class PrepareImageTask(Task):
|
||||||
|
|
||||||
|
"""Prepare an image to be ran on the machine."""
|
||||||
|
|
||||||
|
task_name = 'prepare_image'
|
||||||
|
|
||||||
|
def __init__(self, client, task_id, image_info, reporting_interval=10):
|
||||||
|
super(PrepareImageTask, self).__init__(client, task_id)
|
||||||
|
self._image_info = image_info
|
||||||
|
|
||||||
|
def run():
|
||||||
|
"""Run the Prepare Image task."""
|
||||||
|
pass
|
190
teeth_agent/tests/test_protocol.py
Normal file
190
teeth_agent/tests/test_protocol.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2013 Rackspace, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from twisted.internet import main
|
||||||
|
from twisted.internet.address import IPv4Address
|
||||||
|
from twisted.python import failure
|
||||||
|
from twisted.test.proto_helpers import StringTransportWithDisconnection
|
||||||
|
from twisted.trial import unittest
|
||||||
|
import simplejson as json
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
|
from teeth_agent.protocol import RPCError, RPCProtocol, TeethAgentProtocol
|
||||||
|
from teeth_agent import __version__ as AGENT_VERSION
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTCPTransport(StringTransportWithDisconnection, object):
|
||||||
|
_aborting = False
|
||||||
|
disconnected = False
|
||||||
|
|
||||||
|
setTcpKeepAlive = Mock(return_value=None)
|
||||||
|
setTcpNoDelay = Mock(return_value=None)
|
||||||
|
setTcpNoDelay = Mock(return_value=None)
|
||||||
|
|
||||||
|
def connectionLost(self, reason):
|
||||||
|
self.protocol.connectionLost(reason)
|
||||||
|
|
||||||
|
def abortConnection(self):
|
||||||
|
if self.disconnected or self._aborting:
|
||||||
|
return
|
||||||
|
self._aborting = True
|
||||||
|
self.connectionLost(failure.Failure(main.CONNECTION_DONE))
|
||||||
|
|
||||||
|
|
||||||
|
class RPCProtocolTest(unittest.TestCase):
|
||||||
|
"""RPC Protocol tests."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tr = FakeTCPTransport()
|
||||||
|
self.proto = RPCProtocol(json.JSONEncoder(), IPv4Address('TCP', '127.0.0.1', 0))
|
||||||
|
self.proto.makeConnection(self.tr)
|
||||||
|
self.tr.protocol = self.proto
|
||||||
|
|
||||||
|
def test_timeout(self):
|
||||||
|
d = defer.Deferred()
|
||||||
|
called = []
|
||||||
|
orig = self.proto.connectionLost
|
||||||
|
|
||||||
|
def lost(arg):
|
||||||
|
orig()
|
||||||
|
called.append(True)
|
||||||
|
d.callback(True)
|
||||||
|
|
||||||
|
self.proto.connectionLost = lost
|
||||||
|
self.proto.timeoutConnection()
|
||||||
|
|
||||||
|
def check(ignore):
|
||||||
|
self.assertEqual(called, [True])
|
||||||
|
|
||||||
|
d.addCallback(check)
|
||||||
|
return d
|
||||||
|
|
||||||
|
def test_recv_command_no_params(self):
|
||||||
|
self.tr.clear()
|
||||||
|
self.proto.lineReceived(json.dumps({'id': '1', 'version': 'v1', 'method': 'BOGUS_STUFF'}))
|
||||||
|
obj = json.loads(self.tr.io.getvalue().strip())
|
||||||
|
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message params.')
|
||||||
|
|
||||||
|
def test_recv_bogus_command(self):
|
||||||
|
self.tr.clear()
|
||||||
|
self.proto.lineReceived(
|
||||||
|
json.dumps({'id': '1', 'version': 'v1', 'method': 'BOGUS_STUFF', 'params': {'d': '1'}}))
|
||||||
|
obj = json.loads(self.tr.io.getvalue().strip())
|
||||||
|
self.assertEqual(obj['fatal_error'], 'protocol violation: unsupported command.')
|
||||||
|
|
||||||
|
def test_recv_valid_json_no_id(self):
|
||||||
|
self.tr.clear()
|
||||||
|
self.proto.lineReceived(json.dumps({'version': 'v913'}))
|
||||||
|
obj = json.loads(self.tr.io.getvalue().strip())
|
||||||
|
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message id.')
|
||||||
|
|
||||||
|
def test_recv_valid_json_no_version(self):
|
||||||
|
self.tr.clear()
|
||||||
|
self.proto.lineReceived(json.dumps({'version': None, 'id': 'foo'}))
|
||||||
|
obj = json.loads(self.tr.io.getvalue().strip())
|
||||||
|
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message version.')
|
||||||
|
|
||||||
|
def test_recv_invalid_data(self):
|
||||||
|
self.tr.clear()
|
||||||
|
self.proto.lineReceived('')
|
||||||
|
self.proto.lineReceived('invalid json!')
|
||||||
|
obj = json.loads(self.tr.io.getvalue().strip())
|
||||||
|
self.assertEqual(obj['fatal_error'], 'protocol error: unable to decode message.')
|
||||||
|
|
||||||
|
def test_recv_missing_key_parts(self):
|
||||||
|
self.tr.clear()
|
||||||
|
self.proto.lineReceived(json.dumps(
|
||||||
|
{'id': '1', 'version': 'v1'}))
|
||||||
|
obj = json.loads(self.tr.io.getvalue().strip())
|
||||||
|
self.assertEqual(obj['fatal_error'], 'protocol error: malformed message.')
|
||||||
|
|
||||||
|
def test_recv_error_to_unknown_id(self):
|
||||||
|
self.tr.clear()
|
||||||
|
self.proto.lineReceived(json.dumps(
|
||||||
|
{'id': '1', 'version': 'v1', 'error': {'msg': 'something is wrong'}}))
|
||||||
|
obj = json.loads(self.tr.io.getvalue().strip())
|
||||||
|
self.assertEqual(obj['fatal_error'], 'protocol violation: unknown message id referenced.')
|
||||||
|
|
||||||
|
def _send_command(self):
|
||||||
|
self.tr.clear()
|
||||||
|
d = self.proto.send_command('test_command', {'body': 42})
|
||||||
|
req = json.loads(self.tr.io.getvalue().strip())
|
||||||
|
self.tr.clear()
|
||||||
|
return (d, req)
|
||||||
|
|
||||||
|
def test_recv_result(self):
|
||||||
|
dout = defer.Deferred()
|
||||||
|
d, req = self._send_command()
|
||||||
|
self.proto.lineReceived(json.dumps(
|
||||||
|
{'id': req['id'], 'version': 'v1', 'result': {'duh': req['params']['body']}}))
|
||||||
|
self.assertEqual(len(self.tr.io.getvalue()), 0)
|
||||||
|
|
||||||
|
def check(resp):
|
||||||
|
self.assertEqual(resp.result['duh'], 42)
|
||||||
|
dout.callback(True)
|
||||||
|
|
||||||
|
d.addCallback(check)
|
||||||
|
|
||||||
|
return dout
|
||||||
|
|
||||||
|
def test_recv_error(self):
|
||||||
|
d, req = self._send_command()
|
||||||
|
self.proto.lineReceived(json.dumps(
|
||||||
|
{'id': req['id'], 'version': 'v1', 'error': {'msg': 'something is wrong'}}))
|
||||||
|
self.assertEqual(len(self.tr.io.getvalue()), 0)
|
||||||
|
return self.assertFailure(d, RPCError)
|
||||||
|
|
||||||
|
def test_recv_fatal_error(self):
|
||||||
|
d = defer.Deferred()
|
||||||
|
called = []
|
||||||
|
orig = self.proto.connectionLost
|
||||||
|
|
||||||
|
def lost(arg):
|
||||||
|
self.failUnless(isinstance(arg, failure.Failure))
|
||||||
|
orig()
|
||||||
|
called.append(True)
|
||||||
|
d.callback(True)
|
||||||
|
|
||||||
|
self.proto.connectionLost = lost
|
||||||
|
|
||||||
|
def check(ignore):
|
||||||
|
self.assertEqual(called, [True])
|
||||||
|
|
||||||
|
d.addCallback(check)
|
||||||
|
|
||||||
|
self.tr.clear()
|
||||||
|
self.proto.lineReceived(json.dumps({'fatal_error': 'you be broken'}))
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
class TeethAgentProtocolTest(unittest.TestCase):
|
||||||
|
"""Teeth Agent Protocol tests."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tr = FakeTCPTransport()
|
||||||
|
self.proto = TeethAgentProtocol(json.JSONEncoder(), IPv4Address('TCP', '127.0.0.1', 0), None)
|
||||||
|
self.proto.makeConnection(self.tr)
|
||||||
|
self.tr.protocol = self.proto
|
||||||
|
|
||||||
|
def test_on_connect(self):
|
||||||
|
obj = json.loads(self.tr.io.getvalue().strip())
|
||||||
|
self.assertEqual(obj['version'], 'v1')
|
||||||
|
self.assertEqual(obj['method'], 'handshake')
|
||||||
|
self.assertEqual(obj['method'], 'handshake')
|
||||||
|
self.assertEqual(obj['params']['id'], 'a:b:c:d')
|
||||||
|
self.assertEqual(obj['params']['version'], AGENT_VERSION)
|
53
teeth_agent/tests/test_task.py
Normal file
53
teeth_agent/tests/test_task.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2013 Rackspace, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from twisted.trial import unittest
|
||||||
|
from teeth_agent.task import Task
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
|
|
||||||
|
class FakeClient(object):
|
||||||
|
addService = Mock(return_value=None)
|
||||||
|
running = Mock(return_value=0)
|
||||||
|
update_task_status = Mock(return_value=None)
|
||||||
|
finish_task = Mock(return_value=None)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskTest(unittest.TestCase):
|
||||||
|
"""Event Emitter tests."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.task_id = str(uuid.uuid4())
|
||||||
|
self.client = FakeClient()
|
||||||
|
self.task = Task(self.client, self.task_id, 'test_task')
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
del self.task_id
|
||||||
|
del self.task
|
||||||
|
del self.client
|
||||||
|
|
||||||
|
def test_run(self):
|
||||||
|
self.assertEqual(self.task._state, 'starting')
|
||||||
|
self.assertEqual(self.task._id, self.task_id)
|
||||||
|
self.task.run()
|
||||||
|
self.client.addService.assert_called_once_with(self.task)
|
||||||
|
self.task.startService()
|
||||||
|
self.client.update_task_status.assert_called_once_with(self.task)
|
||||||
|
self.task.complete()
|
||||||
|
self.assertEqual(self.task._state, 'complete')
|
||||||
|
self.client.finish_task.assert_called_once_with(self.task)
|
Loading…
x
Reference in New Issue
Block a user