mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-18 07:23:00 +00:00
add mock utils to tests
basic testing framework to get starting writing and testing the REST API including tests for the authorizations API, the only API URL defined so far.
This commit is contained in:
@@ -10,7 +10,10 @@ from tornado import ioloop
|
||||
|
||||
from .. import orm
|
||||
|
||||
# global session object
|
||||
from .mocking import MockHubApp
|
||||
|
||||
|
||||
# global db session object
|
||||
_db = None
|
||||
|
||||
@fixture
|
||||
@@ -31,7 +34,18 @@ def db():
|
||||
_db.commit()
|
||||
return _db
|
||||
|
||||
|
||||
@fixture
|
||||
def io_loop():
|
||||
"""Get the current IOLoop"""
|
||||
ioloop.IOLoop.clear_current()
|
||||
return ioloop.IOLoop.current()
|
||||
|
||||
|
||||
|
||||
@fixture
|
||||
def app(request):
|
||||
app = MockHubApp()
|
||||
app.start([])
|
||||
request.addfinalizer(app.stop)
|
||||
return app
|
||||
|
77
jupyterhub/tests/mocking.py
Normal file
77
jupyterhub/tests/mocking.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""mock utilities for testing"""
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
import mock
|
||||
|
||||
import getpass
|
||||
import threading
|
||||
|
||||
from tornado.ioloop import IOLoop
|
||||
|
||||
from IPython.utils.py3compat import unicode_type
|
||||
|
||||
from ..spawner import LocalProcessSpawner
|
||||
from ..app import JupyterHubApp
|
||||
from ..auth import PAMAuthenticator
|
||||
from .. import orm
|
||||
|
||||
def mock_authenticate(username, password, service='login'):
|
||||
# mimic simplepam's failure to handle unicode
|
||||
if isinstance(username, unicode_type):
|
||||
return False
|
||||
if isinstance(password, unicode_type):
|
||||
return False
|
||||
|
||||
# just use equality for testing
|
||||
if password == username:
|
||||
return True
|
||||
|
||||
|
||||
class MockSpawner(LocalProcessSpawner):
|
||||
|
||||
def make_preexec_fn(self):
|
||||
# skip the setuid stuff
|
||||
return
|
||||
|
||||
def _set_user_changed(self, name, old, new):
|
||||
pass
|
||||
|
||||
class MockPAMAuthenticator(PAMAuthenticator):
|
||||
def authenticate(self, *args, **kwargs):
|
||||
with mock.patch('simplepam.authenticate', mock_authenticate):
|
||||
return super(MockPAMAuthenticator, self).authenticate(*args, **kwargs)
|
||||
|
||||
class MockHubApp(JupyterHubApp):
|
||||
"""HubApp with various mock bits"""
|
||||
# def start_proxy(self):
|
||||
# pass
|
||||
def _authenticator_default(self):
|
||||
return '%s.%s' % (__name__, 'MockPAMAuthenticator')
|
||||
|
||||
def _spawner_class_default(self):
|
||||
return '%s.%s' % (__name__, 'MockSpawner')
|
||||
|
||||
def start(self, argv=None):
|
||||
evt = threading.Event()
|
||||
def _start():
|
||||
self.io_loop = IOLoop.current()
|
||||
# put initialize in start for SQLAlchemy threading reasons
|
||||
super(MockHubApp, self).initialize(argv=argv)
|
||||
user = orm.User(name=getpass.getuser())
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
token = user.new_api_token()
|
||||
self.db.add(token)
|
||||
self.db.commit()
|
||||
self.io_loop.add_callback(evt.set)
|
||||
super(MockHubApp, self).start()
|
||||
|
||||
self._thread = threading.Thread(target=_start)
|
||||
self._thread.start()
|
||||
evt.wait(timeout=5)
|
||||
|
||||
def stop(self):
|
||||
self.io_loop.stop()
|
||||
self._thread.join()
|
||||
|
49
jupyterhub/tests/test_api.py
Normal file
49
jupyterhub/tests/test_api.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Tests for the REST API"""
|
||||
|
||||
import requests
|
||||
|
||||
from ..utils import url_path_join as ujoin
|
||||
from .. import orm
|
||||
|
||||
|
||||
def api_request(app, *api_path, **kwargs):
|
||||
"""Make an API request"""
|
||||
base_url = app.hub.server.url
|
||||
token = app.db.query(orm.APIToken).first()
|
||||
kwargs.setdefault('headers', {})
|
||||
kwargs['headers'].setdefault('Authorization', 'token %s' % token.token)
|
||||
|
||||
url = ujoin(base_url, 'api', *api_path)
|
||||
method = kwargs.pop('method', 'get')
|
||||
f = getattr(requests, method)
|
||||
return f(url, **kwargs)
|
||||
|
||||
def test_auth_api(app):
|
||||
db = app.db
|
||||
r = api_request(app, 'authorizations', 'gobbledygook')
|
||||
assert r.status_code == 404
|
||||
|
||||
# make a new cookie token
|
||||
user = db.query(orm.User).first()
|
||||
cookie_token = user.new_cookie_token()
|
||||
db.add(cookie_token)
|
||||
db.commit()
|
||||
|
||||
# check success:
|
||||
r = api_request(app, 'authorizations', cookie_token.token)
|
||||
assert r.status_code == 200
|
||||
reply = r.json()
|
||||
assert reply['user'] == user.name
|
||||
|
||||
# check fail
|
||||
r = api_request(app, 'authorizations', cookie_token.token,
|
||||
headers={'Authorization': 'no sir'},
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
r = api_request(app, 'authorizations', cookie_token.token,
|
||||
headers={'Authorization': 'token: %s' % cookie_token.token},
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
|
@@ -3,61 +3,39 @@
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
try:
|
||||
from unittest import mock # py3
|
||||
except ImportError:
|
||||
import mock
|
||||
|
||||
import simplepam
|
||||
from IPython.utils.py3compat import unicode_type
|
||||
|
||||
from ..auth import PAMAuthenticator
|
||||
|
||||
|
||||
def fake_auth(username, password, service='login'):
|
||||
# mimic simplepam's failure to handle unicode
|
||||
if isinstance(username, unicode_type):
|
||||
return False
|
||||
if isinstance(password, unicode_type):
|
||||
return False
|
||||
|
||||
# just use equality
|
||||
if password == username:
|
||||
return True
|
||||
from .mocking import MockPAMAuthenticator
|
||||
|
||||
|
||||
def test_pam_auth(io_loop):
|
||||
authenticator = PAMAuthenticator()
|
||||
with mock.patch('simplepam.authenticate', fake_auth):
|
||||
authorized = io_loop.run_sync(lambda : authenticator.authenticate(None, {
|
||||
u'username': u'match',
|
||||
u'password': u'match',
|
||||
}))
|
||||
assert authorized == u'match'
|
||||
|
||||
authorized = io_loop.run_sync(lambda : authenticator.authenticate(None, {
|
||||
u'username': u'match',
|
||||
u'password': u'nomatch',
|
||||
}))
|
||||
authenticator = MockPAMAuthenticator()
|
||||
authorized = io_loop.run_sync(lambda : authenticator.authenticate(None, {
|
||||
u'username': u'match',
|
||||
u'password': u'match',
|
||||
}))
|
||||
assert authorized == u'match'
|
||||
|
||||
authorized = io_loop.run_sync(lambda : authenticator.authenticate(None, {
|
||||
u'username': u'match',
|
||||
u'password': u'nomatch',
|
||||
}))
|
||||
assert authorized is None
|
||||
|
||||
def test_pam_auth_whitelist(io_loop):
|
||||
authenticator = PAMAuthenticator(whitelist={'wash', 'kaylee'})
|
||||
with mock.patch('simplepam.authenticate', fake_auth):
|
||||
authorized = io_loop.run_sync(lambda : authenticator.authenticate(None, {
|
||||
u'username': u'kaylee',
|
||||
u'password': u'kaylee',
|
||||
}))
|
||||
assert authorized == u'kaylee'
|
||||
|
||||
authorized = io_loop.run_sync(lambda : authenticator.authenticate(None, {
|
||||
u'username': u'wash',
|
||||
u'password': u'nomatch',
|
||||
}))
|
||||
assert authorized is None
|
||||
|
||||
authorized = io_loop.run_sync(lambda : authenticator.authenticate(None, {
|
||||
u'username': u'mal',
|
||||
u'password': u'mal',
|
||||
}))
|
||||
assert authorized is None
|
||||
authenticator = MockPAMAuthenticator(whitelist={'wash', 'kaylee'})
|
||||
authorized = io_loop.run_sync(lambda : authenticator.authenticate(None, {
|
||||
u'username': u'kaylee',
|
||||
u'password': u'kaylee',
|
||||
}))
|
||||
assert authorized == u'kaylee'
|
||||
|
||||
authorized = io_loop.run_sync(lambda : authenticator.authenticate(None, {
|
||||
u'username': u'wash',
|
||||
u'password': u'nomatch',
|
||||
}))
|
||||
assert authorized is None
|
||||
|
||||
authorized = io_loop.run_sync(lambda : authenticator.authenticate(None, {
|
||||
u'username': u'mal',
|
||||
u'password': u'mal',
|
||||
}))
|
||||
assert authorized is None
|
||||
|
@@ -15,7 +15,7 @@ from .. import orm
|
||||
_echo_sleep = """
|
||||
import sys, time
|
||||
print(sys.argv)
|
||||
time.sleep(10)
|
||||
time.sleep(30)
|
||||
"""
|
||||
|
||||
_uninterruptible = """
|
||||
|
Reference in New Issue
Block a user