Merge pull request #2342 from minrk/expire-auth

allow auth to expire
This commit is contained in:
Min RK
2019-02-05 13:05:00 +01:00
committed by GitHub
12 changed files with 508 additions and 147 deletions

View File

@@ -19,7 +19,7 @@ except Exception as e:
from tornado.concurrent import run_on_executor from tornado.concurrent import run_on_executor
from traitlets.config import LoggingConfigurable from traitlets.config import LoggingConfigurable
from traitlets import Bool, Set, Unicode, Dict, Any, default, observe from traitlets import Bool, Integer, Set, Unicode, Dict, Any, default, observe
from .handlers.login import LoginHandler from .handlers.login import LoginHandler
from .utils import maybe_future, url_path_join from .utils import maybe_future, url_path_join
@@ -50,6 +50,35 @@ class Authenticator(LoggingConfigurable):
""", """,
) )
auth_refresh_age = Integer(
300,
config=True,
help="""The max age (in seconds) of authentication info
before forcing a refresh of user auth info.
Refreshing auth info allows, e.g. requesting/re-validating auth tokens.
See :meth:`.refresh_user` for what happens when user auth info is refreshed
(nothing by default).
"""
)
refresh_pre_spawn = Bool(
False,
config=True,
help="""Force refresh of auth prior to spawn.
This forces :meth:`.refresh_user` to be called prior to launching
a server, to ensure that auth state is up-to-date.
This can be important when e.g. auth tokens that may have expired
are passed to the spawner via environment variables from auth_state.
If refresh_user cannot refresh the user auth data,
launch will fail until the user logs in again.
"""
)
admin_users = Set( admin_users = Set(
help=""" help="""
Set of users that will have admin rights on this JupyterHub. Set of users that will have admin rights on this JupyterHub.

View File

@@ -28,6 +28,7 @@ from .. import __version__
from .. import orm from .. import orm
from ..objects import Server from ..objects import Server
from ..spawner import LocalProcessSpawner from ..spawner import LocalProcessSpawner
from ..user import User
from ..utils import maybe_future, url_path_join from ..utils import maybe_future, url_path_join
from ..metrics import ( from ..metrics import (
SERVER_SPAWN_DURATION_SECONDS, ServerSpawnStatus, SERVER_SPAWN_DURATION_SECONDS, ServerSpawnStatus,
@@ -240,7 +241,7 @@ class BaseHandler(RequestHandler):
self.db.commit() self.db.commit()
return self._user_from_orm(orm_token.user) return self._user_from_orm(orm_token.user)
async def refresh_user_auth(self, user, force=False): async def refresh_auth(self, user, force=False):
"""Refresh user authentication info """Refresh user authentication info
Calls `authenticator.refresh_user(user)` Calls `authenticator.refresh_user(user)`
@@ -254,7 +255,12 @@ class BaseHandler(RequestHandler):
user (User): the user having been refreshed, user (User): the user having been refreshed,
or None if the user must login again to refresh auth info. or None if the user must login again to refresh auth info.
""" """
if not force: # TODO: and it's sufficiently recent refresh_age = self.authenticator.auth_refresh_age
if not refresh_age:
return user
now = time.monotonic()
if not force and user._auth_refreshed and (now - user._auth_refreshed < refresh_age):
# auth up-to-date
return user return user
# refresh a user at most once per request # refresh a user at most once per request
@@ -275,6 +281,8 @@ class BaseHandler(RequestHandler):
) )
return return
user._auth_refreshed = now
if auth_info == True: if auth_info == True:
# refresh_user confirmed that it's up-to-date, # refresh_user confirmed that it's up-to-date,
# nothing to refresh # nothing to refresh
@@ -355,8 +363,8 @@ class BaseHandler(RequestHandler):
user = self.get_current_user_token() user = self.get_current_user_token()
if user is None: if user is None:
user = self.get_current_user_cookie() user = self.get_current_user_cookie()
if user: if user and isinstance(user, User):
user = await self.refresh_user_auth(user) user = await self.refresh_auth(user)
self._jupyterhub_user = user self._jupyterhub_user = user
except Exception: except Exception:
# don't let errors here raise more than once # don't let errors here raise more than once
@@ -610,6 +618,7 @@ class BaseHandler(RequestHandler):
self.statsd.incr('login.success') self.statsd.incr('login.success')
self.statsd.timing('login.authenticate.success', auth_timer.ms) self.statsd.timing('login.authenticate.success', auth_timer.ms)
self.log.info("User logged in: %s", user.name) self.log.info("User logged in: %s", user.name)
user._auth_refreshed = time.monotonic()
return user return user
else: else:
self.statsd.incr('login.failure') self.statsd.incr('login.failure')
@@ -643,6 +652,11 @@ class BaseHandler(RequestHandler):
async def spawn_single_user(self, user, server_name='', options=None): async def spawn_single_user(self, user, server_name='', options=None):
# in case of error, include 'try again from /hub/home' message # in case of error, include 'try again from /hub/home' message
if self.authenticator.refresh_pre_spawn:
auth_user = await self.refresh_auth(user, force=True)
if auth_user is None:
raise web.HTTPError(403, "auth has expired for %s, login again", user.name)
spawn_start_time = time.perf_counter() spawn_start_time = time.perf_counter()
self.extra_error_html = self.spawn_home_error self.extra_error_html = self.spawn_home_error

View File

@@ -258,7 +258,6 @@ class Service(LoggingConfigurable):
def _default_redirect_uri(self): def _default_redirect_uri(self):
if self.server is None: if self.server is None:
return '' return ''
print(self.domain, self.host, self.server)
return self.host + url_path_join(self.prefix, 'oauth_callback') return self.host + url_path_join(self.prefix, 'oauth_callback')
@property @property

View File

@@ -47,7 +47,7 @@ from ..utils import random_port
from . import mocking from . import mocking
from .mocking import MockHub from .mocking import MockHub
from .utils import ssl_setup from .utils import ssl_setup, add_user
from .test_services import mockservice_cmd from .test_services import mockservice_cmd
import jupyterhub.services.service import jupyterhub.services.service
@@ -185,6 +185,43 @@ def cleanup_after(request, io_loop):
app.db.commit() app.db.commit()
_username_counter = 0
def new_username(prefix='testuser'):
"""Return a new unique username"""
global _username_counter
_username_counter += 1
return '{}-{}'.format(prefix, _username_counter)
@fixture
def username():
"""allocate a temporary username
unique each time the fixture is used
"""
yield new_username()
@fixture
def user(app):
"""Fixture for creating a temporary user
Each time the fixture is used, a new user is created
"""
user = add_user(app.db, app, name=new_username())
yield user
@fixture
def admin_user(app, username):
"""Fixture for creating a temporary admin user"""
user = add_user(app.db, app, name=new_username('testadmin'), admin=True)
yield user
class MockServiceSpawner(jupyterhub.services.service._ServiceSpawner): class MockServiceSpawner(jupyterhub.services.service._ServiceSpawner):
"""mock services for testing. """mock services for testing.

View File

@@ -49,7 +49,7 @@ from ..objects import Server
from ..spawner import LocalProcessSpawner, SimpleLocalProcessSpawner from ..spawner import LocalProcessSpawner, SimpleLocalProcessSpawner
from ..singleuser import SingleUserNotebookApp from ..singleuser import SingleUserNotebookApp
from ..utils import random_port, url_path_join from ..utils import random_port, url_path_join
from .utils import async_requests, ssl_setup from .utils import async_requests, ssl_setup, public_host, public_url
from pamela import PAMError from pamela import PAMError
@@ -223,7 +223,10 @@ class MockHub(JupyterHub):
last_activity_interval = 2 last_activity_interval = 2
log_datefmt = '%M:%S' log_datefmt = '%M:%S'
external_certs = None external_certs = None
log_level = 10
@default('log_level')
def _default_log_level(self):
return 10
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if 'internal_certs_location' in kwargs: if 'internal_certs_location' in kwargs:
@@ -351,31 +354,6 @@ class MockHub(JupyterHub):
return r.cookies return r.cookies
def public_host(app):
"""Return the public *host* (no URL prefix) of the given JupyterHub instance."""
if app.subdomain_host:
return app.subdomain_host
else:
return Server.from_url(app.proxy.public_url).host
def public_url(app, user_or_service=None, path=''):
"""Return the full, public base URL (including prefix) of the given JupyterHub instance."""
if user_or_service:
if app.subdomain_host:
host = user_or_service.host
else:
host = public_host(app)
prefix = user_or_service.prefix
else:
host = public_host(app)
prefix = Server.from_url(app.proxy.public_url).base_url
if path:
return host + url_path_join(prefix, path)
else:
return host + prefix
# single-user-server mocking: # single-user-server mocking:
class MockSingleUserServer(SingleUserNotebookApp): class MockSingleUserServer(SingleUserNotebookApp):

View File

@@ -18,97 +18,13 @@ import jupyterhub
from .. import orm from .. import orm
from ..utils import url_path_join as ujoin from ..utils import url_path_join as ujoin
from .mocking import public_host, public_url from .mocking import public_host, public_url
from .utils import async_requests from .utils import (
add_user,
api_request,
def check_db_locks(func): async_requests,
"""Decorator that verifies no locks are held on database upon exit. auth_header,
find_user,
This decorator for test functions verifies no locks are held on the )
application's database upon exit by creating and dropping a dummy table.
The decorator relies on an instance of JupyterHubApp being the first
argument to the decorated function.
Example
-------
@check_db_locks
def api_request(app, *api_path, **kwargs):
"""
def new_func(app, *args, **kwargs):
retval = func(app, *args, **kwargs)
temp_session = app.session_factory()
temp_session.execute('CREATE TABLE dummy (foo INT)')
temp_session.execute('DROP TABLE dummy')
temp_session.close()
return retval
return new_func
def find_user(db, name, app=None):
"""Find user in database."""
orm_user = db.query(orm.User).filter(orm.User.name == name).first()
if app is None:
return orm_user
else:
return app.users[orm_user.id]
def add_user(db, app=None, **kwargs):
"""Add a user to the database."""
orm_user = find_user(db, name=kwargs.get('name'))
if orm_user is None:
orm_user = orm.User(**kwargs)
db.add(orm_user)
else:
for attr, value in kwargs.items():
setattr(orm_user, attr, value)
db.commit()
if app:
return app.users[orm_user.id]
else:
return orm_user
def auth_header(db, name):
"""Return header with user's API authorization token."""
user = find_user(db, name)
if user is None:
user = add_user(db, name=name)
token = user.new_api_token()
return {'Authorization': 'token %s' % token}
@check_db_locks
async def api_request(app, *api_path, **kwargs):
"""Make an API request"""
base_url = app.hub.url
headers = kwargs.setdefault('headers', {})
if 'Authorization' not in headers and not kwargs.pop('noauth', False):
# make a copy to avoid modifying arg in-place
kwargs['headers'] = h = {}
h.update(headers)
h.update(auth_header(app.db, 'admin'))
url = ujoin(base_url, 'api', *api_path)
method = kwargs.pop('method', 'get')
f = getattr(async_requests, method)
if app.internal_ssl:
kwargs['cert'] = (app.internal_ssl_cert, app.internal_ssl_key)
kwargs["verify"] = app.internal_ssl_ca
resp = await f(url, **kwargs)
assert "frame-ancestors 'self'" in resp.headers['Content-Security-Policy']
assert ujoin(app.hub.base_url, "security/csp-report") in resp.headers['Content-Security-Policy']
assert 'http' not in resp.headers['Content-Security-Policy']
if not kwargs.get('stream', False) and resp.content:
assert resp.headers.get('content-type') == 'application/json'
return resp
# -------------------- # --------------------
@@ -197,7 +113,7 @@ def normalize_timestamp(ts):
""" """
if ts is None: if ts is None:
return return
return re.sub('\d(\.\d+)?', '0', ts) return re.sub(r'\d(\.\d+)?', '0', ts)
def normalize_user(user): def normalize_user(user):

View File

@@ -12,7 +12,8 @@ from requests import HTTPError
from jupyterhub import auth, crypto, orm from jupyterhub import auth, crypto, orm
from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd
from .test_api import add_user from .utils import add_user
async def test_pam_auth(): async def test_pam_auth():
authenticator = MockPAMAuthenticator() authenticator = MockPAMAuthenticator()

View File

@@ -0,0 +1,179 @@
"""
test authentication expiry
authentication can expire in a number of ways:
- needs refresh and can be refreshed
- doesn't need refresh
- needs refresh and cannot be refreshed without new login
"""
import asyncio
from contextlib import contextmanager
from unittest import mock
from urllib.parse import parse_qs, urlparse
import pytest
from .utils import api_request, get_page
async def refresh_expired(authenticator, user):
return None
@pytest.fixture
def disable_refresh(app):
"""Fixture disabling auth refresh"""
with mock.patch.object(app.authenticator, 'refresh_user', refresh_expired):
yield
@pytest.fixture
def refresh_pre_spawn(app):
"""Fixture enabling auth refresh pre spawn"""
app.authenticator.refresh_pre_spawn = True
try:
yield
finally:
app.authenticator.refresh_pre_spawn = False
async def test_auth_refresh_at_login(app, user):
# auth_refreshed starts unset:
assert not user._auth_refreshed
# login sets auth_refreshed timestamp
await app.login_user(user.name)
assert user._auth_refreshed
user._auth_refreshed -= 10
before = user._auth_refreshed
# login again updates auth_refreshed timestamp
# even when auth is fresh
await app.login_user(user.name)
assert user._auth_refreshed > before
async def test_auth_refresh_page(app, user):
cookies = await app.login_user(user.name)
assert user._auth_refreshed
user._auth_refreshed -= 10
before = user._auth_refreshed
# get a page with auth not expired
# doesn't trigger refresh
r = await get_page('home', app, cookies=cookies)
assert r.status_code == 200
assert user._auth_refreshed == before
# get a page with stale auth, refreshes auth
user._auth_refreshed -= app.authenticator.auth_refresh_age
r = await get_page('home', app, cookies=cookies)
assert r.status_code == 200
assert user._auth_refreshed > before
async def test_auth_expired_page(app, user, disable_refresh):
cookies = await app.login_user(user.name)
assert user._auth_refreshed
user._auth_refreshed -= 10
before = user._auth_refreshed
# auth is fresh, doesn't trigger expiry
r = await get_page('home', app, cookies=cookies)
assert user._auth_refreshed == before
assert r.status_code == 200
# get a page with stale auth, triggers expiry
user._auth_refreshed -= app.authenticator.auth_refresh_age
before = user._auth_refreshed
r = await get_page('home', app, cookies=cookies, allow_redirects=False)
# verify that we redirect to login with ?next=requested page
assert r.status_code == 302
redirect_url = urlparse(r.headers['Location'])
assert redirect_url.path.endswith('/login')
query = parse_qs(redirect_url.query)
assert query['next']
next_url = urlparse(query['next'][0])
assert next_url.path == urlparse(r.url).path
# make sure refresh didn't get updated
assert user._auth_refreshed == before
async def test_auth_expired_api(app, user, disable_refresh):
cookies = await app.login_user(user.name)
assert user._auth_refreshed
user._auth_refreshed -= 10
before = user._auth_refreshed
# auth is fresh, doesn't trigger expiry
r = await api_request(app, 'users/' + user.name, name=user.name)
assert user._auth_refreshed == before
assert r.status_code == 200
# get a page with stale auth, triggers expiry
user._auth_refreshed -= app.authenticator.auth_refresh_age
r = await api_request(app, 'users/' + user.name, name=user.name)
# api requests can't do login redirects
assert r.status_code == 403
async def test_refresh_pre_spawn(app, user, refresh_pre_spawn):
cookies = await app.login_user(user.name)
assert user._auth_refreshed
user._auth_refreshed -= 10
before = user._auth_refreshed
# auth is fresh, but should be forced to refresh by spawn
r = await api_request(
app, 'users/{}/server'.format(user.name), method='post', name=user.name
)
assert 200 <= r.status_code < 300
assert user._auth_refreshed > before
async def test_refresh_pre_spawn_expired(app, user, refresh_pre_spawn, disable_refresh):
cookies = await app.login_user(user.name)
assert user._auth_refreshed
user._auth_refreshed -= 10
before = user._auth_refreshed
# auth is fresh, doesn't trigger expiry
r = await api_request(
app, 'users/{}/server'.format(user.name), method='post', name=user.name
)
assert r.status_code == 403
assert user._auth_refreshed == before
async def test_refresh_pre_spawn_admin_request(
app, user, admin_user, refresh_pre_spawn
):
await app.login_user(user.name)
await app.login_user(admin_user.name)
user._auth_refreshed -= 10
before = user._auth_refreshed
# admin request, auth is fresh. Should still refresh user auth.
r = await api_request(
app, 'users', user.name, 'server', method='post', name=admin_user.name
)
assert 200 <= r.status_code < 300
assert user._auth_refreshed > before
async def test_refresh_pre_spawn_expired_admin_request(
app, user, admin_user, refresh_pre_spawn, disable_refresh
):
await app.login_user(user.name)
await app.login_user(admin_user.name)
user._auth_refreshed -= 10
# auth needs refresh but can't without a new login; spawn should fail
user._auth_refreshed -= app.authenticator.auth_refresh_age
r = await api_request(
app, 'users', user.name, 'server', method='post', name=admin_user.name
)
# api requests can't do login redirects
assert r.status_code == 403

View File

@@ -2,6 +2,7 @@
import asyncio import asyncio
import sys import sys
from unittest import mock
from urllib.parse import urlencode, urlparse from urllib.parse import urlencode, urlparse
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
@@ -13,21 +14,17 @@ from ..utils import url_path_join as ujoin
from .. import orm from .. import orm
from ..auth import Authenticator from ..auth import Authenticator
import mock
import pytest import pytest
from .mocking import FormSpawner, public_url, public_host from .mocking import FormSpawner
from .test_api import api_request, add_user from .utils import (
from .utils import async_requests async_requests,
api_request,
add_user,
def get_page(path, app, hub=True, **kw): get_page,
if hub: public_url,
prefix = app.hub.base_url public_host,
else: )
prefix = app.base_url
base_url = ujoin(public_host(app), prefix)
return async_requests.get(ujoin(base_url, path), **kw)
async def test_root_no_auth(app): async def test_root_no_auth(app):

View File

@@ -84,7 +84,8 @@ async def test_external_proxy(request):
# add user to the db and start a single user server # add user to the db and start a single user server
name = 'river' name = 'river'
add_user(app.db, app, name=name) add_user(app.db, app, name=name)
r = await api_request(app, 'users', name, 'server', method='post') r = await api_request(app, 'users', name, 'server', method='post',
bypass_proxy=True)
r.raise_for_status() r.raise_for_status()
routes = await app.proxy.get_all_routes() routes = await app.proxy.get_all_routes()
@@ -108,7 +109,7 @@ async def test_external_proxy(request):
assert list(routes.keys()) == [] assert list(routes.keys()) == []
# poke the server to update the proxy # poke the server to update the proxy
r = await api_request(app, 'proxy', method='post') r = await api_request(app, 'proxy', method='post', bypass_proxy=True)
r.raise_for_status() r.raise_for_status()
# check that the routes are correct # check that the routes are correct
@@ -135,10 +136,16 @@ async def test_external_proxy(request):
# tell the hub where the new proxy is # tell the hub where the new proxy is
new_api_url = 'http://{}:{}'.format(proxy_ip, proxy_port) new_api_url = 'http://{}:{}'.format(proxy_ip, proxy_port)
r = await api_request(app, 'proxy', method='patch', data=json.dumps({ r = await api_request(
'api_url': new_api_url, app,
'auth_token': new_auth_token, 'proxy',
})) method='patch',
data=json.dumps({
'api_url': new_api_url,
'auth_token': new_auth_token,
}),
bypass_proxy=True,
)
r.raise_for_status() r.raise_for_status()
assert app.proxy.api_url == new_api_url assert app.proxy.api_url == new_api_url

View File

@@ -4,6 +4,10 @@ from concurrent.futures import ThreadPoolExecutor
from certipy import Certipy from certipy import Certipy
import requests import requests
from jupyterhub import orm
from jupyterhub.objects import Server
from jupyterhub.utils import url_path_join as ujoin
class _AsyncRequests: class _AsyncRequests:
"""Wrapper around requests to return a Future from request methods """Wrapper around requests to return a Future from request methods
@@ -46,3 +50,144 @@ def ssl_setup(cert_dir, authority_name):
"external", authority_name, overwrite=True, alt_names=alt_names "external", authority_name, overwrite=True, alt_names=alt_names
) )
return external_certs return external_certs
def check_db_locks(func):
"""Decorator that verifies no locks are held on database upon exit.
This decorator for test functions verifies no locks are held on the
application's database upon exit by creating and dropping a dummy table.
The decorator relies on an instance of JupyterHubApp being the first
argument to the decorated function.
Example
-------
@check_db_locks
def api_request(app, *api_path, **kwargs):
"""
def new_func(app, *args, **kwargs):
retval = func(app, *args, **kwargs)
temp_session = app.session_factory()
temp_session.execute('CREATE TABLE dummy (foo INT)')
temp_session.execute('DROP TABLE dummy')
temp_session.close()
return retval
return new_func
def find_user(db, name, app=None):
"""Find user in database."""
orm_user = db.query(orm.User).filter(orm.User.name == name).first()
if app is None:
return orm_user
else:
return app.users[orm_user.id]
def add_user(db, app=None, **kwargs):
"""Add a user to the database."""
orm_user = find_user(db, name=kwargs.get('name'))
if orm_user is None:
orm_user = orm.User(**kwargs)
db.add(orm_user)
else:
for attr, value in kwargs.items():
setattr(orm_user, attr, value)
db.commit()
if app:
return app.users[orm_user.id]
else:
return orm_user
def auth_header(db, name):
"""Return header with user's API authorization token."""
user = find_user(db, name)
if user is None:
user = add_user(db, name=name)
token = user.new_api_token()
return {'Authorization': 'token %s' % token}
@check_db_locks
async def api_request(app, *api_path, method='get',
noauth=False, bypass_proxy=False,
**kwargs):
"""Make an API request"""
if bypass_proxy:
# make a direct request to the hub,
# skipping the proxy
base_url = app.hub.url
else:
base_url = public_url(app, path='hub')
headers = kwargs.setdefault('headers', {})
if (
'Authorization' not in headers
and not noauth
and 'cookies' not in kwargs
):
# make a copy to avoid modifying arg in-place
kwargs['headers'] = h = {}
h.update(headers)
h.update(auth_header(app.db, kwargs.pop('name', 'admin')))
if 'cookies' in kwargs:
# for cookie-authenticated requests,
# set Referer so it looks like the request originated
# from a Hub-served page
headers.setdefault('Referer', ujoin(base_url, 'test'))
url = ujoin(base_url, 'api', *api_path)
f = getattr(async_requests, method)
if app.internal_ssl:
kwargs['cert'] = (app.internal_ssl_cert, app.internal_ssl_key)
kwargs["verify"] = app.internal_ssl_ca
resp = await f(url, **kwargs)
assert "frame-ancestors 'self'" in resp.headers['Content-Security-Policy']
assert ujoin(app.hub.base_url, "security/csp-report") in resp.headers['Content-Security-Policy']
assert 'http' not in resp.headers['Content-Security-Policy']
if not kwargs.get('stream', False) and resp.content:
assert resp.headers.get('content-type') == 'application/json'
return resp
def get_page(path, app, hub=True, **kw):
if hub:
prefix = app.hub.base_url
else:
prefix = app.base_url
base_url = ujoin(public_host(app), prefix)
return async_requests.get(ujoin(base_url, path), **kw)
def public_host(app):
"""Return the public *host* (no URL prefix) of the given JupyterHub instance."""
if app.subdomain_host:
return app.subdomain_host
else:
return Server.from_url(app.proxy.public_url).host
def public_url(app, user_or_service=None, path=''):
"""Return the full, public base URL (including prefix) of the given JupyterHub instance."""
if user_or_service:
if app.subdomain_host:
host = user_or_service.host
else:
host = public_host(app)
prefix = user_or_service.prefix
else:
host = public_host(app)
prefix = Server.from_url(app.proxy.public_url).base_url
if path:
return host + ujoin(prefix, path)
else:
return host + prefix

View File

@@ -8,7 +8,9 @@ import warnings
from sqlalchemy import inspect from sqlalchemy import inspect
from tornado import gen from tornado import gen
from tornado.httputil import urlencode
from tornado.log import app_log from tornado.log import app_log
from tornado import web
from .utils import maybe_future, url_path_join, make_ssl_context from .utils import maybe_future, url_path_join, make_ssl_context
@@ -136,6 +138,7 @@ class User:
orm_user = None orm_user = None
log = app_log log = app_log
settings = None settings = None
_auth_refreshed = None
def __init__(self, orm_user, settings=None, db=None): def __init__(self, orm_user, settings=None, db=None):
self.db = db or inspect(orm_user).session self.db = db or inspect(orm_user).session
@@ -380,6 +383,59 @@ class User:
url_parts.extend(['server/progress']) url_parts.extend(['server/progress'])
return url_path_join(*url_parts) return url_path_join(*url_parts)
async def refresh_auth(self, handler):
"""Refresh authentication if needed
Checks authentication expiry and refresh it if needed.
See Spawner.
If the auth is expired and cannot be refreshed
without forcing a new login, a few things can happen:
1. if this is a normal user spawn,
the user should be redirected to login
and back to spawn after login.
2. if this is a spawn via API or other user,
spawn will fail until the user logs in again.
Args:
handler (RequestHandler):
The handler for the request triggering the spawn.
May be None
"""
authenticator = self.authenticator
if authenticator is None or not authenticator.refresh_pre_spawn:
# nothing to do
return
# refresh auth
auth_user = await handler.refresh_auth(self, force=True)
if auth_user:
# auth refreshed, all done
return
# if we got to here, auth is expired and couldn't be refreshed
self.log.error(
"Auth expired for %s; cannot spawn until they login again",
self.name,
)
# auth expired, cannot spawn without a fresh login
# it's the current user *and* spawn via GET, trigger login redirect
if handler.request.method == 'GET' and handler.current_user is self:
self.log.info("Redirecting %s to login to refresh auth", self.name)
url = self.get_login_url()
next_url = self.request.uri
sep = '&' if '?' in url else '?'
url += sep + urlencode(dict(next=next_url))
self.redirect(url)
raise web.Finish()
else:
# spawn via POST or on behalf of another user.
# nothing we can do here but fail
raise web.HTTPError(400, "{}'s authentication has expired".format(self.name))
async def spawn(self, server_name='', options=None, handler=None): async def spawn(self, server_name='', options=None, handler=None):
"""Start the user's spawner """Start the user's spawner
@@ -395,6 +451,9 @@ class User:
""" """
db = self.db db = self.db
if handler:
await self.refresh_auth(handler)
base_url = url_path_join(self.base_url, server_name) + '/' base_url = url_path_join(self.base_url, server_name) + '/'
orm_server = orm.Server( orm_server = orm.Server(
@@ -436,7 +495,7 @@ class User:
# trigger pre-spawn hook on authenticator # trigger pre-spawn hook on authenticator
authenticator = self.authenticator authenticator = self.authenticator
if (authenticator): if authenticator:
await maybe_future(authenticator.pre_spawn_start(self, spawner)) await maybe_future(authenticator.pre_spawn_start(self, spawner))
spawner._start_pending = True spawner._start_pending = True