mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-18 15:33:02 +00:00
@@ -19,7 +19,7 @@ except Exception as e:
|
||||
from tornado.concurrent import run_on_executor
|
||||
|
||||
from traitlets.config import LoggingConfigurable
|
||||
from traitlets import Bool, Set, Unicode, Dict, Any, default, observe
|
||||
from traitlets import Bool, Integer, Set, Unicode, Dict, Any, default, observe
|
||||
|
||||
from .handlers.login import LoginHandler
|
||||
from .utils import maybe_future, url_path_join
|
||||
@@ -50,6 +50,35 @@ class Authenticator(LoggingConfigurable):
|
||||
""",
|
||||
)
|
||||
|
||||
auth_refresh_age = Integer(
|
||||
300,
|
||||
config=True,
|
||||
help="""The max age (in seconds) of authentication info
|
||||
before forcing a refresh of user auth info.
|
||||
|
||||
Refreshing auth info allows, e.g. requesting/re-validating auth tokens.
|
||||
|
||||
See :meth:`.refresh_user` for what happens when user auth info is refreshed
|
||||
(nothing by default).
|
||||
"""
|
||||
)
|
||||
|
||||
refresh_pre_spawn = Bool(
|
||||
False,
|
||||
config=True,
|
||||
help="""Force refresh of auth prior to spawn.
|
||||
|
||||
This forces :meth:`.refresh_user` to be called prior to launching
|
||||
a server, to ensure that auth state is up-to-date.
|
||||
|
||||
This can be important when e.g. auth tokens that may have expired
|
||||
are passed to the spawner via environment variables from auth_state.
|
||||
|
||||
If refresh_user cannot refresh the user auth data,
|
||||
launch will fail until the user logs in again.
|
||||
"""
|
||||
)
|
||||
|
||||
admin_users = Set(
|
||||
help="""
|
||||
Set of users that will have admin rights on this JupyterHub.
|
||||
|
@@ -28,6 +28,7 @@ from .. import __version__
|
||||
from .. import orm
|
||||
from ..objects import Server
|
||||
from ..spawner import LocalProcessSpawner
|
||||
from ..user import User
|
||||
from ..utils import maybe_future, url_path_join
|
||||
from ..metrics import (
|
||||
SERVER_SPAWN_DURATION_SECONDS, ServerSpawnStatus,
|
||||
@@ -240,7 +241,7 @@ class BaseHandler(RequestHandler):
|
||||
self.db.commit()
|
||||
return self._user_from_orm(orm_token.user)
|
||||
|
||||
async def refresh_user_auth(self, user, force=False):
|
||||
async def refresh_auth(self, user, force=False):
|
||||
"""Refresh user authentication info
|
||||
|
||||
Calls `authenticator.refresh_user(user)`
|
||||
@@ -254,7 +255,12 @@ class BaseHandler(RequestHandler):
|
||||
user (User): the user having been refreshed,
|
||||
or None if the user must login again to refresh auth info.
|
||||
"""
|
||||
if not force: # TODO: and it's sufficiently recent
|
||||
refresh_age = self.authenticator.auth_refresh_age
|
||||
if not refresh_age:
|
||||
return user
|
||||
now = time.monotonic()
|
||||
if not force and user._auth_refreshed and (now - user._auth_refreshed < refresh_age):
|
||||
# auth up-to-date
|
||||
return user
|
||||
|
||||
# refresh a user at most once per request
|
||||
@@ -275,6 +281,8 @@ class BaseHandler(RequestHandler):
|
||||
)
|
||||
return
|
||||
|
||||
user._auth_refreshed = now
|
||||
|
||||
if auth_info == True:
|
||||
# refresh_user confirmed that it's up-to-date,
|
||||
# nothing to refresh
|
||||
@@ -355,8 +363,8 @@ class BaseHandler(RequestHandler):
|
||||
user = self.get_current_user_token()
|
||||
if user is None:
|
||||
user = self.get_current_user_cookie()
|
||||
if user:
|
||||
user = await self.refresh_user_auth(user)
|
||||
if user and isinstance(user, User):
|
||||
user = await self.refresh_auth(user)
|
||||
self._jupyterhub_user = user
|
||||
except Exception:
|
||||
# don't let errors here raise more than once
|
||||
@@ -610,6 +618,7 @@ class BaseHandler(RequestHandler):
|
||||
self.statsd.incr('login.success')
|
||||
self.statsd.timing('login.authenticate.success', auth_timer.ms)
|
||||
self.log.info("User logged in: %s", user.name)
|
||||
user._auth_refreshed = time.monotonic()
|
||||
return user
|
||||
else:
|
||||
self.statsd.incr('login.failure')
|
||||
@@ -643,6 +652,11 @@ class BaseHandler(RequestHandler):
|
||||
|
||||
async def spawn_single_user(self, user, server_name='', options=None):
|
||||
# in case of error, include 'try again from /hub/home' message
|
||||
if self.authenticator.refresh_pre_spawn:
|
||||
auth_user = await self.refresh_auth(user, force=True)
|
||||
if auth_user is None:
|
||||
raise web.HTTPError(403, "auth has expired for %s, login again", user.name)
|
||||
|
||||
spawn_start_time = time.perf_counter()
|
||||
self.extra_error_html = self.spawn_home_error
|
||||
|
||||
|
@@ -258,7 +258,6 @@ class Service(LoggingConfigurable):
|
||||
def _default_redirect_uri(self):
|
||||
if self.server is None:
|
||||
return ''
|
||||
print(self.domain, self.host, self.server)
|
||||
return self.host + url_path_join(self.prefix, 'oauth_callback')
|
||||
|
||||
@property
|
||||
|
@@ -47,7 +47,7 @@ from ..utils import random_port
|
||||
|
||||
from . import mocking
|
||||
from .mocking import MockHub
|
||||
from .utils import ssl_setup
|
||||
from .utils import ssl_setup, add_user
|
||||
from .test_services import mockservice_cmd
|
||||
|
||||
import jupyterhub.services.service
|
||||
@@ -185,6 +185,43 @@ def cleanup_after(request, io_loop):
|
||||
app.db.commit()
|
||||
|
||||
|
||||
_username_counter = 0
|
||||
|
||||
|
||||
def new_username(prefix='testuser'):
|
||||
"""Return a new unique username"""
|
||||
global _username_counter
|
||||
_username_counter += 1
|
||||
return '{}-{}'.format(prefix, _username_counter)
|
||||
|
||||
|
||||
@fixture
|
||||
def username():
|
||||
"""allocate a temporary username
|
||||
|
||||
unique each time the fixture is used
|
||||
"""
|
||||
yield new_username()
|
||||
|
||||
|
||||
@fixture
|
||||
def user(app):
|
||||
"""Fixture for creating a temporary user
|
||||
|
||||
Each time the fixture is used, a new user is created
|
||||
"""
|
||||
user = add_user(app.db, app, name=new_username())
|
||||
yield user
|
||||
|
||||
|
||||
@fixture
|
||||
def admin_user(app, username):
|
||||
"""Fixture for creating a temporary admin user"""
|
||||
user = add_user(app.db, app, name=new_username('testadmin'), admin=True)
|
||||
yield user
|
||||
|
||||
|
||||
|
||||
class MockServiceSpawner(jupyterhub.services.service._ServiceSpawner):
|
||||
"""mock services for testing.
|
||||
|
||||
|
@@ -49,7 +49,7 @@ from ..objects import Server
|
||||
from ..spawner import LocalProcessSpawner, SimpleLocalProcessSpawner
|
||||
from ..singleuser import SingleUserNotebookApp
|
||||
from ..utils import random_port, url_path_join
|
||||
from .utils import async_requests, ssl_setup
|
||||
from .utils import async_requests, ssl_setup, public_host, public_url
|
||||
|
||||
from pamela import PAMError
|
||||
|
||||
@@ -223,7 +223,10 @@ class MockHub(JupyterHub):
|
||||
last_activity_interval = 2
|
||||
log_datefmt = '%M:%S'
|
||||
external_certs = None
|
||||
log_level = 10
|
||||
|
||||
@default('log_level')
|
||||
def _default_log_level(self):
|
||||
return 10
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if 'internal_certs_location' in kwargs:
|
||||
@@ -351,31 +354,6 @@ class MockHub(JupyterHub):
|
||||
return r.cookies
|
||||
|
||||
|
||||
def public_host(app):
|
||||
"""Return the public *host* (no URL prefix) of the given JupyterHub instance."""
|
||||
if app.subdomain_host:
|
||||
return app.subdomain_host
|
||||
else:
|
||||
return Server.from_url(app.proxy.public_url).host
|
||||
|
||||
|
||||
def public_url(app, user_or_service=None, path=''):
|
||||
"""Return the full, public base URL (including prefix) of the given JupyterHub instance."""
|
||||
if user_or_service:
|
||||
if app.subdomain_host:
|
||||
host = user_or_service.host
|
||||
else:
|
||||
host = public_host(app)
|
||||
prefix = user_or_service.prefix
|
||||
else:
|
||||
host = public_host(app)
|
||||
prefix = Server.from_url(app.proxy.public_url).base_url
|
||||
if path:
|
||||
return host + url_path_join(prefix, path)
|
||||
else:
|
||||
return host + prefix
|
||||
|
||||
|
||||
# single-user-server mocking:
|
||||
|
||||
class MockSingleUserServer(SingleUserNotebookApp):
|
||||
|
@@ -18,97 +18,13 @@ import jupyterhub
|
||||
from .. import orm
|
||||
from ..utils import url_path_join as ujoin
|
||||
from .mocking import public_host, public_url
|
||||
from .utils import async_requests
|
||||
|
||||
|
||||
def check_db_locks(func):
|
||||
"""Decorator that verifies no locks are held on database upon exit.
|
||||
|
||||
This decorator for test functions verifies no locks are held on the
|
||||
application's database upon exit by creating and dropping a dummy table.
|
||||
|
||||
The decorator relies on an instance of JupyterHubApp being the first
|
||||
argument to the decorated function.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
@check_db_locks
|
||||
def api_request(app, *api_path, **kwargs):
|
||||
|
||||
"""
|
||||
def new_func(app, *args, **kwargs):
|
||||
retval = func(app, *args, **kwargs)
|
||||
|
||||
temp_session = app.session_factory()
|
||||
temp_session.execute('CREATE TABLE dummy (foo INT)')
|
||||
temp_session.execute('DROP TABLE dummy')
|
||||
temp_session.close()
|
||||
|
||||
return retval
|
||||
|
||||
return new_func
|
||||
|
||||
|
||||
def find_user(db, name, app=None):
|
||||
"""Find user in database."""
|
||||
orm_user = db.query(orm.User).filter(orm.User.name == name).first()
|
||||
if app is None:
|
||||
return orm_user
|
||||
else:
|
||||
return app.users[orm_user.id]
|
||||
|
||||
|
||||
def add_user(db, app=None, **kwargs):
|
||||
"""Add a user to the database."""
|
||||
orm_user = find_user(db, name=kwargs.get('name'))
|
||||
if orm_user is None:
|
||||
orm_user = orm.User(**kwargs)
|
||||
db.add(orm_user)
|
||||
else:
|
||||
for attr, value in kwargs.items():
|
||||
setattr(orm_user, attr, value)
|
||||
db.commit()
|
||||
if app:
|
||||
return app.users[orm_user.id]
|
||||
else:
|
||||
return orm_user
|
||||
|
||||
|
||||
def auth_header(db, name):
|
||||
"""Return header with user's API authorization token."""
|
||||
user = find_user(db, name)
|
||||
if user is None:
|
||||
user = add_user(db, name=name)
|
||||
token = user.new_api_token()
|
||||
return {'Authorization': 'token %s' % token}
|
||||
|
||||
|
||||
@check_db_locks
|
||||
async def api_request(app, *api_path, **kwargs):
|
||||
"""Make an API request"""
|
||||
base_url = app.hub.url
|
||||
headers = kwargs.setdefault('headers', {})
|
||||
|
||||
if 'Authorization' not in headers and not kwargs.pop('noauth', False):
|
||||
# make a copy to avoid modifying arg in-place
|
||||
kwargs['headers'] = h = {}
|
||||
h.update(headers)
|
||||
h.update(auth_header(app.db, 'admin'))
|
||||
|
||||
url = ujoin(base_url, 'api', *api_path)
|
||||
method = kwargs.pop('method', 'get')
|
||||
f = getattr(async_requests, method)
|
||||
if app.internal_ssl:
|
||||
kwargs['cert'] = (app.internal_ssl_cert, app.internal_ssl_key)
|
||||
kwargs["verify"] = app.internal_ssl_ca
|
||||
resp = await f(url, **kwargs)
|
||||
assert "frame-ancestors 'self'" in resp.headers['Content-Security-Policy']
|
||||
assert ujoin(app.hub.base_url, "security/csp-report") in resp.headers['Content-Security-Policy']
|
||||
assert 'http' not in resp.headers['Content-Security-Policy']
|
||||
if not kwargs.get('stream', False) and resp.content:
|
||||
assert resp.headers.get('content-type') == 'application/json'
|
||||
return resp
|
||||
from .utils import (
|
||||
add_user,
|
||||
api_request,
|
||||
async_requests,
|
||||
auth_header,
|
||||
find_user,
|
||||
)
|
||||
|
||||
|
||||
# --------------------
|
||||
@@ -197,7 +113,7 @@ def normalize_timestamp(ts):
|
||||
"""
|
||||
if ts is None:
|
||||
return
|
||||
return re.sub('\d(\.\d+)?', '0', ts)
|
||||
return re.sub(r'\d(\.\d+)?', '0', ts)
|
||||
|
||||
|
||||
def normalize_user(user):
|
||||
|
@@ -12,7 +12,8 @@ from requests import HTTPError
|
||||
from jupyterhub import auth, crypto, orm
|
||||
|
||||
from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd
|
||||
from .test_api import add_user
|
||||
from .utils import add_user
|
||||
|
||||
|
||||
async def test_pam_auth():
|
||||
authenticator = MockPAMAuthenticator()
|
||||
|
179
jupyterhub/tests/test_auth_expiry.py
Normal file
179
jupyterhub/tests/test_auth_expiry.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
test authentication expiry
|
||||
|
||||
authentication can expire in a number of ways:
|
||||
|
||||
- needs refresh and can be refreshed
|
||||
- doesn't need refresh
|
||||
- needs refresh and cannot be refreshed without new login
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from contextlib import contextmanager
|
||||
from unittest import mock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
|
||||
from .utils import api_request, get_page
|
||||
|
||||
|
||||
async def refresh_expired(authenticator, user):
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_refresh(app):
|
||||
"""Fixture disabling auth refresh"""
|
||||
with mock.patch.object(app.authenticator, 'refresh_user', refresh_expired):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def refresh_pre_spawn(app):
|
||||
"""Fixture enabling auth refresh pre spawn"""
|
||||
app.authenticator.refresh_pre_spawn = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
app.authenticator.refresh_pre_spawn = False
|
||||
|
||||
|
||||
async def test_auth_refresh_at_login(app, user):
|
||||
# auth_refreshed starts unset:
|
||||
assert not user._auth_refreshed
|
||||
# login sets auth_refreshed timestamp
|
||||
await app.login_user(user.name)
|
||||
assert user._auth_refreshed
|
||||
user._auth_refreshed -= 10
|
||||
before = user._auth_refreshed
|
||||
# login again updates auth_refreshed timestamp
|
||||
# even when auth is fresh
|
||||
await app.login_user(user.name)
|
||||
assert user._auth_refreshed > before
|
||||
|
||||
|
||||
async def test_auth_refresh_page(app, user):
|
||||
cookies = await app.login_user(user.name)
|
||||
assert user._auth_refreshed
|
||||
user._auth_refreshed -= 10
|
||||
before = user._auth_refreshed
|
||||
|
||||
# get a page with auth not expired
|
||||
# doesn't trigger refresh
|
||||
r = await get_page('home', app, cookies=cookies)
|
||||
assert r.status_code == 200
|
||||
assert user._auth_refreshed == before
|
||||
|
||||
# get a page with stale auth, refreshes auth
|
||||
user._auth_refreshed -= app.authenticator.auth_refresh_age
|
||||
r = await get_page('home', app, cookies=cookies)
|
||||
assert r.status_code == 200
|
||||
assert user._auth_refreshed > before
|
||||
|
||||
|
||||
async def test_auth_expired_page(app, user, disable_refresh):
|
||||
cookies = await app.login_user(user.name)
|
||||
assert user._auth_refreshed
|
||||
user._auth_refreshed -= 10
|
||||
before = user._auth_refreshed
|
||||
|
||||
# auth is fresh, doesn't trigger expiry
|
||||
r = await get_page('home', app, cookies=cookies)
|
||||
assert user._auth_refreshed == before
|
||||
assert r.status_code == 200
|
||||
|
||||
# get a page with stale auth, triggers expiry
|
||||
user._auth_refreshed -= app.authenticator.auth_refresh_age
|
||||
before = user._auth_refreshed
|
||||
r = await get_page('home', app, cookies=cookies, allow_redirects=False)
|
||||
|
||||
# verify that we redirect to login with ?next=requested page
|
||||
assert r.status_code == 302
|
||||
redirect_url = urlparse(r.headers['Location'])
|
||||
assert redirect_url.path.endswith('/login')
|
||||
query = parse_qs(redirect_url.query)
|
||||
assert query['next']
|
||||
next_url = urlparse(query['next'][0])
|
||||
assert next_url.path == urlparse(r.url).path
|
||||
|
||||
# make sure refresh didn't get updated
|
||||
assert user._auth_refreshed == before
|
||||
|
||||
|
||||
async def test_auth_expired_api(app, user, disable_refresh):
|
||||
cookies = await app.login_user(user.name)
|
||||
assert user._auth_refreshed
|
||||
user._auth_refreshed -= 10
|
||||
before = user._auth_refreshed
|
||||
|
||||
# auth is fresh, doesn't trigger expiry
|
||||
r = await api_request(app, 'users/' + user.name, name=user.name)
|
||||
assert user._auth_refreshed == before
|
||||
assert r.status_code == 200
|
||||
|
||||
# get a page with stale auth, triggers expiry
|
||||
user._auth_refreshed -= app.authenticator.auth_refresh_age
|
||||
r = await api_request(app, 'users/' + user.name, name=user.name)
|
||||
# api requests can't do login redirects
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
async def test_refresh_pre_spawn(app, user, refresh_pre_spawn):
|
||||
cookies = await app.login_user(user.name)
|
||||
assert user._auth_refreshed
|
||||
user._auth_refreshed -= 10
|
||||
before = user._auth_refreshed
|
||||
|
||||
# auth is fresh, but should be forced to refresh by spawn
|
||||
r = await api_request(
|
||||
app, 'users/{}/server'.format(user.name), method='post', name=user.name
|
||||
)
|
||||
assert 200 <= r.status_code < 300
|
||||
assert user._auth_refreshed > before
|
||||
|
||||
|
||||
async def test_refresh_pre_spawn_expired(app, user, refresh_pre_spawn, disable_refresh):
|
||||
cookies = await app.login_user(user.name)
|
||||
assert user._auth_refreshed
|
||||
user._auth_refreshed -= 10
|
||||
before = user._auth_refreshed
|
||||
|
||||
# auth is fresh, doesn't trigger expiry
|
||||
r = await api_request(
|
||||
app, 'users/{}/server'.format(user.name), method='post', name=user.name
|
||||
)
|
||||
assert r.status_code == 403
|
||||
assert user._auth_refreshed == before
|
||||
|
||||
|
||||
async def test_refresh_pre_spawn_admin_request(
|
||||
app, user, admin_user, refresh_pre_spawn
|
||||
):
|
||||
await app.login_user(user.name)
|
||||
await app.login_user(admin_user.name)
|
||||
user._auth_refreshed -= 10
|
||||
before = user._auth_refreshed
|
||||
|
||||
# admin request, auth is fresh. Should still refresh user auth.
|
||||
r = await api_request(
|
||||
app, 'users', user.name, 'server', method='post', name=admin_user.name
|
||||
)
|
||||
assert 200 <= r.status_code < 300
|
||||
assert user._auth_refreshed > before
|
||||
|
||||
|
||||
async def test_refresh_pre_spawn_expired_admin_request(
|
||||
app, user, admin_user, refresh_pre_spawn, disable_refresh
|
||||
):
|
||||
await app.login_user(user.name)
|
||||
await app.login_user(admin_user.name)
|
||||
user._auth_refreshed -= 10
|
||||
|
||||
# auth needs refresh but can't without a new login; spawn should fail
|
||||
user._auth_refreshed -= app.authenticator.auth_refresh_age
|
||||
r = await api_request(
|
||||
app, 'users', user.name, 'server', method='post', name=admin_user.name
|
||||
)
|
||||
# api requests can't do login redirects
|
||||
assert r.status_code == 403
|
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from unittest import mock
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
@@ -13,21 +14,17 @@ from ..utils import url_path_join as ujoin
|
||||
from .. import orm
|
||||
from ..auth import Authenticator
|
||||
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from .mocking import FormSpawner, public_url, public_host
|
||||
from .test_api import api_request, add_user
|
||||
from .utils import async_requests
|
||||
|
||||
|
||||
def get_page(path, app, hub=True, **kw):
|
||||
if hub:
|
||||
prefix = app.hub.base_url
|
||||
else:
|
||||
prefix = app.base_url
|
||||
base_url = ujoin(public_host(app), prefix)
|
||||
return async_requests.get(ujoin(base_url, path), **kw)
|
||||
from .mocking import FormSpawner
|
||||
from .utils import (
|
||||
async_requests,
|
||||
api_request,
|
||||
add_user,
|
||||
get_page,
|
||||
public_url,
|
||||
public_host,
|
||||
)
|
||||
|
||||
|
||||
async def test_root_no_auth(app):
|
||||
|
@@ -84,7 +84,8 @@ async def test_external_proxy(request):
|
||||
# add user to the db and start a single user server
|
||||
name = 'river'
|
||||
add_user(app.db, app, name=name)
|
||||
r = await api_request(app, 'users', name, 'server', method='post')
|
||||
r = await api_request(app, 'users', name, 'server', method='post',
|
||||
bypass_proxy=True)
|
||||
r.raise_for_status()
|
||||
|
||||
routes = await app.proxy.get_all_routes()
|
||||
@@ -108,7 +109,7 @@ async def test_external_proxy(request):
|
||||
assert list(routes.keys()) == []
|
||||
|
||||
# poke the server to update the proxy
|
||||
r = await api_request(app, 'proxy', method='post')
|
||||
r = await api_request(app, 'proxy', method='post', bypass_proxy=True)
|
||||
r.raise_for_status()
|
||||
|
||||
# check that the routes are correct
|
||||
@@ -135,10 +136,16 @@ async def test_external_proxy(request):
|
||||
|
||||
# tell the hub where the new proxy is
|
||||
new_api_url = 'http://{}:{}'.format(proxy_ip, proxy_port)
|
||||
r = await api_request(app, 'proxy', method='patch', data=json.dumps({
|
||||
'api_url': new_api_url,
|
||||
'auth_token': new_auth_token,
|
||||
}))
|
||||
r = await api_request(
|
||||
app,
|
||||
'proxy',
|
||||
method='patch',
|
||||
data=json.dumps({
|
||||
'api_url': new_api_url,
|
||||
'auth_token': new_auth_token,
|
||||
}),
|
||||
bypass_proxy=True,
|
||||
)
|
||||
r.raise_for_status()
|
||||
assert app.proxy.api_url == new_api_url
|
||||
|
||||
|
@@ -4,6 +4,10 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from certipy import Certipy
|
||||
import requests
|
||||
|
||||
from jupyterhub import orm
|
||||
from jupyterhub.objects import Server
|
||||
from jupyterhub.utils import url_path_join as ujoin
|
||||
|
||||
|
||||
class _AsyncRequests:
|
||||
"""Wrapper around requests to return a Future from request methods
|
||||
@@ -46,3 +50,144 @@ def ssl_setup(cert_dir, authority_name):
|
||||
"external", authority_name, overwrite=True, alt_names=alt_names
|
||||
)
|
||||
return external_certs
|
||||
|
||||
|
||||
def check_db_locks(func):
|
||||
"""Decorator that verifies no locks are held on database upon exit.
|
||||
|
||||
This decorator for test functions verifies no locks are held on the
|
||||
application's database upon exit by creating and dropping a dummy table.
|
||||
|
||||
The decorator relies on an instance of JupyterHubApp being the first
|
||||
argument to the decorated function.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
@check_db_locks
|
||||
def api_request(app, *api_path, **kwargs):
|
||||
|
||||
"""
|
||||
def new_func(app, *args, **kwargs):
|
||||
retval = func(app, *args, **kwargs)
|
||||
|
||||
temp_session = app.session_factory()
|
||||
temp_session.execute('CREATE TABLE dummy (foo INT)')
|
||||
temp_session.execute('DROP TABLE dummy')
|
||||
temp_session.close()
|
||||
|
||||
return retval
|
||||
|
||||
return new_func
|
||||
|
||||
|
||||
def find_user(db, name, app=None):
|
||||
"""Find user in database."""
|
||||
orm_user = db.query(orm.User).filter(orm.User.name == name).first()
|
||||
if app is None:
|
||||
return orm_user
|
||||
else:
|
||||
return app.users[orm_user.id]
|
||||
|
||||
|
||||
def add_user(db, app=None, **kwargs):
|
||||
"""Add a user to the database."""
|
||||
orm_user = find_user(db, name=kwargs.get('name'))
|
||||
if orm_user is None:
|
||||
orm_user = orm.User(**kwargs)
|
||||
db.add(orm_user)
|
||||
else:
|
||||
for attr, value in kwargs.items():
|
||||
setattr(orm_user, attr, value)
|
||||
db.commit()
|
||||
if app:
|
||||
return app.users[orm_user.id]
|
||||
else:
|
||||
return orm_user
|
||||
|
||||
|
||||
def auth_header(db, name):
|
||||
"""Return header with user's API authorization token."""
|
||||
user = find_user(db, name)
|
||||
if user is None:
|
||||
user = add_user(db, name=name)
|
||||
token = user.new_api_token()
|
||||
return {'Authorization': 'token %s' % token}
|
||||
|
||||
|
||||
@check_db_locks
|
||||
async def api_request(app, *api_path, method='get',
|
||||
noauth=False, bypass_proxy=False,
|
||||
**kwargs):
|
||||
"""Make an API request"""
|
||||
if bypass_proxy:
|
||||
# make a direct request to the hub,
|
||||
# skipping the proxy
|
||||
base_url = app.hub.url
|
||||
else:
|
||||
base_url = public_url(app, path='hub')
|
||||
headers = kwargs.setdefault('headers', {})
|
||||
|
||||
if (
|
||||
'Authorization' not in headers
|
||||
and not noauth
|
||||
and 'cookies' not in kwargs
|
||||
):
|
||||
# make a copy to avoid modifying arg in-place
|
||||
kwargs['headers'] = h = {}
|
||||
h.update(headers)
|
||||
h.update(auth_header(app.db, kwargs.pop('name', 'admin')))
|
||||
|
||||
if 'cookies' in kwargs:
|
||||
# for cookie-authenticated requests,
|
||||
# set Referer so it looks like the request originated
|
||||
# from a Hub-served page
|
||||
headers.setdefault('Referer', ujoin(base_url, 'test'))
|
||||
|
||||
url = ujoin(base_url, 'api', *api_path)
|
||||
f = getattr(async_requests, method)
|
||||
if app.internal_ssl:
|
||||
kwargs['cert'] = (app.internal_ssl_cert, app.internal_ssl_key)
|
||||
kwargs["verify"] = app.internal_ssl_ca
|
||||
resp = await f(url, **kwargs)
|
||||
assert "frame-ancestors 'self'" in resp.headers['Content-Security-Policy']
|
||||
assert ujoin(app.hub.base_url, "security/csp-report") in resp.headers['Content-Security-Policy']
|
||||
assert 'http' not in resp.headers['Content-Security-Policy']
|
||||
if not kwargs.get('stream', False) and resp.content:
|
||||
assert resp.headers.get('content-type') == 'application/json'
|
||||
return resp
|
||||
|
||||
|
||||
def get_page(path, app, hub=True, **kw):
|
||||
if hub:
|
||||
prefix = app.hub.base_url
|
||||
else:
|
||||
prefix = app.base_url
|
||||
base_url = ujoin(public_host(app), prefix)
|
||||
return async_requests.get(ujoin(base_url, path), **kw)
|
||||
|
||||
|
||||
def public_host(app):
|
||||
"""Return the public *host* (no URL prefix) of the given JupyterHub instance."""
|
||||
if app.subdomain_host:
|
||||
return app.subdomain_host
|
||||
else:
|
||||
return Server.from_url(app.proxy.public_url).host
|
||||
|
||||
|
||||
def public_url(app, user_or_service=None, path=''):
|
||||
"""Return the full, public base URL (including prefix) of the given JupyterHub instance."""
|
||||
if user_or_service:
|
||||
if app.subdomain_host:
|
||||
host = user_or_service.host
|
||||
else:
|
||||
host = public_host(app)
|
||||
prefix = user_or_service.prefix
|
||||
else:
|
||||
host = public_host(app)
|
||||
prefix = Server.from_url(app.proxy.public_url).base_url
|
||||
if path:
|
||||
return host + ujoin(prefix, path)
|
||||
else:
|
||||
return host + prefix
|
||||
|
||||
|
@@ -8,7 +8,9 @@ import warnings
|
||||
|
||||
from sqlalchemy import inspect
|
||||
from tornado import gen
|
||||
from tornado.httputil import urlencode
|
||||
from tornado.log import app_log
|
||||
from tornado import web
|
||||
|
||||
from .utils import maybe_future, url_path_join, make_ssl_context
|
||||
|
||||
@@ -136,6 +138,7 @@ class User:
|
||||
orm_user = None
|
||||
log = app_log
|
||||
settings = None
|
||||
_auth_refreshed = None
|
||||
|
||||
def __init__(self, orm_user, settings=None, db=None):
|
||||
self.db = db or inspect(orm_user).session
|
||||
@@ -380,6 +383,59 @@ class User:
|
||||
url_parts.extend(['server/progress'])
|
||||
return url_path_join(*url_parts)
|
||||
|
||||
async def refresh_auth(self, handler):
|
||||
"""Refresh authentication if needed
|
||||
|
||||
Checks authentication expiry and refresh it if needed.
|
||||
See Spawner.
|
||||
|
||||
If the auth is expired and cannot be refreshed
|
||||
without forcing a new login, a few things can happen:
|
||||
|
||||
1. if this is a normal user spawn,
|
||||
the user should be redirected to login
|
||||
and back to spawn after login.
|
||||
2. if this is a spawn via API or other user,
|
||||
spawn will fail until the user logs in again.
|
||||
|
||||
Args:
|
||||
handler (RequestHandler):
|
||||
The handler for the request triggering the spawn.
|
||||
May be None
|
||||
"""
|
||||
authenticator = self.authenticator
|
||||
if authenticator is None or not authenticator.refresh_pre_spawn:
|
||||
# nothing to do
|
||||
return
|
||||
|
||||
# refresh auth
|
||||
auth_user = await handler.refresh_auth(self, force=True)
|
||||
|
||||
if auth_user:
|
||||
# auth refreshed, all done
|
||||
return
|
||||
|
||||
# if we got to here, auth is expired and couldn't be refreshed
|
||||
self.log.error(
|
||||
"Auth expired for %s; cannot spawn until they login again",
|
||||
self.name,
|
||||
)
|
||||
# auth expired, cannot spawn without a fresh login
|
||||
# it's the current user *and* spawn via GET, trigger login redirect
|
||||
if handler.request.method == 'GET' and handler.current_user is self:
|
||||
self.log.info("Redirecting %s to login to refresh auth", self.name)
|
||||
url = self.get_login_url()
|
||||
next_url = self.request.uri
|
||||
sep = '&' if '?' in url else '?'
|
||||
url += sep + urlencode(dict(next=next_url))
|
||||
self.redirect(url)
|
||||
raise web.Finish()
|
||||
else:
|
||||
# spawn via POST or on behalf of another user.
|
||||
# nothing we can do here but fail
|
||||
raise web.HTTPError(400, "{}'s authentication has expired".format(self.name))
|
||||
|
||||
|
||||
async def spawn(self, server_name='', options=None, handler=None):
|
||||
"""Start the user's spawner
|
||||
|
||||
@@ -395,6 +451,9 @@ class User:
|
||||
"""
|
||||
db = self.db
|
||||
|
||||
if handler:
|
||||
await self.refresh_auth(handler)
|
||||
|
||||
base_url = url_path_join(self.base_url, server_name) + '/'
|
||||
|
||||
orm_server = orm.Server(
|
||||
@@ -436,7 +495,7 @@ class User:
|
||||
|
||||
# trigger pre-spawn hook on authenticator
|
||||
authenticator = self.authenticator
|
||||
if (authenticator):
|
||||
if authenticator:
|
||||
await maybe_future(authenticator.pre_spawn_start(self, spawner))
|
||||
|
||||
spawner._start_pending = True
|
||||
|
Reference in New Issue
Block a user