mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-19 07:53:00 +00:00
@@ -19,7 +19,7 @@ except Exception as e:
|
|||||||
from tornado.concurrent import run_on_executor
|
from tornado.concurrent import run_on_executor
|
||||||
|
|
||||||
from traitlets.config import LoggingConfigurable
|
from traitlets.config import LoggingConfigurable
|
||||||
from traitlets import Bool, Set, Unicode, Dict, Any, default, observe
|
from traitlets import Bool, Integer, Set, Unicode, Dict, Any, default, observe
|
||||||
|
|
||||||
from .handlers.login import LoginHandler
|
from .handlers.login import LoginHandler
|
||||||
from .utils import maybe_future, url_path_join
|
from .utils import maybe_future, url_path_join
|
||||||
@@ -50,6 +50,35 @@ class Authenticator(LoggingConfigurable):
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
auth_refresh_age = Integer(
|
||||||
|
300,
|
||||||
|
config=True,
|
||||||
|
help="""The max age (in seconds) of authentication info
|
||||||
|
before forcing a refresh of user auth info.
|
||||||
|
|
||||||
|
Refreshing auth info allows, e.g. requesting/re-validating auth tokens.
|
||||||
|
|
||||||
|
See :meth:`.refresh_user` for what happens when user auth info is refreshed
|
||||||
|
(nothing by default).
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_pre_spawn = Bool(
|
||||||
|
False,
|
||||||
|
config=True,
|
||||||
|
help="""Force refresh of auth prior to spawn.
|
||||||
|
|
||||||
|
This forces :meth:`.refresh_user` to be called prior to launching
|
||||||
|
a server, to ensure that auth state is up-to-date.
|
||||||
|
|
||||||
|
This can be important when e.g. auth tokens that may have expired
|
||||||
|
are passed to the spawner via environment variables from auth_state.
|
||||||
|
|
||||||
|
If refresh_user cannot refresh the user auth data,
|
||||||
|
launch will fail until the user logs in again.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
admin_users = Set(
|
admin_users = Set(
|
||||||
help="""
|
help="""
|
||||||
Set of users that will have admin rights on this JupyterHub.
|
Set of users that will have admin rights on this JupyterHub.
|
||||||
|
@@ -28,6 +28,7 @@ from .. import __version__
|
|||||||
from .. import orm
|
from .. import orm
|
||||||
from ..objects import Server
|
from ..objects import Server
|
||||||
from ..spawner import LocalProcessSpawner
|
from ..spawner import LocalProcessSpawner
|
||||||
|
from ..user import User
|
||||||
from ..utils import maybe_future, url_path_join
|
from ..utils import maybe_future, url_path_join
|
||||||
from ..metrics import (
|
from ..metrics import (
|
||||||
SERVER_SPAWN_DURATION_SECONDS, ServerSpawnStatus,
|
SERVER_SPAWN_DURATION_SECONDS, ServerSpawnStatus,
|
||||||
@@ -240,7 +241,7 @@ class BaseHandler(RequestHandler):
|
|||||||
self.db.commit()
|
self.db.commit()
|
||||||
return self._user_from_orm(orm_token.user)
|
return self._user_from_orm(orm_token.user)
|
||||||
|
|
||||||
async def refresh_user_auth(self, user, force=False):
|
async def refresh_auth(self, user, force=False):
|
||||||
"""Refresh user authentication info
|
"""Refresh user authentication info
|
||||||
|
|
||||||
Calls `authenticator.refresh_user(user)`
|
Calls `authenticator.refresh_user(user)`
|
||||||
@@ -254,7 +255,12 @@ class BaseHandler(RequestHandler):
|
|||||||
user (User): the user having been refreshed,
|
user (User): the user having been refreshed,
|
||||||
or None if the user must login again to refresh auth info.
|
or None if the user must login again to refresh auth info.
|
||||||
"""
|
"""
|
||||||
if not force: # TODO: and it's sufficiently recent
|
refresh_age = self.authenticator.auth_refresh_age
|
||||||
|
if not refresh_age:
|
||||||
|
return user
|
||||||
|
now = time.monotonic()
|
||||||
|
if not force and user._auth_refreshed and (now - user._auth_refreshed < refresh_age):
|
||||||
|
# auth up-to-date
|
||||||
return user
|
return user
|
||||||
|
|
||||||
# refresh a user at most once per request
|
# refresh a user at most once per request
|
||||||
@@ -275,6 +281,8 @@ class BaseHandler(RequestHandler):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
user._auth_refreshed = now
|
||||||
|
|
||||||
if auth_info == True:
|
if auth_info == True:
|
||||||
# refresh_user confirmed that it's up-to-date,
|
# refresh_user confirmed that it's up-to-date,
|
||||||
# nothing to refresh
|
# nothing to refresh
|
||||||
@@ -355,8 +363,8 @@ class BaseHandler(RequestHandler):
|
|||||||
user = self.get_current_user_token()
|
user = self.get_current_user_token()
|
||||||
if user is None:
|
if user is None:
|
||||||
user = self.get_current_user_cookie()
|
user = self.get_current_user_cookie()
|
||||||
if user:
|
if user and isinstance(user, User):
|
||||||
user = await self.refresh_user_auth(user)
|
user = await self.refresh_auth(user)
|
||||||
self._jupyterhub_user = user
|
self._jupyterhub_user = user
|
||||||
except Exception:
|
except Exception:
|
||||||
# don't let errors here raise more than once
|
# don't let errors here raise more than once
|
||||||
@@ -610,6 +618,7 @@ class BaseHandler(RequestHandler):
|
|||||||
self.statsd.incr('login.success')
|
self.statsd.incr('login.success')
|
||||||
self.statsd.timing('login.authenticate.success', auth_timer.ms)
|
self.statsd.timing('login.authenticate.success', auth_timer.ms)
|
||||||
self.log.info("User logged in: %s", user.name)
|
self.log.info("User logged in: %s", user.name)
|
||||||
|
user._auth_refreshed = time.monotonic()
|
||||||
return user
|
return user
|
||||||
else:
|
else:
|
||||||
self.statsd.incr('login.failure')
|
self.statsd.incr('login.failure')
|
||||||
@@ -643,6 +652,11 @@ class BaseHandler(RequestHandler):
|
|||||||
|
|
||||||
async def spawn_single_user(self, user, server_name='', options=None):
|
async def spawn_single_user(self, user, server_name='', options=None):
|
||||||
# in case of error, include 'try again from /hub/home' message
|
# in case of error, include 'try again from /hub/home' message
|
||||||
|
if self.authenticator.refresh_pre_spawn:
|
||||||
|
auth_user = await self.refresh_auth(user, force=True)
|
||||||
|
if auth_user is None:
|
||||||
|
raise web.HTTPError(403, "auth has expired for %s, login again", user.name)
|
||||||
|
|
||||||
spawn_start_time = time.perf_counter()
|
spawn_start_time = time.perf_counter()
|
||||||
self.extra_error_html = self.spawn_home_error
|
self.extra_error_html = self.spawn_home_error
|
||||||
|
|
||||||
|
@@ -258,7 +258,6 @@ class Service(LoggingConfigurable):
|
|||||||
def _default_redirect_uri(self):
|
def _default_redirect_uri(self):
|
||||||
if self.server is None:
|
if self.server is None:
|
||||||
return ''
|
return ''
|
||||||
print(self.domain, self.host, self.server)
|
|
||||||
return self.host + url_path_join(self.prefix, 'oauth_callback')
|
return self.host + url_path_join(self.prefix, 'oauth_callback')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@@ -47,7 +47,7 @@ from ..utils import random_port
|
|||||||
|
|
||||||
from . import mocking
|
from . import mocking
|
||||||
from .mocking import MockHub
|
from .mocking import MockHub
|
||||||
from .utils import ssl_setup
|
from .utils import ssl_setup, add_user
|
||||||
from .test_services import mockservice_cmd
|
from .test_services import mockservice_cmd
|
||||||
|
|
||||||
import jupyterhub.services.service
|
import jupyterhub.services.service
|
||||||
@@ -185,6 +185,43 @@ def cleanup_after(request, io_loop):
|
|||||||
app.db.commit()
|
app.db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
_username_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
def new_username(prefix='testuser'):
|
||||||
|
"""Return a new unique username"""
|
||||||
|
global _username_counter
|
||||||
|
_username_counter += 1
|
||||||
|
return '{}-{}'.format(prefix, _username_counter)
|
||||||
|
|
||||||
|
|
||||||
|
@fixture
|
||||||
|
def username():
|
||||||
|
"""allocate a temporary username
|
||||||
|
|
||||||
|
unique each time the fixture is used
|
||||||
|
"""
|
||||||
|
yield new_username()
|
||||||
|
|
||||||
|
|
||||||
|
@fixture
|
||||||
|
def user(app):
|
||||||
|
"""Fixture for creating a temporary user
|
||||||
|
|
||||||
|
Each time the fixture is used, a new user is created
|
||||||
|
"""
|
||||||
|
user = add_user(app.db, app, name=new_username())
|
||||||
|
yield user
|
||||||
|
|
||||||
|
|
||||||
|
@fixture
|
||||||
|
def admin_user(app, username):
|
||||||
|
"""Fixture for creating a temporary admin user"""
|
||||||
|
user = add_user(app.db, app, name=new_username('testadmin'), admin=True)
|
||||||
|
yield user
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MockServiceSpawner(jupyterhub.services.service._ServiceSpawner):
|
class MockServiceSpawner(jupyterhub.services.service._ServiceSpawner):
|
||||||
"""mock services for testing.
|
"""mock services for testing.
|
||||||
|
|
||||||
|
@@ -49,7 +49,7 @@ from ..objects import Server
|
|||||||
from ..spawner import LocalProcessSpawner, SimpleLocalProcessSpawner
|
from ..spawner import LocalProcessSpawner, SimpleLocalProcessSpawner
|
||||||
from ..singleuser import SingleUserNotebookApp
|
from ..singleuser import SingleUserNotebookApp
|
||||||
from ..utils import random_port, url_path_join
|
from ..utils import random_port, url_path_join
|
||||||
from .utils import async_requests, ssl_setup
|
from .utils import async_requests, ssl_setup, public_host, public_url
|
||||||
|
|
||||||
from pamela import PAMError
|
from pamela import PAMError
|
||||||
|
|
||||||
@@ -223,7 +223,10 @@ class MockHub(JupyterHub):
|
|||||||
last_activity_interval = 2
|
last_activity_interval = 2
|
||||||
log_datefmt = '%M:%S'
|
log_datefmt = '%M:%S'
|
||||||
external_certs = None
|
external_certs = None
|
||||||
log_level = 10
|
|
||||||
|
@default('log_level')
|
||||||
|
def _default_log_level(self):
|
||||||
|
return 10
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
if 'internal_certs_location' in kwargs:
|
if 'internal_certs_location' in kwargs:
|
||||||
@@ -351,31 +354,6 @@ class MockHub(JupyterHub):
|
|||||||
return r.cookies
|
return r.cookies
|
||||||
|
|
||||||
|
|
||||||
def public_host(app):
|
|
||||||
"""Return the public *host* (no URL prefix) of the given JupyterHub instance."""
|
|
||||||
if app.subdomain_host:
|
|
||||||
return app.subdomain_host
|
|
||||||
else:
|
|
||||||
return Server.from_url(app.proxy.public_url).host
|
|
||||||
|
|
||||||
|
|
||||||
def public_url(app, user_or_service=None, path=''):
|
|
||||||
"""Return the full, public base URL (including prefix) of the given JupyterHub instance."""
|
|
||||||
if user_or_service:
|
|
||||||
if app.subdomain_host:
|
|
||||||
host = user_or_service.host
|
|
||||||
else:
|
|
||||||
host = public_host(app)
|
|
||||||
prefix = user_or_service.prefix
|
|
||||||
else:
|
|
||||||
host = public_host(app)
|
|
||||||
prefix = Server.from_url(app.proxy.public_url).base_url
|
|
||||||
if path:
|
|
||||||
return host + url_path_join(prefix, path)
|
|
||||||
else:
|
|
||||||
return host + prefix
|
|
||||||
|
|
||||||
|
|
||||||
# single-user-server mocking:
|
# single-user-server mocking:
|
||||||
|
|
||||||
class MockSingleUserServer(SingleUserNotebookApp):
|
class MockSingleUserServer(SingleUserNotebookApp):
|
||||||
|
@@ -18,97 +18,13 @@ import jupyterhub
|
|||||||
from .. import orm
|
from .. import orm
|
||||||
from ..utils import url_path_join as ujoin
|
from ..utils import url_path_join as ujoin
|
||||||
from .mocking import public_host, public_url
|
from .mocking import public_host, public_url
|
||||||
from .utils import async_requests
|
from .utils import (
|
||||||
|
add_user,
|
||||||
|
api_request,
|
||||||
def check_db_locks(func):
|
async_requests,
|
||||||
"""Decorator that verifies no locks are held on database upon exit.
|
auth_header,
|
||||||
|
find_user,
|
||||||
This decorator for test functions verifies no locks are held on the
|
)
|
||||||
application's database upon exit by creating and dropping a dummy table.
|
|
||||||
|
|
||||||
The decorator relies on an instance of JupyterHubApp being the first
|
|
||||||
argument to the decorated function.
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
|
|
||||||
@check_db_locks
|
|
||||||
def api_request(app, *api_path, **kwargs):
|
|
||||||
|
|
||||||
"""
|
|
||||||
def new_func(app, *args, **kwargs):
|
|
||||||
retval = func(app, *args, **kwargs)
|
|
||||||
|
|
||||||
temp_session = app.session_factory()
|
|
||||||
temp_session.execute('CREATE TABLE dummy (foo INT)')
|
|
||||||
temp_session.execute('DROP TABLE dummy')
|
|
||||||
temp_session.close()
|
|
||||||
|
|
||||||
return retval
|
|
||||||
|
|
||||||
return new_func
|
|
||||||
|
|
||||||
|
|
||||||
def find_user(db, name, app=None):
|
|
||||||
"""Find user in database."""
|
|
||||||
orm_user = db.query(orm.User).filter(orm.User.name == name).first()
|
|
||||||
if app is None:
|
|
||||||
return orm_user
|
|
||||||
else:
|
|
||||||
return app.users[orm_user.id]
|
|
||||||
|
|
||||||
|
|
||||||
def add_user(db, app=None, **kwargs):
|
|
||||||
"""Add a user to the database."""
|
|
||||||
orm_user = find_user(db, name=kwargs.get('name'))
|
|
||||||
if orm_user is None:
|
|
||||||
orm_user = orm.User(**kwargs)
|
|
||||||
db.add(orm_user)
|
|
||||||
else:
|
|
||||||
for attr, value in kwargs.items():
|
|
||||||
setattr(orm_user, attr, value)
|
|
||||||
db.commit()
|
|
||||||
if app:
|
|
||||||
return app.users[orm_user.id]
|
|
||||||
else:
|
|
||||||
return orm_user
|
|
||||||
|
|
||||||
|
|
||||||
def auth_header(db, name):
|
|
||||||
"""Return header with user's API authorization token."""
|
|
||||||
user = find_user(db, name)
|
|
||||||
if user is None:
|
|
||||||
user = add_user(db, name=name)
|
|
||||||
token = user.new_api_token()
|
|
||||||
return {'Authorization': 'token %s' % token}
|
|
||||||
|
|
||||||
|
|
||||||
@check_db_locks
|
|
||||||
async def api_request(app, *api_path, **kwargs):
|
|
||||||
"""Make an API request"""
|
|
||||||
base_url = app.hub.url
|
|
||||||
headers = kwargs.setdefault('headers', {})
|
|
||||||
|
|
||||||
if 'Authorization' not in headers and not kwargs.pop('noauth', False):
|
|
||||||
# make a copy to avoid modifying arg in-place
|
|
||||||
kwargs['headers'] = h = {}
|
|
||||||
h.update(headers)
|
|
||||||
h.update(auth_header(app.db, 'admin'))
|
|
||||||
|
|
||||||
url = ujoin(base_url, 'api', *api_path)
|
|
||||||
method = kwargs.pop('method', 'get')
|
|
||||||
f = getattr(async_requests, method)
|
|
||||||
if app.internal_ssl:
|
|
||||||
kwargs['cert'] = (app.internal_ssl_cert, app.internal_ssl_key)
|
|
||||||
kwargs["verify"] = app.internal_ssl_ca
|
|
||||||
resp = await f(url, **kwargs)
|
|
||||||
assert "frame-ancestors 'self'" in resp.headers['Content-Security-Policy']
|
|
||||||
assert ujoin(app.hub.base_url, "security/csp-report") in resp.headers['Content-Security-Policy']
|
|
||||||
assert 'http' not in resp.headers['Content-Security-Policy']
|
|
||||||
if not kwargs.get('stream', False) and resp.content:
|
|
||||||
assert resp.headers.get('content-type') == 'application/json'
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------
|
# --------------------
|
||||||
@@ -197,7 +113,7 @@ def normalize_timestamp(ts):
|
|||||||
"""
|
"""
|
||||||
if ts is None:
|
if ts is None:
|
||||||
return
|
return
|
||||||
return re.sub('\d(\.\d+)?', '0', ts)
|
return re.sub(r'\d(\.\d+)?', '0', ts)
|
||||||
|
|
||||||
|
|
||||||
def normalize_user(user):
|
def normalize_user(user):
|
||||||
|
@@ -12,7 +12,8 @@ from requests import HTTPError
|
|||||||
from jupyterhub import auth, crypto, orm
|
from jupyterhub import auth, crypto, orm
|
||||||
|
|
||||||
from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd
|
from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd
|
||||||
from .test_api import add_user
|
from .utils import add_user
|
||||||
|
|
||||||
|
|
||||||
async def test_pam_auth():
|
async def test_pam_auth():
|
||||||
authenticator = MockPAMAuthenticator()
|
authenticator = MockPAMAuthenticator()
|
||||||
|
179
jupyterhub/tests/test_auth_expiry.py
Normal file
179
jupyterhub/tests/test_auth_expiry.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
"""
|
||||||
|
test authentication expiry
|
||||||
|
|
||||||
|
authentication can expire in a number of ways:
|
||||||
|
|
||||||
|
- needs refresh and can be refreshed
|
||||||
|
- doesn't need refresh
|
||||||
|
- needs refresh and cannot be refreshed without new login
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from unittest import mock
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .utils import api_request, get_page
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_expired(authenticator, user):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def disable_refresh(app):
|
||||||
|
"""Fixture disabling auth refresh"""
|
||||||
|
with mock.patch.object(app.authenticator, 'refresh_user', refresh_expired):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def refresh_pre_spawn(app):
|
||||||
|
"""Fixture enabling auth refresh pre spawn"""
|
||||||
|
app.authenticator.refresh_pre_spawn = True
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
app.authenticator.refresh_pre_spawn = False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_refresh_at_login(app, user):
|
||||||
|
# auth_refreshed starts unset:
|
||||||
|
assert not user._auth_refreshed
|
||||||
|
# login sets auth_refreshed timestamp
|
||||||
|
await app.login_user(user.name)
|
||||||
|
assert user._auth_refreshed
|
||||||
|
user._auth_refreshed -= 10
|
||||||
|
before = user._auth_refreshed
|
||||||
|
# login again updates auth_refreshed timestamp
|
||||||
|
# even when auth is fresh
|
||||||
|
await app.login_user(user.name)
|
||||||
|
assert user._auth_refreshed > before
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_refresh_page(app, user):
|
||||||
|
cookies = await app.login_user(user.name)
|
||||||
|
assert user._auth_refreshed
|
||||||
|
user._auth_refreshed -= 10
|
||||||
|
before = user._auth_refreshed
|
||||||
|
|
||||||
|
# get a page with auth not expired
|
||||||
|
# doesn't trigger refresh
|
||||||
|
r = await get_page('home', app, cookies=cookies)
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert user._auth_refreshed == before
|
||||||
|
|
||||||
|
# get a page with stale auth, refreshes auth
|
||||||
|
user._auth_refreshed -= app.authenticator.auth_refresh_age
|
||||||
|
r = await get_page('home', app, cookies=cookies)
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert user._auth_refreshed > before
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_expired_page(app, user, disable_refresh):
|
||||||
|
cookies = await app.login_user(user.name)
|
||||||
|
assert user._auth_refreshed
|
||||||
|
user._auth_refreshed -= 10
|
||||||
|
before = user._auth_refreshed
|
||||||
|
|
||||||
|
# auth is fresh, doesn't trigger expiry
|
||||||
|
r = await get_page('home', app, cookies=cookies)
|
||||||
|
assert user._auth_refreshed == before
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
# get a page with stale auth, triggers expiry
|
||||||
|
user._auth_refreshed -= app.authenticator.auth_refresh_age
|
||||||
|
before = user._auth_refreshed
|
||||||
|
r = await get_page('home', app, cookies=cookies, allow_redirects=False)
|
||||||
|
|
||||||
|
# verify that we redirect to login with ?next=requested page
|
||||||
|
assert r.status_code == 302
|
||||||
|
redirect_url = urlparse(r.headers['Location'])
|
||||||
|
assert redirect_url.path.endswith('/login')
|
||||||
|
query = parse_qs(redirect_url.query)
|
||||||
|
assert query['next']
|
||||||
|
next_url = urlparse(query['next'][0])
|
||||||
|
assert next_url.path == urlparse(r.url).path
|
||||||
|
|
||||||
|
# make sure refresh didn't get updated
|
||||||
|
assert user._auth_refreshed == before
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_expired_api(app, user, disable_refresh):
|
||||||
|
cookies = await app.login_user(user.name)
|
||||||
|
assert user._auth_refreshed
|
||||||
|
user._auth_refreshed -= 10
|
||||||
|
before = user._auth_refreshed
|
||||||
|
|
||||||
|
# auth is fresh, doesn't trigger expiry
|
||||||
|
r = await api_request(app, 'users/' + user.name, name=user.name)
|
||||||
|
assert user._auth_refreshed == before
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
# get a page with stale auth, triggers expiry
|
||||||
|
user._auth_refreshed -= app.authenticator.auth_refresh_age
|
||||||
|
r = await api_request(app, 'users/' + user.name, name=user.name)
|
||||||
|
# api requests can't do login redirects
|
||||||
|
assert r.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_pre_spawn(app, user, refresh_pre_spawn):
|
||||||
|
cookies = await app.login_user(user.name)
|
||||||
|
assert user._auth_refreshed
|
||||||
|
user._auth_refreshed -= 10
|
||||||
|
before = user._auth_refreshed
|
||||||
|
|
||||||
|
# auth is fresh, but should be forced to refresh by spawn
|
||||||
|
r = await api_request(
|
||||||
|
app, 'users/{}/server'.format(user.name), method='post', name=user.name
|
||||||
|
)
|
||||||
|
assert 200 <= r.status_code < 300
|
||||||
|
assert user._auth_refreshed > before
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_pre_spawn_expired(app, user, refresh_pre_spawn, disable_refresh):
|
||||||
|
cookies = await app.login_user(user.name)
|
||||||
|
assert user._auth_refreshed
|
||||||
|
user._auth_refreshed -= 10
|
||||||
|
before = user._auth_refreshed
|
||||||
|
|
||||||
|
# auth is fresh, doesn't trigger expiry
|
||||||
|
r = await api_request(
|
||||||
|
app, 'users/{}/server'.format(user.name), method='post', name=user.name
|
||||||
|
)
|
||||||
|
assert r.status_code == 403
|
||||||
|
assert user._auth_refreshed == before
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_pre_spawn_admin_request(
|
||||||
|
app, user, admin_user, refresh_pre_spawn
|
||||||
|
):
|
||||||
|
await app.login_user(user.name)
|
||||||
|
await app.login_user(admin_user.name)
|
||||||
|
user._auth_refreshed -= 10
|
||||||
|
before = user._auth_refreshed
|
||||||
|
|
||||||
|
# admin request, auth is fresh. Should still refresh user auth.
|
||||||
|
r = await api_request(
|
||||||
|
app, 'users', user.name, 'server', method='post', name=admin_user.name
|
||||||
|
)
|
||||||
|
assert 200 <= r.status_code < 300
|
||||||
|
assert user._auth_refreshed > before
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_pre_spawn_expired_admin_request(
|
||||||
|
app, user, admin_user, refresh_pre_spawn, disable_refresh
|
||||||
|
):
|
||||||
|
await app.login_user(user.name)
|
||||||
|
await app.login_user(admin_user.name)
|
||||||
|
user._auth_refreshed -= 10
|
||||||
|
|
||||||
|
# auth needs refresh but can't without a new login; spawn should fail
|
||||||
|
user._auth_refreshed -= app.authenticator.auth_refresh_age
|
||||||
|
r = await api_request(
|
||||||
|
app, 'users', user.name, 'server', method='post', name=admin_user.name
|
||||||
|
)
|
||||||
|
# api requests can't do login redirects
|
||||||
|
assert r.status_code == 403
|
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
|
from unittest import mock
|
||||||
from urllib.parse import urlencode, urlparse
|
from urllib.parse import urlencode, urlparse
|
||||||
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
@@ -13,21 +14,17 @@ from ..utils import url_path_join as ujoin
|
|||||||
from .. import orm
|
from .. import orm
|
||||||
from ..auth import Authenticator
|
from ..auth import Authenticator
|
||||||
|
|
||||||
import mock
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .mocking import FormSpawner, public_url, public_host
|
from .mocking import FormSpawner
|
||||||
from .test_api import api_request, add_user
|
from .utils import (
|
||||||
from .utils import async_requests
|
async_requests,
|
||||||
|
api_request,
|
||||||
|
add_user,
|
||||||
def get_page(path, app, hub=True, **kw):
|
get_page,
|
||||||
if hub:
|
public_url,
|
||||||
prefix = app.hub.base_url
|
public_host,
|
||||||
else:
|
)
|
||||||
prefix = app.base_url
|
|
||||||
base_url = ujoin(public_host(app), prefix)
|
|
||||||
return async_requests.get(ujoin(base_url, path), **kw)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_root_no_auth(app):
|
async def test_root_no_auth(app):
|
||||||
|
@@ -84,7 +84,8 @@ async def test_external_proxy(request):
|
|||||||
# add user to the db and start a single user server
|
# add user to the db and start a single user server
|
||||||
name = 'river'
|
name = 'river'
|
||||||
add_user(app.db, app, name=name)
|
add_user(app.db, app, name=name)
|
||||||
r = await api_request(app, 'users', name, 'server', method='post')
|
r = await api_request(app, 'users', name, 'server', method='post',
|
||||||
|
bypass_proxy=True)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
routes = await app.proxy.get_all_routes()
|
routes = await app.proxy.get_all_routes()
|
||||||
@@ -108,7 +109,7 @@ async def test_external_proxy(request):
|
|||||||
assert list(routes.keys()) == []
|
assert list(routes.keys()) == []
|
||||||
|
|
||||||
# poke the server to update the proxy
|
# poke the server to update the proxy
|
||||||
r = await api_request(app, 'proxy', method='post')
|
r = await api_request(app, 'proxy', method='post', bypass_proxy=True)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
# check that the routes are correct
|
# check that the routes are correct
|
||||||
@@ -135,10 +136,16 @@ async def test_external_proxy(request):
|
|||||||
|
|
||||||
# tell the hub where the new proxy is
|
# tell the hub where the new proxy is
|
||||||
new_api_url = 'http://{}:{}'.format(proxy_ip, proxy_port)
|
new_api_url = 'http://{}:{}'.format(proxy_ip, proxy_port)
|
||||||
r = await api_request(app, 'proxy', method='patch', data=json.dumps({
|
r = await api_request(
|
||||||
|
app,
|
||||||
|
'proxy',
|
||||||
|
method='patch',
|
||||||
|
data=json.dumps({
|
||||||
'api_url': new_api_url,
|
'api_url': new_api_url,
|
||||||
'auth_token': new_auth_token,
|
'auth_token': new_auth_token,
|
||||||
}))
|
}),
|
||||||
|
bypass_proxy=True,
|
||||||
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
assert app.proxy.api_url == new_api_url
|
assert app.proxy.api_url == new_api_url
|
||||||
|
|
||||||
|
@@ -4,6 +4,10 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
from certipy import Certipy
|
from certipy import Certipy
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from jupyterhub import orm
|
||||||
|
from jupyterhub.objects import Server
|
||||||
|
from jupyterhub.utils import url_path_join as ujoin
|
||||||
|
|
||||||
|
|
||||||
class _AsyncRequests:
|
class _AsyncRequests:
|
||||||
"""Wrapper around requests to return a Future from request methods
|
"""Wrapper around requests to return a Future from request methods
|
||||||
@@ -46,3 +50,144 @@ def ssl_setup(cert_dir, authority_name):
|
|||||||
"external", authority_name, overwrite=True, alt_names=alt_names
|
"external", authority_name, overwrite=True, alt_names=alt_names
|
||||||
)
|
)
|
||||||
return external_certs
|
return external_certs
|
||||||
|
|
||||||
|
|
||||||
|
def check_db_locks(func):
|
||||||
|
"""Decorator that verifies no locks are held on database upon exit.
|
||||||
|
|
||||||
|
This decorator for test functions verifies no locks are held on the
|
||||||
|
application's database upon exit by creating and dropping a dummy table.
|
||||||
|
|
||||||
|
The decorator relies on an instance of JupyterHubApp being the first
|
||||||
|
argument to the decorated function.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
|
||||||
|
@check_db_locks
|
||||||
|
def api_request(app, *api_path, **kwargs):
|
||||||
|
|
||||||
|
"""
|
||||||
|
def new_func(app, *args, **kwargs):
|
||||||
|
retval = func(app, *args, **kwargs)
|
||||||
|
|
||||||
|
temp_session = app.session_factory()
|
||||||
|
temp_session.execute('CREATE TABLE dummy (foo INT)')
|
||||||
|
temp_session.execute('DROP TABLE dummy')
|
||||||
|
temp_session.close()
|
||||||
|
|
||||||
|
return retval
|
||||||
|
|
||||||
|
return new_func
|
||||||
|
|
||||||
|
|
||||||
|
def find_user(db, name, app=None):
|
||||||
|
"""Find user in database."""
|
||||||
|
orm_user = db.query(orm.User).filter(orm.User.name == name).first()
|
||||||
|
if app is None:
|
||||||
|
return orm_user
|
||||||
|
else:
|
||||||
|
return app.users[orm_user.id]
|
||||||
|
|
||||||
|
|
||||||
|
def add_user(db, app=None, **kwargs):
|
||||||
|
"""Add a user to the database."""
|
||||||
|
orm_user = find_user(db, name=kwargs.get('name'))
|
||||||
|
if orm_user is None:
|
||||||
|
orm_user = orm.User(**kwargs)
|
||||||
|
db.add(orm_user)
|
||||||
|
else:
|
||||||
|
for attr, value in kwargs.items():
|
||||||
|
setattr(orm_user, attr, value)
|
||||||
|
db.commit()
|
||||||
|
if app:
|
||||||
|
return app.users[orm_user.id]
|
||||||
|
else:
|
||||||
|
return orm_user
|
||||||
|
|
||||||
|
|
||||||
|
def auth_header(db, name):
|
||||||
|
"""Return header with user's API authorization token."""
|
||||||
|
user = find_user(db, name)
|
||||||
|
if user is None:
|
||||||
|
user = add_user(db, name=name)
|
||||||
|
token = user.new_api_token()
|
||||||
|
return {'Authorization': 'token %s' % token}
|
||||||
|
|
||||||
|
|
||||||
|
@check_db_locks
|
||||||
|
async def api_request(app, *api_path, method='get',
|
||||||
|
noauth=False, bypass_proxy=False,
|
||||||
|
**kwargs):
|
||||||
|
"""Make an API request"""
|
||||||
|
if bypass_proxy:
|
||||||
|
# make a direct request to the hub,
|
||||||
|
# skipping the proxy
|
||||||
|
base_url = app.hub.url
|
||||||
|
else:
|
||||||
|
base_url = public_url(app, path='hub')
|
||||||
|
headers = kwargs.setdefault('headers', {})
|
||||||
|
|
||||||
|
if (
|
||||||
|
'Authorization' not in headers
|
||||||
|
and not noauth
|
||||||
|
and 'cookies' not in kwargs
|
||||||
|
):
|
||||||
|
# make a copy to avoid modifying arg in-place
|
||||||
|
kwargs['headers'] = h = {}
|
||||||
|
h.update(headers)
|
||||||
|
h.update(auth_header(app.db, kwargs.pop('name', 'admin')))
|
||||||
|
|
||||||
|
if 'cookies' in kwargs:
|
||||||
|
# for cookie-authenticated requests,
|
||||||
|
# set Referer so it looks like the request originated
|
||||||
|
# from a Hub-served page
|
||||||
|
headers.setdefault('Referer', ujoin(base_url, 'test'))
|
||||||
|
|
||||||
|
url = ujoin(base_url, 'api', *api_path)
|
||||||
|
f = getattr(async_requests, method)
|
||||||
|
if app.internal_ssl:
|
||||||
|
kwargs['cert'] = (app.internal_ssl_cert, app.internal_ssl_key)
|
||||||
|
kwargs["verify"] = app.internal_ssl_ca
|
||||||
|
resp = await f(url, **kwargs)
|
||||||
|
assert "frame-ancestors 'self'" in resp.headers['Content-Security-Policy']
|
||||||
|
assert ujoin(app.hub.base_url, "security/csp-report") in resp.headers['Content-Security-Policy']
|
||||||
|
assert 'http' not in resp.headers['Content-Security-Policy']
|
||||||
|
if not kwargs.get('stream', False) and resp.content:
|
||||||
|
assert resp.headers.get('content-type') == 'application/json'
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
def get_page(path, app, hub=True, **kw):
|
||||||
|
if hub:
|
||||||
|
prefix = app.hub.base_url
|
||||||
|
else:
|
||||||
|
prefix = app.base_url
|
||||||
|
base_url = ujoin(public_host(app), prefix)
|
||||||
|
return async_requests.get(ujoin(base_url, path), **kw)
|
||||||
|
|
||||||
|
|
||||||
|
def public_host(app):
|
||||||
|
"""Return the public *host* (no URL prefix) of the given JupyterHub instance."""
|
||||||
|
if app.subdomain_host:
|
||||||
|
return app.subdomain_host
|
||||||
|
else:
|
||||||
|
return Server.from_url(app.proxy.public_url).host
|
||||||
|
|
||||||
|
|
||||||
|
def public_url(app, user_or_service=None, path=''):
|
||||||
|
"""Return the full, public base URL (including prefix) of the given JupyterHub instance."""
|
||||||
|
if user_or_service:
|
||||||
|
if app.subdomain_host:
|
||||||
|
host = user_or_service.host
|
||||||
|
else:
|
||||||
|
host = public_host(app)
|
||||||
|
prefix = user_or_service.prefix
|
||||||
|
else:
|
||||||
|
host = public_host(app)
|
||||||
|
prefix = Server.from_url(app.proxy.public_url).base_url
|
||||||
|
if path:
|
||||||
|
return host + ujoin(prefix, path)
|
||||||
|
else:
|
||||||
|
return host + prefix
|
||||||
|
|
||||||
|
@@ -8,7 +8,9 @@ import warnings
|
|||||||
|
|
||||||
from sqlalchemy import inspect
|
from sqlalchemy import inspect
|
||||||
from tornado import gen
|
from tornado import gen
|
||||||
|
from tornado.httputil import urlencode
|
||||||
from tornado.log import app_log
|
from tornado.log import app_log
|
||||||
|
from tornado import web
|
||||||
|
|
||||||
from .utils import maybe_future, url_path_join, make_ssl_context
|
from .utils import maybe_future, url_path_join, make_ssl_context
|
||||||
|
|
||||||
@@ -136,6 +138,7 @@ class User:
|
|||||||
orm_user = None
|
orm_user = None
|
||||||
log = app_log
|
log = app_log
|
||||||
settings = None
|
settings = None
|
||||||
|
_auth_refreshed = None
|
||||||
|
|
||||||
def __init__(self, orm_user, settings=None, db=None):
|
def __init__(self, orm_user, settings=None, db=None):
|
||||||
self.db = db or inspect(orm_user).session
|
self.db = db or inspect(orm_user).session
|
||||||
@@ -380,6 +383,59 @@ class User:
|
|||||||
url_parts.extend(['server/progress'])
|
url_parts.extend(['server/progress'])
|
||||||
return url_path_join(*url_parts)
|
return url_path_join(*url_parts)
|
||||||
|
|
||||||
|
async def refresh_auth(self, handler):
|
||||||
|
"""Refresh authentication if needed
|
||||||
|
|
||||||
|
Checks authentication expiry and refresh it if needed.
|
||||||
|
See Spawner.
|
||||||
|
|
||||||
|
If the auth is expired and cannot be refreshed
|
||||||
|
without forcing a new login, a few things can happen:
|
||||||
|
|
||||||
|
1. if this is a normal user spawn,
|
||||||
|
the user should be redirected to login
|
||||||
|
and back to spawn after login.
|
||||||
|
2. if this is a spawn via API or other user,
|
||||||
|
spawn will fail until the user logs in again.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
handler (RequestHandler):
|
||||||
|
The handler for the request triggering the spawn.
|
||||||
|
May be None
|
||||||
|
"""
|
||||||
|
authenticator = self.authenticator
|
||||||
|
if authenticator is None or not authenticator.refresh_pre_spawn:
|
||||||
|
# nothing to do
|
||||||
|
return
|
||||||
|
|
||||||
|
# refresh auth
|
||||||
|
auth_user = await handler.refresh_auth(self, force=True)
|
||||||
|
|
||||||
|
if auth_user:
|
||||||
|
# auth refreshed, all done
|
||||||
|
return
|
||||||
|
|
||||||
|
# if we got to here, auth is expired and couldn't be refreshed
|
||||||
|
self.log.error(
|
||||||
|
"Auth expired for %s; cannot spawn until they login again",
|
||||||
|
self.name,
|
||||||
|
)
|
||||||
|
# auth expired, cannot spawn without a fresh login
|
||||||
|
# it's the current user *and* spawn via GET, trigger login redirect
|
||||||
|
if handler.request.method == 'GET' and handler.current_user is self:
|
||||||
|
self.log.info("Redirecting %s to login to refresh auth", self.name)
|
||||||
|
url = self.get_login_url()
|
||||||
|
next_url = self.request.uri
|
||||||
|
sep = '&' if '?' in url else '?'
|
||||||
|
url += sep + urlencode(dict(next=next_url))
|
||||||
|
self.redirect(url)
|
||||||
|
raise web.Finish()
|
||||||
|
else:
|
||||||
|
# spawn via POST or on behalf of another user.
|
||||||
|
# nothing we can do here but fail
|
||||||
|
raise web.HTTPError(400, "{}'s authentication has expired".format(self.name))
|
||||||
|
|
||||||
|
|
||||||
async def spawn(self, server_name='', options=None, handler=None):
|
async def spawn(self, server_name='', options=None, handler=None):
|
||||||
"""Start the user's spawner
|
"""Start the user's spawner
|
||||||
|
|
||||||
@@ -395,6 +451,9 @@ class User:
|
|||||||
"""
|
"""
|
||||||
db = self.db
|
db = self.db
|
||||||
|
|
||||||
|
if handler:
|
||||||
|
await self.refresh_auth(handler)
|
||||||
|
|
||||||
base_url = url_path_join(self.base_url, server_name) + '/'
|
base_url = url_path_join(self.base_url, server_name) + '/'
|
||||||
|
|
||||||
orm_server = orm.Server(
|
orm_server = orm.Server(
|
||||||
@@ -436,7 +495,7 @@ class User:
|
|||||||
|
|
||||||
# trigger pre-spawn hook on authenticator
|
# trigger pre-spawn hook on authenticator
|
||||||
authenticator = self.authenticator
|
authenticator = self.authenticator
|
||||||
if (authenticator):
|
if authenticator:
|
||||||
await maybe_future(authenticator.pre_spawn_start(self, spawner))
|
await maybe_future(authenticator.pre_spawn_start(self, spawner))
|
||||||
|
|
||||||
spawner._start_pending = True
|
spawner._start_pending = True
|
||||||
|
Reference in New Issue
Block a user