Protocols are now stored in a list so that the order is considered when selecting the right protocol for a request.

This commit is contained in:
Christophe de Vienne 2011-10-14 18:14:43 +02:00
parent 23b2874cdb
commit cfd98b8ee4

View File

@ -223,7 +223,7 @@ class WSRoot(object):
def __init__(self, protocols=[], webpath=''): def __init__(self, protocols=[], webpath=''):
self._debug = True self._debug = True
self._webpath = webpath self._webpath = webpath
self.protocols = {} self.protocols = []
for protocol in protocols: for protocol in protocols:
self.addprotocol(protocol) self.addprotocol(protocol)
@ -238,7 +238,7 @@ class WSRoot(object):
""" """
if isinstance(protocol, str): if isinstance(protocol, str):
protocol = getprotocol(protocol, **options) protocol = getprotocol(protocol, **options)
self.protocols[protocol.name] = protocol self.protocols.append(protocol)
protocol.root = weakref.proxy(self) protocol.root = weakref.proxy(self)
def getapi(self): def getapi(self):
@ -251,6 +251,11 @@ class WSRoot(object):
self._api = [i for i in scan_api(self)] self._api = [i for i in scan_api(self)]
return self._api return self._api
def _get_protocol(self, name):
for protocol in self.protocols:
if protocol.name == name:
return protocol
def _select_protocol(self, request): def _select_protocol(self, request):
log.debug("Selecting a protocol for the following request :\n" log.debug("Selecting a protocol for the following request :\n"
"headers: %s\nbody: %s", request.headers, "headers: %s\nbody: %s", request.headers,
@ -259,10 +264,10 @@ class WSRoot(object):
or request.body) or request.body)
protocol = None protocol = None
if 'wsmeproto' in request.params: if 'wsmeproto' in request.params:
protocol = self.protocols[request.params['wsmeproto']] return self._get_protocol(request.params['wsmeproto'])
else: else:
for p in self.protocols.values(): for p in self.protocols:
if p.accept(request): if p.accept(request):
protocol = p protocol = p
break break
@ -275,7 +280,8 @@ class WSRoot(object):
protocol = self._select_protocol(request) protocol = self._select_protocol(request)
if protocol is None: if protocol is None:
msg = ("None of the following protocols can handle this " msg = ("None of the following protocols can handle this "
"request : %s" % ','.join(self.protocols.keys())) "request : %s" % ','.join(
(p.name for p in self.protocols)))
res.status = 500 res.status = 500
res.content_type = 'text/plain' res.content_type = 'text/plain'
res.body = msg res.body = msg
@ -298,7 +304,10 @@ class WSRoot(object):
res.status = 200 res.status = 200
if funcdef.protocol_specific and funcdef.return_type is None: if funcdef.protocol_specific and funcdef.return_type is None:
res.body = result if isinstance(result, unicode):
res.unicode_body = result
else:
res.body = result
else: else:
# TODO make sure result type == a._wsme_definition.return_type # TODO make sure result type == a._wsme_definition.return_type
res.body = protocol.encode_result(funcdef, result) res.body = protocol.encode_result(funcdef, result)
@ -341,7 +350,7 @@ class WSRoot(object):
isprotocol_specific = path[0] == '_protocol' isprotocol_specific = path[0] == '_protocol'
if isprotocol_specific: if isprotocol_specific:
a = self.protocols[path[1]] a = self._get_protocol(path[1])
path = path[2:] path = path[2:]
for name in path: for name in path: