Protocols are now stored in a list so that the order is considered when selecting the right protocol for a request.

This commit is contained in:
Christophe de Vienne 2011-10-14 18:14:43 +02:00
parent 23b2874cdb
commit cfd98b8ee4

View File

@ -223,7 +223,7 @@ class WSRoot(object):
def __init__(self, protocols=[], webpath=''):
self._debug = True
self._webpath = webpath
self.protocols = {}
self.protocols = []
for protocol in protocols:
self.addprotocol(protocol)
@ -238,7 +238,7 @@ class WSRoot(object):
"""
if isinstance(protocol, str):
protocol = getprotocol(protocol, **options)
self.protocols[protocol.name] = protocol
self.protocols.append(protocol)
protocol.root = weakref.proxy(self)
def getapi(self):
@ -251,6 +251,11 @@ class WSRoot(object):
self._api = [i for i in scan_api(self)]
return self._api
def _get_protocol(self, name):
for protocol in self.protocols:
if protocol.name == name:
return protocol
def _select_protocol(self, request):
log.debug("Selecting a protocol for the following request :\n"
"headers: %s\nbody: %s", request.headers,
@ -259,10 +264,10 @@ class WSRoot(object):
or request.body)
protocol = None
if 'wsmeproto' in request.params:
protocol = self.protocols[request.params['wsmeproto']]
return self._get_protocol(request.params['wsmeproto'])
else:
for p in self.protocols.values():
for p in self.protocols:
if p.accept(request):
protocol = p
break
@ -275,7 +280,8 @@ class WSRoot(object):
protocol = self._select_protocol(request)
if protocol is None:
msg = ("None of the following protocols can handle this "
"request : %s" % ','.join(self.protocols.keys()))
"request : %s" % ','.join(
(p.name for p in self.protocols)))
res.status = 500
res.content_type = 'text/plain'
res.body = msg
@ -298,6 +304,9 @@ class WSRoot(object):
res.status = 200
if funcdef.protocol_specific and funcdef.return_type is None:
if isinstance(result, unicode):
res.unicode_body = result
else:
res.body = result
else:
# TODO make sure result type == a._wsme_definition.return_type
@ -341,7 +350,7 @@ class WSRoot(object):
isprotocol_specific = path[0] == '_protocol'
if isprotocol_specific:
a = self.protocols[path[1]]
a = self._get_protocol(path[1])
path = path[2:]
for name in path: