catching invalid urls and adding tests

This commit is contained in:
David Goetz 2010-10-29 13:30:34 -07:00
parent 937554c85e
commit 1fc40d6c29
4 changed files with 36 additions and 1 deletions

View File

@ -14,6 +14,7 @@
import time
import eventlet
from webob import Request, Response
from webob.exc import HTTPNotFound
from swift.common.utils import split_path, cache_from_env, get_logger
from swift.proxy.server import get_container_memcache_key
@ -204,7 +205,10 @@ class RateLimitMiddleware(object):
req = Request(env)
if self.memcache_client is None:
self.memcache_client = cache_from_env(env)
version, account, container, obj = split_path(req.path, 1, 4, True)
try:
version, account, container, obj = split_path(req.path, 1, 4, True)
except ValueError:
return HTTPNotFound()(env, start_response)
ratelimit_resp = self.handle_ratelimit(req, account, container, obj)
if ratelimit_resp is None:
return self.app(env, start_response)

View File

@ -208,6 +208,7 @@ def split_path(path, minsegs=1, maxsegs=None, rest_with_last=False):
trailing data, raises ValueError.
:returns: list of segments with a length of maxsegs (non-existant
segments will return as None)
:raises: ValueError if given an invalid path
"""
if not maxsegs:
maxsegs = minsegs
@ -622,6 +623,7 @@ def write_pickle(obj, dest, tmp):
os.fsync(fd)
renamer(tmppath, dest)
def audit_location_generator(devices, datadir, mount_check=True, logger=None):
'''
Given a devices path and a data directory, yield (path, device,

View File

@ -170,6 +170,15 @@ class TestAccount(Base):
self.assert_status(412)
self.assert_body('Bad URL')
def testInvalidPath(self):
was_url = self.env.account.conn.storage_url
self.env.account.conn.storage_url = "/%s" % was_url
self.env.account.conn.make_request('GET')
try:
self.assert_status(404)
finally:
self.env.account.conn.storage_url = was_url
def testPUT(self):
self.env.account.conn.make_request('PUT')
self.assert_status([403, 405])

View File

@ -366,6 +366,26 @@ class TestRateLimit(unittest.TestCase):
time_took = time.time() - begin
self.assert_(round(time_took, 1) == .4)
def test_call_invalid_path(self):
env = {'REQUEST_METHOD': 'GET',
'SCRIPT_NAME': '',
'PATH_INFO': '//v1/AUTH_1234567890',
'SERVER_NAME': '127.0.0.1',
'SERVER_PORT': '80',
'swift.cache': FakeMemcache(),
'SERVER_PROTOCOL': 'HTTP/1.0'}
app = lambda *args, **kwargs: None
rate_mid = ratelimit.RateLimitMiddleware(app, {},
logger=FakeLogger())
class a_callable(object):
def __call__(self, *args, **kwargs):
pass
resp = rate_mid.__call__(env, a_callable())
self.assert_('404 Not Found' in resp[0])
if __name__ == '__main__':
unittest.main()