run autoformat

apologies to anyone finding this commit via git blame or log

run the autoformatting by

    pre-commit run --all-files
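
For context, pre-commit reads its hooks from a `.pre-commit-config.yaml` at the repository root. The snippet below is only an illustrative sketch of the kind of configuration that produces this sort of reformatting (black plus one-import-per-line reordering); the hook repositories, ids, and revisions are assumptions for illustration, not taken from this commit:

    # illustrative only -- not the repository's actual .pre-commit-config.yaml
    repos:
      - repo: https://github.com/asottile/reorder_python_imports
        rev: v1.3.5  # placeholder revision
        hooks:
          - id: reorder-python-imports
      - repo: https://github.com/ambv/black
        rev: 18.9b0  # placeholder revision
        hooks:
          - id: black
            args: [--skip-string-normalization]  # keep existing quote style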
Author: Min RK
Date: 2019-02-14 15:08:59 +01:00
parent ca198e0363
commit 5e60582ef3
118 changed files with 3583 additions and 2934 deletions


@@ -1 +0,0 @@


@@ -1,19 +1,16 @@
 #!/usr/bin/env python
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
 """
 bower-lite

 Since Bower's on its way out,
 stage frontend dependencies from node_modules into components
 """
 import json
 import os
-from os.path import join
 import shutil
+from os.path import join

 HERE = os.path.abspath(os.path.dirname(__file__))


@@ -1,16 +1,16 @@
 -r requirements.txt
-mock
+# temporary pin of attrs for jsonschema 0.3.0a1
+# seems to be a pip bug
+attrs>=17.4.0
 beautifulsoup4
 codecov
 coverage
 cryptography
 html5lib # needed for beautifulsoup
-pytest-cov
-pytest-asyncio
-pytest>=3.3
+mock
 notebook
+pytest-asyncio
+pytest-cov
+pytest>=3.3
 requests-mock
 virtualenv
-# temporary pin of attrs for jsonschema 0.3.0a1
-# seems to be a pip bug
-attrs>=17.4.0


@@ -7,5 +7,3 @@ ENV LANG=en_US.UTF-8
 USER nobody
 CMD ["jupyterhub"]


@@ -18,4 +18,3 @@ Dockerfile.alpine contains base image for jupyterhub. It does not work independ
 * Use dummy authenticator for ease of testing. Update following in jupyterhub_config file
 - c.JupyterHub.authenticator_class = 'dummyauthenticator.DummyAuthenticator'
 - c.DummyAuthenticator.password = "your strong password"


@@ -1,7 +1,7 @@
 # ReadTheDocs uses the `environment.yaml` so make sure to update that as well
 # if you change this file
 -r ../requirements.txt
-sphinx>=1.7
+alabaster_jupyterhub
 recommonmark==0.4.0
 sphinx-copybutton
-alabaster_jupyterhub
+sphinx>=1.7


@@ -13,4 +13,3 @@ Module: :mod:`jupyterhub.app`
 -------------------
 .. autoconfigurable:: JupyterHub


@@ -30,4 +30,3 @@ Module: :mod:`jupyterhub.auth`
 ---------------------------
 .. autoconfigurable:: DummyAuthenticator


@@ -20,4 +20,3 @@ Module: :mod:`jupyterhub.proxy`
 .. autoconfigurable:: ConfigurableHTTPProxy
    :members: debug, auth_token, check_running_interval, api_url, command


@@ -14,4 +14,3 @@ Module: :mod:`jupyterhub.services.service`
 .. autoconfigurable:: Service
    :members: name, admin, url, api_token, managed, kind, command, cwd, environment, user, oauth_client_id, server, prefix, proxy_spec


@@ -38,4 +38,3 @@ Module: :mod:`jupyterhub.services.auth`
 --------------------------------
 .. autoclass:: HubOAuthCallbackHandler


@@ -19,4 +19,3 @@ Module: :mod:`jupyterhub.spawner`
 ----------------------------
 .. autoconfigurable:: LocalProcessSpawner


@@ -34,4 +34,3 @@ Module: :mod:`jupyterhub.user`
 .. attribute:: spawner
    The user's :class:`~.Spawner` instance.


@@ -1,11 +1,12 @@
 # -*- coding: utf-8 -*-
 #
-import sys
 import os
 import shlex
+import sys

-import recommonmark.parser
 # For conversion from markdown to html
+import recommonmark.parser

 # Set paths
 sys.path.insert(0, os.path.abspath('.'))
@@ -21,7 +22,7 @@ extensions = [
     'sphinx.ext.intersphinx',
     'sphinx.ext.napoleon',
     'autodoc_traits',
-    'sphinx_copybutton'
+    'sphinx_copybutton',
 ]

 templates_path = ['_templates']
@@ -68,6 +69,7 @@ source_suffix = ['.rst', '.md']
 # The theme to use for HTML and HTML Help pages.
 import alabaster_jupyterhub

 html_theme = 'alabaster_jupyterhub'
 html_theme_path = [alabaster_jupyterhub.get_html_theme_path()]


@@ -210,4 +210,3 @@ jupyterhub_config.py amendments:
 --This is the address on which the proxy will bind. Sets protocol, ip, base_url
 c.JupyterHub.bind_url = 'http://127.0.0.1:8000/jhub/'
 ```


@@ -1,8 +1,9 @@
"""autodoc extension for configurable traits""" """autodoc extension for configurable traits"""
from traitlets import TraitType, Undefined
from sphinx.domains.python import PyClassmember from sphinx.domains.python import PyClassmember
from sphinx.ext.autodoc import ClassDocumenter, AttributeDocumenter from sphinx.ext.autodoc import AttributeDocumenter
from sphinx.ext.autodoc import ClassDocumenter
from traitlets import TraitType
from traitlets import Undefined
class ConfigurableDocumenter(ClassDocumenter): class ConfigurableDocumenter(ClassDocumenter):


@@ -5,8 +5,10 @@ create a directory for the user before the spawner starts
 # pylint: disable=import-error
 import os
 import shutil

 from jupyter_client.localinterfaces import public_ips


 def create_dir_hook(spawner):
     """ Create directory """
     username = spawner.user.name  # get the username
@@ -16,6 +18,7 @@ def create_dir_hook(spawner):
     # now do whatever you think your user needs
     # ...


 def clean_dir_hook(spawner):
     """ Delete directory """
     username = spawner.user.name  # get the username
@@ -23,6 +26,7 @@ def clean_dir_hook(spawner):
     if os.path.exists(temp_path) and os.path.isdir(temp_path):
         shutil.rmtree(temp_path)


 # attach the hook functions to the spawner
 # pylint: disable=undefined-variable
 c.Spawner.pre_spawn_hook = create_dir_hook


@@ -31,11 +31,11 @@ users and servers, you should add this script to the services list
 twice, just with different ``name``s, different values, and one with
 the ``--cull-users`` option.
 """
-from datetime import datetime, timezone
-from functools import partial
 import json
 import os
+from datetime import datetime
+from datetime import timezone
+from functools import partial

 try:
     from urllib.parse import quote
@@ -85,23 +85,21 @@ def format_td(td):
 @coroutine
-def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concurrency=10):
+def cull_idle(
+    url, api_token, inactive_limit, cull_users=False, max_age=0, concurrency=10
+):
     """Shutdown idle single-user servers

     If cull_users, inactive *users* will be deleted as well.
     """
-    auth_header = {
-        'Authorization': 'token %s' % api_token,
-    }
-    req = HTTPRequest(
-        url=url + '/users',
-        headers=auth_header,
-    )
+    auth_header = {'Authorization': 'token %s' % api_token}
+    req = HTTPRequest(url=url + '/users', headers=auth_header)
     now = datetime.now(timezone.utc)
     client = AsyncHTTPClient()

     if concurrency:
         semaphore = Semaphore(concurrency)

         @coroutine
         def fetch(req):
             """client.fetch wrapped in a semaphore to limit concurrency"""
@@ -110,6 +108,7 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
                 return (yield client.fetch(req))
             finally:
                 yield semaphore.release()

     else:
         fetch = client.fetch
@@ -129,8 +128,8 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
         log_name = '%s/%s' % (user['name'], server_name)
         if server.get('pending'):
             app_log.warning(
-                "Not culling server %s with pending %s",
-                log_name, server['pending'])
+                "Not culling server %s with pending %s", log_name, server['pending']
+            )
             return False

         # jupyterhub < 0.9 defined 'server.url' once the server was ready
@@ -142,8 +141,8 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
         if not server.get('ready', bool(server['url'])):
             app_log.warning(
-                "Not culling not-ready not-pending server %s: %s",
-                log_name, server)
+                "Not culling not-ready not-pending server %s: %s", log_name, server
+            )
             return False

         if server.get('started'):
@@ -163,12 +162,13 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
             # for running servers
             inactive = age

-        should_cull = (inactive is not None and
-                       inactive.total_seconds() >= inactive_limit)
+        should_cull = (
+            inactive is not None and inactive.total_seconds() >= inactive_limit
+        )
         if should_cull:
             app_log.info(
-                "Culling server %s (inactive for %s)",
-                log_name, format_td(inactive))
+                "Culling server %s (inactive for %s)", log_name, format_td(inactive)
+            )

         if max_age and not should_cull:
             # only check started if max_age is specified
@@ -177,32 +177,34 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
             if age is not None and age.total_seconds() >= max_age:
                 app_log.info(
                     "Culling server %s (age: %s, inactive for %s)",
-                    log_name, format_td(age), format_td(inactive))
+                    log_name,
+                    format_td(age),
+                    format_td(inactive),
+                )
                 should_cull = True

         if not should_cull:
             app_log.debug(
                 "Not culling server %s (age: %s, inactive for %s)",
-                log_name, format_td(age), format_td(inactive))
+                log_name,
+                format_td(age),
+                format_td(inactive),
+            )
             return False

         if server_name:
             # culling a named server
             delete_url = url + "/users/%s/servers/%s" % (
-                quote(user['name']), quote(server['name'])
+                quote(user['name']),
+                quote(server['name']),
             )
         else:
             delete_url = url + '/users/%s/server' % quote(user['name'])

-        req = HTTPRequest(
-            url=delete_url, method='DELETE', headers=auth_header,
-        )
+        req = HTTPRequest(url=delete_url, method='DELETE', headers=auth_header)
         resp = yield fetch(req)
         if resp.code == 202:
-            app_log.warning(
-                "Server %s is slow to stop",
-                log_name,
-            )
+            app_log.warning("Server %s is slow to stop", log_name)
             # return False to prevent culling user with pending shutdowns
             return False
         return True
@@ -245,7 +247,9 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
         if still_alive:
             app_log.debug(
                 "Not culling user %s with %i servers still alive",
-                user['name'], still_alive)
+                user['name'],
+                still_alive,
+            )
             return False

         should_cull = False
@@ -265,12 +269,11 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
             # which introduces the 'created' field which is never None
             inactive = age

-        should_cull = (inactive is not None and
-                       inactive.total_seconds() >= inactive_limit)
+        should_cull = (
+            inactive is not None and inactive.total_seconds() >= inactive_limit
+        )
         if should_cull:
-            app_log.info(
-                "Culling user %s (inactive for %s)",
-                user['name'], inactive)
+            app_log.info("Culling user %s (inactive for %s)", user['name'], inactive)

         if max_age and not should_cull:
             # only check created if max_age is specified
@@ -279,19 +282,23 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
             if age is not None and age.total_seconds() >= max_age:
                 app_log.info(
                     "Culling user %s (age: %s, inactive for %s)",
-                    user['name'], format_td(age), format_td(inactive))
+                    user['name'],
+                    format_td(age),
+                    format_td(inactive),
+                )
                 should_cull = True

         if not should_cull:
             app_log.debug(
                 "Not culling user %s (created: %s, last active: %s)",
-                user['name'], format_td(age), format_td(inactive))
+                user['name'],
+                format_td(age),
+                format_td(inactive),
+            )
             return False

         req = HTTPRequest(
-            url=url + '/users/%s' % user['name'],
-            method='DELETE',
-            headers=auth_header,
+            url=url + '/users/%s' % user['name'], method='DELETE', headers=auth_header
         )
         yield fetch(req)
         return True
@@ -316,20 +323,30 @@ if __name__ == '__main__':
help="The JupyterHub API URL", help="The JupyterHub API URL",
) )
define('timeout', default=600, help="The idle timeout (in seconds)") define('timeout', default=600, help="The idle timeout (in seconds)")
define('cull_every', default=0, define(
help="The interval (in seconds) for checking for idle servers to cull") 'cull_every',
define('max_age', default=0, default=0,
help="The maximum age (in seconds) of servers that should be culled even if they are active") help="The interval (in seconds) for checking for idle servers to cull",
define('cull_users', default=False, )
define(
'max_age',
default=0,
help="The maximum age (in seconds) of servers that should be culled even if they are active",
)
define(
'cull_users',
default=False,
help="""Cull users in addition to servers. help="""Cull users in addition to servers.
This is for use in temporary-user cases such as tmpnb.""", This is for use in temporary-user cases such as tmpnb.""",
) )
define('concurrency', default=10, define(
'concurrency',
default=10,
help="""Limit the number of concurrent requests made to the Hub. help="""Limit the number of concurrent requests made to the Hub.
Deleting a lot of users at the same time can slow down the Hub, Deleting a lot of users at the same time can slow down the Hub,
so limit the number of API requests we have outstanding at any given time. so limit the number of API requests we have outstanding at any given time.
""" """,
) )
parse_command_line() parse_command_line()
@@ -343,7 +360,8 @@ if __name__ == '__main__':
         app_log.warning(
             "Could not load pycurl: %s\n"
             "pycurl is recommended if you have a large number of users.",
-            e)
+            e,
+        )

     loop = IOLoop.current()
     cull = partial(


@@ -4,7 +4,9 @@ import os
 # this could come from anywhere
 api_token = os.getenv("JUPYTERHUB_API_TOKEN")
 if not api_token:
-    raise ValueError("Make sure to `export JUPYTERHUB_API_TOKEN=$(openssl rand -hex 32)`")
+    raise ValueError(
+        "Make sure to `export JUPYTERHUB_API_TOKEN=$(openssl rand -hex 32)`"
+    )
# tell JupyterHub to register the service as an external oauth client # tell JupyterHub to register the service as an external oauth client
@@ -14,5 +16,5 @@ c.JupyterHub.services = [
         'oauth_client_id': "whoami-oauth-client-test",
         'api_token': api_token,
         'oauth_redirect_uri': 'http://127.0.0.1:5555/oauth_callback',
-    },
+    }
 ]


@@ -3,18 +3,19 @@
 Implements OAuth handshake manually
 so all URLs and requests necessary for OAuth with JupyterHub should be in one place
 """
 import json
 import os
 import sys
-from urllib.parse import urlencode, urlparse
+from urllib.parse import urlencode
+from urllib.parse import urlparse

-from tornado.auth import OAuth2Mixin
-from tornado.httpclient import AsyncHTTPClient, HTTPRequest
-from tornado.httputil import url_concat
-from tornado.ioloop import IOLoop
 from tornado import log
 from tornado import web
+from tornado.auth import OAuth2Mixin
+from tornado.httpclient import AsyncHTTPClient
+from tornado.httpclient import HTTPRequest
+from tornado.httputil import url_concat
+from tornado.ioloop import IOLoop


 class JupyterHubLoginHandler(web.RequestHandler):
@@ -32,11 +33,11 @@ class JupyterHubLoginHandler(web.RequestHandler):
             code=code,
             redirect_uri=self.settings['redirect_uri'],
         )
-        req = HTTPRequest(self.settings['token_url'], method='POST',
+        req = HTTPRequest(
+            self.settings['token_url'],
+            method='POST',
             body=urlencode(params).encode('utf8'),
-            headers={
-                'Content-Type': 'application/x-www-form-urlencoded',
-            },
+            headers={'Content-Type': 'application/x-www-form-urlencoded'},
         )
         response = await AsyncHTTPClient().fetch(req)
         data = json.loads(response.body.decode('utf8', 'replace'))
@@ -55,14 +56,16 @@ class JupyterHubLoginHandler(web.RequestHandler):
         # we are the login handler,
         # begin oauth process which will come back later with an
         # authorization_code
-        self.redirect(url_concat(
-            self.settings['authorize_url'],
-            dict(
-                redirect_uri=self.settings['redirect_uri'],
-                client_id=self.settings['client_id'],
-                response_type='code',
-            )
-        ))
+        self.redirect(
+            url_concat(
+                self.settings['authorize_url'],
+                dict(
+                    redirect_uri=self.settings['redirect_uri'],
+                    client_id=self.settings['client_id'],
+                    response_type='code',
+                ),
+            )
+        )


 class WhoAmIHandler(web.RequestHandler):
@@ -85,10 +88,7 @@ class WhoAmIHandler(web.RequestHandler):
"""Retrieve the user for a given token, via /hub/api/user""" """Retrieve the user for a given token, via /hub/api/user"""
req = HTTPRequest( req = HTTPRequest(
self.settings['user_url'], self.settings['user_url'], headers={'Authorization': f'token {token}'}
headers={
'Authorization': f'token {token}'
},
) )
response = await AsyncHTTPClient().fetch(req) response = await AsyncHTTPClient().fetch(req)
return json.loads(response.body.decode('utf8', 'replace')) return json.loads(response.body.decode('utf8', 'replace'))
@@ -110,23 +110,23 @@ def main():
     token_url = hub_api + '/oauth2/token'
     user_url = hub_api + '/user'

-    app = web.Application([
-        ('/oauth_callback', JupyterHubLoginHandler),
-        ('/', WhoAmIHandler),
-    ],
+    app = web.Application(
+        [('/oauth_callback', JupyterHubLoginHandler), ('/', WhoAmIHandler)],
         login_url='/oauth_callback',
         cookie_secret=os.urandom(32),
         api_token=os.environ['JUPYTERHUB_API_TOKEN'],
         client_id=os.environ['JUPYTERHUB_CLIENT_ID'],
-        redirect_uri=os.environ['JUPYTERHUB_SERVICE_URL'].rstrip('/') + '/oauth_callback',
+        redirect_uri=os.environ['JUPYTERHUB_SERVICE_URL'].rstrip('/')
+        + '/oauth_callback',
         authorize_url=authorize_url,
         token_url=token_url,
         user_url=user_url,
     )

     url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
-    log.app_log.info("Running basic whoami service on %s",
-                     os.environ['JUPYTERHUB_SERVICE_URL'])
+    log.app_log.info(
+        "Running basic whoami service on %s", os.environ['JUPYTERHUB_SERVICE_URL']
+    )
     app.listen(url.port, url.hostname)
     IOLoop.current().start()


@@ -8,10 +8,10 @@ c.Authenticator.whitelist = {'ganymede', 'io', 'rhea'}
 # These environment variables are automatically supplied by the linked postgres
 # container.
-import os;
+import os
 pg_pass = os.getenv('POSTGRES_ENV_JPY_PSQL_PASSWORD')
 pg_host = os.getenv('POSTGRES_PORT_5432_TCP_ADDR')

 c.JupyterHub.db_url = 'postgresql://jupyterhub:{}@{}:5432/jupyterhub'.format(
-    pg_pass,
-    pg_host,
+    pg_pass, pg_host
 )


@@ -1,11 +1,14 @@
 import argparse
 import datetime
 import json
 import os

+from tornado import escape
+from tornado import gen
+from tornado import ioloop
+from tornado import web
 from jupyterhub.services.auth import HubAuthenticated
-from tornado import escape, gen, ioloop, web


 class AnnouncementRequestHandler(HubAuthenticated, web.RequestHandler):
@@ -53,19 +56,19 @@ def main():
 def parse_arguments():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--api-prefix", "-a",
+    parser.add_argument(
+        "--api-prefix",
+        "-a",
         default=os.environ.get("JUPYTERHUB_SERVICE_PREFIX", "/"),
-        help="application API prefix")
-    parser.add_argument("--port", "-p",
-        default=8888,
-        help="port for API to listen on",
-        type=int)
+        help="application API prefix",
+    )
+    parser.add_argument(
+        "--port", "-p", default=8888, help="port for API to listen on", type=int
+    )
     return parser.parse_args()


-def create_application(api_prefix="/",
-        handler=AnnouncementRequestHandler,
-        **kwargs):
+def create_application(api_prefix="/", handler=AnnouncementRequestHandler, **kwargs):
     storage = dict(announcement="", timestamp="", user="")
     return web.Application([(api_prefix, handler, dict(storage=storage))])


@@ -22,4 +22,3 @@ In the external example, some extra steps are required to set up supervisor:
 3. install `shared-notebook-service` somewhere on your system, and update `/path/to/shared-notebook-service` to the absolute path of this destination
 3. copy `shared-notebook.conf` to `/etc/supervisor/conf.d/`
 4. `supervisorctl reload`


@@ -1,18 +1,9 @@
 # our user list
-c.Authenticator.whitelist = [
-    'minrk',
-    'ellisonbg',
-    'willingc',
-]
+c.Authenticator.whitelist = ['minrk', 'ellisonbg', 'willingc']

 # ellisonbg and willingc have access to a shared server:
-c.JupyterHub.load_groups = {
-    'shared': [
-        'ellisonbg',
-        'willingc',
-    ]
-}
+c.JupyterHub.load_groups = {'shared': ['ellisonbg', 'willingc']}

 # start the notebook server as a service
 c.JupyterHub.services = [


@@ -1,18 +1,9 @@
 # our user list
-c.Authenticator.whitelist = [
-    'minrk',
-    'ellisonbg',
-    'willingc',
-]
+c.Authenticator.whitelist = ['minrk', 'ellisonbg', 'willingc']

 # ellisonbg and willingc have access to a shared server:
-c.JupyterHub.load_groups = {
-    'shared': [
-        'ellisonbg',
-        'willingc',
-    ]
-}
+c.JupyterHub.load_groups = {'shared': ['ellisonbg', 'willingc']}

 service_name = 'shared-notebook'
 service_port = 9999
@@ -23,10 +14,6 @@ c.JupyterHub.services = [
     {
         'name': service_name,
         'url': 'http://127.0.0.1:{}'.format(service_port),
-        'command': [
-            'jupyterhub-singleuser',
-            '--group=shared',
-            '--debug',
-        ],
+        'command': ['jupyterhub-singleuser', '--group=shared', '--debug'],
     }
 ]


@@ -6,16 +6,12 @@ c.JupyterHub.services = [
         'name': 'whoami',
         'url': 'http://127.0.0.1:10101',
         'command': ['flask', 'run', '--port=10101'],
-        'environment': {
-            'FLASK_APP': 'whoami-flask.py',
-        }
+        'environment': {'FLASK_APP': 'whoami-flask.py'},
     },
     {
         'name': 'whoami-oauth',
         'url': 'http://127.0.0.1:10201',
         'command': ['flask', 'run', '--port=10201'],
-        'environment': {
-            'FLASK_APP': 'whoami-oauth.py',
-        }
+        'environment': {'FLASK_APP': 'whoami-oauth.py'},
     },
 ]


@@ -2,29 +2,29 @@
""" """
whoami service authentication with the Hub whoami service authentication with the Hub
""" """
from functools import wraps
import json import json
import os import os
from functools import wraps
from urllib.parse import quote from urllib.parse import quote
from flask import Flask, redirect, request, Response from flask import Flask
from flask import redirect
from flask import request
from flask import Response
from jupyterhub.services.auth import HubAuth from jupyterhub.services.auth import HubAuth
prefix = os.environ.get('JUPYTERHUB_SERVICE_PREFIX', '/') prefix = os.environ.get('JUPYTERHUB_SERVICE_PREFIX', '/')
auth = HubAuth( auth = HubAuth(api_token=os.environ['JUPYTERHUB_API_TOKEN'], cache_max_age=60)
api_token=os.environ['JUPYTERHUB_API_TOKEN'],
cache_max_age=60,
)
app = Flask(__name__) app = Flask(__name__)
def authenticated(f): def authenticated(f):
"""Decorator for authenticating with the Hub""" """Decorator for authenticating with the Hub"""
@wraps(f) @wraps(f)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
cookie = request.cookies.get(auth.cookie_name) cookie = request.cookies.get(auth.cookie_name)
@@ -40,6 +40,7 @@ def authenticated(f):
         else:
             # redirect to login url on failed auth
             return redirect(auth.login_url + '?next=%s' % quote(request.path))
+
     return decorated
@@ -47,7 +48,5 @@ def authenticated(f):
 @authenticated
 def whoami(user):
     return Response(
-        json.dumps(user, indent=1, sort_keys=True),
-        mimetype='application/json',
+        json.dumps(user, indent=1, sort_keys=True), mimetype='application/json'
     )


@@ -2,28 +2,29 @@
""" """
whoami service authentication with the Hub whoami service authentication with the Hub
""" """
from functools import wraps
import json import json
import os import os
from functools import wraps
from flask import Flask, redirect, request, Response, make_response from flask import Flask
from flask import make_response
from flask import redirect
from flask import request
from flask import Response
from jupyterhub.services.auth import HubOAuth from jupyterhub.services.auth import HubOAuth
prefix = os.environ.get('JUPYTERHUB_SERVICE_PREFIX', '/') prefix = os.environ.get('JUPYTERHUB_SERVICE_PREFIX', '/')
auth = HubOAuth( auth = HubOAuth(api_token=os.environ['JUPYTERHUB_API_TOKEN'], cache_max_age=60)
api_token=os.environ['JUPYTERHUB_API_TOKEN'],
cache_max_age=60,
)
app = Flask(__name__) app = Flask(__name__)
def authenticated(f): def authenticated(f):
"""Decorator for authenticating with the Hub via OAuth""" """Decorator for authenticating with the Hub via OAuth"""
@wraps(f) @wraps(f)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
token = request.cookies.get(auth.cookie_name) token = request.cookies.get(auth.cookie_name)
@@ -39,6 +40,7 @@ def authenticated(f):
             response = make_response(redirect(auth.login_url + '&state=%s' % state))
             response.set_cookie(auth.state_cookie_name, state)
             return response
+
     return decorated
@@ -46,10 +48,10 @@ def authenticated(f):
 @authenticated
 def whoami(user):
     return Response(
-        json.dumps(user, indent=1, sort_keys=True),
-        mimetype='application/json',
+        json.dumps(user, indent=1, sort_keys=True), mimetype='application/json'
     )


 @app.route(prefix + 'oauth_callback')
 def oauth_callback():
     code = request.args.get('code', None)


@@ -4,18 +4,22 @@ This example service serves `/services/whoami/`,
 authenticated with the Hub,
 showing the user their own info.
 """
-from getpass import getuser
 import json
 import os
+from getpass import getuser
 from urllib.parse import urlparse

-from tornado.ioloop import IOLoop
 from tornado.httpserver import HTTPServer
-from tornado.web import RequestHandler, Application, authenticated
+from tornado.ioloop import IOLoop
+from tornado.web import Application
+from tornado.web import authenticated
+from tornado.web import RequestHandler

-from jupyterhub.services.auth import HubOAuthenticated, HubOAuthCallbackHandler
+from jupyterhub.services.auth import HubOAuthCallbackHandler
+from jupyterhub.services.auth import HubOAuthenticated
 from jupyterhub.utils import url_path_join


 class WhoAmIHandler(HubOAuthenticated, RequestHandler):
     # hub_users can be a set of users who are allowed to access the service
     # `getuser()` here would mean only the user who started the service
@@ -29,12 +33,21 @@ class WhoAmIHandler(HubOAuthenticated, RequestHandler):
         self.set_header('content-type', 'application/json')
         self.write(json.dumps(user_model, indent=1, sort_keys=True))


 def main():
-    app = Application([
-        (os.environ['JUPYTERHUB_SERVICE_PREFIX'], WhoAmIHandler),
-        (url_path_join(os.environ['JUPYTERHUB_SERVICE_PREFIX'], 'oauth_callback'), HubOAuthCallbackHandler),
-        (r'.*', WhoAmIHandler),
-    ], cookie_secret=os.urandom(32))
+    app = Application(
+        [
+            (os.environ['JUPYTERHUB_SERVICE_PREFIX'], WhoAmIHandler),
+            (
+                url_path_join(
+                    os.environ['JUPYTERHUB_SERVICE_PREFIX'], 'oauth_callback'
+                ),
+                HubOAuthCallbackHandler,
+            ),
+            (r'.*', WhoAmIHandler),
+        ],
+        cookie_secret=os.urandom(32),
+    )

     http_server = HTTPServer(app)
     url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
@@ -43,5 +56,6 @@ def main():
     IOLoop.current().start()


 if __name__ == '__main__':
     main()


@@ -2,14 +2,16 @@
 This serves `/services/whoami/`, authenticated with the Hub, showing the user their own info.
 """
-from getpass import getuser
 import json
 import os
+from getpass import getuser
 from urllib.parse import urlparse

-from tornado.ioloop import IOLoop
 from tornado.httpserver import HTTPServer
-from tornado.web import RequestHandler, Application, authenticated
+from tornado.ioloop import IOLoop
+from tornado.web import Application
+from tornado.web import authenticated
+from tornado.web import RequestHandler

 from jupyterhub.services.auth import HubAuthenticated
@@ -27,11 +29,14 @@ class WhoAmIHandler(HubAuthenticated, RequestHandler):
         self.set_header('content-type', 'application/json')
         self.write(json.dumps(user_model, indent=1, sort_keys=True))


 def main():
-    app = Application([
-        (os.environ['JUPYTERHUB_SERVICE_PREFIX'] + '/?', WhoAmIHandler),
-        (r'.*', WhoAmIHandler),
-    ])
+    app = Application(
+        [
+            (os.environ['JUPYTERHUB_SERVICE_PREFIX'] + '/?', WhoAmIHandler),
+            (r'.*', WhoAmIHandler),
+        ]
+    )

     http_server = HTTPServer(app)
     url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
@@ -40,5 +45,6 @@ def main():
     IOLoop.current().start()


 if __name__ == '__main__':
     main()


@@ -5,6 +5,7 @@ import shlex
 from jupyterhub.spawner import LocalProcessSpawner


 class DemoFormSpawner(LocalProcessSpawner):
     def _options_form_default(self):
         default_env = "YOURNAME=%s\n" % self.user.name
@@ -13,7 +14,9 @@ class DemoFormSpawner(LocalProcessSpawner):
<input name="args" placeholder="e.g. --debug"></input> <input name="args" placeholder="e.g. --debug"></input>
<label for="env">Environment variables (one per line)</label> <label for="env">Environment variables (one per line)</label>
<textarea name="env">{env}</textarea> <textarea name="env">{env}</textarea>
""".format(env=default_env) """.format(
env=default_env
)
def options_from_form(self, formdata): def options_from_form(self, formdata):
options = {} options = {}
@@ -43,4 +46,5 @@ class DemoFormSpawner(LocalProcessSpawner):
         env.update(self.user_options['env'])
         return env


 c.JupyterHub.spawner_class = DemoFormSpawner


@@ -1 +1,2 @@
-from ._version import version_info, __version__
+from ._version import __version__
+from ._version import version_info


@@ -1,2 +1,3 @@
 from .app import main
+
 main()


@@ -5,6 +5,7 @@ def get_data_files():
"""Walk up until we find share/jupyterhub""" """Walk up until we find share/jupyterhub"""
import sys import sys
from os.path import join, abspath, dirname, exists, split from os.path import join, abspath, dirname, exists, split
path = abspath(dirname(__file__)) path = abspath(dirname(__file__))
starting_points = [path] starting_points = [path]
if not path.startswith(sys.prefix): if not path.startswith(sys.prefix):


@@ -1,5 +1,4 @@
"""JupyterHub version info""" """JupyterHub version info"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
@@ -23,16 +22,23 @@ __version__ = ".".join(map(str, version_info[:3])) + ".".join(version_info[3:])
 def _check_version(hub_version, singleuser_version, log):
     """Compare Hub and single-user server versions"""
     if not hub_version:
-        log.warning("Hub has no version header, which means it is likely < 0.8. Expected %s", __version__)
+        log.warning(
+            "Hub has no version header, which means it is likely < 0.8. Expected %s",
+            __version__,
+        )
         return

     if not singleuser_version:
-        log.warning("Single-user server has no version header, which means it is likely < 0.8. Expected %s", __version__)
+        log.warning(
+            "Single-user server has no version header, which means it is likely < 0.8. Expected %s",
+            __version__,
+        )
         return

     # compare minor X.Y versions
     if hub_version != singleuser_version:
         from distutils.version import LooseVersion as V

         hub_major_minor = V(hub_version).version[:2]
         singleuser_major_minor = V(singleuser_version).version[:2]
         extra = ""
@@ -50,4 +56,6 @@ def _check_version(hub_version, singleuser_version, log):
             singleuser_version,
         )
     else:
-        log.debug("jupyterhub and jupyterhub-singleuser both on version %s" % hub_version)
+        log.debug(
+            "jupyterhub and jupyterhub-singleuser both on version %s" % hub_version
+        )


@@ -1,9 +1,10 @@
+import logging
 import sys
+from logging.config import fileConfig

 from alembic import context
-from sqlalchemy import engine_from_config, pool
-import logging
-from logging.config import fileConfig
+from sqlalchemy import engine_from_config
+from sqlalchemy import pool

 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.
@@ -14,6 +15,7 @@ config = context.config
 if 'jupyterhub' in sys.modules:
     from traitlets.config import MultipleInstanceError
     from jupyterhub.app import JupyterHub

     app = None
     if JupyterHub.initialized():
         try:
@@ -32,6 +34,7 @@ else:
     # add your model's MetaData object here for 'autogenerate' support
     from jupyterhub import orm

     target_metadata = orm.Base.metadata

 # other values from the config, defined by the needs of env.py,
@@ -53,8 +56,7 @@ def run_migrations_offline():
""" """
url = config.get_main_option("sqlalchemy.url") url = config.get_main_option("sqlalchemy.url")
context.configure( context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
url=url, target_metadata=target_metadata, literal_binds=True)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()
@@ -70,17 +72,16 @@ def run_migrations_online():
     connectable = engine_from_config(
         config.get_section(config.config_ini_section),
         prefix='sqlalchemy.',
-        poolclass=pool.NullPool)
+        poolclass=pool.NullPool,
+    )

     with connectable.connect() as connection:
-        context.configure(
-            connection=connection,
-            target_metadata=target_metadata
-        )
+        context.configure(connection=connection, target_metadata=target_metadata)

         with context.begin_transaction():
             context.run_migrations()


 if context.is_offline_mode():
     run_migrations_offline()
 else:


@@ -5,7 +5,6 @@ Revises:
 Create Date: 2016-04-11 16:05:34.873288
 """
 # revision identifiers, used by Alembic.
 revision = '19c0846f6344'
 down_revision = None


@@ -5,7 +5,6 @@ Revises: 3ec6993fe20c
 Create Date: 2017-12-07 14:43:51.500740
 """
 # revision identifiers, used by Alembic.
 revision = '1cebaf56856c'
 down_revision = '3ec6993fe20c'
@@ -13,6 +12,7 @@ branch_labels = None
 depends_on = None

 import logging

 logger = logging.getLogger('alembic')

 from alembic import op


@@ -12,7 +12,6 @@ Revises: af4cbdb2d13c
 Create Date: 2017-07-28 16:44:40.413648
 """
 # revision identifiers, used by Alembic.
 revision = '3ec6993fe20c'
 down_revision = 'af4cbdb2d13c'
@@ -44,7 +43,9 @@ def upgrade():
     except sa.exc.OperationalError:
         # this won't be a problem moving forward, but downgrade will fail
         if op.get_context().dialect.name == 'sqlite':
-            logger.warning("sqlite cannot drop columns. Leaving unused old columns in place.")
+            logger.warning(
+                "sqlite cannot drop columns. Leaving unused old columns in place."
+            )
         else:
             raise
@@ -54,15 +55,13 @@ def upgrade():
 def downgrade():
     # drop all the new tables
     engine = op.get_bind().engine
-    for table in ('oauth_clients',
-                  'oauth_codes',
-                  'oauth_access_tokens',
-                  'spawners'):
+    for table in ('oauth_clients', 'oauth_codes', 'oauth_access_tokens', 'spawners'):
         if engine.has_table(table):
             op.drop_table(table)

     op.drop_column('users', 'encrypted_auth_state')
     op.add_column('users', sa.Column('auth_state', JSONDict))
-    op.add_column('users', sa.Column('_server_id', sa.Integer, sa.ForeignKey('servers.id')))
+    op.add_column(
+        'users', sa.Column('_server_id', sa.Integer, sa.ForeignKey('servers.id'))
+    )


@@ -5,7 +5,6 @@ Revises: 1cebaf56856c
 Create Date: 2017-12-19 15:21:09.300513
 """
 # revision identifiers, used by Alembic.
 revision = '56cc5a70207e'
 down_revision = '1cebaf56856c'
@@ -16,22 +15,48 @@ from alembic import op
 import sqlalchemy as sa
 import logging

 logger = logging.getLogger('alembic')


 def upgrade():
     tables = op.get_bind().engine.table_names()
     op.add_column('api_tokens', sa.Column('created', sa.DateTime(), nullable=True))
-    op.add_column('api_tokens', sa.Column('last_activity', sa.DateTime(), nullable=True))
-    op.add_column('api_tokens', sa.Column('note', sa.Unicode(length=1023), nullable=True))
+    op.add_column(
+        'api_tokens', sa.Column('last_activity', sa.DateTime(), nullable=True)
+    )
+    op.add_column(
+        'api_tokens', sa.Column('note', sa.Unicode(length=1023), nullable=True)
+    )
     if 'oauth_access_tokens' in tables:
-        op.add_column('oauth_access_tokens', sa.Column('created', sa.DateTime(), nullable=True))
-        op.add_column('oauth_access_tokens', sa.Column('last_activity', sa.DateTime(), nullable=True))
+        op.add_column(
+            'oauth_access_tokens', sa.Column('created', sa.DateTime(), nullable=True)
+        )
+        op.add_column(
+            'oauth_access_tokens',
+            sa.Column('last_activity', sa.DateTime(), nullable=True),
+        )
         if op.get_context().dialect.name == 'sqlite':
-            logger.warning("sqlite cannot use ALTER TABLE to create foreign keys. Upgrade will be incomplete.")
+            logger.warning(
+                "sqlite cannot use ALTER TABLE to create foreign keys. Upgrade will be incomplete."
+            )
         else:
-            op.create_foreign_key(None, 'oauth_access_tokens', 'oauth_clients', ['client_id'], ['identifier'], ondelete='CASCADE')
-            op.create_foreign_key(None, 'oauth_codes', 'oauth_clients', ['client_id'], ['identifier'], ondelete='CASCADE')
+            op.create_foreign_key(
+                None,
+                'oauth_access_tokens',
+                'oauth_clients',
+                ['client_id'],
+                ['identifier'],
+                ondelete='CASCADE',
+            )
+            op.create_foreign_key(
+                None,
+                'oauth_codes',
+                'oauth_clients',
+                ['client_id'],
+                ['identifier'],
+                ondelete='CASCADE',
+            )


 def downgrade():


@@ -5,7 +5,6 @@ Revises: d68c98b66cd4
 Create Date: 2018-05-07 11:35:58.050542
 """
 # revision identifiers, used by Alembic.
 revision = '896818069c98'
 down_revision = 'd68c98b66cd4'


@@ -5,7 +5,6 @@ Revises: 56cc5a70207e
 Create Date: 2018-03-21 14:27:17.466841
 """
 # revision identifiers, used by Alembic.
 revision = '99a28a4418e1'
 down_revision = '56cc5a70207e'
@@ -18,15 +17,18 @@ import sqlalchemy as sa
 from datetime import datetime


 def upgrade():
     op.add_column('users', sa.Column('created', sa.DateTime, nullable=True))
     c = op.get_bind()
     # fill created date with current time
     now = datetime.utcnow()
-    c.execute("""
+    c.execute(
+        """
         UPDATE users
         SET created='%s'
-        """ % (now,)
+        """
+        % (now,)
     )

     tables = c.engine.table_names()
@@ -34,11 +36,13 @@ def upgrade():
     if 'spawners' in tables:
         op.add_column('spawners', sa.Column('started', sa.DateTime, nullable=True))
         # fill started value with now for running servers
-        c.execute("""
+        c.execute(
+            """
             UPDATE spawners
             SET started='%s'
             WHERE server_id IS NOT NULL
-            """ % (now,)
+            """
+            % (now,)
         )


@@ -5,7 +5,6 @@ Revises: eeb276e51423
 Create Date: 2016-07-28 16:16:38.245348
 """
 # revision identifiers, used by Alembic.
 revision = 'af4cbdb2d13c'
 down_revision = 'eeb276e51423'


@@ -5,7 +5,6 @@ Revises: 99a28a4418e1
 Create Date: 2018-04-13 10:50:17.968636
 """
 # revision identifiers, used by Alembic.
 revision = 'd68c98b66cd4'
 down_revision = '99a28a4418e1'
@@ -20,8 +19,7 @@ def upgrade():
     tables = op.get_bind().engine.table_names()
     if 'oauth_clients' in tables:
         op.add_column(
-            'oauth_clients',
-            sa.Column('description', sa.Unicode(length=1023))
+            'oauth_clients', sa.Column('description', sa.Unicode(length=1023))
         )


@@ -6,7 +6,6 @@ Revision ID: eeb276e51423
 Revises: 19c0846f6344
 Create Date: 2016-04-11 16:06:49.239831
 """
 # revision identifiers, used by Alembic.
 revision = 'eeb276e51423'
 down_revision = '19c0846f6344'
@@ -17,6 +16,7 @@ from alembic import op
 import sqlalchemy as sa

 from jupyterhub.orm import JSONDict


 def upgrade():
     op.add_column('users', sa.Column('auth_state', JSONDict))


@@ -1,5 +1,10 @@
+from . import auth
+from . import groups
+from . import hub
+from . import proxy
+from . import services
+from . import users
 from .base import *
-from . import auth, hub, proxy, users, groups, services

 default_handlers = []
 for mod in (auth, hub, proxy, users, groups, services):


@@ -1,25 +1,23 @@
"""Authorization handlers""" """Authorization handlers"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from datetime import datetime
import json import json
from urllib.parse import ( from datetime import datetime
parse_qsl, from urllib.parse import parse_qsl
quote, from urllib.parse import quote
urlencode, from urllib.parse import urlencode
urlparse, from urllib.parse import urlparse
urlunparse, from urllib.parse import urlunparse
)
from oauthlib import oauth2 from oauthlib import oauth2
from tornado import web from tornado import web
from .. import orm from .. import orm
from ..user import User from ..user import User
from ..utils import token_authenticated, compare_token from ..utils import compare_token
from .base import BaseHandler, APIHandler from ..utils import token_authenticated
from .base import APIHandler
from .base import BaseHandler
class TokenAPIHandler(APIHandler): class TokenAPIHandler(APIHandler):
@@ -70,7 +68,9 @@ class TokenAPIHandler(APIHandler):
         if data and data.get('username'):
             user = self.find_user(data['username'])
             if user is not requester and not requester.admin:
-                raise web.HTTPError(403, "Only admins can request tokens for other users.")
+                raise web.HTTPError(
+                    403, "Only admins can request tokens for other users."
+                )
             if requester.admin and user is None:
                 raise web.HTTPError(400, "No such user '%s'" % data['username'])
@@ -82,11 +82,11 @@ class TokenAPIHandler(APIHandler):
note += " by %s %s" % (kind, requester.name) note += " by %s %s" % (kind, requester.name)
api_token = user.new_api_token(note=note) api_token = user.new_api_token(note=note)
self.write(json.dumps({ self.write(
'token': api_token, json.dumps(
'warning': warn_msg, {'token': api_token, 'warning': warn_msg, 'user': self.user_model(user)}
'user': self.user_model(user), )
})) )
class CookieAPIHandler(APIHandler): class CookieAPIHandler(APIHandler):
@@ -94,7 +94,9 @@ class CookieAPIHandler(APIHandler):
     def get(self, cookie_name, cookie_value=None):
         cookie_name = quote(cookie_name, safe='')
         if cookie_value is None:
-            self.log.warning("Cookie values in request body is deprecated, use `/cookie_name/cookie_value`")
+            self.log.warning(
+                "Cookie values in request body is deprecated, use `/cookie_name/cookie_value`"
+            )
             cookie_value = self.request.body
         else:
             cookie_value = cookie_value.encode('utf8')
@@ -134,7 +136,9 @@ class OAuthHandler:
             return uri
         # make absolute local redirects full URLs
         # to satisfy oauthlib's absolute URI requirement
-        redirect_uri = self.request.protocol + "://" + self.request.headers['Host'] + redirect_uri
+        redirect_uri = (
+            self.request.protocol + "://" + self.request.headers['Host'] + redirect_uri
+        )
         parsed_url = urlparse(uri)
         query_list = parse_qsl(parsed_url.query, keep_blank_values=True)
         for idx, item in enumerate(query_list):
@@ -142,10 +146,7 @@ class OAuthHandler:
                 query_list[idx] = ('redirect_uri', redirect_uri)
                 break

-        return urlunparse(
-            urlparse(uri)
-            ._replace(query=urlencode(query_list))
-        )
+        return urlunparse(urlparse(uri)._replace(query=urlencode(query_list)))

     def add_credentials(self, credentials=None):
         """Add oauth credentials
@@ -164,11 +165,7 @@ class OAuthHandler:
         user = self.current_user

         # Extra credentials we need in the validator
-        credentials.update({
-            'user': user,
-            'handler': self,
-            'session_id': session_id,
-        })
+        credentials.update({'user': user, 'handler': self, 'session_id': session_id})
         return credentials

     def send_oauth_response(self, headers, body, status):
@@ -193,7 +190,8 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
     def _complete_login(self, uri, headers, scopes, credentials):
         try:
             headers, body, status = self.oauth_provider.create_authorization_response(
-                uri, 'POST', '', headers, scopes, credentials)
+                uri, 'POST', '', headers, scopes, credentials
+            )

         except oauth2.FatalClientError as e:
             # TODO: human error page
@@ -213,13 +211,15 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
         uri, http_method, body, headers = self.extract_oauth_params()
         try:
             scopes, credentials = self.oauth_provider.validate_authorization_request(
-                uri, http_method, body, headers)
+                uri, http_method, body, headers
+            )
             credentials = self.add_credentials(credentials)
             client = self.oauth_provider.fetch_by_client_id(credentials['client_id'])
             if client.redirect_uri.startswith(self.current_user.url):
                 self.log.debug(
                     "Skipping oauth confirmation for %s accessing %s",
-                    self.current_user, client.description,
+                    self.current_user,
+                    client.description,
                 )
                 # access to my own server doesn't require oauth confirmation
                 # this is the pre-1.0 behavior for all oauth
@@ -228,11 +228,7 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
# Render oauth 'Authorize application...' page # Render oauth 'Authorize application...' page
self.write( self.write(
self.render_template( self.render_template("oauth.html", scopes=scopes, oauth_client=client)
"oauth.html",
scopes=scopes,
oauth_client=client,
)
) )
# Errors that should be shown to the user on the provider website # Errors that should be shown to the user on the provider website
@@ -252,7 +248,9 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
if referer != full_url: if referer != full_url:
# OAuth post must be made to the URL it came from # OAuth post must be made to the URL it came from
self.log.error("OAuth POST from %s != %s", referer, full_url) self.log.error("OAuth POST from %s != %s", referer, full_url)
raise web.HTTPError(403, "Authorization form must be sent from authorization page") raise web.HTTPError(
403, "Authorization form must be sent from authorization page"
)
# The scopes the user actually authorized, i.e. checkboxes # The scopes the user actually authorized, i.e. checkboxes
# that were selected. # that were selected.
@@ -262,7 +260,7 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
try: try:
headers, body, status = self.oauth_provider.create_authorization_response( headers, body, status = self.oauth_provider.create_authorization_response(
uri, http_method, body, headers, scopes, credentials, uri, http_method, body, headers, scopes, credentials
) )
except oauth2.FatalClientError as e: except oauth2.FatalClientError as e:
raise web.HTTPError(e.status_code, e.description) raise web.HTTPError(e.status_code, e.description)
@@ -277,7 +275,8 @@ class OAuthTokenHandler(OAuthHandler, APIHandler):
try: try:
headers, body, status = self.oauth_provider.create_token_response( headers, body, status = self.oauth_provider.create_token_response(
uri, http_method, body, headers, credentials) uri, http_method, body, headers, credentials
)
except oauth2.FatalClientError as e: except oauth2.FatalClientError as e:
raise web.HTTPError(e.status_code, e.description) raise web.HTTPError(e.status_code, e.description)
else: else:


@@ -1,10 +1,8 @@
"""Base API handlers""" """Base API handlers"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from datetime import datetime
import json import json
from datetime import datetime
from http.client import responses from http.client import responses
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@@ -12,7 +10,8 @@ from tornado import web
from .. import orm from .. import orm
from ..handlers import BaseHandler from ..handlers import BaseHandler
from ..utils import isoformat, url_path_join from ..utils import isoformat
from ..utils import url_path_join
class APIHandler(BaseHandler): class APIHandler(BaseHandler):
@@ -55,8 +54,11 @@ class APIHandler(BaseHandler):
host_path = url_path_join(host, self.hub.base_url) host_path = url_path_join(host, self.hub.base_url)
referer_path = referer.split('://', 1)[-1] referer_path = referer.split('://', 1)[-1]
if not (referer_path + '/').startswith(host_path): if not (referer_path + '/').startswith(host_path):
self.log.warning("Blocking Cross Origin API request. Referer: %s, Host: %s", self.log.warning(
referer, host_path) "Blocking Cross Origin API request. Referer: %s, Host: %s",
referer,
host_path,
)
return False return False
return True return True
@@ -105,9 +107,13 @@ class APIHandler(BaseHandler):
if exception and isinstance(exception, SQLAlchemyError): if exception and isinstance(exception, SQLAlchemyError):
try: try:
exception_str = str(exception) exception_str = str(exception)
self.log.warning("Rolling back session due to database error %s", exception_str) self.log.warning(
"Rolling back session due to database error %s", exception_str
)
except Exception: except Exception:
self.log.warning("Rolling back session due to database error %s", type(exception)) self.log.warning(
"Rolling back session due to database error %s", type(exception)
)
self.db.rollback() self.db.rollback()
self.set_header('Content-Type', 'application/json') self.set_header('Content-Type', 'application/json')
@@ -121,10 +127,9 @@ class APIHandler(BaseHandler):
# Content-Length must be recalculated. # Content-Length must be recalculated.
self.clear_header('Content-Length') self.clear_header('Content-Length')
self.write(json.dumps({ self.write(
'status': status_code, json.dumps({'status': status_code, 'message': message or status_message})
'message': message or status_message, )
}))
def server_model(self, spawner, include_state=False): def server_model(self, spawner, include_state=False):
"""Get the JSON model for a Spawner""" """Get the JSON model for a Spawner"""
@@ -144,21 +149,17 @@ class APIHandler(BaseHandler):
expires_at = None expires_at = None
if isinstance(token, orm.APIToken): if isinstance(token, orm.APIToken):
kind = 'api_token' kind = 'api_token'
extra = { extra = {'note': token.note}
'note': token.note,
}
expires_at = token.expires_at expires_at = token.expires_at
elif isinstance(token, orm.OAuthAccessToken): elif isinstance(token, orm.OAuthAccessToken):
kind = 'oauth' kind = 'oauth'
extra = { extra = {'oauth_client': token.client.description or token.client.client_id}
'oauth_client': token.client.description or token.client.client_id,
}
if token.expires_at: if token.expires_at:
expires_at = datetime.fromtimestamp(token.expires_at) expires_at = datetime.fromtimestamp(token.expires_at)
else: else:
raise TypeError( raise TypeError(
"token must be an APIToken or OAuthAccessToken, not %s" "token must be an APIToken or OAuthAccessToken, not %s" % type(token)
% type(token)) )
if token.user: if token.user:
owner_key = 'user' owner_key = 'user'
@@ -219,23 +220,11 @@ class APIHandler(BaseHandler):
def service_model(self, service): def service_model(self, service):
"""Get the JSON model for a Service object""" """Get the JSON model for a Service object"""
return { return {'kind': 'service', 'name': service.name, 'admin': service.admin}
'kind': 'service',
'name': service.name,
'admin': service.admin,
}
_user_model_types = { _user_model_types = {'name': str, 'admin': bool, 'groups': list, 'auth_state': dict}
'name': str,
'admin': bool,
'groups': list,
'auth_state': dict,
}
_group_model_types = { _group_model_types = {'name': str, 'users': list}
'name': str,
'users': list,
}
def _check_model(self, model, model_types, name): def _check_model(self, model, model_types, name):
"""Check a model provided by a REST API request """Check a model provided by a REST API request
@@ -251,24 +240,29 @@ class APIHandler(BaseHandler):
raise web.HTTPError(400, "Invalid JSON keys: %r" % model) raise web.HTTPError(400, "Invalid JSON keys: %r" % model)
for key, value in model.items(): for key, value in model.items():
if not isinstance(value, model_types[key]): if not isinstance(value, model_types[key]):
raise web.HTTPError(400, "%s.%s must be %s, not: %r" % ( raise web.HTTPError(
name, key, model_types[key], type(value) 400,
)) "%s.%s must be %s, not: %r"
% (name, key, model_types[key], type(value)),
)
def _check_user_model(self, model): def _check_user_model(self, model):
"""Check a request-provided user model from a REST API""" """Check a request-provided user model from a REST API"""
self._check_model(model, self._user_model_types, 'user') self._check_model(model, self._user_model_types, 'user')
for username in model.get('users', []): for username in model.get('users', []):
if not isinstance(username, str): if not isinstance(username, str):
raise web.HTTPError(400, ("usernames must be str, not %r", type(username))) raise web.HTTPError(
400, ("usernames must be str, not %r", type(username))
)
def _check_group_model(self, model): def _check_group_model(self, model):
"""Check a request-provided group model from a REST API""" """Check a request-provided group model from a REST API"""
self._check_model(model, self._group_model_types, 'group') self._check_model(model, self._group_model_types, 'group')
for groupname in model.get('groups', []): for groupname in model.get('groups', []):
if not isinstance(groupname, str): if not isinstance(groupname, str):
raise web.HTTPError(400, ("group names must be str, not %r", type(groupname))) raise web.HTTPError(
400, ("group names must be str, not %r", type(groupname))
)
def options(self, *args, **kwargs): def options(self, *args, **kwargs):
self.finish() self.finish()
@@ -279,6 +273,7 @@ class API404(APIHandler):
Ensures JSON 404 errors for malformed URLs Ensures JSON 404 errors for malformed URLs
""" """
async def prepare(self): async def prepare(self):
await super().prepare() await super().prepare()
raise web.HTTPError(404) raise web.HTTPError(404)


@@ -1,11 +1,10 @@
"""Group handlers""" """Group handlers"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import json import json
from tornado import gen, web from tornado import gen
from tornado import web
from .. import orm from .. import orm
from ..utils import admin_only from ..utils import admin_only
@@ -34,6 +33,7 @@ class _GroupAPIHandler(APIHandler):
raise web.HTTPError(404, "No such group: %s", name) raise web.HTTPError(404, "No such group: %s", name)
return group return group
class GroupListAPIHandler(_GroupAPIHandler): class GroupListAPIHandler(_GroupAPIHandler):
@admin_only @admin_only
def get(self): def get(self):
@@ -61,9 +61,7 @@ class GroupListAPIHandler(_GroupAPIHandler):
# check that users exist # check that users exist
users = self._usernames_to_users(usernames) users = self._usernames_to_users(usernames)
# create the group # create the group
self.log.info("Creating new group %s with %i users", self.log.info("Creating new group %s with %i users", name, len(users))
name, len(users),
)
self.log.debug("Users: %s", usernames) self.log.debug("Users: %s", usernames)
group = orm.Group(name=name, users=users) group = orm.Group(name=name, users=users)
self.db.add(group) self.db.add(group)
@@ -99,9 +97,7 @@ class GroupAPIHandler(_GroupAPIHandler):
users = self._usernames_to_users(usernames) users = self._usernames_to_users(usernames)
# create the group # create the group
self.log.info("Creating new group %s with %i users", self.log.info("Creating new group %s with %i users", name, len(users))
name, len(users),
)
self.log.debug("Users: %s", usernames) self.log.debug("Users: %s", usernames)
group = orm.Group(name=name, users=users) group = orm.Group(name=name, users=users)
self.db.add(group) self.db.add(group)
@@ -121,6 +117,7 @@ class GroupAPIHandler(_GroupAPIHandler):
class GroupUsersAPIHandler(_GroupAPIHandler): class GroupUsersAPIHandler(_GroupAPIHandler):
"""Modify a group's user list""" """Modify a group's user list"""
@admin_only @admin_only
def post(self, name): def post(self, name):
"""POST adds users to a group""" """POST adds users to a group"""


@@ -1,21 +1,18 @@
"""API handlers for administering the Hub itself""" """API handlers for administering the Hub itself"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import json import json
import sys import sys
from tornado import web from tornado import web
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from .._version import __version__
from ..utils import admin_only from ..utils import admin_only
from .base import APIHandler from .base import APIHandler
from .._version import __version__
class ShutdownAPIHandler(APIHandler): class ShutdownAPIHandler(APIHandler):
@admin_only @admin_only
def post(self): def post(self):
"""POST /api/shutdown triggers a clean shutdown """POST /api/shutdown triggers a clean shutdown
@@ -26,6 +23,7 @@ class ShutdownAPIHandler(APIHandler):
- proxy: specify whether the proxy should be terminated - proxy: specify whether the proxy should be terminated
""" """
from ..app import JupyterHub from ..app import JupyterHub
app = JupyterHub.instance() app = JupyterHub.instance()
data = self.get_json_body() data = self.get_json_body()
@@ -33,19 +31,21 @@ class ShutdownAPIHandler(APIHandler):
if 'proxy' in data: if 'proxy' in data:
proxy = data['proxy'] proxy = data['proxy']
if proxy not in {True, False}: if proxy not in {True, False}:
raise web.HTTPError(400, "proxy must be true or false, got %r" % proxy) raise web.HTTPError(
400, "proxy must be true or false, got %r" % proxy
)
app.cleanup_proxy = proxy app.cleanup_proxy = proxy
if 'servers' in data: if 'servers' in data:
servers = data['servers'] servers = data['servers']
if servers not in {True, False}: if servers not in {True, False}:
raise web.HTTPError(400, "servers must be true or false, got %r" % servers) raise web.HTTPError(
400, "servers must be true or false, got %r" % servers
)
app.cleanup_servers = servers app.cleanup_servers = servers
# finish the request # finish the request
self.set_status(202) self.set_status(202)
self.finish(json.dumps({ self.finish(json.dumps({"message": "Shutting down Hub"}))
"message": "Shutting down Hub"
}))
# stop the eventloop, which will trigger cleanup # stop the eventloop, which will trigger cleanup
loop = IOLoop.current() loop = IOLoop.current()
@@ -53,7 +53,6 @@ class ShutdownAPIHandler(APIHandler):
class RootAPIHandler(APIHandler): class RootAPIHandler(APIHandler):
def get(self): def get(self):
"""GET /api/ returns info about the Hub and its API. """GET /api/ returns info about the Hub and its API.
@@ -61,14 +60,11 @@ class RootAPIHandler(APIHandler):
For now, it just returns the version of JupyterHub itself. For now, it just returns the version of JupyterHub itself.
""" """
data = { data = {'version': __version__}
'version': __version__,
}
self.finish(json.dumps(data)) self.finish(json.dumps(data))
class InfoAPIHandler(APIHandler): class InfoAPIHandler(APIHandler):
@admin_only @admin_only
def get(self): def get(self):
"""GET /api/info returns detailed info about the Hub and its API. """GET /api/info returns detailed info about the Hub and its API.
@@ -77,10 +73,11 @@ class InfoAPIHandler(APIHandler):
For now, it just returns the version of JupyterHub itself. For now, it just returns the version of JupyterHub itself.
""" """
def _class_info(typ): def _class_info(typ):
"""info about a class (Spawner or Authenticator)""" """info about a class (Spawner or Authenticator)"""
info = { info = {
'class': '{mod}.{name}'.format(mod=typ.__module__, name=typ.__name__), 'class': '{mod}.{name}'.format(mod=typ.__module__, name=typ.__name__)
} }
pkg = typ.__module__.split('.')[0] pkg = typ.__module__.split('.')[0]
try: try:


@@ -1,12 +1,11 @@
"""Proxy handlers""" """Proxy handlers"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import json import json
from urllib.parse import urlparse from urllib.parse import urlparse
from tornado import gen, web from tornado import gen
from tornado import web
from .. import orm from .. import orm
from ..utils import admin_only from ..utils import admin_only
@@ -14,7 +13,6 @@ from .base import APIHandler
class ProxyAPIHandler(APIHandler): class ProxyAPIHandler(APIHandler):
@admin_only @admin_only
async def get(self): async def get(self):
"""GET /api/proxy fetches the routing table """GET /api/proxy fetches the routing table
@@ -58,6 +56,4 @@ class ProxyAPIHandler(APIHandler):
await self.proxy.check_routes(self.users, self.services) await self.proxy.check_routes(self.users, self.services)
default_handlers = [ default_handlers = [(r"/api/proxy", ProxyAPIHandler)]
(r"/api/proxy", ProxyAPIHandler),
]


@@ -2,10 +2,8 @@
Currently GET-only, no actions can be taken to modify services. Currently GET-only, no actions can be taken to modify services.
""" """
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import json import json
from tornado import web from tornado import web
@@ -14,6 +12,7 @@ from .. import orm
from ..utils import admin_only from ..utils import admin_only
from .base import APIHandler from .base import APIHandler
def service_model(service): def service_model(service):
"""Produce the model for a service""" """Produce the model for a service"""
return { return {
@@ -23,9 +22,10 @@ def service_model(service):
'prefix': service.server.base_url if service.server else '', 'prefix': service.server.base_url if service.server else '',
'command': service.command, 'command': service.command,
'pid': service.proc.pid if service.proc else 0, 'pid': service.proc.pid if service.proc else 0,
'info': service.info 'info': service.info,
} }
class ServiceListAPIHandler(APIHandler): class ServiceListAPIHandler(APIHandler):
@admin_only @admin_only
def get(self): def get(self):
@@ -35,6 +35,7 @@ class ServiceListAPIHandler(APIHandler):
def admin_or_self(method): def admin_or_self(method):
"""Decorator for restricting access to either the target service or admin""" """Decorator for restricting access to either the target service or admin"""
def decorated_method(self, name): def decorated_method(self, name):
current = self.current_user current = self.current_user
if current is None: if current is None:
@@ -49,10 +50,11 @@ def admin_or_self(method):
if name not in self.services: if name not in self.services:
raise web.HTTPError(404) raise web.HTTPError(404)
return method(self, name) return method(self, name)
return decorated_method return decorated_method
class ServiceAPIHandler(APIHandler):
class ServiceAPIHandler(APIHandler):
@admin_or_self @admin_or_self
def get(self, name): def get(self, name):
service = self.services[name] service = self.services[name]


@@ -1,11 +1,11 @@
"""User handlers""" """User handlers"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import asyncio import asyncio
from datetime import datetime, timedelta, timezone
import json import json
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from async_generator import aclosing from async_generator import aclosing
from dateutil.parser import parse as parse_date from dateutil.parser import parse as parse_date
@@ -14,7 +14,11 @@ from tornado.iostream import StreamClosedError
from .. import orm from .. import orm
from ..user import User from ..user import User
from ..utils import admin_only, isoformat, iterate_until, maybe_future, url_path_join from ..utils import admin_only
from ..utils import isoformat
from ..utils import iterate_until
from ..utils import maybe_future
from ..utils import url_path_join
from .base import APIHandler from .base import APIHandler
@@ -89,7 +93,9 @@ class UserListAPIHandler(APIHandler):
except Exception as e: except Exception as e:
self.log.error("Failed to create user: %s" % name, exc_info=True) self.log.error("Failed to create user: %s" % name, exc_info=True)
self.users.delete(user) self.users.delete(user)
raise web.HTTPError(400, "Failed to create user %s: %s" % (name, str(e))) raise web.HTTPError(
400, "Failed to create user %s: %s" % (name, str(e))
)
else: else:
created.append(user) created.append(user)
@@ -99,6 +105,7 @@ class UserListAPIHandler(APIHandler):
def admin_or_self(method): def admin_or_self(method):
"""Decorator for restricting access to either the target user or admin""" """Decorator for restricting access to either the target user or admin"""
def m(self, name, *args, **kwargs): def m(self, name, *args, **kwargs):
current = self.current_user current = self.current_user
if current is None: if current is None:
@@ -110,15 +117,17 @@ def admin_or_self(method):
if not self.find_user(name): if not self.find_user(name):
raise web.HTTPError(404) raise web.HTTPError(404)
return method(self, name, *args, **kwargs) return method(self, name, *args, **kwargs)
return m return m
class UserAPIHandler(APIHandler): class UserAPIHandler(APIHandler):
@admin_or_self @admin_or_self
async def get(self, name): async def get(self, name):
user = self.find_user(name) user = self.find_user(name)
model = self.user_model(user, include_servers=True, include_state=self.current_user.admin) model = self.user_model(
user, include_servers=True, include_state=self.current_user.admin
)
# auth state will only be shown if the requester is an admin # auth state will only be shown if the requester is an admin
# this means users can't see their own auth state unless they # this means users can't see their own auth state unless they
# are admins, Hub admins often are also marked as admins so they # are admins, Hub admins often are also marked as admins so they
@@ -161,11 +170,16 @@ class UserAPIHandler(APIHandler):
if user.name == self.current_user.name: if user.name == self.current_user.name:
raise web.HTTPError(400, "Cannot delete yourself!") raise web.HTTPError(400, "Cannot delete yourself!")
if user.spawner._stop_pending: if user.spawner._stop_pending:
raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name) raise web.HTTPError(
400, "%s's server is in the process of stopping, please wait." % name
)
if user.running: if user.running:
await self.stop_single_user(user) await self.stop_single_user(user)
if user.spawner._stop_pending: if user.spawner._stop_pending:
raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name) raise web.HTTPError(
400,
"%s's server is in the process of stopping, please wait." % name,
)
await maybe_future(self.authenticator.delete_user(user)) await maybe_future(self.authenticator.delete_user(user))
# remove from registry # remove from registry
@@ -183,7 +197,10 @@ class UserAPIHandler(APIHandler):
if 'name' in data and data['name'] != name: if 'name' in data and data['name'] != name:
# check if the new name is already taken inside db # check if the new name is already taken inside db
if self.find_user(data['name']): if self.find_user(data['name']):
raise web.HTTPError(400, "User %s already exists, username must be unique" % data['name']) raise web.HTTPError(
400,
"User %s already exists, username must be unique" % data['name'],
)
for key, value in data.items(): for key, value in data.items():
if key == 'auth_state': if key == 'auth_state':
await user.save_auth_state(value) await user.save_auth_state(value)
@@ -197,6 +214,7 @@ class UserAPIHandler(APIHandler):
class UserTokenListAPIHandler(APIHandler): class UserTokenListAPIHandler(APIHandler):
"""API endpoint for listing/creating tokens""" """API endpoint for listing/creating tokens"""
@admin_or_self @admin_or_self
def get(self, name): def get(self, name):
"""Get tokens for a given user""" """Get tokens for a given user"""
@@ -207,6 +225,7 @@ class UserTokenListAPIHandler(APIHandler):
now = datetime.utcnow() now = datetime.utcnow()
api_tokens = [] api_tokens = []
def sort_key(token): def sort_key(token):
return token.last_activity or token.created return token.last_activity or token.created
@@ -228,10 +247,7 @@ class UserTokenListAPIHandler(APIHandler):
self.db.commit() self.db.commit()
continue continue
oauth_tokens.append(self.token_model(token)) oauth_tokens.append(self.token_model(token))
self.write(json.dumps({ self.write(json.dumps({'api_tokens': api_tokens, 'oauth_tokens': oauth_tokens}))
'api_tokens': api_tokens,
'oauth_tokens': oauth_tokens,
}))
async def post(self, name): async def post(self, name):
body = self.get_json_body() or {} body = self.get_json_body() or {}
@@ -253,8 +269,9 @@ class UserTokenListAPIHandler(APIHandler):
except Exception as e: except Exception as e:
# suppress and log error here in case Authenticator # suppress and log error here in case Authenticator
# isn't prepared to handle auth via this data # isn't prepared to handle auth via this data
self.log.error("Error authenticating request for %s: %s", self.log.error(
self.request.uri, e) "Error authenticating request for %s: %s", self.request.uri, e
)
raise web.HTTPError(403) raise web.HTTPError(403)
requester = self.find_user(name) requester = self.find_user(name)
if requester is None: if requester is None:
@@ -274,9 +291,16 @@ class UserTokenListAPIHandler(APIHandler):
if requester is not user: if requester is not user:
note += " by %s %s" % (kind, requester.name) note += " by %s %s" % (kind, requester.name)
api_token = user.new_api_token(note=note, expires_in=body.get('expires_in', None)) api_token = user.new_api_token(
note=note, expires_in=body.get('expires_in', None)
)
if requester is not user: if requester is not user:
self.log.info("%s %s requested API token for %s", kind.title(), requester.name, user.name) self.log.info(
"%s %s requested API token for %s",
kind.title(),
requester.name,
user.name,
)
else: else:
user_kind = 'user' if isinstance(user, User) else 'service' user_kind = 'user' if isinstance(user, User) else 'service'
self.log.info("%s %s requested new API token", user_kind.title(), user.name) self.log.info("%s %s requested new API token", user_kind.title(), user.name)
@@ -333,8 +357,7 @@ class UserTokenAPIHandler(APIHandler):
if isinstance(token, orm.OAuthAccessToken): if isinstance(token, orm.OAuthAccessToken):
client_id = token.client_id client_id = token.client_id
tokens = [ tokens = [
token for token in user.oauth_tokens token for token in user.oauth_tokens if token.client_id == client_id
if token.client_id == client_id
] ]
else: else:
tokens = [token] tokens = [token]
@@ -354,16 +377,19 @@ class UserServerAPIHandler(APIHandler):
if server_name: if server_name:
if not self.allow_named_servers: if not self.allow_named_servers:
raise web.HTTPError(400, "Named servers are not enabled.") raise web.HTTPError(400, "Named servers are not enabled.")
if self.named_server_limit_per_user > 0 and server_name not in user.orm_spawners: if (
self.named_server_limit_per_user > 0
and server_name not in user.orm_spawners
):
named_spawners = list(user.all_spawners(include_default=False)) named_spawners = list(user.all_spawners(include_default=False))
if self.named_server_limit_per_user <= len(named_spawners): if self.named_server_limit_per_user <= len(named_spawners):
raise web.HTTPError( raise web.HTTPError(
400, 400,
"User {} already has the maximum of {} named servers." "User {} already has the maximum of {} named servers."
" One must be deleted before a new server can be created".format( " One must be deleted before a new server can be created".format(
name, name, self.named_server_limit_per_user
self.named_server_limit_per_user ),
)) )
spawner = user.spawners[server_name] spawner = user.spawners[server_name]
pending = spawner.pending pending = spawner.pending
if pending == 'spawn': if pending == 'spawn':
@@ -396,7 +422,6 @@ class UserServerAPIHandler(APIHandler):
options = self.get_json_body() options = self.get_json_body()
remove = (options or {}).get('remove', False) remove = (options or {}).get('remove', False)
def _remove_spawner(f=None): def _remove_spawner(f=None):
if f and f.exception(): if f and f.exception():
return return
@@ -408,7 +433,9 @@ class UserServerAPIHandler(APIHandler):
if not self.allow_named_servers: if not self.allow_named_servers:
raise web.HTTPError(400, "Named servers are not enabled.") raise web.HTTPError(400, "Named servers are not enabled.")
if server_name not in user.orm_spawners: if server_name not in user.orm_spawners:
raise web.HTTPError(404, "%s has no server named '%s'" % (name, server_name)) raise web.HTTPError(
404, "%s has no server named '%s'" % (name, server_name)
)
elif remove: elif remove:
raise web.HTTPError(400, "Cannot delete the default server") raise web.HTTPError(400, "Cannot delete the default server")
@@ -423,7 +450,8 @@ class UserServerAPIHandler(APIHandler):
if spawner.pending: if spawner.pending:
raise web.HTTPError( raise web.HTTPError(
400, "%s is pending %s, please wait" % (spawner._log_name, spawner.pending) 400,
"%s is pending %s, please wait" % (spawner._log_name, spawner.pending),
) )
stop_future = None stop_future = None
@@ -449,13 +477,16 @@ class UserAdminAccessAPIHandler(APIHandler):
This handler sets the necessary cookie for an admin to login to a single-user server. This handler sets the necessary cookie for an admin to login to a single-user server.
""" """
@admin_only @admin_only
def post(self, name): def post(self, name):
self.log.warning("Deprecated in JupyterHub 0.8." self.log.warning(
" Admin access API is not needed now that we use OAuth.") "Deprecated in JupyterHub 0.8."
" Admin access API is not needed now that we use OAuth."
)
current = self.current_user current = self.current_user
self.log.warning("Admin user %s has requested access to %s's server", self.log.warning(
current.name, name, "Admin user %s has requested access to %s's server", current.name, name
) )
if not self.settings.get('admin_access', False): if not self.settings.get('admin_access', False):
raise web.HTTPError(403, "admin access to user servers disabled") raise web.HTTPError(403, "admin access to user servers disabled")
@@ -501,10 +532,7 @@ class SpawnProgressAPIHandler(APIHandler):
except (StreamClosedError, RuntimeError): except (StreamClosedError, RuntimeError):
return return
await asyncio.wait( await asyncio.wait([self._finish_future], timeout=self.keepalive_interval)
[self._finish_future],
timeout=self.keepalive_interval,
)
@admin_or_self @admin_or_self
async def get(self, username, server_name=''): async def get(self, username, server_name=''):
@@ -535,11 +563,7 @@ class SpawnProgressAPIHandler(APIHandler):
'html_message': 'Server ready at <a href="{0}">{0}</a>'.format(url), 'html_message': 'Server ready at <a href="{0}">{0}</a>'.format(url),
'url': url, 'url': url,
} }
failed_event = { failed_event = {'progress': 100, 'failed': True, 'message': "Spawn failed"}
'progress': 100,
'failed': True,
'message': "Spawn failed",
}
if spawner.ready: if spawner.ready:
# spawner already ready. Trigger progress-completion immediately # spawner already ready. Trigger progress-completion immediately
@@ -561,7 +585,9 @@ class SpawnProgressAPIHandler(APIHandler):
raise web.HTTPError(400, "%s is not starting...", spawner._log_name) raise web.HTTPError(400, "%s is not starting...", spawner._log_name)
# retrieve progress events from the Spawner # retrieve progress events from the Spawner
async with aclosing(iterate_until(spawn_future, spawner._generate_progress())) as events: async with aclosing(
iterate_until(spawn_future, spawner._generate_progress())
) as events:
async for event in events: async for event in events:
# don't allow events to sneakily set the 'ready' flag # don't allow events to sneakily set the 'ready' flag
if 'ready' in event: if 'ready' in event:
@@ -584,7 +610,9 @@ class SpawnProgressAPIHandler(APIHandler):
if f and f.done() and f.exception(): if f and f.done() and f.exception():
failed_event['message'] = "Spawn failed: %s" % f.exception() failed_event['message'] = "Spawn failed: %s" % f.exception()
else: else:
self.log.warning("Server %s didn't start for unknown reason", spawner._log_name) self.log.warning(
"Server %s didn't start for unknown reason", spawner._log_name
)
await self.send_event(failed_event) await self.send_event(failed_event)
@@ -609,13 +637,12 @@ def _parse_timestamp(timestamp):
400, 400,
"Rejecting activity from more than an hour in the future: {}".format( "Rejecting activity from more than an hour in the future: {}".format(
isoformat(dt) isoformat(dt)
) ),
) )
return dt return dt
class ActivityAPIHandler(APIHandler): class ActivityAPIHandler(APIHandler):
def _validate_servers(self, user, servers): def _validate_servers(self, user, servers):
"""Validate servers dict argument """Validate servers dict argument
@@ -632,10 +659,7 @@ class ActivityAPIHandler(APIHandler):
if server_name not in spawners: if server_name not in spawners:
raise web.HTTPError( raise web.HTTPError(
400, 400,
"No such server '{}' for user {}".format( "No such server '{}' for user {}".format(server_name, user.name),
server_name,
user.name,
)
) )
# check that each per-server field is a dict # check that each per-server field is a dict
if not isinstance(server_info, dict): if not isinstance(server_info, dict):
@@ -645,7 +669,9 @@ class ActivityAPIHandler(APIHandler):
raise web.HTTPError(400, msg) raise web.HTTPError(400, msg)
# parse last_activity timestamps # parse last_activity timestamps
# _parse_timestamp above is responsible for raising errors # _parse_timestamp above is responsible for raising errors
server_info['last_activity'] = _parse_timestamp(server_info['last_activity']) server_info['last_activity'] = _parse_timestamp(
server_info['last_activity']
)
return servers return servers
@admin_or_self @admin_or_self
@@ -663,8 +689,7 @@ class ActivityAPIHandler(APIHandler):
servers = body.get('servers') servers = body.get('servers')
if not last_activity_timestamp and not servers: if not last_activity_timestamp and not servers:
raise web.HTTPError( raise web.HTTPError(
400, 400, "body must contain at least one of `last_activity` or `servers`"
"body must contain at least one of `last_activity` or `servers`"
) )
if servers: if servers:
@@ -677,13 +702,9 @@ class ActivityAPIHandler(APIHandler):
# update user.last_activity if specified # update user.last_activity if specified
if last_activity_timestamp: if last_activity_timestamp:
last_activity = _parse_timestamp(last_activity_timestamp) last_activity = _parse_timestamp(last_activity_timestamp)
if ( if (not user.last_activity) or last_activity > user.last_activity:
(not user.last_activity) self.log.debug(
or last_activity > user.last_activity "Activity for user %s: %s", user.name, isoformat(last_activity)
):
self.log.debug("Activity for user %s: %s",
user.name,
isoformat(last_activity),
) )
user.last_activity = last_activity user.last_activity = last_activity
else: else:
@@ -699,11 +720,9 @@ class ActivityAPIHandler(APIHandler):
last_activity = server_info['last_activity'] last_activity = server_info['last_activity']
spawner = user.orm_spawners[server_name] spawner = user.orm_spawners[server_name]
if ( if (not spawner.last_activity) or last_activity > spawner.last_activity:
(not spawner.last_activity) self.log.debug(
or last_activity > spawner.last_activity "Activity on server %s/%s: %s",
):
self.log.debug("Activity on server %s/%s: %s",
user.name, user.name,
server_name, server_name,
isoformat(last_activity), isoformat(last_activity),

File diff suppressed because it is too large.


@@ -1,16 +1,16 @@
"""Base Authenticator class and the default PAM Authenticator""" """Base Authenticator class and the default PAM Authenticator"""
# Copyright (c) IPython Development Team. # Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from concurrent.futures import ThreadPoolExecutor
import inspect import inspect
import pipes import pipes
import re import re
import sys import sys
from shutil import which
from subprocess import Popen, PIPE, STDOUT
import warnings import warnings
from concurrent.futures import ThreadPoolExecutor
from shutil import which
from subprocess import PIPE
from subprocess import Popen
from subprocess import STDOUT
try: try:
import pamela import pamela
@@ -33,7 +33,9 @@ class Authenticator(LoggingConfigurable):
db = Any() db = Any()
enable_auth_state = Bool(False, config=True, enable_auth_state = Bool(
False,
config=True,
help="""Enable persisting auth_state (if available). help="""Enable persisting auth_state (if available).
auth_state will be encrypted and stored in the Hub's database. auth_state will be encrypted and stored in the Hub's database.
@@ -62,7 +64,7 @@ class Authenticator(LoggingConfigurable):
See :meth:`.refresh_user` for what happens when user auth info is refreshed See :meth:`.refresh_user` for what happens when user auth info is refreshed
(nothing by default). (nothing by default).
""" """,
) )
refresh_pre_spawn = Bool( refresh_pre_spawn = Bool(
@@ -78,7 +80,7 @@ class Authenticator(LoggingConfigurable):
If refresh_user cannot refresh the user auth data, If refresh_user cannot refresh the user auth data,
launch will fail until the user logs in again. launch will fail until the user logs in again.
""" """,
) )
admin_users = Set( admin_users = Set(
@@ -131,8 +133,11 @@ class Authenticator(LoggingConfigurable):
sorted_names = sorted(short_names) sorted_names = sorted(short_names)
single = ''.join(sorted_names) single = ''.join(sorted_names)
string_set_typo = "set('%s')" % single string_set_typo = "set('%s')" % single
self.log.warning("whitelist contains single-character names: %s; did you mean set([%r]) instead of %s?", self.log.warning(
sorted_names[:8], single, string_set_typo, "whitelist contains single-character names: %s; did you mean set([%r]) instead of %s?",
sorted_names[:8],
single,
string_set_typo,
) )
custom_html = Unicode( custom_html = Unicode(
@@ -199,7 +204,8 @@ class Authenticator(LoggingConfigurable):
""" """
).tag(config=True) ).tag(config=True)
delete_invalid_users = Bool(False, delete_invalid_users = Bool(
False,
help="""Delete any users from the database that do not pass validation help="""Delete any users from the database that do not pass validation
When JupyterHub starts, `.add_user` will be called When JupyterHub starts, `.add_user` will be called
@@ -213,10 +219,11 @@ class Authenticator(LoggingConfigurable):
If False (default), invalid users remain in the Hub's database If False (default), invalid users remain in the Hub's database
and a warning will be issued. and a warning will be issued.
This is the default to avoid data loss due to config changes. This is the default to avoid data loss due to config changes.
""" """,
) )
post_auth_hook = Any(config=True, post_auth_hook = Any(
config=True,
help=""" help="""
An optional hook function that you can implement to do some An optional hook function that you can implement to do some
bootstrapping work during authentication. For example, loading user account bootstrapping work during authentication. For example, loading user account
@@ -248,12 +255,16 @@ class Authenticator(LoggingConfigurable):
c.Authenticator.post_auth_hook = my_hook c.Authenticator.post_auth_hook = my_hook
""" """,
) )
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
for method_name in ('check_whitelist', 'check_blacklist', 'check_group_whitelist'): for method_name in (
'check_whitelist',
'check_blacklist',
'check_group_whitelist',
):
original_method = getattr(self, method_name, None) original_method = getattr(self, method_name, None)
if original_method is None: if original_method is None:
# no such method (check_group_whitelist is optional) # no such method (check_group_whitelist is optional)
@@ -273,14 +284,14 @@ class Authenticator(LoggingConfigurable):
Adapting for compatibility. Adapting for compatibility.
""".format( """.format(
self.__class__.__name__, self.__class__.__name__, method_name
method_name,
), ),
DeprecationWarning DeprecationWarning,
) )
def wrapped_method(username, authentication=None, **kwargs): def wrapped_method(username, authentication=None, **kwargs):
return original_method(username, **kwargs) return original_method(username, **kwargs)
setattr(self, method_name, wrapped_method) setattr(self, method_name, wrapped_method)
async def run_post_auth_hook(self, handler, authentication): async def run_post_auth_hook(self, handler, authentication):
@@ -299,11 +310,7 @@ class Authenticator(LoggingConfigurable):
""" """
if self.post_auth_hook is not None: if self.post_auth_hook is not None:
authentication = await maybe_future( authentication = await maybe_future(
self.post_auth_hook( self.post_auth_hook(self, handler, authentication)
self,
handler,
authentication,
)
) )
return authentication return authentication
@@ -380,21 +387,25 @@ class Authenticator(LoggingConfigurable):
if 'name' not in authenticated: if 'name' not in authenticated:
raise ValueError("user missing a name: %r" % authenticated) raise ValueError("user missing a name: %r" % authenticated)
else: else:
authenticated = { authenticated = {'name': authenticated}
'name': authenticated,
}
authenticated.setdefault('auth_state', None) authenticated.setdefault('auth_state', None)
# Leave the default as None, but reevaluate later post-whitelist # Leave the default as None, but reevaluate later post-whitelist
authenticated.setdefault('admin', None) authenticated.setdefault('admin', None)
# normalize the username # normalize the username
authenticated['name'] = username = self.normalize_username(authenticated['name']) authenticated['name'] = username = self.normalize_username(
authenticated['name']
)
if not self.validate_username(username): if not self.validate_username(username):
self.log.warning("Disallowing invalid username %r.", username) self.log.warning("Disallowing invalid username %r.", username)
return return
blacklist_pass = await maybe_future(self.check_blacklist(username, authenticated)) blacklist_pass = await maybe_future(
whitelist_pass = await maybe_future(self.check_whitelist(username, authenticated)) self.check_blacklist(username, authenticated)
)
whitelist_pass = await maybe_future(
self.check_whitelist(username, authenticated)
)
if blacklist_pass: if blacklist_pass:
pass pass
@@ -404,7 +415,9 @@ class Authenticator(LoggingConfigurable):
if whitelist_pass: if whitelist_pass:
if authenticated['admin'] is None: if authenticated['admin'] is None:
authenticated['admin'] = await maybe_future(self.is_admin(handler, authenticated)) authenticated['admin'] = await maybe_future(
self.is_admin(handler, authenticated)
)
authenticated = await self.run_post_auth_hook(handler, authenticated) authenticated = await self.run_post_auth_hook(handler, authenticated)
@@ -534,7 +547,9 @@ class Authenticator(LoggingConfigurable):
""" """
self.whitelist.discard(user.name) self.whitelist.discard(user.name)
auto_login = Bool(False, config=True, auto_login = Bool(
False,
config=True,
help="""Automatically begin the login process help="""Automatically begin the login process
rather than starting with a "Login with..." link at `/hub/login` rather than starting with a "Login with..." link at `/hub/login`
@@ -544,7 +559,7 @@ class Authenticator(LoggingConfigurable):
registered with `.get_handlers()`. registered with `.get_handlers()`.
.. versionadded:: 0.8 .. versionadded:: 0.8
""" """,
) )
def login_url(self, base_url): def login_url(self, base_url):
@@ -592,9 +607,7 @@ class Authenticator(LoggingConfigurable):
list of ``('/url', Handler)`` tuples passed to tornado. list of ``('/url', Handler)`` tuples passed to tornado.
The Hub prefix is added to any URLs. The Hub prefix is added to any URLs.
""" """
return [ return [('/login', LoginHandler)]
('/login', LoginHandler),
]
class LocalAuthenticator(Authenticator): class LocalAuthenticator(Authenticator):
@@ -603,12 +616,13 @@ class LocalAuthenticator(Authenticator):
Checks for local users, and can attempt to create them if they exist. Checks for local users, and can attempt to create them if they exist.
""" """
create_system_users = Bool(False, create_system_users = Bool(
False,
help=""" help="""
If set to True, will attempt to create local system users if they do not exist already. If set to True, will attempt to create local system users if they do not exist already.
Supports Linux and BSD variants only. Supports Linux and BSD variants only.
""" """,
).tag(config=True) ).tag(config=True)
add_user_cmd = Command( add_user_cmd = Command(
@@ -699,8 +713,9 @@ class LocalAuthenticator(Authenticator):
raise KeyError( raise KeyError(
"User {} does not exist on the system." "User {} does not exist on the system."
" Set LocalAuthenticator.create_system_users=True" " Set LocalAuthenticator.create_system_users=True"
" to automatically create system users from jupyterhub users." " to automatically create system users from jupyterhub users.".format(
.format(user.name) user.name
)
) )
await maybe_future(super().add_user(user)) await maybe_future(super().add_user(user))
@@ -711,6 +726,7 @@ class LocalAuthenticator(Authenticator):
on Windows on Windows
""" """
import grp import grp
return grp.getgrnam(name) return grp.getgrnam(name)
@staticmethod @staticmethod
@@ -719,6 +735,7 @@ class LocalAuthenticator(Authenticator):
on Windows on Windows
""" """
import pwd import pwd
return pwd.getpwnam(name) return pwd.getpwnam(name)
@staticmethod @staticmethod
@@ -727,6 +744,7 @@ class LocalAuthenticator(Authenticator):
on Windows on Windows
""" """
import os import os
return os.getgrouplist(name, group) return os.getgrouplist(name, group)
def system_user_exists(self, user): def system_user_exists(self, user):
@@ -758,23 +776,27 @@ class PAMAuthenticator(LocalAuthenticator):
# run PAM in a thread, since it can be slow # run PAM in a thread, since it can be slow
executor = Any() executor = Any()
@default('executor') @default('executor')
def _default_executor(self): def _default_executor(self):
return ThreadPoolExecutor(1) return ThreadPoolExecutor(1)
encoding = Unicode('utf8', encoding = Unicode(
'utf8',
help=""" help="""
The text encoding to use when communicating with PAM The text encoding to use when communicating with PAM
""" """,
).tag(config=True) ).tag(config=True)
service = Unicode('login', service = Unicode(
'login',
help=""" help="""
The name of the PAM service to use for authentication The name of the PAM service to use for authentication
""" """,
).tag(config=True) ).tag(config=True)
open_sessions = Bool(True, open_sessions = Bool(
True,
help=""" help="""
Whether to open a new PAM session when spawners are started. Whether to open a new PAM session when spawners are started.
@@ -784,10 +806,11 @@ class PAMAuthenticator(LocalAuthenticator):
If any errors are encountered when opening/closing PAM sessions, If any errors are encountered when opening/closing PAM sessions,
this is automatically set to False. this is automatically set to False.
""" """,
).tag(config=True) ).tag(config=True)
check_account = Bool(True, check_account = Bool(
True,
help=""" help="""
Whether to check the user's account status via PAM during authentication. Whether to check the user's account status via PAM during authentication.
@@ -797,7 +820,7 @@ class PAMAuthenticator(LocalAuthenticator):
Disabling this can be dangerous as authenticated but unauthorized users may Disabling this can be dangerous as authenticated but unauthorized users may
be granted access and, therefore, arbitrary execution on the system. be granted access and, therefore, arbitrary execution on the system.
""" """,
).tag(config=True) ).tag(config=True)
admin_groups = Set( admin_groups = Set(
@@ -809,14 +832,15 @@ class PAMAuthenticator(LocalAuthenticator):
""" """
).tag(config=True) ).tag(config=True)
pam_normalize_username = Bool(False, pam_normalize_username = Bool(
False,
help=""" help="""
Round-trip the username via PAM lookups to make sure it is unique Round-trip the username via PAM lookups to make sure it is unique
PAM can accept multiple usernames that map to the same user, PAM can accept multiple usernames that map to the same user,
for example DOMAIN\\username in some cases. To prevent this, for example DOMAIN\\username in some cases. To prevent this,
convert username into uid, then back to uid to normalize. convert username into uid, then back to uid to normalize.
""" """,
).tag(config=True) ).tag(config=True)
def __init__(self, **kwargs): def __init__(self, **kwargs):
@@ -844,12 +868,19 @@ class PAMAuthenticator(LocalAuthenticator):
# (returning None instead of just the username) as this indicates some sort of system failure # (returning None instead of just the username) as this indicates some sort of system failure
admin_group_gids = {self._getgrnam(x).gr_gid for x in self.admin_groups} admin_group_gids = {self._getgrnam(x).gr_gid for x in self.admin_groups}
user_group_gids = set(self._getgrouplist(username, self._getpwnam(username).pw_gid)) user_group_gids = set(
self._getgrouplist(username, self._getpwnam(username).pw_gid)
)
admin_status = len(admin_group_gids & user_group_gids) != 0 admin_status = len(admin_group_gids & user_group_gids) != 0
except Exception as e: except Exception as e:
if handler is not None: if handler is not None:
self.log.error("PAM Admin Group Check failed (%s@%s): %s", username, handler.request.remote_ip, e) self.log.error(
"PAM Admin Group Check failed (%s@%s): %s",
username,
handler.request.remote_ip,
e,
)
else: else:
self.log.error("PAM Admin Group Check failed: %s", e) self.log.error("PAM Admin Group Check failed: %s", e)
# re-raise to return a 500 to the user and indicate a problem. We failed, not them. # re-raise to return a 500 to the user and indicate a problem. We failed, not them.
@@ -865,27 +896,40 @@ class PAMAuthenticator(LocalAuthenticator):
""" """
username = data['username'] username = data['username']
try: try:
pamela.authenticate(username, data['password'], service=self.service, encoding=self.encoding) pamela.authenticate(
username, data['password'], service=self.service, encoding=self.encoding
)
except pamela.PAMError as e: except pamela.PAMError as e:
if handler is not None: if handler is not None:
self.log.warning("PAM Authentication failed (%s@%s): %s", username, handler.request.remote_ip, e) self.log.warning(
"PAM Authentication failed (%s@%s): %s",
username,
handler.request.remote_ip,
e,
)
else: else:
self.log.warning("PAM Authentication failed: %s", e) self.log.warning("PAM Authentication failed: %s", e)
return None return None
if self.check_account: if self.check_account:
try: try:
pamela.check_account(username, service=self.service, encoding=self.encoding) pamela.check_account(
username, service=self.service, encoding=self.encoding
)
except pamela.PAMError as e: except pamela.PAMError as e:
if handler is not None: if handler is not None:
self.log.warning("PAM Account Check failed (%s@%s): %s", username, handler.request.remote_ip, e) self.log.warning(
"PAM Account Check failed (%s@%s): %s",
username,
handler.request.remote_ip,
e,
)
else: else:
self.log.warning("PAM Account Check failed: %s", e) self.log.warning("PAM Account Check failed: %s", e)
return None return None
return username return username
@run_on_executor @run_on_executor
def pre_spawn_start(self, user, spawner): def pre_spawn_start(self, user, spawner):
"""Open PAM session for user if so configured""" """Open PAM session for user if so configured"""
@@ -904,7 +948,9 @@ class PAMAuthenticator(LocalAuthenticator):
if not self.open_sessions: if not self.open_sessions:
return return
try: try:
pamela.close_session(user.name, service=self.service, encoding=self.encoding) pamela.close_session(
user.name, service=self.service, encoding=self.encoding
)
except pamela.PAMError as e: except pamela.PAMError as e:
self.log.warning("Failed to close PAM session for %s: %s", user.name, e) self.log.warning("Failed to close PAM session for %s: %s", user.name, e)
self.log.warning("Disabling PAM sessions from now on.") self.log.warning("Disabling PAM sessions from now on.")
@@ -916,12 +962,14 @@ class PAMAuthenticator(LocalAuthenticator):
PAM can accept multiple usernames as the same user, normalize them.""" PAM can accept multiple usernames as the same user, normalize them."""
if self.pam_normalize_username: if self.pam_normalize_username:
import pwd import pwd
uid = pwd.getpwnam(username).pw_uid uid = pwd.getpwnam(username).pw_uid
username = pwd.getpwuid(uid).pw_name username = pwd.getpwuid(uid).pw_name
username = self.username_map.get(username, username) username = self.username_map.get(username, username)
else: else:
return super().normalize_username(username) return super().normalize_username(username)
class DummyAuthenticator(Authenticator): class DummyAuthenticator(Authenticator):
"""Dummy Authenticator for testing """Dummy Authenticator for testing
@@ -938,7 +986,7 @@ class DummyAuthenticator(Authenticator):
Set a global password for all users wanting to log in. Set a global password for all users wanting to log in.
This allows users with any username to log in with the same static password. This allows users with any username to log in with the same static password.
""" """,
) )
async def authenticate(self, handler, data): async def authenticate(self, handler, data):


@@ -1,35 +1,43 @@
import base64 import base64
from binascii import a2b_hex
from concurrent.futures import ThreadPoolExecutor
import json import json
import os import os
from binascii import a2b_hex
from concurrent.futures import ThreadPoolExecutor
from traitlets.config import SingletonConfigurable, Config from traitlets import Any
from traitlets import ( from traitlets import default
Any, Dict, Integer, List, from traitlets import Dict
default, observe, validate, from traitlets import Integer
) from traitlets import List
from traitlets import observe
from traitlets import validate
from traitlets.config import Config
from traitlets.config import SingletonConfigurable
try: try:
import cryptography import cryptography
from cryptography.fernet import Fernet, MultiFernet, InvalidToken from cryptography.fernet import Fernet, MultiFernet, InvalidToken
except ImportError: except ImportError:
cryptography = None cryptography = None
class InvalidToken(Exception): class InvalidToken(Exception):
pass pass
from .utils import maybe_future from .utils import maybe_future
KEY_ENV = 'JUPYTERHUB_CRYPT_KEY' KEY_ENV = 'JUPYTERHUB_CRYPT_KEY'
class EncryptionUnavailable(Exception): class EncryptionUnavailable(Exception):
pass pass
class CryptographyUnavailable(EncryptionUnavailable): class CryptographyUnavailable(EncryptionUnavailable):
def __str__(self): def __str__(self):
return "cryptography library is required for encryption" return "cryptography library is required for encryption"
class NoEncryptionKeys(EncryptionUnavailable): class NoEncryptionKeys(EncryptionUnavailable):
def __str__(self): def __str__(self):
return "Encryption keys must be specified in %s env" % KEY_ENV return "Encryption keys must be specified in %s env" % KEY_ENV
@@ -70,13 +78,16 @@ def _validate_key(key):
return key return key
class CryptKeeper(SingletonConfigurable): class CryptKeeper(SingletonConfigurable):
"""Encapsulate encryption configuration """Encapsulate encryption configuration
Use via the encryption_config singleton below. Use via the encryption_config singleton below.
""" """
n_threads = Integer(max(os.cpu_count(), 1), config=True, n_threads = Integer(
max(os.cpu_count(), 1),
config=True,
help="The number of threads to allocate for encryption", help="The number of threads to allocate for encryption",
) )
@@ -84,22 +95,27 @@ class CryptKeeper(SingletonConfigurable):
def _config_default(self): def _config_default(self):
# load application config by default # load application config by default
from .app import JupyterHub from .app import JupyterHub
if JupyterHub.initialized(): if JupyterHub.initialized():
return JupyterHub.instance().config return JupyterHub.instance().config
else: else:
return Config() return Config()
executor = Any() executor = Any()
def _executor_default(self): def _executor_default(self):
return ThreadPoolExecutor(self.n_threads) return ThreadPoolExecutor(self.n_threads)
keys = List(config=True) keys = List(config=True)
def _keys_default(self): def _keys_default(self):
if KEY_ENV not in os.environ: if KEY_ENV not in os.environ:
return [] return []
# key can be a ;-separated sequence for key rotation. # key can be a ;-separated sequence for key rotation.
# First item in the list is used for encryption. # First item in the list is used for encryption.
return [ _validate_key(key) for key in os.environ[KEY_ENV].split(';') if key.strip() ] return [
_validate_key(key) for key in os.environ[KEY_ENV].split(';') if key.strip()
]
@validate('keys') @validate('keys')
def _ensure_bytes(self, proposal): def _ensure_bytes(self, proposal):
@@ -107,6 +123,7 @@ class CryptKeeper(SingletonConfigurable):
return [_validate_key(key) for key in proposal.value] return [_validate_key(key) for key in proposal.value]
fernet = Any() fernet = Any()
def _fernet_default(self): def _fernet_default(self):
if cryptography is None or not self.keys: if cryptography is None or not self.keys:
return None return None
@@ -153,6 +170,7 @@ def encrypt(data):
""" """
return CryptKeeper.instance().encrypt(data) return CryptKeeper.instance().encrypt(data)
def decrypt(data): def decrypt(data):
"""decrypt some data with the crypt keeper """decrypt some data with the crypt keeper
View File
@@ -1,15 +1,13 @@
"""Database utilities for JupyterHub""" """Database utilities for JupyterHub"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
# Based on pgcontents.utils.migrate, used under the Apache license. # Based on pgcontents.utils.migrate, used under the Apache license.
from contextlib import contextmanager
from datetime import datetime
import os import os
import shutil import shutil
from subprocess import check_call
import sys import sys
from contextlib import contextmanager
from datetime import datetime
from subprocess import check_call
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from sqlalchemy import create_engine from sqlalchemy import create_engine
@@ -85,9 +83,7 @@ def upgrade(db_url, revision='head'):
The alembic revision to upgrade to. The alembic revision to upgrade to.
""" """
with _temp_alembic_ini(db_url) as alembic_ini: with _temp_alembic_ini(db_url) as alembic_ini:
check_call( check_call(['alembic', '-c', alembic_ini, 'upgrade', revision])
['alembic', '-c', alembic_ini, 'upgrade', revision]
)
def backup_db_file(db_file, log=None): def backup_db_file(db_file, log=None):
@@ -133,30 +129,27 @@ def upgrade_if_needed(db_url, backup=True, log=None):
def shell(args=None): def shell(args=None):
"""Start an IPython shell hooked up to the jupyerhub database""" """Start an IPython shell hooked up to the jupyerhub database"""
from .app import JupyterHub from .app import JupyterHub
hub = JupyterHub() hub = JupyterHub()
hub.load_config_file(hub.config_file) hub.load_config_file(hub.config_file)
db_url = hub.db_url db_url = hub.db_url
db = orm.new_session_factory(db_url, **hub.db_kwargs)() db = orm.new_session_factory(db_url, **hub.db_kwargs)()
ns = { ns = {'db': db, 'db_url': db_url, 'orm': orm}
'db': db,
'db_url': db_url,
'orm': orm,
}
import IPython import IPython
IPython.start_ipython(args, user_ns=ns) IPython.start_ipython(args, user_ns=ns)
def _alembic(args): def _alembic(args):
"""Run an alembic command with a temporary alembic.ini""" """Run an alembic command with a temporary alembic.ini"""
from .app import JupyterHub from .app import JupyterHub
hub = JupyterHub() hub = JupyterHub()
hub.load_config_file(hub.config_file) hub.load_config_file(hub.config_file)
db_url = hub.db_url db_url = hub.db_url
with _temp_alembic_ini(db_url) as alembic_ini: with _temp_alembic_ini(db_url) as alembic_ini:
check_call( check_call(['alembic', '-c', alembic_ini] + args)
['alembic', '-c', alembic_ini] + args
)
def main(args=None): def main(args=None):
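Note (hypothetical sketch, not part of this commit): both upgrade() and _alembic() above shell out to alembic with a throwaway alembic.ini produced by _temp_alembic_ini. The helper amounts to roughly the following; the ini contents here are illustrative stand-ins:

    import os
    from contextlib import contextmanager
    from subprocess import check_call
    from tempfile import TemporaryDirectory

    @contextmanager
    def temp_alembic_ini(db_url):
        # hypothetical stand-in for jupyterhub.dbutil._temp_alembic_ini
        with TemporaryDirectory() as td:
            ini = os.path.join(td, 'alembic.ini')
            with open(ini, 'w') as f:
                # script_location would point at jupyterhub's bundled migrations
                f.write("[alembic]\nscript_location = ./alembic\nsqlalchemy.url = %s\n" % db_url)
            yield ini

    with temp_alembic_ini('sqlite:///jupyterhub.sqlite') as alembic_ini:
        check_call(['alembic', '-c', alembic_ini, 'upgrade', 'head'])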
View File
@@ -1,8 +1,10 @@
from . import base
from . import login
from . import metrics
from . import pages
from .base import * from .base import *
from .login import * from .login import *
from . import base, pages, login, metrics
default_handlers = [] default_handlers = []
for mod in (base, pages, login, metrics): for mod in (base, pages, login, metrics):
default_handlers.extend(mod.default_handlers) default_handlers.extend(mod.default_handlers)
View File
@@ -1,41 +1,49 @@
"""HTTP Handlers for the hub server""" """HTTP Handlers for the hub server"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import asyncio import asyncio
import copy import copy
from datetime import datetime, timedelta
from http.client import responses
import json import json
import math import math
import random import random
import re import re
import time import time
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode
import uuid import uuid
from datetime import datetime
from datetime import timedelta
from http.client import responses
from urllib.parse import parse_qs
from urllib.parse import urlencode
from urllib.parse import urlparse
from urllib.parse import urlunparse
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from tornado.log import app_log from tornado import gen
from tornado.httputil import url_concat, HTTPHeaders from tornado import web
from tornado.httputil import HTTPHeaders
from tornado.httputil import url_concat
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from tornado.web import RequestHandler, MissingArgumentError from tornado.log import app_log
from tornado import gen, web from tornado.web import MissingArgumentError
from tornado.web import RequestHandler
from .. import __version__ from .. import __version__
from .. import orm from .. import orm
from ..metrics import PROXY_ADD_DURATION_SECONDS
from ..metrics import ProxyAddStatus
from ..metrics import RUNNING_SERVERS
from ..metrics import SERVER_POLL_DURATION_SECONDS
from ..metrics import SERVER_SPAWN_DURATION_SECONDS
from ..metrics import SERVER_STOP_DURATION_SECONDS
from ..metrics import ServerPollStatus
from ..metrics import ServerSpawnStatus
from ..metrics import ServerStopStatus
from ..objects import Server from ..objects import Server
from ..spawner import LocalProcessSpawner from ..spawner import LocalProcessSpawner
from ..user import User from ..user import User
from ..utils import maybe_future, url_path_join from ..utils import maybe_future
from ..metrics import ( from ..utils import url_path_join
SERVER_SPAWN_DURATION_SECONDS, ServerSpawnStatus,
PROXY_ADD_DURATION_SECONDS, ProxyAddStatus,
SERVER_POLL_DURATION_SECONDS, ServerPollStatus,
RUNNING_SERVERS, SERVER_STOP_DURATION_SECONDS, ServerStopStatus
)
# pattern for the authentication token header # pattern for the authentication token header
auth_header_pat = re.compile(r'^(?:token|bearer)\s+([^\s]+)$', flags=re.IGNORECASE) auth_header_pat = re.compile(r'^(?:token|bearer)\s+([^\s]+)$', flags=re.IGNORECASE)
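Note (illustrative): auth_header_pat accepts both the `token` and `bearer` schemes, case-insensitively, and captures only the token value:

    import re

    auth_header_pat = re.compile(r'^(?:token|bearer)\s+([^\s]+)$', flags=re.IGNORECASE)

    for header in ("token abc123", "Bearer abc123", "Basic dXNlcjpwYXNz"):
        match = auth_header_pat.match(header)
        print(header, '->', match.group(1) if match else None)
    # token abc123 -> abc123
    # Bearer abc123 -> abc123
    # Basic dXNlcjpwYXNz -> None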
@@ -52,6 +60,7 @@ reasons = {
# constant, not configurable # constant, not configurable
SESSION_COOKIE_NAME = 'jupyterhub-session-id' SESSION_COOKIE_NAME = 'jupyterhub-session-id'
class BaseHandler(RequestHandler): class BaseHandler(RequestHandler):
"""Base Handler class with access to common methods and properties.""" """Base Handler class with access to common methods and properties."""
@@ -157,8 +166,8 @@ class BaseHandler(RequestHandler):
@property @property
def csp_report_uri(self): def csp_report_uri(self):
return self.settings.get('csp_report_uri', return self.settings.get(
url_path_join(self.hub.base_url, 'security/csp-report') 'csp_report_uri', url_path_join(self.hub.base_url, 'security/csp-report')
) )
@property @property
@@ -167,10 +176,9 @@ class BaseHandler(RequestHandler):
Can be overridden by defining Content-Security-Policy in settings['headers'] Can be overridden by defining Content-Security-Policy in settings['headers']
""" """
return '; '.join([ return '; '.join(
"frame-ancestors 'self'", ["frame-ancestors 'self'", "report-uri " + self.csp_report_uri]
"report-uri " + self.csp_report_uri, )
])
def get_content_type(self): def get_content_type(self):
return 'text/html' return 'text/html'
@@ -190,7 +198,9 @@ class BaseHandler(RequestHandler):
self.set_header(header_name, header_content) self.set_header(header_name, header_content)
if 'Access-Control-Allow-Headers' not in headers: if 'Access-Control-Allow-Headers' not in headers:
self.set_header('Access-Control-Allow-Headers', 'accept, content-type, authorization') self.set_header(
'Access-Control-Allow-Headers', 'accept, content-type, authorization'
)
if 'Content-Security-Policy' not in headers: if 'Content-Security-Policy' not in headers:
self.set_header('Content-Security-Policy', self.content_security_policy) self.set_header('Content-Security-Policy', self.content_security_policy)
self.set_header('Content-Type', self.get_content_type()) self.set_header('Content-Type', self.get_content_type())
@@ -236,8 +246,7 @@ class BaseHandler(RequestHandler):
orm_token = orm.OAuthAccessToken.find(self.db, token) orm_token = orm.OAuthAccessToken.find(self.db, token)
if orm_token is None: if orm_token is None:
return None return None
orm_token.last_activity = \ orm_token.last_activity = orm_token.user.last_activity = datetime.utcnow()
orm_token.user.last_activity = datetime.utcnow()
self.db.commit() self.db.commit()
return self._user_from_orm(orm_token.user) return self._user_from_orm(orm_token.user)
@@ -259,7 +268,11 @@ class BaseHandler(RequestHandler):
if not refresh_age: if not refresh_age:
return user return user
now = time.monotonic() now = time.monotonic()
if not force and user._auth_refreshed and (now - user._auth_refreshed < refresh_age): if (
not force
and user._auth_refreshed
and (now - user._auth_refreshed < refresh_age)
):
# auth up-to-date # auth up-to-date
return user return user
@@ -276,8 +289,7 @@ class BaseHandler(RequestHandler):
if not auth_info: if not auth_info:
self.log.warning( self.log.warning(
"User %s has stale auth info. Login is required to refresh.", "User %s has stale auth info. Login is required to refresh.", user.name
user.name,
) )
return return
@@ -325,10 +337,9 @@ class BaseHandler(RequestHandler):
def _user_for_cookie(self, cookie_name, cookie_value=None): def _user_for_cookie(self, cookie_name, cookie_value=None):
"""Get the User for a given cookie, if there is one""" """Get the User for a given cookie, if there is one"""
cookie_id = self.get_secure_cookie( cookie_id = self.get_secure_cookie(
cookie_name, cookie_name, cookie_value, max_age_days=self.cookie_max_age_days
cookie_value,
max_age_days=self.cookie_max_age_days,
) )
def clear(): def clear():
self.clear_cookie(cookie_name, path=self.hub.base_url) self.clear_cookie(cookie_name, path=self.hub.base_url)
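Note (illustrative): the max_age_days passed to get_secure_cookie above is enforced by tornado's signed-value machinery; the same check can be exercised directly:

    from tornado.web import create_signed_value, decode_signed_value

    secret = 'not-a-real-secret'  # stand-in for the hub's cookie_secret
    signed = create_signed_value(secret, 'jupyterhub-session-id', 'abc123')
    print(decode_signed_value(secret, 'jupyterhub-session-id', signed, max_age_days=14))
    # b'abc123' while fresh; None once the signature is older than max_age_days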
@@ -434,7 +445,11 @@ class BaseHandler(RequestHandler):
# clear hub cookie # clear hub cookie
self.clear_cookie(self.hub.cookie_name, path=self.hub.base_url, **kwargs) self.clear_cookie(self.hub.cookie_name, path=self.hub.base_url, **kwargs)
# clear services cookie # clear services cookie
self.clear_cookie('jupyterhub-services', path=url_path_join(self.base_url, 'services'), **kwargs) self.clear_cookie(
'jupyterhub-services',
path=url_path_join(self.base_url, 'services'),
**kwargs
)
def _set_cookie(self, key, value, encrypted=True, **overrides): def _set_cookie(self, key, value, encrypted=True, **overrides):
"""Setting any cookie should go through here """Setting any cookie should go through here
@@ -444,9 +459,7 @@ class BaseHandler(RequestHandler):
""" """
# tornado <4.2 have a bug that consider secure==True as soon as # tornado <4.2 have a bug that consider secure==True as soon as
# 'secure' kwarg is passed to set_secure_cookie # 'secure' kwarg is passed to set_secure_cookie
kwargs = { kwargs = {'httponly': True}
'httponly': True,
}
if self.request.protocol == 'https': if self.request.protocol == 'https':
kwargs['secure'] = True kwargs['secure'] = True
if self.subdomain_host: if self.subdomain_host:
@@ -463,14 +476,10 @@ class BaseHandler(RequestHandler):
self.log.debug("Setting cookie %s: %s", key, kwargs) self.log.debug("Setting cookie %s: %s", key, kwargs)
set_cookie(key, value, **kwargs) set_cookie(key, value, **kwargs)
def _set_user_cookie(self, user, server): def _set_user_cookie(self, user, server):
self.log.debug("Setting cookie for %s: %s", user.name, server.cookie_name) self.log.debug("Setting cookie for %s: %s", user.name, server.cookie_name)
self._set_cookie( self._set_cookie(
server.cookie_name, server.cookie_name, user.cookie_id, encrypted=True, path=server.base_url
user.cookie_id,
encrypted=True,
path=server.base_url,
) )
def get_session_cookie(self): def get_session_cookie(self):
@@ -494,10 +503,13 @@ class BaseHandler(RequestHandler):
def set_service_cookie(self, user): def set_service_cookie(self, user):
"""set the login cookie for services""" """set the login cookie for services"""
self._set_user_cookie(user, orm.Server( self._set_user_cookie(
user,
orm.Server(
cookie_name='jupyterhub-services', cookie_name='jupyterhub-services',
base_url=url_path_join(self.base_url, 'services') base_url=url_path_join(self.base_url, 'services'),
)) ),
)
def set_hub_cookie(self, user): def set_hub_cookie(self, user):
"""set the login cookie for the Hub""" """set the login cookie for the Hub"""
@@ -508,7 +520,9 @@ class BaseHandler(RequestHandler):
if self.subdomain_host and not self.request.host.startswith(self.domain): if self.subdomain_host and not self.request.host.startswith(self.domain):
self.log.warning( self.log.warning(
"Possibly setting cookie on wrong domain: %s != %s", "Possibly setting cookie on wrong domain: %s != %s",
self.request.host, self.domain) self.request.host,
self.domain,
)
# set single cookie for services # set single cookie for services
if self.db.query(orm.Service).filter(orm.Service.server != None).first(): if self.db.query(orm.Service).filter(orm.Service.server != None).first():
@@ -555,8 +569,10 @@ class BaseHandler(RequestHandler):
# ultimately redirecting to the logged-in user's server. # ultimately redirecting to the logged-in user's server.
without_prefix = next_url[len(self.base_url) :] without_prefix = next_url[len(self.base_url) :]
next_url = url_path_join(self.hub.base_url, without_prefix) next_url = url_path_join(self.hub.base_url, without_prefix)
self.log.warning("Redirecting %s to %s. For sharing public links, use /user-redirect/", self.log.warning(
self.request.uri, next_url, "Redirecting %s to %s. For sharing public links, use /user-redirect/",
self.request.uri,
next_url,
) )
if not next_url: if not next_url:
@@ -627,8 +643,9 @@ class BaseHandler(RequestHandler):
else: else:
self.statsd.incr('login.failure') self.statsd.incr('login.failure')
self.statsd.timing('login.authenticate.failure', auth_timer.ms) self.statsd.timing('login.authenticate.failure', auth_timer.ms)
self.log.warning("Failed login for %s", (data or {}).get('username', 'unknown user')) self.log.warning(
"Failed login for %s", (data or {}).get('username', 'unknown user')
)
# --------------------------------------------------------------- # ---------------------------------------------------------------
# spawning-related # spawning-related
@@ -659,7 +676,9 @@ class BaseHandler(RequestHandler):
if self.authenticator.refresh_pre_spawn: if self.authenticator.refresh_pre_spawn:
auth_user = await self.refresh_auth(user, force=True) auth_user = await self.refresh_auth(user, force=True)
if auth_user is None: if auth_user is None:
raise web.HTTPError(403, "auth has expired for %s, login again", user.name) raise web.HTTPError(
403, "auth has expired for %s, login again", user.name
)
spawn_start_time = time.perf_counter() spawn_start_time = time.perf_counter()
self.extra_error_html = self.spawn_home_error self.extra_error_html = self.spawn_home_error
@@ -681,7 +700,9 @@ class BaseHandler(RequestHandler):
# but for 10k users this takes ~5ms # but for 10k users this takes ~5ms
# and saves us from bookkeeping errors # and saves us from bookkeeping errors
active_counts = self.users.count_active_users() active_counts = self.users.count_active_users()
spawn_pending_count = active_counts['spawn_pending'] + active_counts['proxy_pending'] spawn_pending_count = (
active_counts['spawn_pending'] + active_counts['proxy_pending']
)
active_count = active_counts['active'] active_count = active_counts['active']
concurrent_spawn_limit = self.concurrent_spawn_limit concurrent_spawn_limit = self.concurrent_spawn_limit
@@ -700,18 +721,21 @@ class BaseHandler(RequestHandler):
# round suggestion to nicer human value (nearest 10 seconds or minute) # round suggestion to nicer human value (nearest 10 seconds or minute)
if retry_time <= 90: if retry_time <= 90:
# round human seconds up to nearest 10 # round human seconds up to nearest 10
human_retry_time = "%i0 seconds" % math.ceil(retry_time / 10.) human_retry_time = "%i0 seconds" % math.ceil(retry_time / 10.0)
else: else:
# round number of minutes # round number of minutes
human_retry_time = "%i minutes" % math.round(retry_time / 60.) human_retry_time = "%i minutes" % math.round(retry_time / 60.0)
self.log.warning( self.log.warning(
'%s pending spawns, throttling. Suggested retry in %s seconds.', '%s pending spawns, throttling. Suggested retry in %s seconds.',
spawn_pending_count, retry_time, spawn_pending_count,
retry_time,
) )
err = web.HTTPError( err = web.HTTPError(
429, 429,
"Too many users trying to log in right now. Try again in {}.".format(human_retry_time) "Too many users trying to log in right now. Try again in {}.".format(
human_retry_time
),
) )
# can't call set_header directly here because it gets ignored # can't call set_header directly here because it gets ignored
# when errors are raised # when errors are raised
@@ -720,14 +744,13 @@ class BaseHandler(RequestHandler):
raise err raise err
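Note (worked example, not part of this commit): the retry hint above rounds to the nearest 10 seconds below 90 seconds and to whole minutes beyond that. A rough standalone version (using plain round() where the hub code calls math.round):

    import math

    def human_retry(retry_time):
        # mirror of the rounding above, for illustration only
        if retry_time <= 90:
            return "%i0 seconds" % math.ceil(retry_time / 10.0)
        return "%i minutes" % round(retry_time / 60.0)

    print(human_retry(37))   # 40 seconds
    print(human_retry(170))  # 3 minutes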
if active_server_limit and active_count >= active_server_limit: if active_server_limit and active_count >= active_server_limit:
self.log.info( self.log.info('%s servers active, no space available', active_count)
'%s servers active, no space available',
active_count,
)
SERVER_SPAWN_DURATION_SECONDS.labels( SERVER_SPAWN_DURATION_SECONDS.labels(
status=ServerSpawnStatus.too_many_users status=ServerSpawnStatus.too_many_users
).observe(time.perf_counter() - spawn_start_time) ).observe(time.perf_counter() - spawn_start_time)
raise web.HTTPError(429, "Active user limit exceeded. Try again in a few minutes.") raise web.HTTPError(
429, "Active user limit exceeded. Try again in a few minutes."
)
tic = IOLoop.current().time() tic = IOLoop.current().time()
@@ -735,12 +758,16 @@ class BaseHandler(RequestHandler):
spawn_future = user.spawn(server_name, options, handler=self) spawn_future = user.spawn(server_name, options, handler=self)
self.log.debug("%i%s concurrent spawns", self.log.debug(
"%i%s concurrent spawns",
spawn_pending_count, spawn_pending_count,
'/%i' % concurrent_spawn_limit if concurrent_spawn_limit else '') '/%i' % concurrent_spawn_limit if concurrent_spawn_limit else '',
self.log.debug("%i%s active servers", )
self.log.debug(
"%i%s active servers",
active_count, active_count,
'/%i' % active_server_limit if active_server_limit else '') '/%i' % active_server_limit if active_server_limit else '',
)
spawner = user.spawners[server_name] spawner = user.spawners[server_name]
# set spawn_pending now, so there's no gap where _spawn_pending is False # set spawn_pending now, so there's no gap where _spawn_pending is False
@@ -756,7 +783,9 @@ class BaseHandler(RequestHandler):
# wait for spawn Future # wait for spawn Future
await spawn_future await spawn_future
toc = IOLoop.current().time() toc = IOLoop.current().time()
self.log.info("User %s took %.3f seconds to start", user_server_name, toc-tic) self.log.info(
"User %s took %.3f seconds to start", user_server_name, toc - tic
)
self.statsd.timing('spawner.success', (toc - tic) * 1000) self.statsd.timing('spawner.success', (toc - tic) * 1000)
RUNNING_SERVERS.inc() RUNNING_SERVERS.inc()
SERVER_SPAWN_DURATION_SECONDS.labels( SERVER_SPAWN_DURATION_SECONDS.labels(
@@ -767,18 +796,16 @@ class BaseHandler(RequestHandler):
try: try:
await self.proxy.add_user(user, server_name) await self.proxy.add_user(user, server_name)
PROXY_ADD_DURATION_SECONDS.labels( PROXY_ADD_DURATION_SECONDS.labels(status='success').observe(
status='success'
).observe(
time.perf_counter() - proxy_add_start_time time.perf_counter() - proxy_add_start_time
) )
except Exception: except Exception:
self.log.exception("Failed to add %s to proxy!", user_server_name) self.log.exception("Failed to add %s to proxy!", user_server_name)
self.log.error("Stopping %s to avoid inconsistent state", user_server_name) self.log.error(
"Stopping %s to avoid inconsistent state", user_server_name
)
await user.stop() await user.stop()
PROXY_ADD_DURATION_SECONDS.labels( PROXY_ADD_DURATION_SECONDS.labels(status='failure').observe(
status='failure'
).observe(
time.perf_counter() - proxy_add_start_time time.perf_counter() - proxy_add_start_time
) )
else: else:
@@ -818,7 +845,8 @@ class BaseHandler(RequestHandler):
self.log.warning( self.log.warning(
"%i consecutive spawns failed. " "%i consecutive spawns failed. "
"Hub will exit if failure count reaches %i before succeeding", "Hub will exit if failure count reaches %i before succeeding",
failure_count, failure_limit, failure_count,
failure_limit,
) )
if failure_limit and failure_count >= failure_limit: if failure_limit and failure_count >= failure_limit:
self.log.critical( self.log.critical(
@@ -828,6 +856,7 @@ class BaseHandler(RequestHandler):
# mostly propagating errors for the current failures # mostly propagating errors for the current failures
def abort(): def abort():
raise SystemExit(1) raise SystemExit(1)
IOLoop.current().call_later(2, abort) IOLoop.current().call_later(2, abort)
finish_spawn_future.add_done_callback(_track_failure_count) finish_spawn_future.add_done_callback(_track_failure_count)
@@ -842,8 +871,11 @@ class BaseHandler(RequestHandler):
if spawner._spawn_pending and not spawner._waiting_for_response: if spawner._spawn_pending and not spawner._waiting_for_response:
# still in Spawner.start, which is taking a long time # still in Spawner.start, which is taking a long time
# we shouldn't poll while spawn is incomplete. # we shouldn't poll while spawn is incomplete.
self.log.warning("User %s is slow to start (timeout=%s)", self.log.warning(
user_server_name, self.slow_spawn_timeout) "User %s is slow to start (timeout=%s)",
user_server_name,
self.slow_spawn_timeout,
)
return return
# start has finished, but the server hasn't come up # start has finished, but the server hasn't come up
@@ -861,22 +893,34 @@ class BaseHandler(RequestHandler):
status=ServerSpawnStatus.failure status=ServerSpawnStatus.failure
).observe(time.perf_counter() - spawn_start_time) ).observe(time.perf_counter() - spawn_start_time)
raise web.HTTPError(500, "Spawner failed to start [status=%s]. The logs for %s may contain details." % ( raise web.HTTPError(
status, spawner._log_name)) 500,
"Spawner failed to start [status=%s]. The logs for %s may contain details."
% (status, spawner._log_name),
)
if spawner._waiting_for_response: if spawner._waiting_for_response:
# hit timeout waiting for response, but server's running. # hit timeout waiting for response, but server's running.
# Hope that it'll show up soon enough, # Hope that it'll show up soon enough,
# though it's possible that it started at the wrong URL # though it's possible that it started at the wrong URL
self.log.warning("User %s is slow to become responsive (timeout=%s)", self.log.warning(
user_server_name, self.slow_spawn_timeout) "User %s is slow to become responsive (timeout=%s)",
self.log.debug("Expecting server for %s at: %s", user_server_name,
user_server_name, spawner.server.url) self.slow_spawn_timeout,
)
self.log.debug(
"Expecting server for %s at: %s",
user_server_name,
spawner.server.url,
)
if spawner._proxy_pending: if spawner._proxy_pending:
# User.spawn finished, but it hasn't been added to the proxy # User.spawn finished, but it hasn't been added to the proxy
# Could be due to load or a slow proxy # Could be due to load or a slow proxy
self.log.warning("User %s is slow to be added to the proxy (timeout=%s)", self.log.warning(
user_server_name, self.slow_spawn_timeout) "User %s is slow to be added to the proxy (timeout=%s)",
user_server_name,
self.slow_spawn_timeout,
)
async def user_stopped(self, user, server_name): async def user_stopped(self, user, server_name):
"""Callback that fires when the spawner has stopped""" """Callback that fires when the spawner has stopped"""
@@ -888,12 +932,11 @@ class BaseHandler(RequestHandler):
status=ServerPollStatus.from_status(status) status=ServerPollStatus.from_status(status)
).observe(time.perf_counter() - poll_start_time) ).observe(time.perf_counter() - poll_start_time)
if status is None: if status is None:
status = 'unknown' status = 'unknown'
self.log.warning("User %s server stopped, with exit code: %s", self.log.warning(
user.name, status, "User %s server stopped, with exit code: %s", user.name, status
) )
await self.proxy.delete_user(user, server_name) await self.proxy.delete_user(user, server_name)
await user.stop(server_name) await user.stop(server_name)
@@ -920,7 +963,9 @@ class BaseHandler(RequestHandler):
await self.proxy.delete_user(user, server_name) await self.proxy.delete_user(user, server_name)
await user.stop(server_name) await user.stop(server_name)
toc = time.perf_counter() toc = time.perf_counter()
self.log.info("User %s server took %.3f seconds to stop", user.name, toc - tic) self.log.info(
"User %s server took %.3f seconds to stop", user.name, toc - tic
)
self.statsd.timing('spawner.stop', (toc - tic) * 1000) self.statsd.timing('spawner.stop', (toc - tic) * 1000)
RUNNING_SERVERS.dec() RUNNING_SERVERS.dec()
SERVER_STOP_DURATION_SECONDS.labels( SERVER_STOP_DURATION_SECONDS.labels(
@@ -934,14 +979,15 @@ class BaseHandler(RequestHandler):
spawner._stop_future = None spawner._stop_future = None
spawner._stop_pending = False spawner._stop_pending = False
future = spawner._stop_future = asyncio.ensure_future(stop()) future = spawner._stop_future = asyncio.ensure_future(stop())
try: try:
await gen.with_timeout(timedelta(seconds=self.slow_stop_timeout), future) await gen.with_timeout(timedelta(seconds=self.slow_stop_timeout), future)
except gen.TimeoutError: except gen.TimeoutError:
# hit timeout, but stop is still pending # hit timeout, but stop is still pending
self.log.warning("User %s:%s server is slow to stop", user.name, server_name) self.log.warning(
"User %s:%s server is slow to stop", user.name, server_name
)
# return handle on the future for hooking up callbacks # return handle on the future for hooking up callbacks
return future return future
@@ -1051,6 +1097,7 @@ class BaseHandler(RequestHandler):
class Template404(BaseHandler): class Template404(BaseHandler):
"""Render our 404 template""" """Render our 404 template"""
async def prepare(self): async def prepare(self):
await super().prepare() await super().prepare()
raise web.HTTPError(404) raise web.HTTPError(404)
@@ -1061,6 +1108,7 @@ class PrefixRedirectHandler(BaseHandler):
Redirects /foo to /prefix/foo, etc. Redirects /foo to /prefix/foo, etc.
""" """
def get(self): def get(self):
uri = self.request.uri uri = self.request.uri
# Since self.base_url will end with trailing slash. # Since self.base_url will end with trailing slash.
@@ -1076,9 +1124,7 @@ class PrefixRedirectHandler(BaseHandler):
# default / -> /hub/ redirect # default / -> /hub/ redirect
# avoiding extra hop through /hub # avoiding extra hop through /hub
path = '/' path = '/'
self.redirect(url_path_join( self.redirect(url_path_join(self.hub.base_url, path), permanent=False)
self.hub.base_url, path,
), permanent=False)
class UserSpawnHandler(BaseHandler): class UserSpawnHandler(BaseHandler):
@@ -1113,8 +1159,11 @@ class UserSpawnHandler(BaseHandler):
if user is None: if user is None:
# no such user # no such user
raise web.HTTPError(404, "No such user %s" % user_name) raise web.HTTPError(404, "No such user %s" % user_name)
self.log.info("Admin %s requesting spawn on behalf of %s", self.log.info(
current_user.name, user.name) "Admin %s requesting spawn on behalf of %s",
current_user.name,
user.name,
)
admin_spawn = True admin_spawn = True
should_spawn = True should_spawn = True
else: else:
@@ -1122,7 +1171,7 @@ class UserSpawnHandler(BaseHandler):
admin_spawn = False admin_spawn = False
# For non-admins, we should spawn if the user matches # For non-admins, we should spawn if the user matches
# otherwise redirect users to their own server # otherwise redirect users to their own server
should_spawn = (current_user and current_user.name == user_name) should_spawn = current_user and current_user.name == user_name
if "api" in user_path.split("/") and user and not user.active: if "api" in user_path.split("/") and user and not user.active:
# API request for not-running server (e.g. notebook UI left open) # API request for not-running server (e.g. notebook UI left open)
@@ -1142,12 +1191,19 @@ class UserSpawnHandler(BaseHandler):
port = host_info.port port = host_info.port
if not port: if not port:
port = 443 if host_info.scheme == 'https' else 80 port = 443 if host_info.scheme == 'https' else 80
if port != Server.from_url(self.proxy.public_url).connect_port and port == self.hub.connect_port: if (
self.log.warning(""" port != Server.from_url(self.proxy.public_url).connect_port
and port == self.hub.connect_port
):
self.log.warning(
"""
Detected possible direct connection to Hub's private ip: %s, bypassing proxy. Detected possible direct connection to Hub's private ip: %s, bypassing proxy.
This will result in a redirect loop. This will result in a redirect loop.
Make sure to connect to the proxied public URL %s Make sure to connect to the proxied public URL %s
""", self.request.full_url(), self.proxy.public_url) """,
self.request.full_url(),
self.proxy.public_url,
)
# logged in as valid user, check for pending spawn # logged in as valid user, check for pending spawn
if self.allow_named_servers: if self.allow_named_servers:
@@ -1167,19 +1223,31 @@ class UserSpawnHandler(BaseHandler):
# Implicit spawn on /user/:name is not allowed if the user's last spawn failed. # Implicit spawn on /user/:name is not allowed if the user's last spawn failed.
# We should point the user to Home if the most recent spawn failed. # We should point the user to Home if the most recent spawn failed.
exc = spawner._spawn_future.exception() exc = spawner._spawn_future.exception()
self.log.error("Preventing implicit spawn for %s because last spawn failed: %s", self.log.error(
spawner._log_name, exc) "Preventing implicit spawn for %s because last spawn failed: %s",
spawner._log_name,
exc,
)
# raise a copy because each time an Exception object is re-raised, its traceback grows # raise a copy because each time an Exception object is re-raised, its traceback grows
raise copy.copy(exc).with_traceback(exc.__traceback__) raise copy.copy(exc).with_traceback(exc.__traceback__)
# check for pending spawn # check for pending spawn
if spawner.pending == 'spawn' and spawner._spawn_future: if spawner.pending == 'spawn' and spawner._spawn_future:
# wait on the pending spawn # wait on the pending spawn
self.log.debug("Waiting for %s pending %s", spawner._log_name, spawner.pending) self.log.debug(
"Waiting for %s pending %s", spawner._log_name, spawner.pending
)
try: try:
await gen.with_timeout(timedelta(seconds=self.slow_spawn_timeout), spawner._spawn_future) await gen.with_timeout(
timedelta(seconds=self.slow_spawn_timeout),
spawner._spawn_future,
)
except gen.TimeoutError: except gen.TimeoutError:
self.log.info("Pending spawn for %s didn't finish in %.1f seconds", spawner._log_name, self.slow_spawn_timeout) self.log.info(
"Pending spawn for %s didn't finish in %.1f seconds",
spawner._log_name,
self.slow_spawn_timeout,
)
pass pass
# we may have waited above, check pending again: # we may have waited above, check pending again:
@@ -1194,10 +1262,7 @@ class UserSpawnHandler(BaseHandler):
else: else:
page = "spawn_pending.html" page = "spawn_pending.html"
html = self.render_template( html = self.render_template(
page, page, user=user, spawner=spawner, progress_url=spawner._progress_url
user=user,
spawner=spawner,
progress_url=spawner._progress_url,
) )
self.finish(html) self.finish(html)
return return
@@ -1218,8 +1283,11 @@ class UserSpawnHandler(BaseHandler):
if current_user.name != user.name: if current_user.name != user.name:
# spawning on behalf of another user # spawning on behalf of another user
url_parts.append(user.name) url_parts.append(user.name)
self.redirect(url_concat(url_path_join(*url_parts), self.redirect(
{'next': self.request.uri})) url_concat(
url_path_join(*url_parts), {'next': self.request.uri}
)
)
return return
else: else:
await self.spawn_single_user(user, server_name) await self.spawn_single_user(user, server_name)
@@ -1230,9 +1298,7 @@ class UserSpawnHandler(BaseHandler):
# spawn has started, but not finished # spawn has started, but not finished
self.statsd.incr('redirects.user_spawn_pending', 1) self.statsd.incr('redirects.user_spawn_pending', 1)
html = self.render_template( html = self.render_template(
"spawn_pending.html", "spawn_pending.html", user=user, progress_url=spawner._progress_url
user=user,
progress_url=spawner._progress_url,
) )
self.finish(html) self.finish(html)
return return
@@ -1243,14 +1309,16 @@ class UserSpawnHandler(BaseHandler):
try: try:
redirects = int(self.get_argument('redirects', 0)) redirects = int(self.get_argument('redirects', 0))
except ValueError: except ValueError:
self.log.warning("Invalid redirects argument %r", self.get_argument('redirects')) self.log.warning(
"Invalid redirects argument %r", self.get_argument('redirects')
)
redirects = 0 redirects = 0
# check redirect limit to prevent browser-enforced limits. # check redirect limit to prevent browser-enforced limits.
# In case of version mismatch, raise on only two redirects. # In case of version mismatch, raise on only two redirects.
if redirects >= self.settings.get( if redirects >= self.settings.get('user_redirect_limit', 4) or (
'user_redirect_limit', 4 redirects >= 2 and spawner._jupyterhub_version != __version__
) or (redirects >= 2 and spawner._jupyterhub_version != __version__): ):
# We stop if we've been redirected too many times. # We stop if we've been redirected too many times.
msg = "Redirect loop detected." msg = "Redirect loop detected."
if spawner._jupyterhub_version != __version__: if spawner._jupyterhub_version != __version__:
@@ -1259,7 +1327,8 @@ class UserSpawnHandler(BaseHandler):
" Try installing jupyterhub=={hub} in the user environment" " Try installing jupyterhub=={hub} in the user environment"
" if you continue to have problems." " if you continue to have problems."
).format( ).format(
singleuser=spawner._jupyterhub_version or 'unknown (likely < 0.8)', singleuser=spawner._jupyterhub_version
or 'unknown (likely < 0.8)',
hub=__version__, hub=__version__,
) )
raise web.HTTPError(500, msg) raise web.HTTPError(500, msg)
@@ -1297,10 +1366,9 @@ class UserSpawnHandler(BaseHandler):
# not logged in, clear any cookies and reload # not logged in, clear any cookies and reload
self.statsd.incr('redirects.user_to_login', 1) self.statsd.incr('redirects.user_to_login', 1)
self.clear_login_cookie() self.clear_login_cookie()
self.redirect(url_concat( self.redirect(
self.settings['login_url'], url_concat(self.settings['login_url'], {'next': self.request.uri})
{'next': self.request.uri}, )
))
class UserRedirectHandler(BaseHandler): class UserRedirectHandler(BaseHandler):
@@ -1314,6 +1382,7 @@ class UserRedirectHandler(BaseHandler):
.. versionadded:: 0.7 .. versionadded:: 0.7
""" """
@web.authenticated @web.authenticated
def get(self, path): def get(self, path):
user = self.current_user user = self.current_user
@@ -1326,12 +1395,13 @@ class UserRedirectHandler(BaseHandler):
class CSPReportHandler(BaseHandler): class CSPReportHandler(BaseHandler):
'''Accepts a content security policy violation report''' '''Accepts a content security policy violation report'''
@web.authenticated @web.authenticated
def post(self): def post(self):
'''Log a content security policy violation report''' '''Log a content security policy violation report'''
self.log.warning( self.log.warning(
"Content security violation: %s", "Content security violation: %s",
self.request.body.decode('utf8', 'replace') self.request.body.decode('utf8', 'replace'),
) )
# Report it to statsd as well # Report it to statsd as well
self.statsd.incr('csp_report') self.statsd.incr('csp_report')
@@ -1339,11 +1409,13 @@ class CSPReportHandler(BaseHandler):
class AddSlashHandler(BaseHandler): class AddSlashHandler(BaseHandler):
"""Handler for adding trailing slash to URLs that need them""" """Handler for adding trailing slash to URLs that need them"""
def get(self, *args): def get(self, *args):
src = urlparse(self.request.uri) src = urlparse(self.request.uri)
dest = src._replace(path=src.path + '/') dest = src._replace(path=src.path + '/')
self.redirect(urlunparse(dest)) self.redirect(urlunparse(dest))
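Note (illustrative): because AddSlashHandler rewrites only the path component, query strings survive the redirect:

    from urllib.parse import urlparse, urlunparse

    src = urlparse('/hub/admin?sort=name')
    dest = src._replace(path=src.path + '/')
    print(urlunparse(dest))  # /hub/admin/?sort=name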
default_handlers = [ default_handlers = [
(r'', AddSlashHandler), # add trailing / to `/hub` (r'', AddSlashHandler), # add trailing / to `/hub`
(r'/user/(?P<user_name>[^/]+)(?P<user_path>/.*)?', UserSpawnHandler), (r'/user/(?P<user_name>[^/]+)(?P<user_path>/.*)?', UserSpawnHandler),
View File
@@ -1,16 +1,14 @@
"""HTTP Handlers for the hub server""" """HTTP Handlers for the hub server"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import asyncio import asyncio
from tornado import web
from tornado.escape import url_escape from tornado.escape import url_escape
from tornado.httputil import url_concat from tornado.httputil import url_concat
from tornado import web
from .base import BaseHandler
from ..utils import maybe_future from ..utils import maybe_future
from .base import BaseHandler
class LogoutHandler(BaseHandler): class LogoutHandler(BaseHandler):
@@ -52,7 +50,8 @@ class LoginHandler(BaseHandler):
"""Render the login page.""" """Render the login page."""
def _render(self, login_error=None, username=None): def _render(self, login_error=None, username=None):
return self.render_template('login.html', return self.render_template(
'login.html',
next=url_escape(self.get_argument('next', default='')), next=url_escape(self.get_argument('next', default='')),
username=username, username=username,
login_error=login_error, login_error=login_error,
@@ -87,7 +86,9 @@ class LoginHandler(BaseHandler):
self.redirect(self.get_next_url(user)) self.redirect(self.get_next_url(user))
else: else:
if self.get_argument('next', default=False): if self.get_argument('next', default=False):
auto_login_url = url_concat(auto_login_url, {'next': self.get_next_url()}) auto_login_url = url_concat(
auto_login_url, {'next': self.get_next_url()}
)
self.redirect(auto_login_url) self.redirect(auto_login_url)
return return
username = self.get_argument('username', default='') username = self.get_argument('username', default='')
@@ -109,8 +110,7 @@ class LoginHandler(BaseHandler):
self.redirect(self.get_next_url(user)) self.redirect(self.get_next_url(user))
else: else:
html = self._render( html = self._render(
login_error='Invalid username or password', login_error='Invalid username or password', username=data['username']
username=data['username'],
) )
self.finish(html) self.finish(html)
@@ -118,7 +118,4 @@ class LoginHandler(BaseHandler):
# /login renders the login page or the "Login with..." link, # /login renders the login page or the "Login with..." link,
# so it should always be registered. # so it should always be registered.
# /logout clears cookies. # /logout clears cookies.
default_handlers = [ default_handlers = [(r"/login", LoginHandler), (r"/logout", LogoutHandler)]
(r"/login", LoginHandler),
(r"/logout", LogoutHandler),
]
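Note (illustrative): the url_concat calls in this handler just append URL-encoded query parameters, e.g. to preserve the `next` destination across login:

    from tornado.httputil import url_concat

    print(url_concat('/hub/login', {'next': '/hub/home'}))
    # /hub/login?next=%2Fhub%2Fhome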
View File
@@ -1,18 +1,21 @@
from prometheus_client import REGISTRY, CONTENT_TYPE_LATEST, generate_latest from prometheus_client import CONTENT_TYPE_LATEST
from prometheus_client import generate_latest
from prometheus_client import REGISTRY
from tornado import gen from tornado import gen
from .base import BaseHandler
from ..utils import metrics_authentication from ..utils import metrics_authentication
from .base import BaseHandler
class MetricsHandler(BaseHandler): class MetricsHandler(BaseHandler):
""" """
Handler to serve Prometheus metrics Handler to serve Prometheus metrics
""" """
@metrics_authentication @metrics_authentication
async def get(self): async def get(self):
self.set_header('Content-Type', CONTENT_TYPE_LATEST) self.set_header('Content-Type', CONTENT_TYPE_LATEST)
self.write(generate_latest(REGISTRY)) self.write(generate_latest(REGISTRY))
default_handlers = [
(r'/metrics$', MetricsHandler) default_handlers = [(r'/metrics$', MetricsHandler)]
]
View File
@@ -1,18 +1,19 @@
"""Basic html-rendering handlers.""" """Basic html-rendering handlers."""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from http.client import responses from http.client import responses
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
from tornado import web, gen from tornado import gen
from tornado import web
from tornado.httputil import url_concat from tornado.httputil import url_concat
from .. import orm from .. import orm
from ..utils import admin_only, url_path_join, maybe_future from ..utils import admin_only
from ..utils import maybe_future
from ..utils import url_path_join
from .base import BaseHandler from .base import BaseHandler
@@ -29,6 +30,7 @@ class RootHandler(BaseHandler):
Otherwise, renders login page. Otherwise, renders login page.
""" """
def get(self): def get(self):
user = self.current_user user = self.current_user
if self.default_url: if self.default_url:
@@ -53,8 +55,13 @@ class HomeHandler(BaseHandler):
# send the user to /spawn if they have no active servers, # send the user to /spawn if they have no active servers,
# to establish that this is an explicit spawn request rather # to establish that this is an explicit spawn request rather
# than an implicit one, which can be caused by any link to `/user/:name(/:server_name)` # than an implicit one, which can be caused by any link to `/user/:name(/:server_name)`
url = url_path_join(self.hub.base_url, 'user', user.name) if user.active else url_path_join(self.hub.base_url, 'spawn') url = (
html = self.render_template('home.html', url_path_join(self.hub.base_url, 'user', user.name)
if user.active
else url_path_join(self.hub.base_url, 'spawn')
)
html = self.render_template(
'home.html',
user=user, user=user,
url=url, url=url,
allow_named_servers=self.allow_named_servers, allow_named_servers=self.allow_named_servers,
@@ -74,13 +81,15 @@ class SpawnHandler(BaseHandler):
Only enabled when Spawner.options_form is defined. Only enabled when Spawner.options_form is defined.
""" """
def _render_form(self, for_user, spawner_options_form, message=''): def _render_form(self, for_user, spawner_options_form, message=''):
return self.render_template('spawn.html', return self.render_template(
'spawn.html',
for_user=for_user, for_user=for_user,
spawner_options_form=spawner_options_form, spawner_options_form=spawner_options_form,
error_message=message, error_message=message,
url=self.request.uri, url=self.request.uri,
spawner=for_user.spawner spawner=for_user.spawner,
) )
@web.authenticated @web.authenticated
@@ -92,7 +101,9 @@ class SpawnHandler(BaseHandler):
user = current_user = self.current_user user = current_user = self.current_user
if for_user is not None and for_user != user.name: if for_user is not None and for_user != user.name:
if not user.admin: if not user.admin:
raise web.HTTPError(403, "Only admins can spawn on behalf of other users") raise web.HTTPError(
403, "Only admins can spawn on behalf of other users"
)
user = self.find_user(for_user) user = self.find_user(for_user)
if user is None: if user is None:
@@ -108,7 +119,9 @@ class SpawnHandler(BaseHandler):
if spawner_options_form: if spawner_options_form:
# Add handler to spawner here so you can access query params in form rendering. # Add handler to spawner here so you can access query params in form rendering.
user.spawner.handler = self user.spawner.handler = self
form = self._render_form(for_user=user, spawner_options_form=spawner_options_form) form = self._render_form(
for_user=user, spawner_options_form=spawner_options_form
)
self.finish(form) self.finish(form)
else: else:
# Explicit spawn request: clear _spawn_future # Explicit spawn request: clear _spawn_future
@@ -129,7 +142,9 @@ class SpawnHandler(BaseHandler):
user = current_user = self.current_user user = current_user = self.current_user
if for_user is not None and for_user != user.name: if for_user is not None and for_user != user.name:
if not user.admin: if not user.admin:
raise web.HTTPError(403, "Only admins can spawn on behalf of other users") raise web.HTTPError(
403, "Only admins can spawn on behalf of other users"
)
user = self.find_user(for_user) user = self.find_user(for_user)
if user is None: if user is None:
raise web.HTTPError(404, "No such user: %s" % for_user) raise web.HTTPError(404, "No such user: %s" % for_user)
@@ -151,9 +166,13 @@ class SpawnHandler(BaseHandler):
options = await maybe_future(user.spawner.options_from_form(form_options)) options = await maybe_future(user.spawner.options_from_form(form_options))
await self.spawn_single_user(user, options=options) await self.spawn_single_user(user, options=options)
except Exception as e: except Exception as e:
self.log.error("Failed to spawn single-user server with form", exc_info=True) self.log.error(
"Failed to spawn single-user server with form", exc_info=True
)
spawner_options_form = await user.spawner.get_options_form() spawner_options_form = await user.spawner.get_options_form()
form = self._render_form(for_user=user, spawner_options_form=spawner_options_form, message=str(e)) form = self._render_form(
for_user=user, spawner_options_form=spawner_options_form, message=str(e)
)
self.finish(form) self.finish(form)
return return
if current_user is user: if current_user is user:
@@ -176,9 +195,7 @@ class AdminHandler(BaseHandler):
def get(self): def get(self):
available = {'name', 'admin', 'running', 'last_activity'} available = {'name', 'admin', 'running', 'last_activity'}
default_sort = ['admin', 'name'] default_sort = ['admin', 'name']
mapping = { mapping = {'running': orm.Spawner.server_id}
'running': orm.Spawner.server_id,
}
for name in available: for name in available:
if name not in mapping: if name not in mapping:
mapping[name] = getattr(orm.User, name) mapping[name] = getattr(orm.User, name)
@@ -219,11 +236,13 @@ class AdminHandler(BaseHandler):
users = self.db.query(orm.User).outerjoin(orm.Spawner).order_by(*ordered) users = self.db.query(orm.User).outerjoin(orm.Spawner).order_by(*ordered)
users = [self._user_from_orm(u) for u in users] users = [self._user_from_orm(u) for u in users]
from itertools import chain from itertools import chain
running = [] running = []
for u in users: for u in users:
running.extend(s for s in u.spawners.values() if s.active) running.extend(s for s in u.spawners.values() if s.active)
html = self.render_template('admin.html', html = self.render_template(
'admin.html',
current_user=self.current_user, current_user=self.current_user,
admin_access=self.settings.get('admin_access', False), admin_access=self.settings.get('admin_access', False),
users=users, users=users,
@@ -243,11 +262,9 @@ class TokenPageHandler(BaseHandler):
never = datetime(1900, 1, 1) never = datetime(1900, 1, 1)
user = self.current_user user = self.current_user
def sort_key(token): def sort_key(token):
return ( return (token.last_activity or never, token.created or never)
token.last_activity or never,
token.created or never,
)
now = datetime.utcnow() now = datetime.utcnow()
api_tokens = [] api_tokens = []
@@ -285,13 +302,13 @@ class TokenPageHandler(BaseHandler):
for token in tokens[1:]: for token in tokens[1:]:
if token.created < created: if token.created < created:
created = token.created created = token.created
if ( if last_activity is None or (
last_activity is None or token.last_activity and token.last_activity > last_activity
(token.last_activity and token.last_activity > last_activity)
): ):
last_activity = token.last_activity last_activity = token.last_activity
token = tokens[0] token = tokens[0]
oauth_clients.append({ oauth_clients.append(
{
'client': token.client, 'client': token.client,
'description': token.client.description or token.client.identifier, 'description': token.client.description or token.client.identifier,
'created': created, 'created': created,
@@ -301,21 +318,17 @@ class TokenPageHandler(BaseHandler):
# revoking one oauth token revokes all oauth tokens for that client # revoking one oauth token revokes all oauth tokens for that client
'token_id': tokens[0].api_id, 'token_id': tokens[0].api_id,
'token_count': len(tokens), 'token_count': len(tokens),
}) }
)
# sort oauth clients by last activity, created # sort oauth clients by last activity, created
def sort_key(client): def sort_key(client):
return ( return (client['last_activity'] or never, client['created'] or never)
client['last_activity'] or never,
client['created'] or never,
)
oauth_clients = sorted(oauth_clients, key=sort_key, reverse=True) oauth_clients = sorted(oauth_clients, key=sort_key, reverse=True)
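Note (illustrative): the sort_key helpers above use a sentinel datetime so tokens or clients that were never used still sort, landing last when ordered in reverse:

    from datetime import datetime

    never = datetime(1900, 1, 1)
    clients = [
        {'id': 'a', 'last_activity': datetime(2019, 2, 1), 'created': datetime(2019, 1, 1)},
        {'id': 'b', 'last_activity': None, 'created': datetime(2019, 1, 15)},
    ]
    clients.sort(key=lambda c: (c['last_activity'] or never, c['created'] or never), reverse=True)
    print([c['id'] for c in clients])  # ['a', 'b']: the never-used client sorts last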
html = self.render_template( html = self.render_template(
'token.html', 'token.html', api_tokens=api_tokens, oauth_clients=oauth_clients
api_tokens=api_tokens,
oauth_clients=oauth_clients,
) )
self.finish(html) self.finish(html)
@@ -331,10 +344,12 @@ class ProxyErrorHandler(BaseHandler):
hub_home = url_path_join(self.hub.base_url, 'home') hub_home = url_path_join(self.hub.base_url, 'home')
message_html = '' message_html = ''
if status_code == 503: if status_code == 503:
message_html = ' '.join([ message_html = ' '.join(
[
"Your server appears to be down.", "Your server appears to be down.",
"Try restarting it <a href='%s'>from the hub</a>" % hub_home "Try restarting it <a href='%s'>from the hub</a>" % hub_home,
]) ]
)
ns = dict( ns = dict(
status_code=status_code, status_code=status_code,
status_message=status_message, status_message=status_message,
@@ -355,6 +370,7 @@ class ProxyErrorHandler(BaseHandler):
class HealthCheckHandler(BaseHandler): class HealthCheckHandler(BaseHandler):
"""Answer to health check""" """Answer to health check"""
def get(self, *args): def get(self, *args):
self.finish() self.finish()
View File
@@ -1,14 +1,16 @@
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import os import os
from tornado.web import StaticFileHandler from tornado.web import StaticFileHandler
class CacheControlStaticFilesHandler(StaticFileHandler): class CacheControlStaticFilesHandler(StaticFileHandler):
"""StaticFileHandler subclass that sets Cache-Control: no-cache without `?v=` """StaticFileHandler subclass that sets Cache-Control: no-cache without `?v=`
rather than relying on default browser cache behavior. rather than relying on default browser cache behavior.
""" """
def compute_etag(self): def compute_etag(self):
return None return None
@@ -16,8 +18,10 @@ class CacheControlStaticFilesHandler(StaticFileHandler):
if "v" not in self.request.arguments: if "v" not in self.request.arguments:
self.add_header("Cache-Control", "no-cache") self.add_header("Cache-Control", "no-cache")
class LogoHandler(StaticFileHandler): class LogoHandler(StaticFileHandler):
"""A singular handler for serving the logo.""" """A singular handler for serving the logo."""
def get(self): def get(self):
return super().get('') return super().get('')
@@ -25,4 +29,3 @@ class LogoHandler(StaticFileHandler):
def get_absolute_path(cls, root, path): def get_absolute_path(cls, root, path):
"""We only serve one file, ignore relative path""" """We only serve one file, ignore relative path"""
return os.path.abspath(root) return os.path.abspath(root)
View File
@@ -1,13 +1,15 @@
"""logging utilities""" """logging utilities"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import json import json
import traceback import traceback
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse
from urllib.parse import urlunparse
from tornado.log import LogFormatter, access_log from tornado.log import access_log
from tornado.web import StaticFileHandler, HTTPError from tornado.log import LogFormatter
from tornado.web import HTTPError
from tornado.web import StaticFileHandler
from .metrics import prometheus_log_method from .metrics import prometheus_log_method
@@ -23,7 +25,11 @@ def coroutine_frames(all_frames):
continue continue
# start out conservative with filename + function matching # start out conservative with filename + function matching
# maybe just filename matching would be sufficient # maybe just filename matching would be sufficient
elif frame[0].endswith('tornado/gen.py') and frame[2] in {'run', 'wrapper', '__init__'}: elif frame[0].endswith('tornado/gen.py') and frame[2] in {
'run',
'wrapper',
'__init__',
}:
continue continue
elif frame[0].endswith('tornado/concurrent.py') and frame[2] == 'result': elif frame[0].endswith('tornado/concurrent.py') and frame[2] == 'result':
continue continue
@@ -51,9 +57,11 @@ def coroutine_traceback(typ, value, tb):
class CoroutineLogFormatter(LogFormatter): class CoroutineLogFormatter(LogFormatter):
"""Log formatter that scrubs coroutine frames""" """Log formatter that scrubs coroutine frames"""
def formatException(self, exc_info): def formatException(self, exc_info):
return ''.join(coroutine_traceback(*exc_info)) return ''.join(coroutine_traceback(*exc_info))
# url params to be scrubbed if seen # url params to be scrubbed if seen
# any url param that *contains* one of these # any url param that *contains* one of these
# will be scrubbed from logs # will be scrubbed from logs
@@ -96,6 +104,7 @@ def _scrub_headers(headers):
# log_request adapted from IPython (BSD) # log_request adapted from IPython (BSD)
def log_request(handler): def log_request(handler):
"""log a bit more information about each request than tornado's default """log a bit more information about each request than tornado's default
View File
@@ -17,13 +17,13 @@ them manually here.
""" """
from enum import Enum from enum import Enum
from prometheus_client import Histogram
from prometheus_client import Gauge from prometheus_client import Gauge
from prometheus_client import Histogram
REQUEST_DURATION_SECONDS = Histogram( REQUEST_DURATION_SECONDS = Histogram(
'request_duration_seconds', 'request_duration_seconds',
'request duration for all HTTP requests', 'request duration for all HTTP requests',
['method', 'handler', 'code'] ['method', 'handler', 'code'],
) )
SERVER_SPAWN_DURATION_SECONDS = Histogram( SERVER_SPAWN_DURATION_SECONDS = Histogram(
@@ -32,32 +32,29 @@ SERVER_SPAWN_DURATION_SECONDS = Histogram(
['status'], ['status'],
# Use custom bucket sizes, since the default bucket ranges # Use custom bucket sizes, since the default bucket ranges
# are meant for quick running processes. Spawns can take a while! # are meant for quick running processes. Spawns can take a while!
buckets=[0.5, 1, 2.5, 5, 10, 15, 30, 60, 120, float("inf")] buckets=[0.5, 1, 2.5, 5, 10, 15, 30, 60, 120, float("inf")],
) )
RUNNING_SERVERS = Gauge( RUNNING_SERVERS = Gauge(
'running_servers', 'running_servers', 'the number of user servers currently running'
'the number of user servers currently running'
) )
RUNNING_SERVERS.set(0) RUNNING_SERVERS.set(0)
TOTAL_USERS = Gauge( TOTAL_USERS = Gauge('total_users', 'toal number of users')
'total_users',
'toal number of users'
)
TOTAL_USERS.set(0) TOTAL_USERS.set(0)
CHECK_ROUTES_DURATION_SECONDS = Histogram( CHECK_ROUTES_DURATION_SECONDS = Histogram(
'check_routes_duration_seconds', 'check_routes_duration_seconds', 'Time taken to validate all routes in proxy'
'Time taken to validate all routes in proxy'
) )
class ServerSpawnStatus(Enum): class ServerSpawnStatus(Enum):
""" """
Possible values for 'status' label of SERVER_SPAWN_DURATION_SECONDS Possible values for 'status' label of SERVER_SPAWN_DURATION_SECONDS
""" """
success = 'success' success = 'success'
failure = 'failure' failure = 'failure'
already_pending = 'already-pending' already_pending = 'already-pending'
@@ -67,27 +64,29 @@ class ServerSpawnStatus(Enum):
def __str__(self): def __str__(self):
return self.value return self.value
for s in ServerSpawnStatus: for s in ServerSpawnStatus:
# Create empty metrics with the given status # Create empty metrics with the given status
SERVER_SPAWN_DURATION_SECONDS.labels(status=s) SERVER_SPAWN_DURATION_SECONDS.labels(status=s)
PROXY_ADD_DURATION_SECONDS = Histogram( PROXY_ADD_DURATION_SECONDS = Histogram(
'proxy_add_duration_seconds', 'proxy_add_duration_seconds', 'duration for adding user routes to proxy', ['status']
'duration for adding user routes to proxy',
['status']
) )
class ProxyAddStatus(Enum): class ProxyAddStatus(Enum):
""" """
Possible values for 'status' label of PROXY_ADD_DURATION_SECONDS Possible values for 'status' label of PROXY_ADD_DURATION_SECONDS
""" """
success = 'success' success = 'success'
failure = 'failure' failure = 'failure'
def __str__(self): def __str__(self):
return self.value return self.value
for s in ProxyAddStatus: for s in ProxyAddStatus:
PROXY_ADD_DURATION_SECONDS.labels(status=s) PROXY_ADD_DURATION_SECONDS.labels(status=s)
@@ -95,13 +94,15 @@ for s in ProxyAddStatus:
SERVER_POLL_DURATION_SECONDS = Histogram( SERVER_POLL_DURATION_SECONDS = Histogram(
'server_poll_duration_seconds', 'server_poll_duration_seconds',
'time taken to poll if server is running', 'time taken to poll if server is running',
['status'] ['status'],
) )
class ServerPollStatus(Enum): class ServerPollStatus(Enum):
""" """
Possible values for 'status' label of SERVER_POLL_DURATION_SECONDS Possible values for 'status' label of SERVER_POLL_DURATION_SECONDS
""" """
running = 'running' running = 'running'
stopped = 'stopped' stopped = 'stopped'
@@ -112,27 +113,28 @@ class ServerPollStatus(Enum):
return cls.running return cls.running
return cls.stopped return cls.stopped
for s in ServerPollStatus: for s in ServerPollStatus:
SERVER_POLL_DURATION_SECONDS.labels(status=s) SERVER_POLL_DURATION_SECONDS.labels(status=s)
SERVER_STOP_DURATION_SECONDS = Histogram( SERVER_STOP_DURATION_SECONDS = Histogram(
'server_stop_seconds', 'server_stop_seconds', 'time taken for server stopping operation', ['status']
'time taken for server stopping operation',
['status'],
) )
class ServerStopStatus(Enum): class ServerStopStatus(Enum):
""" """
Possible values for 'status' label of SERVER_STOP_DURATION_SECONDS Possible values for 'status' label of SERVER_STOP_DURATION_SECONDS
""" """
success = 'success' success = 'success'
failure = 'failure' failure = 'failure'
def __str__(self): def __str__(self):
return self.value return self.value
for s in ServerStopStatus: for s in ServerStopStatus:
SERVER_STOP_DURATION_SECONDS.labels(status=s) SERVER_STOP_DURATION_SECONDS.labels(status=s)
@@ -156,5 +158,5 @@ def prometheus_log_method(handler):
REQUEST_DURATION_SECONDS.labels( REQUEST_DURATION_SECONDS.labels(
method=handler.request.method, method=handler.request.method,
handler='{}.{}'.format(handler.__class__.__module__, type(handler).__name__), handler='{}.{}'.format(handler.__class__.__module__, type(handler).__name__),
code=handler.get_status() code=handler.get_status(),
).observe(handler.request.request_time()) ).observe(handler.request.request_time())
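The metrics module reformatted above follows one pattern throughout: declare a Prometheus collector with explicit labels (and, for slow operations like spawns, custom histogram buckets), pre-create each label combination so the series exist at zero, then observe durations at the call site. A minimal sketch of the same pattern; the metric name, labels, and helper below are made up for illustration, not taken from jupyterhub/metrics.py:

    import time

    from prometheus_client import Histogram

    # custom buckets, since the defaults target sub-second operations
    EXAMPLE_SPAWN_SECONDS = Histogram(
        'example_spawn_duration_seconds',
        'example spawn duration',
        ['status'],
        buckets=[0.5, 1, 2.5, 5, 10, 30, 60, float("inf")],
    )

    # pre-create each label combination so the series exist before first use
    for status in ('success', 'failure'):
        EXAMPLE_SPAWN_SECONDS.labels(status=status)

    def timed_spawn(spawn):
        """Record how long a spawn callable takes, labeled by outcome."""
        start = time.perf_counter()
        try:
            result = spawn()
        except Exception:
            EXAMPLE_SPAWN_SECONDS.labels(status='failure').observe(time.perf_counter() - start)
            raise
        EXAMPLE_SPAWN_SECONDS.labels(status='success').observe(time.perf_counter() - start)
        return result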


@@ -2,30 +2,29 @@
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
""" """
from datetime import datetime from datetime import datetime
from urllib.parse import urlparse from urllib.parse import urlparse
from oauthlib.oauth2 import RequestValidator, WebApplicationServer from oauthlib.oauth2 import RequestValidator
from oauthlib.oauth2 import WebApplicationServer
from oauthlib.oauth2.rfc6749.grant_types import authorization_code
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session
from tornado import web
from tornado.escape import url_escape from tornado.escape import url_escape
from tornado.log import app_log from tornado.log import app_log
from tornado import web
from .. import orm from .. import orm
from ..utils import url_path_join, hash_token, compare_token from ..utils import compare_token
from ..utils import hash_token
from ..utils import url_path_join
# patch absolute-uri check # patch absolute-uri check
# because we want to allow relative uri oauth # because we want to allow relative uri oauth
# for internal services # for internal services
from oauthlib.oauth2.rfc6749.grant_types import authorization_code
authorization_code.is_absolute_uri = lambda uri: True authorization_code.is_absolute_uri = lambda uri: True
class JupyterHubRequestValidator(RequestValidator): class JupyterHubRequestValidator(RequestValidator):
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
super().__init__() super().__init__()
@@ -51,10 +50,7 @@ class JupyterHubRequestValidator(RequestValidator):
client_id = request.client_id client_id = request.client_id
client_secret = request.client_secret client_secret = request.client_secret
oauth_client = ( oauth_client = (
self.db self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
) )
if oauth_client is None: if oauth_client is None:
return False return False
@@ -78,10 +74,7 @@ class JupyterHubRequestValidator(RequestValidator):
- Authorization Code Grant - Authorization Code Grant
""" """
orm_client = ( orm_client = (
self.db self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
) )
if orm_client is None: if orm_client is None:
app_log.warning("No such oauth client %s", client_id) app_log.warning("No such oauth client %s", client_id)
@@ -89,8 +82,9 @@ class JupyterHubRequestValidator(RequestValidator):
request.client = orm_client request.client = orm_client
return True return True
def confirm_redirect_uri(self, client_id, code, redirect_uri, client, def confirm_redirect_uri(
*args, **kwargs): self, client_id, code, redirect_uri, client, *args, **kwargs
):
"""Ensure that the authorization process represented by this authorization """Ensure that the authorization process represented by this authorization
code began with this 'redirect_uri'. code began with this 'redirect_uri'.
If the client specifies a redirect_uri when obtaining code then that If the client specifies a redirect_uri when obtaining code then that
@@ -108,8 +102,10 @@ class JupyterHubRequestValidator(RequestValidator):
""" """
# TODO: record redirect_uri used during oauth # TODO: record redirect_uri used during oauth
# if we ever support multiple destinations # if we ever support multiple destinations
app_log.debug("confirm_redirect_uri: client_id=%s, redirect_uri=%s", app_log.debug(
client_id, redirect_uri, "confirm_redirect_uri: client_id=%s, redirect_uri=%s",
client_id,
redirect_uri,
) )
if redirect_uri == client.redirect_uri: if redirect_uri == client.redirect_uri:
return True return True
@@ -127,10 +123,7 @@ class JupyterHubRequestValidator(RequestValidator):
- Implicit Grant - Implicit Grant
""" """
orm_client = ( orm_client = (
self.db self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
) )
if orm_client is None: if orm_client is None:
raise KeyError(client_id) raise KeyError(client_id)
@@ -159,7 +152,9 @@ class JupyterHubRequestValidator(RequestValidator):
""" """
raise NotImplementedError() raise NotImplementedError()
def is_within_original_scope(self, request_scopes, refresh_token, request, *args, **kwargs): def is_within_original_scope(
self, request_scopes, refresh_token, request, *args, **kwargs
):
"""Check if requested scopes are within a scope of the refresh token. """Check if requested scopes are within a scope of the refresh token.
When access tokens are refreshed the scope of the new token When access tokens are refreshed the scope of the new token
needs to be within the scope of the original token. This is needs to be within the scope of the original token. This is
@@ -227,12 +222,15 @@ class JupyterHubRequestValidator(RequestValidator):
- Authorization Code Grant - Authorization Code Grant
""" """
log_code = code.get('code', 'undefined')[:3] + '...' log_code = code.get('code', 'undefined')[:3] + '...'
app_log.debug("Saving authorization code %s, %s, %s, %s", client_id, log_code, args, kwargs) app_log.debug(
"Saving authorization code %s, %s, %s, %s",
client_id,
log_code,
args,
kwargs,
)
orm_client = ( orm_client = (
self.db self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
) )
if orm_client is None: if orm_client is None:
raise ValueError("No such client: %s" % client_id) raise ValueError("No such client: %s" % client_id)
@@ -330,7 +328,11 @@ class JupyterHubRequestValidator(RequestValidator):
app_log.debug("Saving bearer token %s", log_token) app_log.debug("Saving bearer token %s", log_token)
if request.user is None: if request.user is None:
raise ValueError("No user for access token: %s" % request.user) raise ValueError("No user for access token: %s" % request.user)
client = self.db.query(orm.OAuthClient).filter_by(identifier=request.client.client_id).first() client = (
self.db.query(orm.OAuthClient)
.filter_by(identifier=request.client.client_id)
.first()
)
orm_access_token = orm.OAuthAccessToken( orm_access_token = orm.OAuthAccessToken(
client=client, client=client,
grant_type=orm.GrantType.authorization_code, grant_type=orm.GrantType.authorization_code,
@@ -400,10 +402,7 @@ class JupyterHubRequestValidator(RequestValidator):
""" """
app_log.debug("Validating client id %s", client_id) app_log.debug("Validating client id %s", client_id)
orm_client = ( orm_client = (
self.db self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
) )
if orm_client is None: if orm_client is None:
return False return False
@@ -431,19 +430,13 @@ class JupyterHubRequestValidator(RequestValidator):
Method is used by: Method is used by:
- Authorization Code Grant - Authorization Code Grant
""" """
orm_code = ( orm_code = self.db.query(orm.OAuthCode).filter_by(code=code).first()
self.db
.query(orm.OAuthCode)
.filter_by(code=code)
.first()
)
if orm_code is None: if orm_code is None:
app_log.debug("No such code: %s", code) app_log.debug("No such code: %s", code)
return False return False
if orm_code.client_id != client_id: if orm_code.client_id != client_id:
app_log.debug( app_log.debug(
"OAuth code client id mismatch: %s != %s", "OAuth code client id mismatch: %s != %s", client_id, orm_code.client_id
client_id, orm_code.client_id,
) )
return False return False
request.user = orm_code.user request.user = orm_code.user
@@ -453,7 +446,9 @@ class JupyterHubRequestValidator(RequestValidator):
request.scopes = ['identify'] request.scopes = ['identify']
return True return True
def validate_grant_type(self, client_id, grant_type, client, request, *args, **kwargs): def validate_grant_type(
self, client_id, grant_type, client, request, *args, **kwargs
):
"""Ensure client is authorized to use the grant_type requested. """Ensure client is authorized to use the grant_type requested.
:param client_id: Unicode client identifier :param client_id: Unicode client identifier
:param grant_type: Unicode grant type, i.e. authorization_code, password. :param grant_type: Unicode grant type, i.e. authorization_code, password.
@@ -480,14 +475,13 @@ class JupyterHubRequestValidator(RequestValidator):
- Authorization Code Grant - Authorization Code Grant
- Implicit Grant - Implicit Grant
""" """
app_log.debug("validate_redirect_uri: client_id=%s, redirect_uri=%s", app_log.debug(
client_id, redirect_uri, "validate_redirect_uri: client_id=%s, redirect_uri=%s",
client_id,
redirect_uri,
) )
orm_client = ( orm_client = (
self.db self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
) )
if orm_client is None: if orm_client is None:
app_log.warning("No such oauth client %s", client_id) app_log.warning("No such oauth client %s", client_id)
@@ -495,7 +489,9 @@ class JupyterHubRequestValidator(RequestValidator):
if redirect_uri == orm_client.redirect_uri: if redirect_uri == orm_client.redirect_uri:
return True return True
else: else:
app_log.warning("Redirect uri %s != %s", redirect_uri, orm_client.redirect_uri) app_log.warning(
"Redirect uri %s != %s", redirect_uri, orm_client.redirect_uri
)
return False return False
def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs): def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs):
@@ -514,7 +510,9 @@ class JupyterHubRequestValidator(RequestValidator):
return False return False
raise NotImplementedError('Subclasses must implement this method.') raise NotImplementedError('Subclasses must implement this method.')
def validate_response_type(self, client_id, response_type, client, request, *args, **kwargs): def validate_response_type(
self, client_id, response_type, client, request, *args, **kwargs
):
"""Ensure client is authorized to use the response_type requested. """Ensure client is authorized to use the response_type requested.
:param client_id: Unicode client identifier :param client_id: Unicode client identifier
:param response_type: Unicode response type, i.e. code, token. :param response_type: Unicode response type, i.e. code, token.
@@ -555,10 +553,8 @@ class JupyterHubOAuthServer(WebApplicationServer):
hash its client_secret before putting it in the database. hash its client_secret before putting it in the database.
""" """
# clear existing clients with same ID # clear existing clients with same ID
for orm_client in ( for orm_client in self.db.query(orm.OAuthClient).filter_by(
self.db identifier=client_id
.query(orm.OAuthClient)\
.filter_by(identifier=client_id)
): ):
self.db.delete(orm_client) self.db.delete(orm_client)
self.db.commit() self.db.commit()
@@ -574,12 +570,7 @@ class JupyterHubOAuthServer(WebApplicationServer):
def fetch_by_client_id(self, client_id): def fetch_by_client_id(self, client_id):
"""Find a client by its id""" """Find a client by its id"""
return ( return self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
self.db
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
)
def make_provider(session_factory, url_prefix, login_url): def make_provider(session_factory, url_prefix, login_url):
@@ -588,4 +579,3 @@ def make_provider(session_factory, url_prefix, login_url):
validator = JupyterHubRequestValidator(db) validator = JupyterHubRequestValidator(db)
server = JupyterHubOAuthServer(db, validator) server = JupyterHubOAuthServer(db, validator)
return server return server
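For context on make_provider, whose body is only reflowed here, the wiring is roughly as below. This is an illustrative sketch, not code from this commit: the module path, the in-memory database URL, and the prefix values are assumptions.

    from jupyterhub.oauth.provider import make_provider
    from jupyterhub.orm import new_session_factory

    # the Hub passes its own session factory and URL prefixes; placeholders here
    session_factory = new_session_factory('sqlite:///:memory:')
    provider = make_provider(
        session_factory, url_prefix='/hub/api/oauth2', login_url='/hub/login'
    )
    # clients are looked up straight from the ORM, as in fetch_by_client_id above
    print(provider.fetch_by_client_id('no-such-client'))  # None until a client is registered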


@@ -1,22 +1,28 @@
"""Some general objects for use in JupyterHub""" """Some general objects for use in JupyterHub"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import socket import socket
from urllib.parse import urlparse, urlunparse
import warnings import warnings
from urllib.parse import urlparse
from urllib.parse import urlunparse
from traitlets import default
from traitlets import HasTraits
from traitlets import Instance
from traitlets import Integer
from traitlets import observe
from traitlets import Unicode
from traitlets import validate
from traitlets import (
HasTraits, Instance, Integer, Unicode,
default, observe, validate,
)
from .traitlets import URLPrefix
from . import orm from . import orm
from .utils import ( from .traitlets import URLPrefix
url_path_join, can_connect, wait_for_server, from .utils import can_connect
wait_for_http_server, random_port, make_ssl_context, from .utils import make_ssl_context
) from .utils import random_port
from .utils import url_path_join
from .utils import wait_for_http_server
from .utils import wait_for_server
class Server(HasTraits): class Server(HasTraits):
"""An object representing an HTTP endpoint. """An object representing an HTTP endpoint.
@@ -24,6 +30,7 @@ class Server(HasTraits):
*Some* of these reside in the database (user servers), *Some* of these reside in the database (user servers),
but others (Hub, proxy) are in-memory only. but others (Hub, proxy) are in-memory only.
""" """
orm_server = Instance(orm.Server, allow_none=True) orm_server = Instance(orm.Server, allow_none=True)
ip = Unicode() ip = Unicode()
@@ -141,36 +148,31 @@ class Server(HasTraits):
def host(self): def host(self):
if self.connect_url: if self.connect_url:
parsed = urlparse(self.connect_url) parsed = urlparse(self.connect_url)
return "{proto}://{host}".format( return "{proto}://{host}".format(proto=parsed.scheme, host=parsed.netloc)
proto=parsed.scheme,
host=parsed.netloc,
)
return "{proto}://{ip}:{port}".format( return "{proto}://{ip}:{port}".format(
proto=self.proto, proto=self.proto, ip=self._connect_ip, port=self._connect_port
ip=self._connect_ip,
port=self._connect_port,
) )
@property @property
def url(self): def url(self):
if self.connect_url: if self.connect_url:
return self.connect_url return self.connect_url
return "{host}{uri}".format( return "{host}{uri}".format(host=self.host, uri=self.base_url)
host=self.host,
uri=self.base_url,
)
def wait_up(self, timeout=10, http=False, ssl_context=None): def wait_up(self, timeout=10, http=False, ssl_context=None):
"""Wait for this server to come up""" """Wait for this server to come up"""
if http: if http:
ssl_context = ssl_context or make_ssl_context( ssl_context = ssl_context or make_ssl_context(
self.keyfile, self.certfile, cafile=self.cafile) self.keyfile, self.certfile, cafile=self.cafile
)
return wait_for_http_server( return wait_for_http_server(
self.url, timeout=timeout, ssl_context=ssl_context) self.url, timeout=timeout, ssl_context=ssl_context
)
else: else:
return wait_for_server(self._connect_ip, self._connect_port, timeout=timeout) return wait_for_server(
self._connect_ip, self._connect_port, timeout=timeout
)
def is_up(self): def is_up(self):
"""Is the server accepting connections?""" """Is the server accepting connections?"""
@@ -190,11 +192,13 @@ class Hub(Server):
@property @property
def server(self): def server(self):
warnings.warn("Hub.server is deprecated in JupyterHub 0.8. Access attributes on the Hub directly.", warnings.warn(
"Hub.server is deprecated in JupyterHub 0.8. Access attributes on the Hub directly.",
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
return self return self
public_host = Unicode() public_host = Unicode()
routespec = Unicode() routespec = Unicode()
@@ -205,5 +209,7 @@ class Hub(Server):
def __repr__(self): def __repr__(self):
return "<%s %s:%s>" % ( return "<%s %s:%s>" % (
self.__class__.__name__, self.server.ip, self.server.port, self.__class__.__name__,
self.server.ip,
self.server.port,
) )
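The Server object touched above is a thin traitlets wrapper around an ip/port/base_url triple, with host and url derived as shown in the hunks. A rough usage sketch; the trait values are placeholders and the port trait is assumed from the _connect_port property used above:

    from jupyterhub.objects import Server

    server = Server(ip='127.0.0.1', port=8081, base_url='/hub/', proto='http')
    print(server.host)  # e.g. 'http://127.0.0.1:8081'
    print(server.url)   # host plus base_url, e.g. 'http://127.0.0.1:8081/hub/'

    # when connect_url is set, host and url are derived from it instead
    server.connect_url = 'https://hub.internal:443/'
    print(server.host)  # 'https://hub.internal:443'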


@@ -1,36 +1,45 @@
"""sqlalchemy ORM tools for the state of the constellation of processes""" """sqlalchemy ORM tools for the state of the constellation of processes"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from datetime import datetime, timedelta
import enum import enum
import json import json
from datetime import datetime
from datetime import timedelta
import alembic.config
import alembic.command import alembic.command
import alembic.config
from alembic.script import ScriptDirectory from alembic.script import ScriptDirectory
from tornado.log import app_log from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy.types import TypeDecorator, Text, LargeBinary from sqlalchemy import create_engine
from sqlalchemy import ( from sqlalchemy import DateTime
create_engine, event, exc, inspect, or_, select, from sqlalchemy import Enum
Column, Integer, ForeignKey, Unicode, Boolean, from sqlalchemy import event
DateTime, Enum, Table, from sqlalchemy import exc
) from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy import Unicode
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import ( from sqlalchemy.orm import interfaces
Session, from sqlalchemy.orm import object_session
interfaces, object_session, relationship, sessionmaker, from sqlalchemy.orm import relationship
) from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from sqlalchemy.sql.expression import bindparam from sqlalchemy.sql.expression import bindparam
from sqlalchemy.types import LargeBinary
from sqlalchemy.types import Text
from sqlalchemy.types import TypeDecorator
from tornado.log import app_log
from .utils import ( from .utils import compare_token
random_port, from .utils import hash_token
new_token, hash_token, compare_token, from .utils import new_token
) from .utils import random_port
# top-level variable for easier mocking in tests # top-level variable for easier mocking in tests
utcnow = datetime.utcnow utcnow = datetime.utcnow
@@ -68,6 +77,7 @@ class Server(Base):
connection and cookie info connection and cookie info
""" """
__tablename__ = 'servers' __tablename__ = 'servers'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
@@ -82,7 +92,9 @@ class Server(Base):
# user:group many:many mapping table # user:group many:many mapping table
user_group_map = Table('user_group_map', Base.metadata, user_group_map = Table(
'user_group_map',
Base.metadata,
Column('user_id', ForeignKey('users.id', ondelete='CASCADE'), primary_key=True), Column('user_id', ForeignKey('users.id', ondelete='CASCADE'), primary_key=True),
Column('group_id', ForeignKey('groups.id', ondelete='CASCADE'), primary_key=True), Column('group_id', ForeignKey('groups.id', ondelete='CASCADE'), primary_key=True),
) )
@@ -90,6 +102,7 @@ user_group_map = Table('user_group_map', Base.metadata,
class Group(Base): class Group(Base):
"""User Groups""" """User Groups"""
__tablename__ = 'groups' __tablename__ = 'groups'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(255), unique=True) name = Column(Unicode(255), unique=True)
@@ -97,7 +110,9 @@ class Group(Base):
def __repr__(self): def __repr__(self):
return "<%s %s (%i users)>" % ( return "<%s %s (%i users)>" % (
self.__class__.__name__, self.name, len(self.users) self.__class__.__name__,
self.name,
len(self.users),
) )
@classmethod @classmethod
@@ -130,15 +145,15 @@ class User(Base):
`servers` is a list that contains a reference for each of the user's single user notebook servers. `servers` is a list that contains a reference for each of the user's single user notebook servers.
The method `server` returns the first entry in the user's `servers` list. The method `server` returns the first entry in the user's `servers` list.
""" """
__tablename__ = 'users' __tablename__ = 'users'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(255), unique=True) name = Column(Unicode(255), unique=True)
_orm_spawners = relationship( _orm_spawners = relationship(
"Spawner", "Spawner", backref="user", cascade="all, delete-orphan"
backref="user",
cascade="all, delete-orphan",
) )
@property @property
def orm_spawners(self): def orm_spawners(self):
return {s.name: s for s in self._orm_spawners} return {s.name: s for s in self._orm_spawners}
@@ -147,20 +162,12 @@ class User(Base):
created = Column(DateTime, default=datetime.utcnow) created = Column(DateTime, default=datetime.utcnow)
last_activity = Column(DateTime, nullable=True) last_activity = Column(DateTime, nullable=True)
api_tokens = relationship( api_tokens = relationship("APIToken", backref="user", cascade="all, delete-orphan")
"APIToken",
backref="user",
cascade="all, delete-orphan",
)
oauth_tokens = relationship( oauth_tokens = relationship(
"OAuthAccessToken", "OAuthAccessToken", backref="user", cascade="all, delete-orphan"
backref="user",
cascade="all, delete-orphan",
) )
oauth_codes = relationship( oauth_codes = relationship(
"OAuthCode", "OAuthCode", backref="user", cascade="all, delete-orphan"
backref="user",
cascade="all, delete-orphan",
) )
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True) cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
# User.state is actually Spawner state # User.state is actually Spawner state
@@ -192,8 +199,10 @@ class User(Base):
""" """
return db.query(cls).filter(cls.name == name).first() return db.query(cls).filter(cls.name == name).first()
class Spawner(Base): class Spawner(Base):
""""State about a Spawner""" """"State about a Spawner"""
__tablename__ = 'spawners' __tablename__ = 'spawners'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
@@ -214,10 +223,12 @@ class Spawner(Base):
# for which these should all be False # for which these should all be False
active = running = ready = False active = running = ready = False
pending = None pending = None
@property @property
def orm_spawner(self): def orm_spawner(self):
return self return self
class Service(Base): class Service(Base):
"""A service run with JupyterHub """A service run with JupyterHub
@@ -235,6 +246,7 @@ class Service(Base):
- pid: the process id (if managed) - pid: the process id (if managed)
""" """
__tablename__ = 'services' __tablename__ = 'services'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
@@ -243,9 +255,7 @@ class Service(Base):
admin = Column(Boolean, default=False) admin = Column(Boolean, default=False)
api_tokens = relationship( api_tokens = relationship(
"APIToken", "APIToken", backref="service", cascade="all, delete-orphan"
backref="service",
cascade="all, delete-orphan",
) )
# service-specific interface # service-specific interface
@@ -270,6 +280,7 @@ class Service(Base):
class Hashed(object): class Hashed(object):
"""Mixin for tables with hashed tokens""" """Mixin for tables with hashed tokens"""
prefix_length = 4 prefix_length = 4
algorithm = "sha512" algorithm = "sha512"
rounds = 16384 rounds = 16384
@@ -299,7 +310,9 @@ class Hashed(object):
else: else:
rounds = self.rounds rounds = self.rounds
salt_bytes = self.salt_bytes salt_bytes = self.salt_bytes
self.hashed = hash_token(token, rounds=rounds, salt=salt_bytes, algorithm=self.algorithm) self.hashed = hash_token(
token, rounds=rounds, salt=salt_bytes, algorithm=self.algorithm
)
def match(self, token): def match(self, token):
"""Is this my token?""" """Is this my token?"""
@@ -309,8 +322,9 @@ class Hashed(object):
def check_token(cls, db, token): def check_token(cls, db, token):
"""Check if a token is acceptable""" """Check if a token is acceptable"""
if len(token) < cls.min_length: if len(token) < cls.min_length:
raise ValueError("Tokens must be at least %i characters, got %r" % ( raise ValueError(
cls.min_length, token) "Tokens must be at least %i characters, got %r"
% (cls.min_length, token)
) )
found = cls.find(db, token) found = cls.find(db, token)
if found: if found:
@@ -344,10 +358,13 @@ class Hashed(object):
class APIToken(Hashed, Base): class APIToken(Hashed, Base):
"""An API token""" """An API token"""
__tablename__ = 'api_tokens' __tablename__ = 'api_tokens'
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True) user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
service_id = Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True) service_id = Column(
Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True
)
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
hashed = Column(Unicode(255), unique=True) hashed = Column(Unicode(255), unique=True)
@@ -375,10 +392,7 @@ class APIToken(Hashed, Base):
kind = 'owner' kind = 'owner'
name = 'unknown' name = 'unknown'
return "<{cls}('{pre}...', {kind}='{name}')>".format( return "<{cls}('{pre}...', {kind}='{name}')>".format(
cls=self.__class__.__name__, cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name
pre=self.prefix,
kind=kind,
name=name,
) )
@classmethod @classmethod
@@ -387,9 +401,7 @@ class APIToken(Hashed, Base):
now = utcnow() now = utcnow()
deleted = False deleted = False
for token in ( for token in (
db.query(cls) db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now)
.filter(cls.expires_at != None)
.filter(cls.expires_at < now)
): ):
app_log.debug("Purging expired %s", token) app_log.debug("Purging expired %s", token)
deleted = True deleted = True
@@ -421,8 +433,15 @@ class APIToken(Hashed, Base):
return orm_token return orm_token
@classmethod @classmethod
def new(cls, token=None, user=None, service=None, note='', generated=True, def new(
expires_in=None): cls,
token=None,
user=None,
service=None,
note='',
generated=True,
expires_in=None,
):
"""Generate a new API token for a user or service""" """Generate a new API token for a user or service"""
assert user or service assert user or service
assert not (user and service) assert not (user and service)
@@ -473,7 +492,9 @@ class OAuthAccessToken(Hashed, Base):
def api_id(self): def api_id(self):
return 'o%i' % self.id return 'o%i' % self.id
client_id = Column(Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')) client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
grant_type = Column(Enum(GrantType), nullable=False) grant_type = Column(Enum(GrantType), nullable=False)
expires_at = Column(Integer) expires_at = Column(Integer)
refresh_token = Column(Unicode(255)) refresh_token = Column(Unicode(255))
@@ -517,7 +538,9 @@ class OAuthAccessToken(Hashed, Base):
class OAuthCode(Base): class OAuthCode(Base):
__tablename__ = 'oauth_codes' __tablename__ = 'oauth_codes'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
client_id = Column(Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')) client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
code = Column(Unicode(36)) code = Column(Unicode(36))
expires_at = Column(Integer) expires_at = Column(Integer)
redirect_uri = Column(Unicode(1023)) redirect_uri = Column(Unicode(1023))
@@ -539,18 +562,14 @@ class OAuthClient(Base):
return self.identifier return self.identifier
access_tokens = relationship( access_tokens = relationship(
OAuthAccessToken, OAuthAccessToken, backref='client', cascade='all, delete-orphan'
backref='client',
cascade='all, delete-orphan',
)
codes = relationship(
OAuthCode,
backref='client',
cascade='all, delete-orphan',
) )
codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan')
# General database utilities # General database utilities
class DatabaseSchemaMismatch(Exception): class DatabaseSchemaMismatch(Exception):
"""Exception raised when the database schema version does not match """Exception raised when the database schema version does not match
@@ -560,6 +579,7 @@ class DatabaseSchemaMismatch(Exception):
def register_foreign_keys(engine): def register_foreign_keys(engine):
"""register PRAGMA foreign_keys=on on connection""" """register PRAGMA foreign_keys=on on connection"""
@event.listens_for(engine, "connect") @event.listens_for(engine, "connect")
def connect(dbapi_con, con_record): def connect(dbapi_con, con_record):
cursor = dbapi_con.cursor() cursor = dbapi_con.cursor()
@@ -609,6 +629,7 @@ def register_ping_connection(engine):
https://docs.sqlalchemy.org/en/rel_1_1/core/pooling.html#disconnect-handling-pessimistic https://docs.sqlalchemy.org/en/rel_1_1/core/pooling.html#disconnect-handling-pessimistic
""" """
@event.listens_for(engine, "engine_connect") @event.listens_for(engine, "engine_connect")
def ping_connection(connection, branch): def ping_connection(connection, branch):
if branch: if branch:
@@ -633,7 +654,9 @@ def register_ping_connection(engine):
# condition, which is based on inspection of the original exception # condition, which is based on inspection of the original exception
# by the dialect in use. # by the dialect in use.
if err.connection_invalidated: if err.connection_invalidated:
app_log.error("Database connection error, attempting to reconnect: %s", err) app_log.error(
"Database connection error, attempting to reconnect: %s", err
)
# run the same SELECT again - the connection will re-validate # run the same SELECT again - the connection will re-validate
# itself and establish a new connection. The disconnect detection # itself and establish a new connection. The disconnect detection
# here also causes the whole connection pool to be invalidated # here also causes the whole connection pool to be invalidated
@@ -697,29 +720,37 @@ def check_db_revision(engine):
# check database schema version # check database schema version
# it should always be defined at this point # it should always be defined at this point
alembic_revision = engine.execute('SELECT version_num FROM alembic_version').first()[0] alembic_revision = engine.execute(
'SELECT version_num FROM alembic_version'
).first()[0]
if alembic_revision == head: if alembic_revision == head:
app_log.debug("database schema version found: %s", alembic_revision) app_log.debug("database schema version found: %s", alembic_revision)
pass pass
else: else:
raise DatabaseSchemaMismatch("Found database schema version {found} != {head}. " raise DatabaseSchemaMismatch(
"Found database schema version {found} != {head}. "
"Backup your database and run `jupyterhub upgrade-db`" "Backup your database and run `jupyterhub upgrade-db`"
" to upgrade to the latest schema.".format( " to upgrade to the latest schema.".format(
found=alembic_revision, found=alembic_revision, head=head
head=head, )
)) )
def mysql_large_prefix_check(engine): def mysql_large_prefix_check(engine):
"""Check mysql has innodb_large_prefix set""" """Check mysql has innodb_large_prefix set"""
if not str(engine.url).startswith('mysql'): if not str(engine.url).startswith('mysql'):
return False return False
variables = dict(engine.execute( variables = dict(
engine.execute(
'show variables where variable_name like ' 'show variables where variable_name like '
'"innodb_large_prefix" or ' '"innodb_large_prefix" or '
'variable_name like "innodb_file_format";').fetchall()) 'variable_name like "innodb_file_format";'
if (variables['innodb_file_format'] == 'Barracuda' and ).fetchall()
variables['innodb_large_prefix'] == 'ON'): )
if (
variables['innodb_file_format'] == 'Barracuda'
and variables['innodb_large_prefix'] == 'ON'
):
return True return True
else: else:
return False return False
@@ -730,10 +761,9 @@ def add_row_format(base):
t.dialect_kwargs['mysql_ROW_FORMAT'] = 'DYNAMIC' t.dialect_kwargs['mysql_ROW_FORMAT'] = 'DYNAMIC'
def new_session_factory(url="sqlite:///:memory:", def new_session_factory(
reset=False, url="sqlite:///:memory:", reset=False, expire_on_commit=False, **kwargs
expire_on_commit=False, ):
**kwargs):
"""Create a new session at url""" """Create a new session at url"""
if url.startswith('sqlite'): if url.startswith('sqlite'):
kwargs.setdefault('connect_args', {'check_same_thread': False}) kwargs.setdefault('connect_args', {'check_same_thread': False})
@@ -767,7 +797,5 @@ def new_session_factory(url="sqlite:///:memory:",
# SQLAlchemy to expire objects after committing - we don't expect # SQLAlchemy to expire objects after committing - we don't expect
# concurrent runs of the hub talking to the same db. Turning # concurrent runs of the hub talking to the same db. Turning
# this off gives us a major performance boost # this off gives us a major performance boost
session_factory = sessionmaker(bind=engine, session_factory = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
expire_on_commit=expire_on_commit,
)
return session_factory return session_factory
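To see the reorganized ORM imports in action, a minimal round-trip through new_session_factory looks roughly like this. The database URL, user name, and note are placeholders; the calls themselves (User.find, APIToken.new, the Hashed mixin) appear in the hunks above:

    from jupyterhub import orm

    # in-memory SQLite; new_session_factory sets up the engine and returns a sessionmaker
    session_factory = orm.new_session_factory('sqlite:///:memory:')
    db = session_factory()

    user = orm.User(name='ada')
    db.add(user)
    db.commit()

    # User.find is the simple filter-by-name query shown above
    assert orm.User.find(db, 'ada') is user

    # APIToken.new hashes the token via the Hashed mixin before storing it
    token = orm.APIToken.new(user=user, note='example')
    assert orm.APIToken.find(db, token) is not None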


@@ -14,36 +14,38 @@ Route Specification:
'host.tld/path/' for host-based routing or '/path/' for default routing. 'host.tld/path/' for host-based routing or '/path/' for default routing.
- Route paths should be normalized to always start and end with '/' - Route paths should be normalized to always start and end with '/'
""" """
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import asyncio import asyncio
from functools import wraps
import json import json
import os import os
import signal import signal
from subprocess import Popen
import time import time
from urllib.parse import quote, urlparse from functools import wraps
from subprocess import Popen
from urllib.parse import quote
from urllib.parse import urlparse
from tornado import gen from tornado import gen
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPError from tornado.httpclient import AsyncHTTPClient
from tornado.httpclient import HTTPError
from tornado.httpclient import HTTPRequest
from tornado.ioloop import PeriodicCallback from tornado.ioloop import PeriodicCallback
from traitlets import Any
from traitlets import Bool
from traitlets import ( from traitlets import default
Any, Bool, Instance, Integer, Unicode, from traitlets import Instance
default, observe, from traitlets import Integer
) from traitlets import observe
from jupyterhub.traitlets import Command from traitlets import Unicode
from traitlets.config import LoggingConfigurable from traitlets.config import LoggingConfigurable
from . import utils
from .metrics import CHECK_ROUTES_DURATION_SECONDS from .metrics import CHECK_ROUTES_DURATION_SECONDS
from .objects import Server from .objects import Server
from . import utils from .utils import make_ssl_context
from .utils import url_path_join, make_ssl_context from .utils import url_path_join
from jupyterhub.traitlets import Command
def _one_at_a_time(method): def _one_at_a_time(method):
@@ -53,6 +55,7 @@ def _one_at_a_time(method):
queue them instead of allowing them to be concurrently outstanding. queue them instead of allowing them to be concurrently outstanding.
""" """
method._lock = asyncio.Lock() method._lock = asyncio.Lock()
@wraps(method) @wraps(method)
async def locked_method(*args, **kwargs): async def locked_method(*args, **kwargs):
async with method._lock: async with method._lock:
@@ -86,6 +89,7 @@ class Proxy(LoggingConfigurable):
""" """
db_factory = Any() db_factory = Any()
@property @property
def db(self): def db(self):
return self.db_factory() return self.db_factory()
@@ -97,13 +101,16 @@ class Proxy(LoggingConfigurable):
ssl_cert = Unicode() ssl_cert = Unicode()
host_routing = Bool() host_routing = Bool()
should_start = Bool(True, config=True, should_start = Bool(
True,
config=True,
help="""Should the Hub start the proxy help="""Should the Hub start the proxy
If True, the Hub will start the proxy and stop it. If True, the Hub will start the proxy and stop it.
Set to False if the proxy is managed externally, Set to False if the proxy is managed externally,
such as by systemd, docker, or another service manager. such as by systemd, docker, or another service manager.
""") """,
)
def start(self): def start(self):
"""Start the proxy. """Start the proxy.
@@ -136,9 +143,13 @@ class Proxy(LoggingConfigurable):
# check host routing # check host routing
host_route = not routespec.startswith('/') host_route = not routespec.startswith('/')
if host_route and not self.host_routing: if host_route and not self.host_routing:
raise ValueError("Cannot add host-based route %r, not using host-routing" % routespec) raise ValueError(
"Cannot add host-based route %r, not using host-routing" % routespec
)
if self.host_routing and not host_route: if self.host_routing and not host_route:
raise ValueError("Cannot add route without host %r, using host-routing" % routespec) raise ValueError(
"Cannot add route without host %r, using host-routing" % routespec
)
# add trailing slash # add trailing slash
if not routespec.endswith('/'): if not routespec.endswith('/'):
return routespec + '/' return routespec + '/'
@@ -220,16 +231,19 @@ class Proxy(LoggingConfigurable):
"""Add a service's server to the proxy table.""" """Add a service's server to the proxy table."""
if not service.server: if not service.server:
raise RuntimeError( raise RuntimeError(
"Service %s does not have an http endpoint to add to the proxy.", service.name) "Service %s does not have an http endpoint to add to the proxy.",
service.name,
)
self.log.info("Adding service %s to proxy %s => %s", self.log.info(
service.name, service.proxy_spec, service.server.host, "Adding service %s to proxy %s => %s",
service.name,
service.proxy_spec,
service.server.host,
) )
await self.add_route( await self.add_route(
service.proxy_spec, service.proxy_spec, service.server.host, {'service': service.name}
service.server.host,
{'service': service.name}
) )
async def delete_service(self, service, client=None): async def delete_service(self, service, client=None):
@@ -240,22 +254,23 @@ class Proxy(LoggingConfigurable):
async def add_user(self, user, server_name='', client=None): async def add_user(self, user, server_name='', client=None):
"""Add a user's server to the proxy table.""" """Add a user's server to the proxy table."""
spawner = user.spawners[server_name] spawner = user.spawners[server_name]
self.log.info("Adding user %s to proxy %s => %s", self.log.info(
user.name, spawner.proxy_spec, spawner.server.host, "Adding user %s to proxy %s => %s",
user.name,
spawner.proxy_spec,
spawner.server.host,
) )
if spawner.pending and spawner.pending != 'spawn': if spawner.pending and spawner.pending != 'spawn':
raise RuntimeError( raise RuntimeError(
"%s is pending %s, shouldn't be added to the proxy yet!" % (spawner._log_name, spawner.pending) "%s is pending %s, shouldn't be added to the proxy yet!"
% (spawner._log_name, spawner.pending)
) )
await self.add_route( await self.add_route(
spawner.proxy_spec, spawner.proxy_spec,
spawner.server.host, spawner.server.host,
{ {'user': user.name, 'server_name': server_name},
'user': user.name,
'server_name': server_name,
}
) )
async def delete_user(self, user, server_name=''): async def delete_user(self, user, server_name=''):
@@ -314,7 +329,9 @@ class Proxy(LoggingConfigurable):
else: else:
route = routes[self.app.hub.routespec] route = routes[self.app.hub.routespec]
if route['target'] != hub.host: if route['target'] != hub.host:
self.log.warning("Updating default route %s%s", route['target'], hub.host) self.log.warning(
"Updating default route %s%s", route['target'], hub.host
)
futures.append(self.add_hub_route(hub)) futures.append(self.add_hub_route(hub))
for user in user_dict.values(): for user in user_dict.values():
@@ -324,14 +341,17 @@ class Proxy(LoggingConfigurable):
good_routes.add(spec) good_routes.add(spec)
if spec not in user_routes: if spec not in user_routes:
self.log.warning( self.log.warning(
"Adding missing route for %s (%s)", spec, spawner.server) "Adding missing route for %s (%s)", spec, spawner.server
)
futures.append(self.add_user(user, name)) futures.append(self.add_user(user, name))
else: else:
route = routes[spec] route = routes[spec]
if route['target'] != spawner.server.host: if route['target'] != spawner.server.host:
self.log.warning( self.log.warning(
"Updating route for %s (%s%s)", "Updating route for %s (%s%s)",
spec, route['target'], spawner.server, spec,
route['target'],
spawner.server,
) )
futures.append(self.add_user(user, name)) futures.append(self.add_user(user, name))
elif spawner.pending: elif spawner.pending:
@@ -341,22 +361,26 @@ class Proxy(LoggingConfigurable):
good_routes.add(spawner.proxy_spec) good_routes.add(spawner.proxy_spec)
# check service routes # check service routes
service_routes = {r['data']['service']: r service_routes = {
for r in routes.values() if 'service' in r['data']} r['data']['service']: r for r in routes.values() if 'service' in r['data']
}
for service in service_dict.values(): for service in service_dict.values():
if service.server is None: if service.server is None:
continue continue
good_routes.add(service.proxy_spec) good_routes.add(service.proxy_spec)
if service.name not in service_routes: if service.name not in service_routes:
self.log.warning("Adding missing route for %s (%s)", self.log.warning(
service.name, service.server) "Adding missing route for %s (%s)", service.name, service.server
)
futures.append(self.add_service(service)) futures.append(self.add_service(service))
else: else:
route = service_routes[service.name] route = service_routes[service.name]
if route['target'] != service.server.host: if route['target'] != service.server.host:
self.log.warning( self.log.warning(
"Updating route for %s (%s%s)", "Updating route for %s (%s%s)",
route['routespec'], route['target'], service.server.host, route['routespec'],
route['target'],
service.server.host,
) )
futures.append(self.add_service(service)) futures.append(self.add_service(service))
@@ -424,7 +448,7 @@ class ConfigurableHTTPProxy(Proxy):
help="""The Proxy auth token help="""The Proxy auth token
Loaded from the CONFIGPROXY_AUTH_TOKEN env variable by default. Loaded from the CONFIGPROXY_AUTH_TOKEN env variable by default.
""", """
).tag(config=True) ).tag(config=True)
check_running_interval = Integer(5, config=True) check_running_interval = Integer(5, config=True)
@@ -437,8 +461,8 @@ class ConfigurableHTTPProxy(Proxy):
token = utils.new_token() token = utils.new_token()
return token return token
api_url = Unicode(config=True, api_url = Unicode(
help="""The ip (or hostname) of the proxy's API endpoint""" config=True, help="""The ip (or hostname) of the proxy's API endpoint"""
) )
@default('api_url') @default('api_url')
@@ -448,13 +472,12 @@ class ConfigurableHTTPProxy(Proxy):
if self.app.internal_ssl: if self.app.internal_ssl:
proto = 'https' proto = 'https'
return "{proto}://{url}".format( return "{proto}://{url}".format(proto=proto, url=url)
proto=proto,
url=url,
)
command = Command('configurable-http-proxy', config=True, command = Command(
help="""The command to start the proxy""" 'configurable-http-proxy',
config=True,
help="""The command to start the proxy""",
) )
pid_file = Unicode( pid_file = Unicode(
@@ -463,11 +486,14 @@ class ConfigurableHTTPProxy(Proxy):
help="File in which to write the PID of the proxy process.", help="File in which to write the PID of the proxy process.",
) )
_check_running_callback = Any(help="PeriodicCallback to check if the proxy is running") _check_running_callback = Any(
help="PeriodicCallback to check if the proxy is running"
)
def _check_pid(self, pid): def _check_pid(self, pid):
if os.name == 'nt': if os.name == 'nt':
import psutil import psutil
if not psutil.pid_exists(pid): if not psutil.pid_exists(pid):
raise ProcessLookupError raise ProcessLookupError
else: else:
@@ -558,11 +584,16 @@ class ConfigurableHTTPProxy(Proxy):
env = os.environ.copy() env = os.environ.copy()
env['CONFIGPROXY_AUTH_TOKEN'] = self.auth_token env['CONFIGPROXY_AUTH_TOKEN'] = self.auth_token
cmd = self.command + [ cmd = self.command + [
'--ip', public_server.ip, '--ip',
'--port', str(public_server.port), public_server.ip,
'--api-ip', api_server.ip, '--port',
'--api-port', str(api_server.port), str(public_server.port),
'--error-target', url_path_join(self.hub.url, 'error'), '--api-ip',
api_server.ip,
'--api-port',
str(api_server.port),
'--error-target',
url_path_join(self.hub.url, 'error'),
] ]
if self.app.subdomain_host: if self.app.subdomain_host:
cmd.append('--host-routing') cmd.append('--host-routing')
@@ -595,28 +626,36 @@ class ConfigurableHTTPProxy(Proxy):
cmd.extend(['--client-ssl-request-cert']) cmd.extend(['--client-ssl-request-cert'])
cmd.extend(['--client-ssl-reject-unauthorized']) cmd.extend(['--client-ssl-reject-unauthorized'])
if self.app.statsd_host: if self.app.statsd_host:
cmd.extend([ cmd.extend(
'--statsd-host', self.app.statsd_host, [
'--statsd-port', str(self.app.statsd_port), '--statsd-host',
'--statsd-prefix', self.app.statsd_prefix + '.chp' self.app.statsd_host,
]) '--statsd-port',
str(self.app.statsd_port),
'--statsd-prefix',
self.app.statsd_prefix + '.chp',
]
)
# Warn if SSL is not used # Warn if SSL is not used
if ' --ssl' not in ' '.join(cmd): if ' --ssl' not in ' '.join(cmd):
self.log.warning("Running JupyterHub without SSL." self.log.warning(
" I hope there is SSL termination happening somewhere else...") "Running JupyterHub without SSL."
" I hope there is SSL termination happening somewhere else..."
)
self.log.info("Starting proxy @ %s", public_server.bind_url) self.log.info("Starting proxy @ %s", public_server.bind_url)
self.log.debug("Proxy cmd: %s", cmd) self.log.debug("Proxy cmd: %s", cmd)
shell = os.name == 'nt' shell = os.name == 'nt'
try: try:
self.proxy_process = Popen(cmd, env=env, start_new_session=True, shell=shell) self.proxy_process = Popen(
cmd, env=env, start_new_session=True, shell=shell
)
except FileNotFoundError as e: except FileNotFoundError as e:
self.log.error( self.log.error(
"Failed to find proxy %r\n" "Failed to find proxy %r\n"
"The proxy can be installed with `npm install -g configurable-http-proxy`." "The proxy can be installed with `npm install -g configurable-http-proxy`."
"To install `npm`, install nodejs which includes `npm`." "To install `npm`, install nodejs which includes `npm`."
"If you see an `EACCES` error or permissions error, refer to the `npm` " "If you see an `EACCES` error or permissions error, refer to the `npm` "
"documentation on How To Prevent Permissions Errors." "documentation on How To Prevent Permissions Errors." % self.command
% self.command
) )
raise raise
@@ -625,8 +664,7 @@ class ConfigurableHTTPProxy(Proxy):
def _check_process(): def _check_process():
status = self.proxy_process.poll() status = self.proxy_process.poll()
if status is not None: if status is not None:
e = RuntimeError( e = RuntimeError("Proxy failed to start with exit code %i" % status)
"Proxy failed to start with exit code %i" % status)
raise e from None raise e from None
for server in (public_server, api_server): for server in (public_server, api_server):
@@ -678,8 +716,9 @@ class ConfigurableHTTPProxy(Proxy):
"""Check if the proxy is still running""" """Check if the proxy is still running"""
if self.proxy_process.poll() is None: if self.proxy_process.poll() is None:
return return
self.log.error("Proxy stopped with exit code %r", self.log.error(
'unknown' if self.proxy_process is None else self.proxy_process.poll() "Proxy stopped with exit code %r",
'unknown' if self.proxy_process is None else self.proxy_process.poll(),
) )
self._remove_pid_file() self._remove_pid_file()
await self.start() await self.start()
@@ -724,10 +763,10 @@ class ConfigurableHTTPProxy(Proxy):
if isinstance(body, dict): if isinstance(body, dict):
body = json.dumps(body) body = json.dumps(body)
self.log.debug("Proxy: Fetching %s %s", method, url) self.log.debug("Proxy: Fetching %s %s", method, url)
req = HTTPRequest(url, req = HTTPRequest(
url,
method=method, method=method,
headers={'Authorization': 'token {}'.format( headers={'Authorization': 'token {}'.format(self.auth_token)},
self.auth_token)},
body=body, body=body,
) )
async with self.semaphore: async with self.semaphore:
@@ -739,11 +778,7 @@ class ConfigurableHTTPProxy(Proxy):
body['target'] = target body['target'] = target
body['jupyterhub'] = True body['jupyterhub'] = True
path = self._routespec_to_chp_path(routespec) path = self._routespec_to_chp_path(routespec)
await self.api_request( await self.api_request(path, method='POST', body=body)
path,
method='POST',
body=body,
)
async def delete_route(self, routespec): async def delete_route(self, routespec):
path = self._routespec_to_chp_path(routespec) path = self._routespec_to_chp_path(routespec)
@@ -762,11 +797,7 @@ class ConfigurableHTTPProxy(Proxy):
"""Reformat CHP data format to JupyterHub's proxy API.""" """Reformat CHP data format to JupyterHub's proxy API."""
target = chp_data.pop('target') target = chp_data.pop('target')
chp_data.pop('jupyterhub') chp_data.pop('jupyterhub')
return { return {'routespec': routespec, 'target': target, 'data': chp_data}
'routespec': routespec,
'target': target,
'data': chp_data,
}
async def get_all_routes(self, client=None): async def get_all_routes(self, client=None):
"""Fetch the proxy's routes.""" """Fetch the proxy's routes."""
@@ -779,6 +810,5 @@ class ConfigurableHTTPProxy(Proxy):
# exclude routes not associated with JupyterHub # exclude routes not associated with JupyterHub
self.log.debug("Omitting non-jupyterhub route %r", routespec) self.log.debug("Omitting non-jupyterhub route %r", routespec)
continue continue
all_routes[routespec] = self._reformat_routespec( all_routes[routespec] = self._reformat_routespec(routespec, chp_data)
routespec, chp_data)
return all_routes return all_routes
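The _one_at_a_time decorator near the top of this file is a small reusable pattern worth calling out: wrap a coroutine so concurrent calls queue on a shared asyncio.Lock instead of overlapping. A standalone sketch of the same idea; the function names below are illustrative:

    import asyncio
    from functools import wraps

    def one_at_a_time(method):
        """Queue concurrent calls to a coroutine instead of letting them overlap."""
        # creating the Lock at decoration time is fine on Python 3.10+,
        # where asyncio.Lock no longer binds to an event loop on construction
        method._lock = asyncio.Lock()

        @wraps(method)
        async def locked_method(*args, **kwargs):
            async with method._lock:
                return await method(*args, **kwargs)

        return locked_method

    @one_at_a_time
    async def check_routes():
        # stand-in for a proxy API call; overlapping calls would race
        await asyncio.sleep(0.1)
        return 'ok'

    async def main():
        # the second call waits for the first instead of running concurrently
        print(await asyncio.gather(check_routes(), check_routes()))

    asyncio.run(main())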


@@ -9,7 +9,6 @@ model describing the authenticated user.
authenticate with the Hub. authenticate with the Hub.
""" """
import base64 import base64
import json import json
import os import os
@@ -18,22 +17,25 @@ import re
import socket import socket
import string import string
import time import time
from urllib.parse import quote, urlencode
import uuid import uuid
import warnings import warnings
from urllib.parse import quote
from urllib.parse import urlencode
import requests import requests
from tornado.gen import coroutine from tornado.gen import coroutine
from tornado.log import app_log
from tornado.httputil import url_concat from tornado.httputil import url_concat
from tornado.web import HTTPError, RequestHandler from tornado.log import app_log
from tornado.web import HTTPError
from tornado.web import RequestHandler
from traitlets import default
from traitlets import Dict
from traitlets import Instance
from traitlets import Integer
from traitlets import observe
from traitlets import Unicode
from traitlets import validate
from traitlets.config import SingletonConfigurable from traitlets.config import SingletonConfigurable
from traitlets import (
Unicode, Integer, Instance, Dict,
default, observe, validate,
)
from ..utils import url_path_join from ..utils import url_path_join
@@ -63,13 +65,14 @@ class _ExpiringDict(dict):
def __repr__(self): def __repr__(self):
"""include values and timestamps in repr""" """include values and timestamps in repr"""
now = time.monotonic() now = time.monotonic()
return repr({ return repr(
{
key: '{value} (age={age:.0f}s)'.format( key: '{value} (age={age:.0f}s)'.format(
value=repr(value)[:16] + '...', value=repr(value)[:16] + '...', age=now - self.timestamps[key]
age=now-self.timestamps[key],
) )
for key, value in self.values.items() for key, value in self.values.items()
}) }
)
def _check_age(self, key): def _check_age(self, key):
"""Check timestamp for a key""" """Check timestamp for a key"""
@@ -131,24 +134,28 @@ class HubAuth(SingletonConfigurable):
""" """
hub_host = Unicode('', hub_host = Unicode(
'',
help="""The public host of JupyterHub help="""The public host of JupyterHub
Only used if JupyterHub is spreading servers across subdomains. Only used if JupyterHub is spreading servers across subdomains.
""" """,
).tag(config=True) ).tag(config=True)
@default('hub_host') @default('hub_host')
def _default_hub_host(self): def _default_hub_host(self):
return os.getenv('JUPYTERHUB_HOST', '') return os.getenv('JUPYTERHUB_HOST', '')
base_url = Unicode(os.getenv('JUPYTERHUB_SERVICE_PREFIX') or '/', base_url = Unicode(
os.getenv('JUPYTERHUB_SERVICE_PREFIX') or '/',
help="""The base URL prefix of this application help="""The base URL prefix of this application
e.g. /services/service-name/ or /user/name/ e.g. /services/service-name/ or /user/name/
Default: get from JUPYTERHUB_SERVICE_PREFIX Default: get from JUPYTERHUB_SERVICE_PREFIX
""" """,
).tag(config=True) ).tag(config=True)
@validate('base_url') @validate('base_url')
def _add_slash(self, proposal): def _add_slash(self, proposal):
"""Ensure base_url starts and ends with /""" """Ensure base_url starts and ends with /"""
@@ -160,12 +167,14 @@ class HubAuth(SingletonConfigurable):
return value return value
# where is the hub # where is the hub
api_url = Unicode(os.getenv('JUPYTERHUB_API_URL') or 'http://127.0.0.1:8081/hub/api', api_url = Unicode(
os.getenv('JUPYTERHUB_API_URL') or 'http://127.0.0.1:8081/hub/api',
help="""The base API URL of the Hub. help="""The base API URL of the Hub.
Typically `http://hub-ip:hub-port/hub/api` Typically `http://hub-ip:hub-port/hub/api`
""" """,
).tag(config=True) ).tag(config=True)
@default('api_url') @default('api_url')
def _api_url(self): def _api_url(self):
env_url = os.getenv('JUPYTERHUB_API_URL') env_url = os.getenv('JUPYTERHUB_API_URL')
@@ -174,56 +183,64 @@ class HubAuth(SingletonConfigurable):
else: else:
return 'http://127.0.0.1:8081' + url_path_join(self.hub_prefix, 'api') return 'http://127.0.0.1:8081' + url_path_join(self.hub_prefix, 'api')
api_token = Unicode(os.getenv('JUPYTERHUB_API_TOKEN', ''), api_token = Unicode(
os.getenv('JUPYTERHUB_API_TOKEN', ''),
help="""API key for accessing Hub API. help="""API key for accessing Hub API.
Generate with `jupyterhub token [username]` or add to JupyterHub.services config. Generate with `jupyterhub token [username]` or add to JupyterHub.services config.
""" """,
).tag(config=True) ).tag(config=True)
hub_prefix = Unicode('/hub/', hub_prefix = Unicode(
'/hub/',
help="""The URL prefix for the Hub itself. help="""The URL prefix for the Hub itself.
Typically /hub/ Typically /hub/
""" """,
).tag(config=True) ).tag(config=True)
@default('hub_prefix') @default('hub_prefix')
def _default_hub_prefix(self): def _default_hub_prefix(self):
return url_path_join(os.getenv('JUPYTERHUB_BASE_URL') or '/', 'hub') + '/' return url_path_join(os.getenv('JUPYTERHUB_BASE_URL') or '/', 'hub') + '/'
login_url = Unicode('/hub/login', login_url = Unicode(
'/hub/login',
help="""The login URL to use help="""The login URL to use
Typically /hub/login Typically /hub/login
""" """,
).tag(config=True) ).tag(config=True)
@default('login_url') @default('login_url')
def _default_login_url(self): def _default_login_url(self):
return self.hub_host + url_path_join(self.hub_prefix, 'login') return self.hub_host + url_path_join(self.hub_prefix, 'login')
keyfile = Unicode('', keyfile = Unicode(
'',
help="""The ssl key to use for requests help="""The ssl key to use for requests
Use with certfile Use with certfile
""" """,
).tag(config=True) ).tag(config=True)
certfile = Unicode('', certfile = Unicode(
'',
help="""The ssl cert to use for requests help="""The ssl cert to use for requests
Use with keyfile Use with keyfile
""" """,
).tag(config=True) ).tag(config=True)
client_ca = Unicode('', client_ca = Unicode(
'',
help="""The ssl certificate authority to use to verify requests help="""The ssl certificate authority to use to verify requests
Use with keyfile and certfile Use with keyfile and certfile
""" """,
).tag(config=True) ).tag(config=True)
cookie_name = Unicode('jupyterhub-services', cookie_name = Unicode(
help="""The name of the cookie I should be looking for""" 'jupyterhub-services', help="""The name of the cookie I should be looking for"""
).tag(config=True) ).tag(config=True)
cookie_options = Dict( cookie_options = Dict(
@@ -245,21 +262,26 @@ class HubAuth(SingletonConfigurable):
return {} return {}
cookie_cache_max_age = Integer(help="DEPRECATED. Use cache_max_age") cookie_cache_max_age = Integer(help="DEPRECATED. Use cache_max_age")
@observe('cookie_cache_max_age') @observe('cookie_cache_max_age')
def _deprecated_cookie_cache(self, change): def _deprecated_cookie_cache(self, change):
warnings.warn("cookie_cache_max_age is deprecated in JupyterHub 0.8. Use cache_max_age instead.") warnings.warn(
"cookie_cache_max_age is deprecated in JupyterHub 0.8. Use cache_max_age instead."
)
self.cache_max_age = change.new self.cache_max_age = change.new
cache_max_age = Integer(300, cache_max_age = Integer(
300,
help="""The maximum time (in seconds) to cache the Hub's responses for authentication. help="""The maximum time (in seconds) to cache the Hub's responses for authentication.
A larger value reduces load on the Hub and occasional response lag. A larger value reduces load on the Hub and occasional response lag.
A smaller value reduces propagation time of changes on the Hub (rare). A smaller value reduces propagation time of changes on the Hub (rare).
Default: 300 (five minutes) Default: 300 (five minutes)
""" """,
).tag(config=True) ).tag(config=True)
cache = Instance(_ExpiringDict, allow_none=False) cache = Instance(_ExpiringDict, allow_none=False)
@default('cache') @default('cache')
def _default_cache(self): def _default_cache(self):
return _ExpiringDict(self.cache_max_age) return _ExpiringDict(self.cache_max_age)
@@ -311,25 +333,42 @@ class HubAuth(SingletonConfigurable):
except requests.ConnectionError as e: except requests.ConnectionError as e:
app_log.error("Error connecting to %s: %s", self.api_url, e) app_log.error("Error connecting to %s: %s", self.api_url, e)
msg = "Failed to connect to Hub API at %r." % self.api_url msg = "Failed to connect to Hub API at %r." % self.api_url
msg += " Is the Hub accessible at this URL (from host: %s)?" % socket.gethostname() msg += (
" Is the Hub accessible at this URL (from host: %s)?"
% socket.gethostname()
)
if '127.0.0.1' in self.api_url: if '127.0.0.1' in self.api_url:
msg += " Make sure to set c.JupyterHub.hub_ip to an IP accessible to" + \ msg += (
" single-user servers if the servers are not on the same host as the Hub." " Make sure to set c.JupyterHub.hub_ip to an IP accessible to"
+ " single-user servers if the servers are not on the same host as the Hub."
)
raise HTTPError(500, msg) raise HTTPError(500, msg)
data = None data = None
if r.status_code == 404 and allow_404: if r.status_code == 404 and allow_404:
pass pass
elif r.status_code == 403: elif r.status_code == 403:
app_log.error("I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s", r.status_code, r.reason) app_log.error(
"I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s",
r.status_code,
r.reason,
)
app_log.error(r.text) app_log.error(r.text)
raise HTTPError(500, "Permission failure checking authorization, I may need a new token") raise HTTPError(
500, "Permission failure checking authorization, I may need a new token"
)
elif r.status_code >= 500: elif r.status_code >= 500:
app_log.error("Upstream failure verifying auth token: [%i] %s", r.status_code, r.reason) app_log.error(
"Upstream failure verifying auth token: [%i] %s",
r.status_code,
r.reason,
)
app_log.error(r.text) app_log.error(r.text)
raise HTTPError(502, "Failed to check authorization (upstream problem)") raise HTTPError(502, "Failed to check authorization (upstream problem)")
elif r.status_code >= 400: elif r.status_code >= 400:
app_log.warning("Failed to check authorization: [%i] %s", r.status_code, r.reason) app_log.warning(
"Failed to check authorization: [%i] %s", r.status_code, r.reason
)
app_log.warning(r.text) app_log.warning(r.text)
msg = "Failed to check authorization" msg = "Failed to check authorization"
# pass on error_description from oauth failure # pass on error_description from oauth failure
@@ -358,10 +397,12 @@ class HubAuth(SingletonConfigurable):
The 'name' field contains the user's name. The 'name' field contains the user's name.
""" """
return self._check_hub_authorization( return self._check_hub_authorization(
url=url_path_join(self.api_url, url=url_path_join(
self.api_url,
"authorizations/cookie", "authorizations/cookie",
self.cookie_name, self.cookie_name,
quote(encrypted_cookie, safe='')), quote(encrypted_cookie, safe=''),
),
cache_key='cookie:{}:{}'.format(session_id, encrypted_cookie), cache_key='cookie:{}:{}'.format(session_id, encrypted_cookie),
use_cache=use_cache, use_cache=use_cache,
) )
@@ -379,9 +420,9 @@ class HubAuth(SingletonConfigurable):
The 'name' field contains the user's name. The 'name' field contains the user's name.
""" """
return self._check_hub_authorization( return self._check_hub_authorization(
url=url_path_join(self.api_url, url=url_path_join(
"authorizations/token", self.api_url, "authorizations/token", quote(token, safe='')
quote(token, safe='')), ),
cache_key='token:{}:{}'.format(session_id, token), cache_key='token:{}:{}'.format(session_id, token),
use_cache=use_cache, use_cache=use_cache,
) )
@@ -399,7 +440,9 @@ class HubAuth(SingletonConfigurable):
user_token = handler.get_argument('token', '') user_token = handler.get_argument('token', '')
if not user_token: if not user_token:
# get it from Authorization header # get it from Authorization header
m = self.auth_header_pat.match(handler.request.headers.get(self.auth_header_name, '')) m = self.auth_header_pat.match(
handler.request.headers.get(self.auth_header_name, '')
)
if m: if m:
user_token = m.group(1) user_token = m.group(1)
return user_token return user_token
@@ -469,11 +512,14 @@ class HubOAuth(HubAuth):
@default('login_url') @default('login_url')
def _login_url(self): def _login_url(self):
return url_concat(self.oauth_authorization_url, { return url_concat(
self.oauth_authorization_url,
{
'client_id': self.oauth_client_id, 'client_id': self.oauth_client_id,
'redirect_uri': self.oauth_redirect_uri, 'redirect_uri': self.oauth_redirect_uri,
'response_type': 'code', 'response_type': 'code',
}) },
)
@property @property
def cookie_name(self): def cookie_name(self):
@@ -511,6 +557,7 @@ class HubOAuth(HubAuth):
Use JUPYTERHUB_CLIENT_ID by default. Use JUPYTERHUB_CLIENT_ID by default.
""" """
).tag(config=True) ).tag(config=True)
@default('oauth_client_id') @default('oauth_client_id')
def _client_id(self): def _client_id(self):
return os.getenv('JUPYTERHUB_CLIENT_ID', '') return os.getenv('JUPYTERHUB_CLIENT_ID', '')
@@ -527,13 +574,18 @@ class HubOAuth(HubAuth):
Should generally be /base_url/oauth_callback Should generally be /base_url/oauth_callback
""" """
).tag(config=True) ).tag(config=True)
@default('oauth_redirect_uri') @default('oauth_redirect_uri')
def _default_redirect(self): def _default_redirect(self):
return os.getenv('JUPYTERHUB_OAUTH_CALLBACK_URL') or url_path_join(self.base_url, 'oauth_callback') return os.getenv('JUPYTERHUB_OAUTH_CALLBACK_URL') or url_path_join(
self.base_url, 'oauth_callback'
)
oauth_authorization_url = Unicode('/hub/api/oauth2/authorize', oauth_authorization_url = Unicode(
'/hub/api/oauth2/authorize',
help="The URL to redirect to when starting the OAuth process", help="The URL to redirect to when starting the OAuth process",
).tag(config=True) ).tag(config=True)
@default('oauth_authorization_url') @default('oauth_authorization_url')
def _auth_url(self): def _auth_url(self):
return self.hub_host + url_path_join(self.hub_prefix, 'api/oauth2/authorize') return self.hub_host + url_path_join(self.hub_prefix, 'api/oauth2/authorize')
@@ -541,6 +593,7 @@ class HubOAuth(HubAuth):
oauth_token_url = Unicode( oauth_token_url = Unicode(
help="""The URL for requesting an OAuth token from JupyterHub""" help="""The URL for requesting an OAuth token from JupyterHub"""
).tag(config=True) ).tag(config=True)
@default('oauth_token_url') @default('oauth_token_url')
def _token_url(self): def _token_url(self):
return url_path_join(self.api_url, 'oauth2/token') return url_path_join(self.api_url, 'oauth2/token')
@@ -565,11 +618,12 @@ class HubOAuth(HubAuth):
redirect_uri=self.oauth_redirect_uri, redirect_uri=self.oauth_redirect_uri,
) )
token_reply = self._api_request('POST', self.oauth_token_url, token_reply = self._api_request(
'POST',
self.oauth_token_url,
data=urlencode(params).encode('utf8'), data=urlencode(params).encode('utf8'),
headers={ headers={'Content-Type': 'application/x-www-form-urlencoded'},
'Content-Type': 'application/x-www-form-urlencoded' )
})
return token_reply['access_token'] return token_reply['access_token']
@@ -577,9 +631,11 @@ class HubOAuth(HubAuth):
"""Encode a state dict as url-safe base64""" """Encode a state dict as url-safe base64"""
# trim trailing `=` because = is itself not url-safe! # trim trailing `=` because = is itself not url-safe!
json_state = json.dumps(state) json_state = json.dumps(state)
return base64.urlsafe_b64encode( return (
json_state.encode('utf8') base64.urlsafe_b64encode(json_state.encode('utf8'))
).decode('ascii').rstrip('=') .decode('ascii')
.rstrip('=')
)
def _decode_state(self, b64_state): def _decode_state(self, b64_state):
"""Decode a base64 state """Decode a base64 state
@@ -621,7 +677,9 @@ class HubOAuth(HubAuth):
# use a randomized cookie suffix to avoid collisions # use a randomized cookie suffix to avoid collisions
# in case of concurrent logins # in case of concurrent logins
app_log.warning("Detected unused OAuth state cookies") app_log.warning("Detected unused OAuth state cookies")
cookie_suffix = ''.join(random.choice(string.ascii_letters) for i in range(8)) cookie_suffix = ''.join(
random.choice(string.ascii_letters) for i in range(8)
)
cookie_name = '{}-{}'.format(self.state_cookie_name, cookie_suffix) cookie_name = '{}-{}'.format(self.state_cookie_name, cookie_suffix)
extra_state['cookie_name'] = cookie_name extra_state['cookie_name'] = cookie_name
else: else:
@@ -640,11 +698,7 @@ class HubOAuth(HubAuth):
kwargs['secure'] = True kwargs['secure'] = True
# load user cookie overrides # load user cookie overrides
kwargs.update(self.cookie_options) kwargs.update(self.cookie_options)
handler.set_secure_cookie( handler.set_secure_cookie(cookie_name, b64_state, **kwargs)
cookie_name,
b64_state,
**kwargs
)
return b64_state return b64_state
def generate_state(self, next_url=None, **extra_state): def generate_state(self, next_url=None, **extra_state):
@@ -658,10 +712,7 @@ class HubOAuth(HubAuth):
------- -------
state (str): The base64-encoded state string. state (str): The base64-encoded state string.
""" """
state = { state = {'uuid': uuid.uuid4().hex, 'next_url': next_url}
'uuid': uuid.uuid4().hex,
'next_url': next_url,
}
state.update(extra_state) state.update(extra_state)
return self._encode_state(state) return self._encode_state(state)
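The _encode_state and generate_state hunks above show the OAuth state being JSON-encoded, urlsafe-base64 encoded, and stripped of '=' padding (because '=' itself is not URL-safe). A minimal round-trip sketch; restoring the padding on decode is an assumption here, since the body of _decode_state is not part of this hunk:

    import base64
    import json
    import uuid

    def encode_state(state):
        # mirror _encode_state above: JSON -> urlsafe base64 without '=' padding
        json_state = json.dumps(state)
        return base64.urlsafe_b64encode(json_state.encode('utf8')).decode('ascii').rstrip('=')

    def decode_state(b64_state):
        # assumed inverse: pad back out to a multiple of 4 before decoding
        if len(b64_state) % 4:
            b64_state += '=' * (4 - len(b64_state) % 4)
        return json.loads(base64.urlsafe_b64decode(b64_state).decode('utf8'))

    state = {'uuid': uuid.uuid4().hex, 'next_url': '/user/alice/lab'}
    assert decode_state(encode_state(state)) == state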
@@ -681,21 +732,19 @@ class HubOAuth(HubAuth):
def set_cookie(self, handler, access_token): def set_cookie(self, handler, access_token):
"""Set a cookie recording OAuth result""" """Set a cookie recording OAuth result"""
kwargs = { kwargs = {'path': self.base_url, 'httponly': True}
'path': self.base_url,
'httponly': True,
}
if handler.request.protocol == 'https': if handler.request.protocol == 'https':
kwargs['secure'] = True kwargs['secure'] = True
# load user cookie overrides # load user cookie overrides
kwargs.update(self.cookie_options) kwargs.update(self.cookie_options)
app_log.debug("Setting oauth cookie for %s: %s, %s", app_log.debug(
handler.request.remote_ip, self.cookie_name, kwargs) "Setting oauth cookie for %s: %s, %s",
handler.set_secure_cookie( handler.request.remote_ip,
self.cookie_name, self.cookie_name,
access_token, kwargs,
**kwargs
) )
handler.set_secure_cookie(self.cookie_name, access_token, **kwargs)
def clear_cookie(self, handler): def clear_cookie(self, handler):
"""Clear the OAuth cookie""" """Clear the OAuth cookie"""
handler.clear_cookie(self.cookie_name, path=self.base_url) handler.clear_cookie(self.cookie_name, path=self.base_url)
@@ -703,6 +752,7 @@ class HubOAuth(HubAuth):
class UserNotAllowed(Exception): class UserNotAllowed(Exception):
"""Exception raised when a user is identified and not allowed""" """Exception raised when a user is identified and not allowed"""
def __init__(self, model): def __init__(self, model):
self.model = model self.model = model
@@ -738,6 +788,7 @@ class HubAuthenticated(object):
... ...
""" """
hub_services = None # set of allowed services hub_services = None # set of allowed services
hub_users = None # set of allowed users hub_users = None # set of allowed users
hub_groups = None # set of allowed groups hub_groups = None # set of allowed groups
@@ -748,9 +799,11 @@ class HubAuthenticated(object):
"""Property indicating that all successfully identified user """Property indicating that all successfully identified user
or service should be allowed. or service should be allowed.
""" """
return (self.hub_services is None return (
self.hub_services is None
and self.hub_users is None and self.hub_users is None
and self.hub_groups is None) and self.hub_groups is None
)
# self.hub_auth must be a HubAuth instance. # self.hub_auth must be a HubAuth instance.
# If nothing specified, use default config, # If nothing specified, use default config,
@@ -758,6 +811,7 @@ class HubAuthenticated(object):
# based on JupyterHub environment variables for services. # based on JupyterHub environment variables for services.
_hub_auth = None _hub_auth = None
hub_auth_class = HubAuth hub_auth_class = HubAuth
@property @property
def hub_auth(self): def hub_auth(self):
if self._hub_auth is None: if self._hub_auth is None:
@@ -794,7 +848,9 @@ class HubAuthenticated(object):
name = model['name'] name = model['name']
kind = model.setdefault('kind', 'user') kind = model.setdefault('kind', 'user')
if self.allow_all: if self.allow_all:
app_log.debug("Allowing Hub %s %s (all Hub users and services allowed)", kind, name) app_log.debug(
"Allowing Hub %s %s (all Hub users and services allowed)", kind, name
)
return model return model
if self.allow_admin and model.get('admin', False): if self.allow_admin and model.get('admin', False):
@@ -816,7 +872,11 @@ class HubAuthenticated(object):
return model return model
elif self.hub_groups and set(model['groups']).intersection(self.hub_groups): elif self.hub_groups and set(model['groups']).intersection(self.hub_groups):
allowed_groups = set(model['groups']).intersection(self.hub_groups) allowed_groups = set(model['groups']).intersection(self.hub_groups)
app_log.debug("Allowing Hub user %s in group(s) %s", name, ','.join(sorted(allowed_groups))) app_log.debug(
"Allowing Hub user %s in group(s) %s",
name,
','.join(sorted(allowed_groups)),
)
# group in whitelist # group in whitelist
return model return model
else: else:
@@ -845,7 +905,10 @@ class HubAuthenticated(object):
# This is not the best, but avoids problems that can be caused # This is not the best, but avoids problems that can be caused
# when get_current_user is allowed to raise. # when get_current_user is allowed to raise.
def raise_on_redirect(*args, **kwargs): def raise_on_redirect(*args, **kwargs):
raise HTTPError(403, "{kind} {name} is not allowed.".format(**user_model)) raise HTTPError(
403, "{kind} {name} is not allowed.".format(**user_model)
)
self.redirect = raise_on_redirect self.redirect = raise_on_redirect
return return
except Exception: except Exception:
@@ -869,6 +932,7 @@ class HubAuthenticated(object):
class HubOAuthenticated(HubAuthenticated): class HubOAuthenticated(HubAuthenticated):
"""Simple subclass of HubAuthenticated using OAuth instead of old shared cookies""" """Simple subclass of HubAuthenticated using OAuth instead of old shared cookies"""
hub_auth_class = HubOAuth hub_auth_class = HubOAuth
@@ -917,5 +981,3 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
app_log.info("Logged-in user %s", user_model) app_log.info("Logged-in user %s", user_model)
self.hub_auth.set_cookie(self, token) self.hub_auth.set_cookie(self, token)
self.redirect(next_url or self.hub_auth.base_url) self.redirect(next_url or self.hub_auth.base_url)
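Taken together, the HubOAuthenticated mixin and HubOAuthCallbackHandler reformatted above are what an external service uses to authenticate Hub users via OAuth. A minimal sketch patterned on JupyterHub's whoami-oauth example service; it assumes the process is launched by the Hub, so JUPYTERHUB_API_TOKEN, JUPYTERHUB_SERVICE_PREFIX and JUPYTERHUB_SERVICE_URL are already set in the environment:

    import json
    import os
    from urllib.parse import urlparse

    from tornado.ioloop import IOLoop
    from tornado.web import Application, RequestHandler, authenticated

    from jupyterhub.services.auth import HubOAuthCallbackHandler, HubOAuthenticated
    from jupyterhub.utils import url_path_join

    class WhoAmIHandler(HubOAuthenticated, RequestHandler):
        # leaving hub_users / hub_groups / hub_services unset allows any Hub user
        @authenticated
        def get(self):
            self.set_header('Content-Type', 'application/json')
            self.write(json.dumps(self.get_current_user(), indent=2))

    def main():
        prefix = os.environ['JUPYTERHUB_SERVICE_PREFIX']
        app = Application(
            [
                (prefix, WhoAmIHandler),
                (url_path_join(prefix, 'oauth_callback'), HubOAuthCallbackHandler),
            ],
            cookie_secret=os.urandom(32),  # required for the OAuth state/token cookies
        )
        url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
        app.listen(url.port, url.hostname)
        IOLoop.current().start()

    if __name__ == '__main__':
        main()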

View File

@@ -38,24 +38,26 @@ A hub-managed service with no URL::
} }
""" """
import copy import copy
import os
import pipes import pipes
import shutil import shutil
import os
from subprocess import Popen from subprocess import Popen
from traitlets import ( from traitlets import Any
HasTraits, from traitlets import Bool
Any, Bool, Dict, Unicode, Instance, from traitlets import default
default, from traitlets import Dict
) from traitlets import HasTraits
from traitlets import Instance
from traitlets import Unicode
from traitlets.config import LoggingConfigurable from traitlets.config import LoggingConfigurable
from .. import orm from .. import orm
from ..objects import Server from ..objects import Server
from ..spawner import LocalProcessSpawner
from ..spawner import set_user_setuid
from ..traitlets import Command from ..traitlets import Command
from ..spawner import LocalProcessSpawner, set_user_setuid
from ..utils import url_path_join from ..utils import url_path_join
@@ -81,14 +83,17 @@ class _MockUser(HasTraits):
return '' return ''
return self.server.base_url return self.server.base_url
# We probably shouldn't use a Spawner here, # We probably shouldn't use a Spawner here,
# but there are too many concepts to share. # but there are too many concepts to share.
class _ServiceSpawner(LocalProcessSpawner): class _ServiceSpawner(LocalProcessSpawner):
"""Subclass of LocalProcessSpawner """Subclass of LocalProcessSpawner
Removes notebook-specific-ness from LocalProcessSpawner. Removes notebook-specific-ness from LocalProcessSpawner.
""" """
cwd = Unicode() cwd = Unicode()
cmd = Command(minlen=0) cmd = Command(minlen=0)
@@ -115,7 +120,9 @@ class _ServiceSpawner(LocalProcessSpawner):
self.log.info("Spawning %s", ' '.join(pipes.quote(s) for s in cmd)) self.log.info("Spawning %s", ' '.join(pipes.quote(s) for s in cmd))
try: try:
self.proc = Popen(self.cmd, env=env, self.proc = Popen(
self.cmd,
env=env,
preexec_fn=self.make_preexec_fn(self.user.name), preexec_fn=self.make_preexec_fn(self.user.name),
start_new_session=True, # don't forward signals start_new_session=True, # don't forward signals
cwd=self.cwd or None, cwd=self.cwd or None,
@@ -123,8 +130,10 @@ class _ServiceSpawner(LocalProcessSpawner):
except PermissionError: except PermissionError:
# use which to get abspath # use which to get abspath
script = shutil.which(cmd[0]) or cmd[0] script = shutil.which(cmd[0]) or cmd[0]
self.log.error("Permission denied trying to run %r. Does %s have access to this file?", self.log.error(
script, self.user.name, "Permission denied trying to run %r. Does %s have access to this file?",
script,
self.user.name,
) )
raise raise
@@ -165,9 +174,9 @@ class Service(LoggingConfigurable):
If the service has an http endpoint, it If the service has an http endpoint, it
""" """
).tag(input=True) ).tag(input=True)
admin = Bool(False, admin = Bool(False, help="Does the service need admin-access to the Hub API?").tag(
help="Does the service need admin-access to the Hub API?" input=True
).tag(input=True) )
url = Unicode( url = Unicode(
help="""URL of the service. help="""URL of the service.
@@ -205,22 +214,23 @@ class Service(LoggingConfigurable):
""" """
return 'managed' if self.managed else 'external' return 'managed' if self.managed else 'external'
command = Command(minlen=0, command = Command(minlen=0, help="Command to spawn this service, if managed.").tag(
help="Command to spawn this service, if managed." input=True
).tag(input=True) )
cwd = Unicode( cwd = Unicode(help="""The working directory in which to run the service.""").tag(
help="""The working directory in which to run the service.""" input=True
).tag(input=True) )
environment = Dict( environment = Dict(
help="""Environment variables to pass to the service. help="""Environment variables to pass to the service.
Only used if the Hub is spawning the service. Only used if the Hub is spawning the service.
""" """
).tag(input=True) ).tag(input=True)
user = Unicode("", user = Unicode(
"",
help="""The user to become when launching the service. help="""The user to become when launching the service.
If unspecified, run the service as the same user as the Hub. If unspecified, run the service as the same user as the Hub.
""" """,
).tag(input=True) ).tag(input=True)
domain = Unicode() domain = Unicode()
@@ -245,6 +255,7 @@ class Service(LoggingConfigurable):
Default: `service-<name>` Default: `service-<name>`
""" """
).tag(input=True) ).tag(input=True)
@default('oauth_client_id') @default('oauth_client_id')
def _default_client_id(self): def _default_client_id(self):
return 'service-%s' % self.name return 'service-%s' % self.name
@@ -256,6 +267,7 @@ class Service(LoggingConfigurable):
Default: `/services/:name/oauth_callback` Default: `/services/:name/oauth_callback`
""" """
).tag(input=True) ).tag(input=True)
@default('oauth_redirect_uri') @default('oauth_redirect_uri')
def _default_redirect_uri(self): def _default_redirect_uri(self):
if self.server is None: if self.server is None:
@@ -328,10 +340,7 @@ class Service(LoggingConfigurable):
cwd=self.cwd, cwd=self.cwd,
hub=self.hub, hub=self.hub,
user=_MockUser( user=_MockUser(
name=self.user, name=self.user, service=self, server=self.orm.server, host=self.host
service=self,
server=self.orm.server,
host=self.host,
), ),
internal_ssl=self.app.internal_ssl, internal_ssl=self.app.internal_ssl,
internal_certs_location=self.app.internal_certs_location, internal_certs_location=self.app.internal_certs_location,
@@ -344,7 +353,9 @@ class Service(LoggingConfigurable):
def _proc_stopped(self): def _proc_stopped(self):
"""Called when the service process unexpectedly exits""" """Called when the service process unexpectedly exits"""
self.log.error("Service %s exited with status %i", self.name, self.proc.returncode) self.log.error(
"Service %s exited with status %i", self.name, self.proc.returncode
)
self.start() self.start()
async def stop(self): async def stop(self):
@@ -357,4 +368,4 @@ class Service(LoggingConfigurable):
self.db.delete(self.orm.server) self.db.delete(self.orm.server)
self.db.commit() self.db.commit()
self.spawner.stop_polling() self.spawner.stop_polling()
return (await self.spawner.stop()) return await self.spawner.stop()
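The Service traits being reformatted here (name, admin, url, command, environment, user, ...) are what entries in c.JupyterHub.services map onto. A hedged configuration sketch; the service names, port, command and token below are illustrative, not taken from this commit:

    # jupyterhub_config.py
    c.JupyterHub.services = [
        {
            # hub-managed: the Hub spawns and supervises this process itself
            'name': 'announcement',
            'url': 'http://127.0.0.1:9999',
            'command': ['python', '-m', 'announcement'],
            'environment': {'ANNOUNCEMENT_FILE': '/srv/announcement.txt'},
        },
        {
            # external: started outside the Hub; only the API token is shared
            'name': 'dashboard',
            'admin': True,
            'api_token': 'replace-with-openssl-rand-hex-32-output',
        },
    ]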

View File

@@ -1,23 +1,24 @@
#!/usr/bin/env python #!/usr/bin/env python
"""Extend regular notebook server to be aware of multiuser things.""" """Extend regular notebook server to be aware of multiuser things."""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import asyncio import asyncio
from datetime import datetime, timezone
import json import json
import os import os
import random import random
from datetime import datetime
from datetime import timezone
from textwrap import dedent from textwrap import dedent
from urllib.parse import urlparse from urllib.parse import urlparse
from jinja2 import ChoiceLoader, FunctionLoader from jinja2 import ChoiceLoader
from jinja2 import FunctionLoader
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from tornado import gen from tornado import gen
from tornado import ioloop from tornado import ioloop
from tornado.web import HTTPError, RequestHandler from tornado.httpclient import AsyncHTTPClient
from tornado.httpclient import HTTPRequest
from tornado.web import HTTPError
from tornado.web import RequestHandler
try: try:
import notebook import notebook
@@ -60,7 +61,9 @@ class HubAuthenticatedHandler(HubOAuthenticated):
@property @property
def allow_admin(self): def allow_admin(self):
return self.settings.get('admin_access', os.getenv('JUPYTERHUB_ADMIN_ACCESS') or False) return self.settings.get(
'admin_access', os.getenv('JUPYTERHUB_ADMIN_ACCESS') or False
)
@property @property
def hub_auth(self): def hub_auth(self):
@@ -79,6 +82,7 @@ class HubAuthenticatedHandler(HubOAuthenticated):
class JupyterHubLoginHandler(LoginHandler): class JupyterHubLoginHandler(LoginHandler):
"""LoginHandler that hooks up Hub authentication""" """LoginHandler that hooks up Hub authentication"""
@staticmethod @staticmethod
def login_available(settings): def login_available(settings):
return True return True
@@ -113,12 +117,14 @@ class JupyterHubLogoutHandler(LogoutHandler):
def get(self): def get(self):
self.settings['hub_auth'].clear_cookie(self) self.settings['hub_auth'].clear_cookie(self)
self.redirect( self.redirect(
self.settings['hub_host'] + self.settings['hub_host']
url_path_join(self.settings['hub_prefix'], 'logout')) + url_path_join(self.settings['hub_prefix'], 'logout')
)
class OAuthCallbackHandler(HubOAuthCallbackHandler, IPythonHandler): class OAuthCallbackHandler(HubOAuthCallbackHandler, IPythonHandler):
"""Mixin IPythonHandler to get the right error pages, etc.""" """Mixin IPythonHandler to get the right error pages, etc."""
@property @property
def hub_auth(self): def hub_auth(self):
return self.settings['hub_auth'] return self.settings['hub_auth']
@@ -126,7 +132,8 @@ class OAuthCallbackHandler(HubOAuthCallbackHandler, IPythonHandler):
# register new hub related command-line aliases # register new hub related command-line aliases
aliases = dict(notebook_aliases) aliases = dict(notebook_aliases)
aliases.update({ aliases.update(
{
'user': 'SingleUserNotebookApp.user', 'user': 'SingleUserNotebookApp.user',
'group': 'SingleUserNotebookApp.group', 'group': 'SingleUserNotebookApp.group',
'cookie-name': 'HubAuth.cookie_name', 'cookie-name': 'HubAuth.cookie_name',
@@ -134,15 +141,17 @@ aliases.update({
'hub-host': 'SingleUserNotebookApp.hub_host', 'hub-host': 'SingleUserNotebookApp.hub_host',
'hub-api-url': 'SingleUserNotebookApp.hub_api_url', 'hub-api-url': 'SingleUserNotebookApp.hub_api_url',
'base-url': 'SingleUserNotebookApp.base_url', 'base-url': 'SingleUserNotebookApp.base_url',
})
flags = dict(notebook_flags)
flags.update({
'disable-user-config': ({
'SingleUserNotebookApp': {
'disable_user_config': True
} }
}, "Disable user-controlled configuration of the notebook server.") )
}) flags = dict(notebook_flags)
flags.update(
{
'disable-user-config': (
{'SingleUserNotebookApp': {'disable_user_config': True}},
"Disable user-controlled configuration of the notebook server.",
)
}
)
page_template = """ page_template = """
{% extends "templates/page.html" %} {% extends "templates/page.html" %}
@@ -209,11 +218,14 @@ def _exclude_home(path_list):
class SingleUserNotebookApp(NotebookApp): class SingleUserNotebookApp(NotebookApp):
"""A Subclass of the regular NotebookApp that is aware of the parent multiuser context.""" """A Subclass of the regular NotebookApp that is aware of the parent multiuser context."""
description = dedent("""
description = dedent(
"""
Single-user server for JupyterHub. Extends the Jupyter Notebook server. Single-user server for JupyterHub. Extends the Jupyter Notebook server.
Meant to be invoked by JupyterHub Spawners, and not directly. Meant to be invoked by JupyterHub Spawners, and not directly.
""") """
)
examples = "" examples = ""
subcommands = {} subcommands = {}
@@ -229,6 +241,7 @@ class SingleUserNotebookApp(NotebookApp):
# ensures that each spawn clears any cookies from previous session, # ensures that each spawn clears any cookies from previous session,
# triggering OAuth again # triggering OAuth again
cookie_secret = Bytes() cookie_secret = Bytes()
def _cookie_secret_default(self): def _cookie_secret_default(self):
return os.urandom(32) return os.urandom(32)
@@ -320,14 +333,17 @@ class SingleUserNotebookApp(NotebookApp):
trust_xheaders = True trust_xheaders = True
login_handler_class = JupyterHubLoginHandler login_handler_class = JupyterHubLoginHandler
logout_handler_class = JupyterHubLogoutHandler logout_handler_class = JupyterHubLogoutHandler
port_retries = 0 # disable port-retries, since the Spawner will tell us what port to use port_retries = (
0
) # disable port-retries, since the Spawner will tell us what port to use
disable_user_config = Bool(False, disable_user_config = Bool(
False,
help="""Disable user configuration of single-user server. help="""Disable user configuration of single-user server.
Prevents user-writable files that normally configure the single-user server Prevents user-writable files that normally configure the single-user server
from being loaded, ensuring admins have full control of configuration. from being loaded, ensuring admins have full control of configuration.
""" """,
).tag(config=True) ).tag(config=True)
@validate('notebook_dir') @validate('notebook_dir')
@@ -394,22 +410,15 @@ class SingleUserNotebookApp(NotebookApp):
# create dynamic default http client, # create dynamic default http client,
# configured with any relevant ssl config # configured with any relevant ssl config
hub_http_client = Any() hub_http_client = Any()
@default('hub_http_client') @default('hub_http_client')
def _default_client(self): def _default_client(self):
ssl_context = make_ssl_context( ssl_context = make_ssl_context(
self.keyfile, self.keyfile, self.certfile, cafile=self.client_ca
self.certfile,
cafile=self.client_ca,
)
AsyncHTTPClient.configure(
None,
defaults={
"ssl_options": ssl_context,
},
) )
AsyncHTTPClient.configure(None, defaults={"ssl_options": ssl_context})
return AsyncHTTPClient() return AsyncHTTPClient()
async def check_hub_version(self): async def check_hub_version(self):
"""Test a connection to my Hub """Test a connection to my Hub
@@ -422,8 +431,12 @@ class SingleUserNotebookApp(NotebookApp):
try: try:
resp = await client.fetch(self.hub_api_url) resp = await client.fetch(self.hub_api_url)
except Exception: except Exception:
self.log.exception("Failed to connect to my Hub at %s (attempt %i/%i). Is it running?", self.log.exception(
self.hub_api_url, i, RETRIES) "Failed to connect to my Hub at %s (attempt %i/%i). Is it running?",
self.hub_api_url,
i,
RETRIES,
)
await gen.sleep(min(2 ** i, 16)) await gen.sleep(min(2 ** i, 16))
else: else:
break break
@@ -434,14 +447,15 @@ class SingleUserNotebookApp(NotebookApp):
_check_version(hub_version, __version__, self.log) _check_version(hub_version, __version__, self.log)
server_name = Unicode() server_name = Unicode()
@default('server_name') @default('server_name')
def _server_name_default(self): def _server_name_default(self):
return os.environ.get('JUPYTERHUB_SERVER_NAME', '') return os.environ.get('JUPYTERHUB_SERVER_NAME', '')
hub_activity_url = Unicode( hub_activity_url = Unicode(
config=True, config=True, help="URL for sending JupyterHub activity updates"
help="URL for sending JupyterHub activity updates",
) )
@default('hub_activity_url') @default('hub_activity_url')
def _default_activity_url(self): def _default_activity_url(self):
return os.environ.get('JUPYTERHUB_ACTIVITY_URL', '') return os.environ.get('JUPYTERHUB_ACTIVITY_URL', '')
@@ -452,8 +466,9 @@ class SingleUserNotebookApp(NotebookApp):
help=""" help="""
Interval (in seconds) on which to update the Hub Interval (in seconds) on which to update the Hub
with our latest activity. with our latest activity.
""" """,
) )
@default('hub_activity_interval') @default('hub_activity_interval')
def _default_activity_interval(self): def _default_activity_interval(self):
env_value = os.environ.get('JUPYTERHUB_ACTIVITY_INTERVAL') env_value = os.environ.get('JUPYTERHUB_ACTIVITY_INTERVAL')
@@ -478,10 +493,7 @@ class SingleUserNotebookApp(NotebookApp):
self.log.warning("last activity is using naïve timestamps") self.log.warning("last activity is using naïve timestamps")
last_activity = last_activity.replace(tzinfo=timezone.utc) last_activity = last_activity.replace(tzinfo=timezone.utc)
if ( if self._last_activity_sent and last_activity < self._last_activity_sent:
self._last_activity_sent
and last_activity < self._last_activity_sent
):
self.log.debug("No activity since %s", self._last_activity_sent) self.log.debug("No activity since %s", self._last_activity_sent)
return return
@@ -496,14 +508,14 @@ class SingleUserNotebookApp(NotebookApp):
"Authorization": "token {}".format(self.hub_auth.api_token), "Authorization": "token {}".format(self.hub_auth.api_token),
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
body=json.dumps({ body=json.dumps(
{
'servers': { 'servers': {
self.server_name: { self.server_name: {'last_activity': last_activity_timestamp}
'last_activity': last_activity_timestamp,
},
}, },
'last_activity': last_activity_timestamp, 'last_activity': last_activity_timestamp,
}) }
),
) )
try: try:
await client.fetch(req) await client.fetch(req)
@@ -526,8 +538,8 @@ class SingleUserNotebookApp(NotebookApp):
if not self.hub_activity_url or not self.hub_activity_interval: if not self.hub_activity_url or not self.hub_activity_interval:
self.log.warning("Activity events disabled") self.log.warning("Activity events disabled")
return return
self.log.info("Updating Hub with activity every %s seconds", self.log.info(
self.hub_activity_interval "Updating Hub with activity every %s seconds", self.hub_activity_interval
) )
while True: while True:
try: try:
@@ -561,7 +573,9 @@ class SingleUserNotebookApp(NotebookApp):
api_token = os.environ['JUPYTERHUB_API_TOKEN'] api_token = os.environ['JUPYTERHUB_API_TOKEN']
if not api_token: if not api_token:
self.exit("JUPYTERHUB_API_TOKEN env is required to run jupyterhub-singleuser. Did you launch it manually?") self.exit(
"JUPYTERHUB_API_TOKEN env is required to run jupyterhub-singleuser. Did you launch it manually?"
)
self.hub_auth = HubOAuth( self.hub_auth = HubOAuth(
parent=self, parent=self,
api_token=api_token, api_token=api_token,
@@ -586,21 +600,23 @@ class SingleUserNotebookApp(NotebookApp):
s['hub_prefix'] = self.hub_prefix s['hub_prefix'] = self.hub_prefix
s['hub_host'] = self.hub_host s['hub_host'] = self.hub_host
s['hub_auth'] = self.hub_auth s['hub_auth'] = self.hub_auth
csp_report_uri = s['csp_report_uri'] = self.hub_host + url_path_join(self.hub_prefix, 'security/csp-report') csp_report_uri = s['csp_report_uri'] = self.hub_host + url_path_join(
self.hub_prefix, 'security/csp-report'
)
headers = s.setdefault('headers', {}) headers = s.setdefault('headers', {})
headers['X-JupyterHub-Version'] = __version__ headers['X-JupyterHub-Version'] = __version__
# set CSP header directly to workaround bugs in jupyter/notebook 5.0 # set CSP header directly to workaround bugs in jupyter/notebook 5.0
headers.setdefault('Content-Security-Policy', ';'.join([ headers.setdefault(
"frame-ancestors 'self'", 'Content-Security-Policy',
"report-uri " + csp_report_uri, ';'.join(["frame-ancestors 'self'", "report-uri " + csp_report_uri]),
])) )
super(SingleUserNotebookApp, self).init_webapp() super(SingleUserNotebookApp, self).init_webapp()
# add OAuth callback # add OAuth callback
self.web_app.add_handlers(r".*$", [( self.web_app.add_handlers(
urlparse(self.hub_auth.oauth_redirect_uri).path, r".*$",
OAuthCallbackHandler [(urlparse(self.hub_auth.oauth_redirect_uri).path, OAuthCallbackHandler)],
)]) )
# apply X-JupyterHub-Version to *all* request handlers (even redirects) # apply X-JupyterHub-Version to *all* request handlers (even redirects)
self.patch_default_headers() self.patch_default_headers()
@@ -610,6 +626,7 @@ class SingleUserNotebookApp(NotebookApp):
if hasattr(RequestHandler, '_orig_set_default_headers'): if hasattr(RequestHandler, '_orig_set_default_headers'):
return return
RequestHandler._orig_set_default_headers = RequestHandler.set_default_headers RequestHandler._orig_set_default_headers = RequestHandler.set_default_headers
def set_jupyterhub_header(self): def set_jupyterhub_header(self):
self._orig_set_default_headers() self._orig_set_default_headers()
self.set_header('X-JupyterHub-Version', __version__) self.set_header('X-JupyterHub-Version', __version__)
@@ -619,13 +636,16 @@ class SingleUserNotebookApp(NotebookApp):
def patch_templates(self): def patch_templates(self):
"""Patch page templates to add Hub-related buttons""" """Patch page templates to add Hub-related buttons"""
self.jinja_template_vars['logo_url'] = self.hub_host + url_path_join(self.hub_prefix, 'logo') self.jinja_template_vars['logo_url'] = self.hub_host + url_path_join(
self.hub_prefix, 'logo'
)
self.jinja_template_vars['hub_host'] = self.hub_host self.jinja_template_vars['hub_host'] = self.hub_host
self.jinja_template_vars['hub_prefix'] = self.hub_prefix self.jinja_template_vars['hub_prefix'] = self.hub_prefix
env = self.web_app.settings['jinja2_env'] env = self.web_app.settings['jinja2_env']
env.globals['hub_control_panel_url'] = \ env.globals['hub_control_panel_url'] = self.hub_host + url_path_join(
self.hub_host + url_path_join(self.hub_prefix, 'home') self.hub_prefix, 'home'
)
# patch jinja env loading to modify page template # patch jinja env loading to modify page template
def get_page(name): def get_page(name):
@@ -633,10 +653,7 @@ class SingleUserNotebookApp(NotebookApp):
return page_template return page_template
orig_loader = env.loader orig_loader = env.loader
env.loader = ChoiceLoader([ env.loader = ChoiceLoader([FunctionLoader(get_page), orig_loader])
FunctionLoader(get_page),
orig_loader,
])
def main(argv=None): def main(argv=None):
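The notify_activity hunks above post the single-user server's latest activity timestamp to the Hub. A sketch of the same request, assuming JUPYTERHUB_ACTIVITY_URL, JUPYTERHUB_API_TOKEN and JUPYTERHUB_SERVER_NAME are set by the Spawner, and using requests for brevity where the server itself uses Tornado's AsyncHTTPClient; the ISO8601 UTC timestamp format is an assumption, since the helper that formats it is not shown in this hunk:

    import json
    import os
    from datetime import datetime, timezone

    import requests

    last_activity = datetime.now(timezone.utc).isoformat()
    resp = requests.post(
        os.environ['JUPYTERHUB_ACTIVITY_URL'],
        headers={
            'Authorization': 'token {}'.format(os.environ['JUPYTERHUB_API_TOKEN']),
            'Content-Type': 'application/json',
        },
        data=json.dumps(
            {
                'servers': {
                    os.environ.get('JUPYTERHUB_SERVER_NAME', ''): {
                        'last_activity': last_activity
                    }
                },
                'last_activity': last_activity,
            }
        ),
    )
    resp.raise_for_status()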

View File

@@ -1,10 +1,8 @@
""" """
Contains base Spawner class & default implementation Contains base Spawner class & default implementation
""" """
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import ast import ast
import asyncio import asyncio
import errno import errno
@@ -18,22 +16,35 @@ import warnings
from subprocess import Popen from subprocess import Popen
from tempfile import mkdtemp from tempfile import mkdtemp
# FIXME: remove when we drop Python 3.5 support from async_generator import async_generator
from async_generator import async_generator, yield_ from async_generator import yield_
from sqlalchemy import inspect from sqlalchemy import inspect
from tornado.ioloop import PeriodicCallback from tornado.ioloop import PeriodicCallback
from traitlets import Any
from traitlets import Bool
from traitlets import default
from traitlets import Dict
from traitlets import Float
from traitlets import Instance
from traitlets import Integer
from traitlets import List
from traitlets import observe
from traitlets import Unicode
from traitlets import Union
from traitlets import validate
from traitlets.config import LoggingConfigurable from traitlets.config import LoggingConfigurable
from traitlets import (
Any, Bool, Dict, Instance, Integer, Float, List, Unicode, Union,
default, observe, validate,
)
from .objects import Server from .objects import Server
from .traitlets import Command, ByteSpecification, Callable from .traitlets import ByteSpecification
from .utils import iterate_until, maybe_future, random_port, url_path_join, exponential_backoff from .traitlets import Callable
from .traitlets import Command
from .utils import exponential_backoff
from .utils import iterate_until
from .utils import maybe_future
from .utils import random_port
from .utils import url_path_join
# FIXME: remove when we drop Python 3.5 support
def _quote_safe(s): def _quote_safe(s):
@@ -53,6 +64,7 @@ def _quote_safe(s):
# to avoid getting interpreted by traitlets # to avoid getting interpreted by traitlets
return repr(s) return repr(s)
class Spawner(LoggingConfigurable): class Spawner(LoggingConfigurable):
"""Base class for spawning single-user notebook servers. """Base class for spawning single-user notebook servers.
@@ -146,8 +158,12 @@ class Spawner(LoggingConfigurable):
missing.append(attr) missing.append(attr)
if missing: if missing:
raise NotImplementedError("class `{}` needs to redefine the `start`," raise NotImplementedError(
"`stop` and `poll` methods. `{}` not redefined.".format(cls.__name__, '`, `'.join(missing))) "class `{}` needs to redefine the `start`,"
"`stop` and `poll` methods. `{}` not redefined.".format(
cls.__name__, '`, `'.join(missing)
)
)
proxy_spec = Unicode() proxy_spec = Unicode()
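The error message reformatted above spells out the Spawner contract: subclasses must redefine start, stop and poll. A toy sketch of that contract; it launches nothing real, whereas a working spawner would start a process or container in start and return its (ip, port):

    from jupyterhub.spawner import Spawner

    class NoopSpawner(Spawner):
        async def start(self):
            # a real spawner would launch the single-user server here;
            # since 0.7, start returns (ip, port) instead of setting user.server
            self._running = True
            return ('127.0.0.1', 8888)

        async def stop(self, now=False):
            # a real spawner would terminate the server process here
            self._running = False

        async def poll(self):
            # None means "still running"; an integer exit status means it stopped
            return None if getattr(self, '_running', False) else 0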
@@ -180,6 +196,7 @@ class Spawner(LoggingConfigurable):
if self.orm_spawner: if self.orm_spawner:
return self.orm_spawner.name return self.orm_spawner.name
return '' return ''
hub = Any() hub = Any()
authenticator = Any() authenticator = Any()
internal_ssl = Bool(False) internal_ssl = Bool(False)
@@ -191,7 +208,8 @@ class Spawner(LoggingConfigurable):
oauth_client_id = Unicode() oauth_client_id = Unicode()
handler = Any() handler = Any()
will_resume = Bool(False, will_resume = Bool(
False,
help="""Whether the Spawner will resume on next start help="""Whether the Spawner will resume on next start
@@ -199,18 +217,20 @@ class Spawner(LoggingConfigurable):
If True, an existing Spawner will resume instead of starting anew If True, an existing Spawner will resume instead of starting anew
(e.g. resuming a Docker container), (e.g. resuming a Docker container),
and API tokens in use when the Spawner stops will not be deleted. and API tokens in use when the Spawner stops will not be deleted.
""" """,
) )
ip = Unicode('', ip = Unicode(
'',
help=""" help="""
The IP address (or hostname) the single-user server should listen on. The IP address (or hostname) the single-user server should listen on.
The JupyterHub proxy implementation should be able to send packets to this interface. The JupyterHub proxy implementation should be able to send packets to this interface.
""" """,
).tag(config=True) ).tag(config=True)
port = Integer(0, port = Integer(
0,
help=""" help="""
The port for single-user servers to listen on. The port for single-user servers to listen on.
@@ -221,7 +241,7 @@ class Spawner(LoggingConfigurable):
e.g. in containers. e.g. in containers.
New in version 0.7. New in version 0.7.
""" """,
).tag(config=True) ).tag(config=True)
consecutive_failure_limit = Integer( consecutive_failure_limit = Integer(
@@ -237,47 +257,48 @@ class Spawner(LoggingConfigurable):
""", """,
).tag(config=True) ).tag(config=True)
start_timeout = Integer(60, start_timeout = Integer(
60,
help=""" help="""
Timeout (in seconds) before giving up on starting of single-user server. Timeout (in seconds) before giving up on starting of single-user server.
This is the timeout for start to return, not the timeout for the server to respond. This is the timeout for start to return, not the timeout for the server to respond.
Callers of spawner.start will assume that startup has failed if it takes longer than this. Callers of spawner.start will assume that startup has failed if it takes longer than this.
start should return when the server process is started and its location is known. start should return when the server process is started and its location is known.
""" """,
).tag(config=True) ).tag(config=True)
http_timeout = Integer(30, http_timeout = Integer(
30,
help=""" help="""
Timeout (in seconds) before giving up on a spawned HTTP server Timeout (in seconds) before giving up on a spawned HTTP server
Once a server has successfully been spawned, this is the amount of time Once a server has successfully been spawned, this is the amount of time
we wait before assuming that the server is unable to accept we wait before assuming that the server is unable to accept
connections. connections.
""" """,
).tag(config=True) ).tag(config=True)
poll_interval = Integer(30, poll_interval = Integer(
30,
help=""" help="""
Interval (in seconds) on which to poll the spawner for single-user server's status. Interval (in seconds) on which to poll the spawner for single-user server's status.
At every poll interval, each spawner's `.poll` method is called, which checks At every poll interval, each spawner's `.poll` method is called, which checks
if the single-user server is still running. If it isn't running, then JupyterHub modifies if the single-user server is still running. If it isn't running, then JupyterHub modifies
its own state accordingly and removes appropriate routes from the configurable proxy. its own state accordingly and removes appropriate routes from the configurable proxy.
""" """,
).tag(config=True) ).tag(config=True)
_callbacks = List() _callbacks = List()
_poll_callback = Any() _poll_callback = Any()
debug = Bool(False, debug = Bool(False, help="Enable debug-logging of the single-user server").tag(
help="Enable debug-logging of the single-user server" config=True
).tag(config=True) )
options_form = Union([ options_form = Union(
Unicode(), [Unicode(), Callable()],
Callable()
],
help=""" help="""
An HTML form for options a user can specify on launching their server. An HTML form for options a user can specify on launching their server.
@@ -303,7 +324,8 @@ class Spawner(LoggingConfigurable):
be called asynchronously if it returns a future, rather than a str. Note that be called asynchronously if it returns a future, rather than a str. Note that
the interface of the spawner class is not deemed stable across versions, the interface of the spawner class is not deemed stable across versions,
so using this functionality might cause your JupyterHub upgrades to break. so using this functionality might cause your JupyterHub upgrades to break.
""").tag(config=True) """,
).tag(config=True)
async def get_options_form(self): async def get_options_form(self):
"""Get the options form """Get the options form
@@ -341,9 +363,11 @@ class Spawner(LoggingConfigurable):
These user options are usually provided by the `options_form` displayed to the user when they start These user options are usually provided by the `options_form` displayed to the user when they start
their server. their server.
""") """
)
env_keep = List([ env_keep = List(
[
'PATH', 'PATH',
'PYTHONPATH', 'PYTHONPATH',
'CONDA_ROOT', 'CONDA_ROOT',
@@ -357,14 +381,16 @@ class Spawner(LoggingConfigurable):
This whitelist is used to ensure that sensitive information in the JupyterHub process's environment This whitelist is used to ensure that sensitive information in the JupyterHub process's environment
(such as `CONFIGPROXY_AUTH_TOKEN`) is not passed to the single-user server's process. (such as `CONFIGPROXY_AUTH_TOKEN`) is not passed to the single-user server's process.
""" """,
).tag(config=True) ).tag(config=True)
env = Dict(help="""Deprecated: use Spawner.get_env or Spawner.environment env = Dict(
help="""Deprecated: use Spawner.get_env or Spawner.environment
- extend Spawner.get_env for adding required env in Spawner subclasses - extend Spawner.get_env for adding required env in Spawner subclasses
- Spawner.environment for config-specified env - Spawner.environment for config-specified env
""") """
)
environment = Dict( environment = Dict(
help=""" help="""
@@ -386,7 +412,8 @@ class Spawner(LoggingConfigurable):
""" """
).tag(config=True) ).tag(config=True)
cmd = Command(['jupyterhub-singleuser'], cmd = Command(
['jupyterhub-singleuser'],
allow_none=True, allow_none=True,
help=""" help="""
The command used for starting the single-user server. The command used for starting the single-user server.
@@ -399,16 +426,17 @@ class Spawner(LoggingConfigurable):
Some spawners allow shell-style expansion here, allowing you to use environment variables. Some spawners allow shell-style expansion here, allowing you to use environment variables.
Most, including the default, do not. Consult the documentation for your spawner to verify! Most, including the default, do not. Consult the documentation for your spawner to verify!
""" """,
).tag(config=True) ).tag(config=True)
args = List(Unicode(), args = List(
Unicode(),
help=""" help="""
Extra arguments to be passed to the single-user server. Extra arguments to be passed to the single-user server.
Some spawners allow shell-style expansion here, allowing you to use environment variables here. Some spawners allow shell-style expansion here, allowing you to use environment variables here.
Most, including the default, do not. Consult the documentation for your spawner to verify! Most, including the default, do not. Consult the documentation for your spawner to verify!
""" """,
).tag(config=True) ).tag(config=True)
notebook_dir = Unicode( notebook_dir = Unicode(
@@ -446,14 +474,16 @@ class Spawner(LoggingConfigurable):
def _deprecate_percent_u(self, proposal): def _deprecate_percent_u(self, proposal):
v = proposal['value'] v = proposal['value']
if '%U' in v: if '%U' in v:
self.log.warning("%%U for username in %s is deprecated in JupyterHub 0.7, use {username}", self.log.warning(
"%%U for username in %s is deprecated in JupyterHub 0.7, use {username}",
proposal['trait'].name, proposal['trait'].name,
) )
v = v.replace('%U', '{username}') v = v.replace('%U', '{username}')
self.log.warning("Converting %r to %r", proposal['value'], v) self.log.warning("Converting %r to %r", proposal['value'], v)
return v return v
disable_user_config = Bool(False, disable_user_config = Bool(
False,
help=""" help="""
Disable per-user configuration of single-user servers. Disable per-user configuration of single-user servers.
@@ -462,10 +492,11 @@ class Spawner(LoggingConfigurable):
Note: a user could circumvent this if the user modifies their Python environment, such as when Note: a user could circumvent this if the user modifies their Python environment, such as when
they have their own conda environments / virtualenvs / containers. they have their own conda environments / virtualenvs / containers.
""" """,
).tag(config=True) ).tag(config=True)
mem_limit = ByteSpecification(None, mem_limit = ByteSpecification(
None,
help=""" help="""
Maximum number of bytes a single-user notebook server is allowed to use. Maximum number of bytes a single-user notebook server is allowed to use.
@@ -484,10 +515,11 @@ class Spawner(LoggingConfigurable):
for the limit to work.** The default spawner, `LocalProcessSpawner`, for the limit to work.** The default spawner, `LocalProcessSpawner`,
does **not** implement this support. A custom spawner **must** add does **not** implement this support. A custom spawner **must** add
support for this setting for it to be enforced. support for this setting for it to be enforced.
""" """,
).tag(config=True) ).tag(config=True)
cpu_limit = Float(None, cpu_limit = Float(
None,
allow_none=True, allow_none=True,
help=""" help="""
Maximum number of cpu-cores a single-user notebook server is allowed to use. Maximum number of cpu-cores a single-user notebook server is allowed to use.
@@ -503,10 +535,11 @@ class Spawner(LoggingConfigurable):
for the limit to work.** The default spawner, `LocalProcessSpawner`, for the limit to work.** The default spawner, `LocalProcessSpawner`,
does **not** implement this support. A custom spawner **must** add does **not** implement this support. A custom spawner **must** add
support for this setting for it to be enforced. support for this setting for it to be enforced.
""" """,
).tag(config=True) ).tag(config=True)
mem_guarantee = ByteSpecification(None, mem_guarantee = ByteSpecification(
None,
help=""" help="""
Minimum number of bytes a single-user notebook server is guaranteed to have available. Minimum number of bytes a single-user notebook server is guaranteed to have available.
@@ -520,10 +553,11 @@ class Spawner(LoggingConfigurable):
for the limit to work.** The default spawner, `LocalProcessSpawner`, for the limit to work.** The default spawner, `LocalProcessSpawner`,
does **not** implement this support. A custom spawner **must** add does **not** implement this support. A custom spawner **must** add
support for this setting for it to be enforced. support for this setting for it to be enforced.
""" """,
).tag(config=True) ).tag(config=True)
cpu_guarantee = Float(None, cpu_guarantee = Float(
None,
allow_none=True, allow_none=True,
help=""" help="""
Minimum number of cpu-cores a single-user notebook server is guaranteed to have available. Minimum number of cpu-cores a single-user notebook server is guaranteed to have available.
@@ -535,7 +569,7 @@ class Spawner(LoggingConfigurable):
for the limit to work.** The default spawner, `LocalProcessSpawner`, for the limit to work.** The default spawner, `LocalProcessSpawner`,
does **not** implement this support. A custom spawner **must** add does **not** implement this support. A custom spawner **must** add
support for this setting for it to be enforced. support for this setting for it to be enforced.
""" """,
).tag(config=True) ).tag(config=True)
pre_spawn_hook = Any( pre_spawn_hook = Any(
@@ -621,7 +655,9 @@ class Spawner(LoggingConfigurable):
""" """
env = {} env = {}
if self.env: if self.env:
warnings.warn("Spawner.env is deprecated, found %s" % self.env, DeprecationWarning) warnings.warn(
"Spawner.env is deprecated, found %s" % self.env, DeprecationWarning
)
env.update(self.env) env.update(self.env)
for key in self.env_keep: for key in self.env_keep:
@@ -648,8 +684,9 @@ class Spawner(LoggingConfigurable):
if self.cookie_options: if self.cookie_options:
env['JUPYTERHUB_COOKIE_OPTIONS'] = json.dumps(self.cookie_options) env['JUPYTERHUB_COOKIE_OPTIONS'] = json.dumps(self.cookie_options)
env['JUPYTERHUB_HOST'] = self.hub.public_host env['JUPYTERHUB_HOST'] = self.hub.public_host
env['JUPYTERHUB_OAUTH_CALLBACK_URL'] = \ env['JUPYTERHUB_OAUTH_CALLBACK_URL'] = url_path_join(
url_path_join(self.user.url, self.name, 'oauth_callback') self.user.url, self.name, 'oauth_callback'
)
# Info previously passed on args # Info previously passed on args
env['JUPYTERHUB_USER'] = self.user.name env['JUPYTERHUB_USER'] = self.user.name
@@ -749,7 +786,7 @@ class Spawner(LoggingConfigurable):
May be set in config if all spawners should have the same value(s), May be set in config if all spawners should have the same value(s),
or set at runtime by Spawner that know their names. or set at runtime by Spawner that know their names.
""" """,
) )
@default('ssl_alt_names') @default('ssl_alt_names')
@@ -793,6 +830,7 @@ class Spawner(LoggingConfigurable):
to the host by either IP or DNS name (note the `default_names` below). to the host by either IP or DNS name (note the `default_names` below).
""" """
from certipy import Certipy from certipy import Certipy
default_names = ["DNS:localhost", "IP:127.0.0.1"] default_names = ["DNS:localhost", "IP:127.0.0.1"]
alt_names = [] alt_names = []
alt_names.extend(self.ssl_alt_names) alt_names.extend(self.ssl_alt_names)
@@ -800,10 +838,7 @@ class Spawner(LoggingConfigurable):
if self.ssl_alt_names_include_local: if self.ssl_alt_names_include_local:
alt_names = default_names + alt_names alt_names = default_names + alt_names
self.log.info("Creating certs for %s: %s", self.log.info("Creating certs for %s: %s", self._log_name, ';'.join(alt_names))
self._log_name,
';'.join(alt_names),
)
common_name = self.user.name or 'service' common_name = self.user.name or 'service'
certipy = Certipy(store_dir=self.internal_certs_location) certipy = Certipy(store_dir=self.internal_certs_location)
@@ -812,7 +847,7 @@ class Spawner(LoggingConfigurable):
'user-' + common_name, 'user-' + common_name,
notebook_component, notebook_component,
alt_names=alt_names, alt_names=alt_names,
overwrite=True overwrite=True,
) )
paths = { paths = {
"keyfile": notebook_key_pair['files']['key'], "keyfile": notebook_key_pair['files']['key'],
@@ -862,7 +897,9 @@ class Spawner(LoggingConfigurable):
if self.port: if self.port:
args.append('--port=%i' % self.port) args.append('--port=%i' % self.port)
elif self.server and self.server.port: elif self.server and self.server.port:
self.log.warning("Setting port from user.server is deprecated as of JupyterHub 0.7.") self.log.warning(
"Setting port from user.server is deprecated as of JupyterHub 0.7."
)
args.append('--port=%i' % self.server.port) args.append('--port=%i' % self.server.port)
if self.notebook_dir: if self.notebook_dir:
@@ -903,13 +940,12 @@ class Spawner(LoggingConfigurable):
This method is always an async generator and will always yield at least one event. This method is always an async generator and will always yield at least one event.
""" """
if not self._spawn_pending: if not self._spawn_pending:
self.log.warning("Spawn not pending, can't generate progress for %s", self._log_name) self.log.warning(
"Spawn not pending, can't generate progress for %s", self._log_name
)
return return
await yield_({ await yield_({"progress": 0, "message": "Server requested"})
"progress": 0,
"message": "Server requested",
})
from async_generator import aclosing from async_generator import aclosing
async with aclosing(self.progress()) as progress: async with aclosing(self.progress()) as progress:
@@ -940,10 +976,7 @@ class Spawner(LoggingConfigurable):
.. versionadded:: 0.9 .. versionadded:: 0.9
""" """
await yield_({ await yield_({"progress": 50, "message": "Spawning server..."})
"progress": 50,
"message": "Spawning server...",
})
async def start(self): async def start(self):
"""Start the single-user server """Start the single-user server
@@ -954,7 +987,9 @@ class Spawner(LoggingConfigurable):
.. versionchanged:: 0.7 .. versionchanged:: 0.7
Return ip, port instead of setting on self.user.server directly. Return ip, port instead of setting on self.user.server directly.
""" """
raise NotImplementedError("Override in subclass. Must be a Tornado gen.coroutine.") raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine."
)
async def stop(self, now=False): async def stop(self, now=False):
"""Stop the single-user server """Stop the single-user server
@@ -967,7 +1002,9 @@ class Spawner(LoggingConfigurable):
Must be a coroutine. Must be a coroutine.
""" """
raise NotImplementedError("Override in subclass. Must be a Tornado gen.coroutine.") raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine."
)
async def poll(self): async def poll(self):
"""Check if the single-user process is running """Check if the single-user process is running
@@ -993,7 +1030,9 @@ class Spawner(LoggingConfigurable):
process has not yet completed. process has not yet completed.
""" """
raise NotImplementedError("Override in subclass. Must be a Tornado gen.coroutine.") raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine."
)
def add_poll_callback(self, callback, *args, **kwargs): def add_poll_callback(self, callback, *args, **kwargs):
"""Add a callback to fire when the single-user server stops""" """Add a callback to fire when the single-user server stops"""
@@ -1023,8 +1062,7 @@ class Spawner(LoggingConfigurable):
self.stop_polling() self.stop_polling()
self._poll_callback = PeriodicCallback( self._poll_callback = PeriodicCallback(
self.poll_and_notify, self.poll_and_notify, 1e3 * self.poll_interval
1e3 * self.poll_interval
) )
self._poll_callback.start() self._poll_callback.start()
@@ -1048,8 +1086,10 @@ class Spawner(LoggingConfigurable):
return status return status
death_interval = Float(0.1) death_interval = Float(0.1)
async def wait_for_death(self, timeout=10): async def wait_for_death(self, timeout=10):
"""Wait for the single-user server to die, up to timeout seconds""" """Wait for the single-user server to die, up to timeout seconds"""
async def _wait_for_death(): async def _wait_for_death():
status = await self.poll() status = await self.poll()
return status is not None return status is not None
@@ -1093,6 +1133,7 @@ def set_user_setuid(username, chdir=True):
""" """
import grp import grp
import pwd import pwd
user = pwd.getpwnam(username) user = pwd.getpwnam(username)
uid = user.pw_uid uid = user.pw_uid
gid = user.pw_gid gid = user.pw_gid
@@ -1132,29 +1173,32 @@ class LocalProcessSpawner(Spawner):
Note: This spawner does not implement CPU / memory guarantees and limits. Note: This spawner does not implement CPU / memory guarantees and limits.
""" """
interrupt_timeout = Integer(10, interrupt_timeout = Integer(
10,
help=""" help="""
Seconds to wait for single-user server process to halt after SIGINT. Seconds to wait for single-user server process to halt after SIGINT.
If the process has not exited cleanly after this many seconds, a SIGTERM is sent. If the process has not exited cleanly after this many seconds, a SIGTERM is sent.
""" """,
).tag(config=True) ).tag(config=True)
term_timeout = Integer(5, term_timeout = Integer(
5,
help=""" help="""
Seconds to wait for single-user server process to halt after SIGTERM. Seconds to wait for single-user server process to halt after SIGTERM.
If the process does not exit cleanly after this many seconds of SIGTERM, a SIGKILL is sent. If the process does not exit cleanly after this many seconds of SIGTERM, a SIGKILL is sent.
""" """,
).tag(config=True) ).tag(config=True)
kill_timeout = Integer(5, kill_timeout = Integer(
5,
help=""" help="""
Seconds to wait for process to halt after SIGKILL before giving up. Seconds to wait for process to halt after SIGKILL before giving up.
If the process does not exit cleanly after this many seconds of SIGKILL, it becomes a zombie If the process does not exit cleanly after this many seconds of SIGKILL, it becomes a zombie
process. The hub process will log a warning and then give up. process. The hub process will log a warning and then give up.
""" """,
).tag(config=True) ).tag(config=True)
popen_kwargs = Dict( popen_kwargs = Dict(
@@ -1168,7 +1212,8 @@ class LocalProcessSpawner(Spawner):
""" """
).tag(config=True) ).tag(config=True)
shell_cmd = Command(minlen=0, shell_cmd = Command(
minlen=0,
help="""Specify a shell command to launch. help="""Specify a shell command to launch.
The single-user command will be appended to this list, The single-user command will be appended to this list,
@@ -1185,20 +1230,23 @@ class LocalProcessSpawner(Spawner):
Using shell_cmd gives users control over PATH, etc., Using shell_cmd gives users control over PATH, etc.,
which could change what the jupyterhub-singleuser launch command does. which could change what the jupyterhub-singleuser launch command does.
Only use this for trusted users. Only use this for trusted users.
""" """,
).tag(config=True) ).tag(config=True)
proc = Instance(Popen, proc = Instance(
Popen,
allow_none=True, allow_none=True,
help=""" help="""
The process representing the single-user server process spawned for current user. The process representing the single-user server process spawned for current user.
Is None if no process has been spawned yet. Is None if no process has been spawned yet.
""") """,
pid = Integer(0, )
pid = Integer(
0,
help=""" help="""
The process id (pid) of the single-user server process spawned for current user. The process id (pid) of the single-user server process spawned for current user.
""" """,
) )
def make_preexec_fn(self, name): def make_preexec_fn(self, name):
@@ -1236,6 +1284,7 @@ class LocalProcessSpawner(Spawner):
def user_env(self, env): def user_env(self, env):
"""Augment environment of spawned process with user specific env variables.""" """Augment environment of spawned process with user specific env variables."""
import pwd import pwd
env['USER'] = self.user.name env['USER'] = self.user.name
home = pwd.getpwnam(self.user.name).pw_dir home = pwd.getpwnam(self.user.name).pw_dir
shell = pwd.getpwnam(self.user.name).pw_shell shell = pwd.getpwnam(self.user.name).pw_shell
@@ -1267,6 +1316,7 @@ class LocalProcessSpawner(Spawner):
and make them readable by the user. and make them readable by the user.
""" """
import pwd import pwd
key = paths['keyfile'] key = paths['keyfile']
cert = paths['certfile'] cert = paths['certfile']
ca = paths['cafile'] ca = paths['cafile']
@@ -1324,8 +1374,10 @@ class LocalProcessSpawner(Spawner):
except PermissionError: except PermissionError:
# use which to get abspath # use which to get abspath
script = shutil.which(cmd[0]) or cmd[0] script = shutil.which(cmd[0]) or cmd[0]
self.log.error("Permission denied trying to run %r. Does %s have access to this file?", self.log.error(
script, self.user.name, "Permission denied trying to run %r. Does %s have access to this file?",
script,
self.user.name,
) )
raise raise
@@ -1445,24 +1497,25 @@ class SimpleLocalProcessSpawner(LocalProcessSpawner):
help=""" help="""
Template to expand to set the user home. Template to expand to set the user home.
{username} is expanded to the jupyterhub username. {username} is expanded to the jupyterhub username.
""" """,
) )
home_dir = Unicode(help="The home directory for the user") home_dir = Unicode(help="The home directory for the user")
@default('home_dir') @default('home_dir')
def _default_home_dir(self): def _default_home_dir(self):
return self.home_dir_template.format( return self.home_dir_template.format(username=self.user.name)
username=self.user.name,
)
def make_preexec_fn(self, name): def make_preexec_fn(self, name):
home = self.home_dir home = self.home_dir
def preexec(): def preexec():
try: try:
os.makedirs(home, 0o755, exist_ok=True) os.makedirs(home, 0o755, exist_ok=True)
os.chdir(home) os.chdir(home)
except Exception as e: except Exception as e:
self.log.exception("Error in preexec for %s", name) self.log.exception("Error in preexec for %s", name)
return preexec return preexec
def user_env(self, env): def user_env(self, env):
@@ -1474,4 +1527,3 @@ class SimpleLocalProcessSpawner(LocalProcessSpawner):
def move_certs(self, paths): def move_certs(self, paths):
"""No-op for installing certs""" """No-op for installing certs"""
return paths return paths
View File
@@ -23,34 +23,33 @@ Fixtures to add functionality or spawning behavior
- `slow_bad_spawn` - `slow_bad_spawn`
""" """
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import asyncio import asyncio
from getpass import getuser
import inspect import inspect
import logging import logging
import os import os
import sys import sys
from getpass import getuser
from subprocess import TimeoutExpired from subprocess import TimeoutExpired
from unittest import mock from unittest import mock
from pytest import fixture, raises from pytest import fixture
from tornado import ioloop, gen from pytest import raises
from tornado import gen
from tornado import ioloop
from tornado.httpclient import HTTPError from tornado.httpclient import HTTPError
from tornado.platform.asyncio import AsyncIOMainLoop from tornado.platform.asyncio import AsyncIOMainLoop
from .. import orm
from .. import crypto
from ..utils import random_port
from . import mocking
from .mocking import MockHub
from .utils import ssl_setup, add_user
from .test_services import mockservice_cmd
import jupyterhub.services.service import jupyterhub.services.service
from . import mocking
from .. import crypto
from .. import orm
from ..utils import random_port
from .mocking import MockHub
from .test_services import mockservice_cmd
from .utils import add_user
from .utils import ssl_setup
# global db session object # global db session object
_db = None _db = None
@@ -78,12 +77,7 @@ def app(request, io_loop, ssl_tmpdir):
ssl_enabled = getattr(request.module, "ssl_enabled", False) ssl_enabled = getattr(request.module, "ssl_enabled", False)
kwargs = dict() kwargs = dict()
if ssl_enabled: if ssl_enabled:
kwargs.update( kwargs.update(dict(internal_ssl=True, internal_certs_location=str(ssl_tmpdir)))
dict(
internal_ssl=True,
internal_certs_location=str(ssl_tmpdir),
)
)
mocked_app = MockHub.instance(**kwargs) mocked_app = MockHub.instance(**kwargs)
@@ -107,9 +101,7 @@ def app(request, io_loop, ssl_tmpdir):
@fixture @fixture
def auth_state_enabled(app): def auth_state_enabled(app):
app.authenticator.auth_state = { app.authenticator.auth_state = {'who': 'cares'}
'who': 'cares',
}
app.authenticator.enable_auth_state = True app.authenticator.enable_auth_state = True
ck = crypto.CryptKeeper.instance() ck = crypto.CryptKeeper.instance()
before_keys = ck.keys before_keys = ck.keys
@@ -128,9 +120,7 @@ def db():
global _db global _db
if _db is None: if _db is None:
_db = orm.new_session_factory('sqlite:///:memory:')() _db = orm.new_session_factory('sqlite:///:memory:')()
user = orm.User( user = orm.User(name=getuser())
name=getuser(),
)
_db.add(user) _db.add(user)
_db.commit() _db.commit()
return _db return _db
@@ -221,12 +211,12 @@ def admin_user(app, username):
yield user yield user
class MockServiceSpawner(jupyterhub.services.service._ServiceSpawner): class MockServiceSpawner(jupyterhub.services.service._ServiceSpawner):
"""mock services for testing. """mock services for testing.
Shorter intervals, etc. Shorter intervals, etc.
""" """
poll_interval = 1 poll_interval = 1
@@ -237,11 +227,7 @@ def _mockservice(request, app, url=False):
global _mock_service_counter global _mock_service_counter
_mock_service_counter += 1 _mock_service_counter += 1
name = 'mock-service-%i' % _mock_service_counter name = 'mock-service-%i' % _mock_service_counter
spec = { spec = {'name': name, 'command': mockservice_cmd, 'admin': True}
'name': name,
'command': mockservice_cmd,
'admin': True,
}
if url: if url:
if app.internal_ssl: if app.internal_ssl:
spec['url'] = 'https://127.0.0.1:%i' % random_port() spec['url'] = 'https://127.0.0.1:%i' % random_port()
@@ -250,22 +236,29 @@ def _mockservice(request, app, url=False):
io_loop = app.io_loop io_loop = app.io_loop
with mock.patch.object(jupyterhub.services.service, '_ServiceSpawner', MockServiceSpawner): with mock.patch.object(
jupyterhub.services.service, '_ServiceSpawner', MockServiceSpawner
):
app.services = [spec] app.services = [spec]
app.init_services() app.init_services()
assert name in app._service_map assert name in app._service_map
service = app._service_map[name] service = app._service_map[name]
@gen.coroutine @gen.coroutine
def start(): def start():
# wait for proxy to be updated before starting the service # wait for proxy to be updated before starting the service
yield app.proxy.add_all_services(app._service_map) yield app.proxy.add_all_services(app._service_map)
service.start() service.start()
io_loop.run_sync(start) io_loop.run_sync(start)
def cleanup(): def cleanup():
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(service.stop()) asyncio.get_event_loop().run_until_complete(service.stop())
app.services[:] = [] app.services[:] = []
app._service_map.clear() app._service_map.clear()
request.addfinalizer(cleanup) request.addfinalizer(cleanup)
# ensure process finishes starting # ensure process finishes starting
with raises(TimeoutExpired): with raises(TimeoutExpired):
@@ -290,47 +283,44 @@ def mockservice_url(request, app):
@fixture @fixture
def admin_access(app): def admin_access(app):
"""Grant admin-access with this fixture""" """Grant admin-access with this fixture"""
with mock.patch.dict(app.tornado_settings, with mock.patch.dict(app.tornado_settings, {'admin_access': True}):
{'admin_access': True}):
yield yield
@fixture @fixture
def no_patience(app): def no_patience(app):
"""Set slow-spawning timeouts to zero""" """Set slow-spawning timeouts to zero"""
with mock.patch.dict(app.tornado_settings, with mock.patch.dict(
{'slow_spawn_timeout': 0.1, app.tornado_settings, {'slow_spawn_timeout': 0.1, 'slow_stop_timeout': 0.1}
'slow_stop_timeout': 0.1}): ):
yield yield
@fixture @fixture
def slow_spawn(app): def slow_spawn(app):
"""Fixture enabling SlowSpawner""" """Fixture enabling SlowSpawner"""
with mock.patch.dict(app.tornado_settings, with mock.patch.dict(app.tornado_settings, {'spawner_class': mocking.SlowSpawner}):
{'spawner_class': mocking.SlowSpawner}):
yield yield
@fixture @fixture
def never_spawn(app): def never_spawn(app):
"""Fixture enabling NeverSpawner""" """Fixture enabling NeverSpawner"""
with mock.patch.dict(app.tornado_settings, with mock.patch.dict(app.tornado_settings, {'spawner_class': mocking.NeverSpawner}):
{'spawner_class': mocking.NeverSpawner}):
yield yield
@fixture @fixture
def bad_spawn(app): def bad_spawn(app):
"""Fixture enabling BadSpawner""" """Fixture enabling BadSpawner"""
with mock.patch.dict(app.tornado_settings, with mock.patch.dict(app.tornado_settings, {'spawner_class': mocking.BadSpawner}):
{'spawner_class': mocking.BadSpawner}):
yield yield
@fixture @fixture
def slow_bad_spawn(app): def slow_bad_spawn(app):
"""Fixture enabling SlowBadSpawner""" """Fixture enabling SlowBadSpawner"""
with mock.patch.dict(app.tornado_settings, with mock.patch.dict(
{'spawner_class': mocking.SlowBadSpawner}): app.tornado_settings, {'spawner_class': mocking.SlowBadSpawner}
):
yield yield
View File
@@ -26,32 +26,36 @@ Other components
- public_url - public_url
""" """
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor
import os import os
import sys import sys
from tempfile import NamedTemporaryFile
import threading import threading
from concurrent.futures import ThreadPoolExecutor
from tempfile import NamedTemporaryFile
from unittest import mock from unittest import mock
from urllib.parse import urlparse from urllib.parse import urlparse
from pamela import PAMError
from tornado import gen from tornado import gen
from tornado.concurrent import Future from tornado.concurrent import Future
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from traitlets import Bool
from traitlets import default
from traitlets import Dict
from traitlets import Bool, Dict, default from .. import orm
from ..app import JupyterHub from ..app import JupyterHub
from ..auth import PAMAuthenticator from ..auth import PAMAuthenticator
from .. import orm
from ..objects import Server from ..objects import Server
from ..spawner import LocalProcessSpawner, SimpleLocalProcessSpawner
from ..singleuser import SingleUserNotebookApp from ..singleuser import SingleUserNotebookApp
from ..utils import random_port, url_path_join from ..spawner import LocalProcessSpawner
from .utils import async_requests, ssl_setup, public_host, public_url from ..spawner import SimpleLocalProcessSpawner
from ..utils import random_port
from pamela import PAMError from ..utils import url_path_join
from .utils import async_requests
from .utils import public_host
from .utils import public_url
from .utils import ssl_setup
def mock_authenticate(username, password, service, encoding): def mock_authenticate(username, password, service, encoding):
@@ -79,6 +83,7 @@ class MockSpawner(SimpleLocalProcessSpawner):
- disables user-switching that we need root permissions to do - disables user-switching that we need root permissions to do
- spawns `jupyterhub.tests.mocksu` instead of a full single-user server - spawns `jupyterhub.tests.mocksu` instead of a full single-user server
""" """
def user_env(self, env): def user_env(self, env):
env = super().user_env(env) env = super().user_env(env)
if self.handler: if self.handler:
@@ -90,6 +95,7 @@ class MockSpawner(SimpleLocalProcessSpawner):
return [sys.executable, '-m', 'jupyterhub.tests.mocksu'] return [sys.executable, '-m', 'jupyterhub.tests.mocksu']
use_this_api_token = None use_this_api_token = None
def start(self): def start(self):
if self.use_this_api_token: if self.use_this_api_token:
self.api_token = self.use_this_api_token self.api_token = self.use_this_api_token
@@ -103,6 +109,7 @@ class SlowSpawner(MockSpawner):
delay = 2 delay = 2
_start_future = None _start_future = None
@gen.coroutine @gen.coroutine
def start(self): def start(self):
(ip, port) = yield super().start() (ip, port) = yield super().start()
@@ -140,6 +147,7 @@ class NeverSpawner(MockSpawner):
class BadSpawner(MockSpawner): class BadSpawner(MockSpawner):
"""Spawner that fails immediately""" """Spawner that fails immediately"""
def start(self): def start(self):
raise RuntimeError("I don't work!") raise RuntimeError("I don't work!")
@@ -154,6 +162,7 @@ class SlowBadSpawner(MockSpawner):
class FormSpawner(MockSpawner): class FormSpawner(MockSpawner):
"""A spawner that has an options form defined""" """A spawner that has an options form defined"""
options_form = "IMAFORM" options_form = "IMAFORM"
def options_from_form(self, form_data): def options_from_form(self, form_data):
@@ -167,6 +176,7 @@ class FormSpawner(MockSpawner):
options['hello'] = form_data['hello_file'][0] options['hello'] = form_data['hello_file'][0]
return options return options
class FalsyCallableFormSpawner(FormSpawner): class FalsyCallableFormSpawner(FormSpawner):
"""A spawner that has a callable options form defined returning a falsy value""" """A spawner that has a callable options form defined returning a falsy value"""
@@ -181,6 +191,7 @@ class MockStructGroup:
self.gr_mem = members self.gr_mem = members
self.gr_gid = gid self.gr_gid = gid
class MockStructPasswd: class MockStructPasswd:
"""Mock pwd.struct_passwd""" """Mock pwd.struct_passwd"""
@@ -193,6 +204,7 @@ class MockPAMAuthenticator(PAMAuthenticator):
auth_state = None auth_state = None
# If true, return admin users marked as admin. # If true, return admin users marked as admin.
return_admin = False return_admin = False
@default('admin_users') @default('admin_users')
def _admin_users_default(self): def _admin_users_default(self):
return {'admin'} return {'admin'}
@@ -203,20 +215,20 @@ class MockPAMAuthenticator(PAMAuthenticator):
@gen.coroutine @gen.coroutine
def authenticate(self, *args, **kwargs): def authenticate(self, *args, **kwargs):
with mock.patch.multiple('pamela', with mock.patch.multiple(
'pamela',
authenticate=mock_authenticate, authenticate=mock_authenticate,
open_session=mock_open_session, open_session=mock_open_session,
close_session=mock_open_session, close_session=mock_open_session,
check_account=mock_check_account, check_account=mock_check_account,
): ):
username = yield super(MockPAMAuthenticator, self).authenticate(*args, **kwargs) username = yield super(MockPAMAuthenticator, self).authenticate(
*args, **kwargs
)
if username is None: if username is None:
return return
elif self.auth_state: elif self.auth_state:
return { return {'name': username, 'auth_state': self.auth_state}
'name': username,
'auth_state': self.auth_state,
}
else: else:
return username return username
@@ -349,11 +361,9 @@ class MockHub(JupyterHub):
external_ca = None external_ca = None
if self.internal_ssl: if self.internal_ssl:
external_ca = self.external_certs['files']['ca'] external_ca = self.external_certs['files']['ca']
r = yield async_requests.post(base_url + 'hub/login', r = yield async_requests.post(
data={ base_url + 'hub/login',
'username': name, data={'username': name, 'password': name},
'password': name,
},
allow_redirects=False, allow_redirects=False,
verify=external_ca, verify=external_ca,
) )
@@ -364,6 +374,7 @@ class MockHub(JupyterHub):
# single-user-server mocking: # single-user-server mocking:
class MockSingleUserServer(SingleUserNotebookApp): class MockSingleUserServer(SingleUserNotebookApp):
"""Mock-out problematic parts of single-user server when run in a thread """Mock-out problematic parts of single-user server when run in a thread
@@ -378,7 +389,9 @@ class MockSingleUserServer(SingleUserNotebookApp):
class StubSingleUserSpawner(MockSpawner): class StubSingleUserSpawner(MockSpawner):
"""Spawner that starts a MockSingleUserServer in a thread.""" """Spawner that starts a MockSingleUserServer in a thread."""
_thread = None _thread = None
@gen.coroutine @gen.coroutine
def start(self): def start(self):
ip = self.ip = '127.0.0.1' ip = self.ip = '127.0.0.1'
@@ -387,6 +400,7 @@ class StubSingleUserSpawner(MockSpawner):
args = self.get_args() args = self.get_args()
evt = threading.Event() evt = threading.Event()
print(args, env) print(args, env)
def _run(): def _run():
asyncio.set_event_loop(asyncio.new_event_loop()) asyncio.set_event_loop(asyncio.new_event_loop())
io_loop = IOLoop() io_loop = IOLoop()
@@ -420,4 +434,3 @@ class StubSingleUserSpawner(MockSpawner):
return None return None
else: else:
return 0 return 0
View File
@@ -12,28 +12,33 @@ Handlers and their purpose include:
- WhoAmIHandler: returns name of user making a request (deprecated cookie login) - WhoAmIHandler: returns name of user making a request (deprecated cookie login)
- OWhoAmIHandler: returns name of user making a request (OAuth login) - OWhoAmIHandler: returns name of user making a request (OAuth login)
""" """
import json import json
import pprint
import os import os
import pprint
import sys import sys
from urllib.parse import urlparse from urllib.parse import urlparse
import requests import requests
from tornado import web, httpserver, ioloop from tornado import httpserver
from tornado import ioloop
from tornado import web
from jupyterhub.services.auth import HubAuthenticated, HubOAuthenticated, HubOAuthCallbackHandler from jupyterhub.services.auth import HubAuthenticated
from jupyterhub.services.auth import HubOAuthCallbackHandler
from jupyterhub.services.auth import HubOAuthenticated
from jupyterhub.utils import make_ssl_context from jupyterhub.utils import make_ssl_context
class EchoHandler(web.RequestHandler): class EchoHandler(web.RequestHandler):
"""Reply to an HTTP request with the path of the request.""" """Reply to an HTTP request with the path of the request."""
def get(self): def get(self):
self.write(self.request.path) self.write(self.request.path)
class EnvHandler(web.RequestHandler): class EnvHandler(web.RequestHandler):
"""Reply to an HTTP request with the service's environment as JSON.""" """Reply to an HTTP request with the service's environment as JSON."""
def get(self): def get(self):
self.set_header('Content-Type', 'application/json') self.set_header('Content-Type', 'application/json')
self.write(json.dumps(dict(os.environ))) self.write(json.dumps(dict(os.environ)))
@@ -41,11 +46,12 @@ class EnvHandler(web.RequestHandler):
class APIHandler(web.RequestHandler): class APIHandler(web.RequestHandler):
"""Relay API requests to the Hub's API using the service's API token.""" """Relay API requests to the Hub's API using the service's API token."""
def get(self, path): def get(self, path):
api_token = os.environ['JUPYTERHUB_API_TOKEN'] api_token = os.environ['JUPYTERHUB_API_TOKEN']
api_url = os.environ['JUPYTERHUB_API_URL'] api_url = os.environ['JUPYTERHUB_API_URL']
r = requests.get(api_url + path, r = requests.get(
headers={'Authorization': 'token %s' % api_token}, api_url + path, headers={'Authorization': 'token %s' % api_token}
) )
r.raise_for_status() r.raise_for_status()
self.set_header('Content-Type', 'application/json') self.set_header('Content-Type', 'application/json')
@@ -57,6 +63,7 @@ class WhoAmIHandler(HubAuthenticated, web.RequestHandler):
Uses "deprecated" cookie login Uses "deprecated" cookie login
""" """
@web.authenticated @web.authenticated
def get(self): def get(self):
self.write(self.get_current_user()) self.write(self.get_current_user())
@@ -67,6 +74,7 @@ class OWhoAmIHandler(HubOAuthenticated, web.RequestHandler):
Uses OAuth login flow Uses OAuth login flow
""" """
@web.authenticated @web.authenticated
def get(self): def get(self):
self.write(self.get_current_user()) self.write(self.get_current_user())
@@ -77,14 +85,17 @@ def main():
if os.getenv('JUPYTERHUB_SERVICE_URL'): if os.getenv('JUPYTERHUB_SERVICE_URL'):
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL']) url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
app = web.Application([ app = web.Application(
[
(r'.*/env', EnvHandler), (r'.*/env', EnvHandler),
(r'.*/api/(.*)', APIHandler), (r'.*/api/(.*)', APIHandler),
(r'.*/whoami/?', WhoAmIHandler), (r'.*/whoami/?', WhoAmIHandler),
(r'.*/owhoami/?', OWhoAmIHandler), (r'.*/owhoami/?', OWhoAmIHandler),
(r'.*/oauth_callback', HubOAuthCallbackHandler), (r'.*/oauth_callback', HubOAuthCallbackHandler),
(r'.*', EchoHandler), (r'.*', EchoHandler),
], cookie_secret=os.urandom(32)) ],
cookie_secret=os.urandom(32),
)
ssl_context = None ssl_context = None
key = os.environ.get('JUPYTERHUB_SSL_KEYFILE') or '' key = os.environ.get('JUPYTERHUB_SSL_KEYFILE') or ''
@@ -92,11 +103,7 @@ def main():
ca = os.environ.get('JUPYTERHUB_SSL_CLIENT_CA') or '' ca = os.environ.get('JUPYTERHUB_SSL_CLIENT_CA') or ''
if key and cert and ca: if key and cert and ca:
ssl_context = make_ssl_context( ssl_context = make_ssl_context(key, cert, cafile=ca, check_hostname=False)
key,
cert,
cafile = ca,
check_hostname = False)
server = httpserver.HTTPServer(app, ssl_options=ssl_context) server = httpserver.HTTPServer(app, ssl_options=ssl_context)
server.listen(url.port, url.hostname) server.listen(url.port, url.hostname)
@@ -108,5 +115,6 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
from tornado.options import parse_command_line from tornado.options import parse_command_line
parse_command_line() parse_command_line()
main() main()
View File
@@ -13,28 +13,32 @@ Handlers and their purpose include:
""" """
import argparse import argparse
import json import json
import sys
import os import os
import sys
from tornado import httpserver
from tornado import ioloop
from tornado import web
from tornado import web, httpserver, ioloop
from .mockservice import EnvHandler
from ..utils import make_ssl_context from ..utils import make_ssl_context
from .mockservice import EnvHandler
class EchoHandler(web.RequestHandler): class EchoHandler(web.RequestHandler):
def get(self): def get(self):
self.write(self.request.path) self.write(self.request.path)
class ArgsHandler(web.RequestHandler): class ArgsHandler(web.RequestHandler):
def get(self): def get(self):
self.write(json.dumps(sys.argv)) self.write(json.dumps(sys.argv))
def main(args): def main(args):
app = web.Application([ app = web.Application(
(r'.*/args', ArgsHandler), [(r'.*/args', ArgsHandler), (r'.*/env', EnvHandler), (r'.*', EchoHandler)]
(r'.*/env', EnvHandler), )
(r'.*', EchoHandler),
])
ssl_context = None ssl_context = None
key = os.environ.get('JUPYTERHUB_SSL_KEYFILE') or '' key = os.environ.get('JUPYTERHUB_SSL_KEYFILE') or ''
@@ -42,12 +46,7 @@ def main(args):
ca = os.environ.get('JUPYTERHUB_SSL_CLIENT_CA') or '' ca = os.environ.get('JUPYTERHUB_SSL_CLIENT_CA') or ''
if key and cert and ca: if key and cert and ca:
ssl_context = make_ssl_context( ssl_context = make_ssl_context(key, cert, cafile=ca, check_hostname=False)
key,
cert,
cafile = ca,
check_hostname = False
)
server = httpserver.HTTPServer(app, ssl_options=ssl_context) server = httpserver.HTTPServer(app, ssl_options=ssl_context)
server.listen(args.port) server.listen(args.port)
@@ -56,6 +55,7 @@ def main(args):
except KeyboardInterrupt: except KeyboardInterrupt:
print('\nInterrupted') print('\nInterrupted')
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int) parser.add_argument('--port', type=int)
View File
@@ -4,9 +4,8 @@ Run with old versions of jupyterhub to test upgrade/downgrade
used in test_db.py used in test_db.py
""" """
from datetime import datetime
import os import os
from datetime import datetime
import jupyterhub import jupyterhub
from jupyterhub import orm from jupyterhub import orm
@@ -90,6 +89,7 @@ def populate_db(url):
if __name__ == '__main__': if __name__ == '__main__':
import sys import sys
if len(sys.argv) > 1: if len(sys.argv) > 1:
url = sys.argv[1] url = sys.argv[1]
else: else:
View File
@@ -1,30 +1,32 @@
"""Tests for the REST API.""" """Tests for the REST API."""
import asyncio import asyncio
from datetime import datetime, timedelta
from concurrent.futures import Future
import json import json
import re import re
import sys import sys
from unittest import mock
from urllib.parse import urlparse, quote
import uuid import uuid
from async_generator import async_generator, yield_ from concurrent.futures import Future
from datetime import datetime
from datetime import timedelta
from unittest import mock
from urllib.parse import quote
from urllib.parse import urlparse
from async_generator import async_generator
from async_generator import yield_
from pytest import mark from pytest import mark
from tornado import gen from tornado import gen
import jupyterhub import jupyterhub
from .. import orm from .. import orm
from ..utils import url_path_join as ujoin, utcnow from ..utils import url_path_join as ujoin
from .mocking import public_host, public_url from ..utils import utcnow
from .utils import ( from .mocking import public_host
add_user, from .mocking import public_url
api_request, from .utils import add_user
async_requests, from .utils import api_request
auth_header, from .utils import async_requests
find_user, from .utils import auth_header
) from .utils import find_user
# -------------------- # --------------------
@@ -48,12 +50,15 @@ async def test_auth_api(app):
assert reply['name'] == user.name assert reply['name'] == user.name
# check fail # check fail
r = await api_request(app, 'authorizations/token', api_token, r = await api_request(
headers={'Authorization': 'no sir'}, app, 'authorizations/token', api_token, headers={'Authorization': 'no sir'}
) )
assert r.status_code == 403 assert r.status_code == 403
r = await api_request(app, 'authorizations/token', api_token, r = await api_request(
app,
'authorizations/token',
api_token,
headers={'Authorization': 'token: %s' % user.cookie_id}, headers={'Authorization': 'token: %s' % user.cookie_id},
) )
assert r.status_code == 403 assert r.status_code == 403
@@ -67,37 +72,39 @@ async def test_referer_check(app):
user = add_user(app.db, name='admin', admin=True) user = add_user(app.db, name='admin', admin=True)
cookies = await app.login_user('admin') cookies = await app.login_user('admin')
r = await api_request(app, 'users', r = await api_request(
headers={ app, 'users', headers={'Authorization': '', 'Referer': 'null'}, cookies=cookies
'Authorization': '',
'Referer': 'null',
}, cookies=cookies,
) )
assert r.status_code == 403 assert r.status_code == 403
r = await api_request(app, 'users', r = await api_request(
app,
'users',
headers={ headers={
'Authorization': '', 'Authorization': '',
'Referer': 'http://attack.com/csrf/vulnerability', 'Referer': 'http://attack.com/csrf/vulnerability',
}, cookies=cookies, },
cookies=cookies,
) )
assert r.status_code == 403 assert r.status_code == 403
r = await api_request(app, 'users', r = await api_request(
headers={ app,
'Authorization': '', 'users',
'Referer': url, headers={'Authorization': '', 'Referer': url, 'Host': host},
'Host': host, cookies=cookies,
}, cookies=cookies,
) )
assert r.status_code == 200 assert r.status_code == 200
r = await api_request(app, 'users', r = await api_request(
app,
'users',
headers={ headers={
'Authorization': '', 'Authorization': '',
'Referer': ujoin(url, 'foo/bar/baz/bat'), 'Referer': ujoin(url, 'foo/bar/baz/bat'),
'Host': host, 'Host': host,
}, cookies=cookies, },
cookies=cookies,
) )
assert r.status_code == 200 assert r.status_code == 200
@@ -106,6 +113,7 @@ async def test_referer_check(app):
# User API tests # User API tests
# -------------- # --------------
def normalize_timestamp(ts): def normalize_timestamp(ts):
"""Normalize a timestamp """Normalize a timestamp
@@ -128,12 +136,16 @@ def normalize_user(user):
for server in user['servers'].values(): for server in user['servers'].values():
for key in ('started', 'last_activity'): for key in ('started', 'last_activity'):
server[key] = normalize_timestamp(server[key]) server[key] = normalize_timestamp(server[key])
server['progress_url'] = re.sub(r'.*/hub/api', 'PREFIX/hub/api', server['progress_url']) server['progress_url'] = re.sub(
if (isinstance(server['state'], dict) r'.*/hub/api', 'PREFIX/hub/api', server['progress_url']
and isinstance(server['state'].get('pid', None), int)): )
if isinstance(server['state'], dict) and isinstance(
server['state'].get('pid', None), int
):
server['state']['pid'] = 0 server['state']['pid'] = 0
return user return user
def fill_user(model): def fill_user(model):
"""Fill a default user model """Fill a default user model
@@ -153,6 +165,7 @@ def fill_user(model):
TIMESTAMP = normalize_timestamp(datetime.now().isoformat() + 'Z') TIMESTAMP = normalize_timestamp(datetime.now().isoformat() + 'Z')
@mark.user @mark.user
async def test_get_users(app): async def test_get_users(app):
db = app.db db = app.db
@@ -162,20 +175,11 @@ async def test_get_users(app):
users = sorted(r.json(), key=lambda d: d['name']) users = sorted(r.json(), key=lambda d: d['name'])
users = [normalize_user(u) for u in users] users = [normalize_user(u) for u in users]
assert users == [ assert users == [
fill_user({ fill_user({'name': 'admin', 'admin': True}),
'name': 'admin', fill_user({'name': 'user', 'admin': False, 'last_activity': None}),
'admin': True,
}),
fill_user({
'name': 'user',
'admin': False,
'last_activity': None,
}),
] ]
r = await api_request(app, 'users', r = await api_request(app, 'users', headers=auth_header(db, 'user'))
headers=auth_header(db, 'user'),
)
assert r.status_code == 403 assert r.status_code == 403
@@ -202,17 +206,13 @@ async def test_get_self(app):
) )
db.add(oauth_token) db.add(oauth_token)
db.commit() db.commit()
r = await api_request(app, 'user', headers={ r = await api_request(app, 'user', headers={'Authorization': 'token ' + token})
'Authorization': 'token ' + token,
})
r.raise_for_status() r.raise_for_status()
model = r.json() model = r.json()
assert model['name'] == u.name assert model['name'] == u.name
# invalid auth gets 403 # invalid auth gets 403
r = await api_request(app, 'user', headers={ r = await api_request(app, 'user', headers={'Authorization': 'token notvalid'})
'Authorization': 'token notvalid',
})
assert r.status_code == 403 assert r.status_code == 403
@@ -251,8 +251,11 @@ async def test_add_multi_user_bad(app):
@mark.user @mark.user
async def test_add_multi_user_invalid(app): async def test_add_multi_user_invalid(app):
app.authenticator.username_pattern = r'w.*' app.authenticator.username_pattern = r'w.*'
r = await api_request(app, 'users', method='post', r = await api_request(
data=json.dumps({'usernames': ['Willow', 'Andrew', 'Tara']}) app,
'users',
method='post',
data=json.dumps({'usernames': ['Willow', 'Andrew', 'Tara']}),
) )
app.authenticator.username_pattern = '' app.authenticator.username_pattern = ''
assert r.status_code == 400 assert r.status_code == 400
@@ -263,8 +266,8 @@ async def test_add_multi_user_invalid(app):
async def test_add_multi_user(app): async def test_add_multi_user(app):
db = app.db db = app.db
names = ['a', 'b'] names = ['a', 'b']
r = await api_request(app, 'users', method='post', r = await api_request(
data=json.dumps({'usernames': names}), app, 'users', method='post', data=json.dumps({'usernames': names})
) )
assert r.status_code == 201 assert r.status_code == 201
reply = r.json() reply = r.json()
@@ -278,16 +281,16 @@ async def test_add_multi_user(app):
assert not user.admin assert not user.admin
# try to create the same users again # try to create the same users again
r = await api_request(app, 'users', method='post', r = await api_request(
data=json.dumps({'usernames': names}), app, 'users', method='post', data=json.dumps({'usernames': names})
) )
assert r.status_code == 409 assert r.status_code == 409
names = ['a', 'b', 'ab'] names = ['a', 'b', 'ab']
# try to create the same users again # try to create the same users again
r = await api_request(app, 'users', method='post', r = await api_request(
data=json.dumps({'usernames': names}), app, 'users', method='post', data=json.dumps({'usernames': names})
) )
assert r.status_code == 201 assert r.status_code == 201
reply = r.json() reply = r.json()
@@ -299,7 +302,10 @@ async def test_add_multi_user(app):
async def test_add_multi_user_admin(app): async def test_add_multi_user_admin(app):
db = app.db db = app.db
names = ['c', 'd'] names = ['c', 'd']
r = await api_request(app, 'users', method='post', r = await api_request(
app,
'users',
method='post',
data=json.dumps({'usernames': names, 'admin': True}), data=json.dumps({'usernames': names, 'admin': True}),
) )
assert r.status_code == 201 assert r.status_code == 201
@@ -340,8 +346,8 @@ async def test_add_user_duplicate(app):
async def test_add_admin(app): async def test_add_admin(app):
db = app.db db = app.db
name = 'newadmin' name = 'newadmin'
r = await api_request(app, 'users', name, method='post', r = await api_request(
data=json.dumps({'admin': True}), app, 'users', name, method='post', data=json.dumps({'admin': True})
) )
assert r.status_code == 201 assert r.status_code == 201
user = find_user(db, name) user = find_user(db, name)
@@ -369,8 +375,8 @@ async def test_make_admin(app):
assert user.name == name assert user.name == name
assert not user.admin assert not user.admin
r = await api_request(app, 'users', name, method='patch', r = await api_request(
data=json.dumps({'admin': True}) app, 'users', name, method='patch', data=json.dumps({'admin': True})
) )
assert r.status_code == 200 assert r.status_code == 200
user = find_user(db, name) user = find_user(db, name)
@@ -388,8 +394,8 @@ async def test_set_auth_state(app, auth_state_enabled):
assert user is not None assert user is not None
assert user.name == name assert user.name == name
r = await api_request(app, 'users', name, method='patch', r = await api_request(
data=json.dumps({'auth_state': auth_state}) app, 'users', name, method='patch', data=json.dumps({'auth_state': auth_state})
) )
assert r.status_code == 200 assert r.status_code == 200
@@ -409,7 +415,10 @@ async def test_user_set_auth_state(app, auth_state_enabled):
assert user_auth_state is None assert user_auth_state is None
r = await api_request( r = await api_request(
app, 'users', name, method='patch', app,
'users',
name,
method='patch',
data=json.dumps({'auth_state': auth_state}), data=json.dumps({'auth_state': auth_state}),
headers=auth_header(app.db, name), headers=auth_header(app.db, name),
) )
@@ -446,8 +455,7 @@ async def test_user_get_auth_state(app, auth_state_enabled):
assert user.name == name assert user.name == name
await user.save_auth_state(auth_state) await user.save_auth_state(auth_state)
r = await api_request(app, 'users', name, r = await api_request(app, 'users', name, headers=auth_header(app.db, name))
headers=auth_header(app.db, name))
assert r.status_code == 200 assert r.status_code == 200
assert 'auth_state' not in r.json() assert 'auth_state' not in r.json()
@@ -457,13 +465,10 @@ async def test_spawn(app):
db = app.db db = app.db
name = 'wash' name = 'wash'
user = add_user(db, app=app, name=name) user = add_user(db, app=app, name=name)
options = { options = {'s': ['value'], 'i': 5}
's': ['value'],
'i': 5,
}
before_servers = sorted(db.query(orm.Server), key=lambda s: s.url) before_servers = sorted(db.query(orm.Server), key=lambda s: s.url)
r = await api_request(app, 'users', name, 'server', method='post', r = await api_request(
data=json.dumps(options), app, 'users', name, 'server', method='post', data=json.dumps(options)
) )
assert r.status_code == 201 assert r.status_code == 201
assert 'pid' in user.orm_spawners[''].state assert 'pid' in user.orm_spawners[''].state
@@ -520,7 +525,9 @@ async def test_spawn_handler(app):
app_user = app.users[name] app_user = app.users[name]
# spawn via API with ?foo=bar # spawn via API with ?foo=bar
r = await api_request(app, 'users', name, 'server', method='post', params={'foo': 'bar'}) r = await api_request(
app, 'users', name, 'server', method='post', params={'foo': 'bar'}
)
r.raise_for_status() r.raise_for_status()
# verify that request params got passed down # verify that request params got passed down
@@ -640,6 +647,7 @@ def next_event(it):
if line.startswith('data:'): if line.startswith('data:'):
return json.loads(line.split(':', 1)[1]) return json.loads(line.split(':', 1)[1])
@mark.slow @mark.slow
async def test_progress(request, app, no_patience, slow_spawn): async def test_progress(request, app, no_patience, slow_spawn):
db = app.db db = app.db
@@ -655,15 +663,9 @@ async def test_progress(request, app, no_patience, slow_spawn):
ex = async_requests.executor ex = async_requests.executor
line_iter = iter(r.iter_lines(decode_unicode=True)) line_iter = iter(r.iter_lines(decode_unicode=True))
evt = await ex.submit(next_event, line_iter) evt = await ex.submit(next_event, line_iter)
assert evt == { assert evt == {'progress': 0, 'message': 'Server requested'}
'progress': 0,
'message': 'Server requested',
}
evt = await ex.submit(next_event, line_iter) evt = await ex.submit(next_event, line_iter)
assert evt == { assert evt == {'progress': 50, 'message': 'Spawning server...'}
'progress': 50,
'message': 'Spawning server...',
}
evt = await ex.submit(next_event, line_iter) evt = await ex.submit(next_event, line_iter)
url = app_user.url url = app_user.url
assert evt == { assert evt == {
@@ -769,10 +771,7 @@ async def test_progress_bad_slow(request, app, no_patience, slow_bad_spawn):
async def progress_forever(): async def progress_forever():
"""progress function that yields messages forever""" """progress function that yields messages forever"""
for i in range(1, 10): for i in range(1, 10):
await yield_({ await yield_({'progress': i, 'message': 'Stage %s' % i})
'progress': i,
'message': 'Stage %s' % i,
})
# wait a long time before the next event # wait a long time before the next event
await gen.sleep(10) await gen.sleep(10)
@@ -781,7 +780,8 @@ if sys.version_info >= (3, 6):
# additional progress_forever defined as native # additional progress_forever defined as native
# async generator # async generator
# to test for issues with async_generator wrappers # to test for issues with async_generator wrappers
exec(""" exec(
"""
async def progress_forever_native(): async def progress_forever_native():
for i in range(1, 10): for i in range(1, 10):
yield { yield {
@@ -790,7 +790,9 @@ async def progress_forever_native():
} }
# wait a long time before the next event # wait a long time before the next event
await gen.sleep(10) await gen.sleep(10)
""", globals()) """,
globals(),
)
async def test_spawn_progress_cutoff(request, app, no_patience, slow_spawn): async def test_spawn_progress_cutoff(request, app, no_patience, slow_spawn):
@@ -818,18 +820,14 @@ async def test_spawn_progress_cutoff(request, app, no_patience, slow_spawn):
evt = await ex.submit(next_event, line_iter) evt = await ex.submit(next_event, line_iter)
assert evt['progress'] == 0 assert evt['progress'] == 0
evt = await ex.submit(next_event, line_iter) evt = await ex.submit(next_event, line_iter)
assert evt == { assert evt == {'progress': 1, 'message': 'Stage 1'}
'progress': 1,
'message': 'Stage 1',
}
evt = await ex.submit(next_event, line_iter) evt = await ex.submit(next_event, line_iter)
assert evt['progress'] == 100 assert evt['progress'] == 100
async def test_spawn_limit(app, no_patience, slow_spawn, request): async def test_spawn_limit(app, no_patience, slow_spawn, request):
db = app.db db = app.db
p = mock.patch.dict(app.tornado_settings, p = mock.patch.dict(app.tornado_settings, {'concurrent_spawn_limit': 2})
{'concurrent_spawn_limit': 2})
p.start() p.start()
request.addfinalizer(p.stop) request.addfinalizer(p.stop)
@@ -875,11 +873,11 @@ async def test_spawn_limit(app, no_patience, slow_spawn, request):
while any(u.spawner.active for u in users): while any(u.spawner.active for u in users):
await gen.sleep(0.1) await gen.sleep(0.1)
@mark.slow @mark.slow
async def test_active_server_limit(app, request): async def test_active_server_limit(app, request):
db = app.db db = app.db
p = mock.patch.dict(app.tornado_settings, p = mock.patch.dict(app.tornado_settings, {'active_server_limit': 2})
{'active_server_limit': 2})
p.start() p.start()
request.addfinalizer(p.stop) request.addfinalizer(p.stop)
@@ -932,6 +930,7 @@ async def test_active_server_limit(app, request):
assert counts['ready'] == 0 assert counts['ready'] == 0
assert counts['pending'] == 0 assert counts['pending'] == 0
@mark.slow @mark.slow
async def test_start_stop_race(app, no_patience, slow_spawn): async def test_start_stop_race(app, no_patience, slow_spawn):
user = add_user(app.db, app, name='panda') user = add_user(app.db, app, name='panda')
@@ -996,22 +995,18 @@ async def test_cookie(app):
cookie_name = app.hub.cookie_name cookie_name = app.hub.cookie_name
# cookie jar gives '"cookie-value"', we want 'cookie-value' # cookie jar gives '"cookie-value"', we want 'cookie-value'
cookie = cookies[cookie_name][1:-1] cookie = cookies[cookie_name][1:-1]
r = await api_request(app, 'authorizations/cookie', r = await api_request(app, 'authorizations/cookie', cookie_name, "nothintoseehere")
cookie_name, "nothintoseehere",
)
assert r.status_code == 404 assert r.status_code == 404
r = await api_request(app, 'authorizations/cookie', r = await api_request(
cookie_name, quote(cookie, safe=''), app, 'authorizations/cookie', cookie_name, quote(cookie, safe='')
) )
r.raise_for_status() r.raise_for_status()
reply = r.json() reply = r.json()
assert reply['name'] == name assert reply['name'] == name
# deprecated cookie in body: # deprecated cookie in body:
r = await api_request(app, 'authorizations/cookie', r = await api_request(app, 'authorizations/cookie', cookie_name, data=cookie)
cookie_name, data=cookie,
)
r.raise_for_status() r.raise_for_status()
reply = r.json() reply = r.json()
assert reply['name'] == name assert reply['name'] == name
@@ -1035,15 +1030,11 @@ async def test_check_token(app):
assert r.status_code == 404 assert r.status_code == 404
@mark.parametrize("headers, status", [ @mark.parametrize("headers, status", [({}, 200), ({'Authorization': 'token bad'}, 403)])
({}, 200),
({'Authorization': 'token bad'}, 403),
])
async def test_get_new_token_deprecated(app, headers, status): async def test_get_new_token_deprecated(app, headers, status):
# request a new token # request a new token
r = await api_request(app, 'authorizations', 'token', r = await api_request(
method='post', app, 'authorizations', 'token', method='post', headers=headers
headers=headers,
) )
assert r.status_code == status assert r.status_code == status
if status != 200: if status != 200:
@@ -1058,11 +1049,11 @@ async def test_get_new_token_deprecated(app, headers, status):
async def test_token_formdata_deprecated(app): async def test_token_formdata_deprecated(app):
"""Create a token for a user with formdata and no auth header""" """Create a token for a user with formdata and no auth header"""
data = { data = {'username': 'fake', 'password': 'fake'}
'username': 'fake', r = await api_request(
'password': 'fake', app,
} 'authorizations',
r = await api_request(app, 'authorizations', 'token', 'token',
method='post', method='post',
data=json.dumps(data) if data else None, data=json.dumps(data) if data else None,
noauth=True, noauth=True,
@@ -1076,22 +1067,26 @@ async def test_token_formdata_deprecated(app):
assert reply['name'] == data['username'] assert reply['name'] == data['username']
@mark.parametrize("as_user, for_user, status", [ @mark.parametrize(
"as_user, for_user, status",
[
('admin', 'other', 200), ('admin', 'other', 200),
('admin', 'missing', 400), ('admin', 'missing', 400),
('user', 'other', 403), ('user', 'other', 403),
('user', 'user', 200), ('user', 'user', 200),
]) ],
)
async def test_token_as_user_deprecated(app, as_user, for_user, status): async def test_token_as_user_deprecated(app, as_user, for_user, status):
# ensure both users exist # ensure both users exist
u = add_user(app.db, app, name=as_user) u = add_user(app.db, app, name=as_user)
if for_user != 'missing': if for_user != 'missing':
add_user(app.db, app, name=for_user) add_user(app.db, app, name=for_user)
data = {'username': for_user} data = {'username': for_user}
headers = { headers = {'Authorization': 'token %s' % u.new_api_token()}
'Authorization': 'token %s' % u.new_api_token(), r = await api_request(
} app,
r = await api_request(app, 'authorizations', 'token', 'authorizations',
'token',
method='post', method='post',
data=json.dumps(data), data=json.dumps(data),
headers=headers, headers=headers,
@@ -1107,11 +1102,14 @@ async def test_token_as_user_deprecated(app, as_user, for_user, status):
assert reply['name'] == data['username'] assert reply['name'] == data['username']
@mark.parametrize("headers, status, note, expires_in", [ @mark.parametrize(
"headers, status, note, expires_in",
[
({}, 200, 'test note', None), ({}, 200, 'test note', None),
({}, 200, '', 100), ({}, 200, '', 100),
({'Authorization': 'token bad'}, 403, '', None), ({'Authorization': 'token bad'}, 403, '', None),
]) ],
)
async def test_get_new_token(app, headers, status, note, expires_in): async def test_get_new_token(app, headers, status, note, expires_in):
options = {} options = {}
if note: if note:
@@ -1123,10 +1121,8 @@ async def test_get_new_token(app, headers, status, note, expires_in):
else: else:
body = '' body = ''
# request a new token # request a new token
r = await api_request(app, 'users/admin/tokens', r = await api_request(
method='post', app, 'users/admin/tokens', method='post', headers=headers, data=body
headers=headers,
data=body,
) )
assert r.status_code == status assert r.status_code == status
if status != 200: if status != 200:
@@ -1157,30 +1153,34 @@ async def test_get_new_token(app, headers, status, note, expires_in):
assert normalize_token(reply) == initial assert normalize_token(reply) == initial
# delete the token # delete the token
r = await api_request(app, 'users/admin/tokens', token_id, r = await api_request(app, 'users/admin/tokens', token_id, method='delete')
method='delete')
assert r.status_code == 204 assert r.status_code == 204
# verify deletion # verify deletion
r = await api_request(app, 'users/admin/tokens', token_id) r = await api_request(app, 'users/admin/tokens', token_id)
assert r.status_code == 404 assert r.status_code == 404
@mark.parametrize("as_user, for_user, status", [ @mark.parametrize(
"as_user, for_user, status",
[
('admin', 'other', 200), ('admin', 'other', 200),
('admin', 'missing', 404), ('admin', 'missing', 404),
('user', 'other', 403), ('user', 'other', 403),
('user', 'user', 200), ('user', 'user', 200),
]) ],
)
async def test_token_for_user(app, as_user, for_user, status): async def test_token_for_user(app, as_user, for_user, status):
# ensure both users exist # ensure both users exist
u = add_user(app.db, app, name=as_user) u = add_user(app.db, app, name=as_user)
if for_user != 'missing': if for_user != 'missing':
add_user(app.db, app, name=for_user) add_user(app.db, app, name=for_user)
data = {'username': for_user} data = {'username': for_user}
headers = { headers = {'Authorization': 'token %s' % u.new_api_token()}
'Authorization': 'token %s' % u.new_api_token(), r = await api_request(
} app,
r = await api_request(app, 'users', for_user, 'tokens', 'users',
for_user,
'tokens',
method='post', method='post',
data=json.dumps(data), data=json.dumps(data),
headers=headers, headers=headers,
@@ -1191,9 +1191,7 @@ async def test_token_for_user(app, as_user, for_user, status):
return return
assert 'token' in reply assert 'token' in reply
token_id = reply['id'] token_id = reply['id']
r = await api_request(app, 'users', for_user, 'tokens', token_id, r = await api_request(app, 'users', for_user, 'tokens', token_id, headers=headers)
headers=headers,
)
r.raise_for_status() r.raise_for_status()
reply = r.json() reply = r.json()
assert reply['user'] == for_user assert reply['user'] == for_user
@@ -1203,30 +1201,25 @@ async def test_token_for_user(app, as_user, for_user, status):
note = 'Requested via api by user %s' % as_user note = 'Requested via api by user %s' % as_user
assert reply['note'] == note assert reply['note'] == note
# delete the token # delete the token
r = await api_request(app, 'users', for_user, 'tokens', token_id, r = await api_request(
method='delete', app, 'users', for_user, 'tokens', token_id, method='delete', headers=headers
headers=headers,
) )
assert r.status_code == 204 assert r.status_code == 204
r = await api_request(app, 'users', for_user, 'tokens', token_id, r = await api_request(app, 'users', for_user, 'tokens', token_id, headers=headers)
headers=headers,
)
assert r.status_code == 404 assert r.status_code == 404
async def test_token_authenticator_noauth(app): async def test_token_authenticator_noauth(app):
"""Create a token for a user relying on Authenticator.authenticate and no auth header""" """Create a token for a user relying on Authenticator.authenticate and no auth header"""
name = 'user' name = 'user'
data = { data = {'auth': {'username': name, 'password': name}}
'auth': { r = await api_request(
'username': name, app,
'password': name, 'users',
}, name,
} 'tokens',
r = await api_request(app, 'users', name, 'tokens',
method='post', method='post',
data=json.dumps(data) if data else None, data=json.dumps(data) if data else None,
noauth=True, noauth=True,
@@ -1242,17 +1235,14 @@ async def test_token_authenticator_noauth(app):
async def test_token_authenticator_dict_noauth(app): async def test_token_authenticator_dict_noauth(app):
"""Create a token for a user relying on Authenticator.authenticate and no auth header""" """Create a token for a user relying on Authenticator.authenticate and no auth header"""
app.authenticator.auth_state = { app.authenticator.auth_state = {'who': 'cares'}
'who': 'cares',
}
name = 'user' name = 'user'
data = { data = {'auth': {'username': name, 'password': name}}
'auth': { r = await api_request(
'username': name, app,
'password': name, 'users',
}, name,
} 'tokens',
r = await api_request(app, 'users', name, 'tokens',
method='post', method='post',
data=json.dumps(data) if data else None, data=json.dumps(data) if data else None,
noauth=True, noauth=True,
@@ -1266,22 +1256,21 @@ async def test_token_authenticator_dict_noauth(app):
assert reply['name'] == name assert reply['name'] == name
@mark.parametrize("as_user, for_user, status", [ @mark.parametrize(
"as_user, for_user, status",
[
('admin', 'other', 200), ('admin', 'other', 200),
('admin', 'missing', 404), ('admin', 'missing', 404),
('user', 'other', 403), ('user', 'other', 403),
('user', 'user', 200), ('user', 'user', 200),
]) ],
)
async def test_token_list(app, as_user, for_user, status): async def test_token_list(app, as_user, for_user, status):
u = add_user(app.db, app, name=as_user) u = add_user(app.db, app, name=as_user)
if for_user != 'missing': if for_user != 'missing':
for_user_obj = add_user(app.db, app, name=for_user) for_user_obj = add_user(app.db, app, name=for_user)
headers = { headers = {'Authorization': 'token %s' % u.new_api_token()}
'Authorization': 'token %s' % u.new_api_token(), r = await api_request(app, 'users', for_user, 'tokens', headers=headers)
}
r = await api_request(app, 'users', for_user, 'tokens',
headers=headers,
)
assert r.status_code == status assert r.status_code == status
if status != 200: if status != 200:
return return
@@ -1292,8 +1281,8 @@ async def test_token_list(app, as_user, for_user, status):
assert all(token['user'] == for_user for token in reply['oauth_tokens']) assert all(token['user'] == for_user for token in reply['oauth_tokens'])
# validate individual token ids # validate individual token ids
for token in reply['api_tokens'] + reply['oauth_tokens']: for token in reply['api_tokens'] + reply['oauth_tokens']:
r = await api_request(app, 'users', for_user, 'tokens', token['id'], r = await api_request(
headers=headers, app, 'users', for_user, 'tokens', token['id'], headers=headers
) )
r.raise_for_status() r.raise_for_status()
reply = r.json() reply = r.json()
@@ -1320,19 +1309,15 @@ async def test_groups_list(app):
r = await api_request(app, 'groups') r = await api_request(app, 'groups')
r.raise_for_status() r.raise_for_status()
reply = r.json() reply = r.json()
assert reply == [{ assert reply == [{'kind': 'group', 'name': 'alphaflight', 'users': []}]
'kind': 'group',
'name': 'alphaflight',
'users': []
}]
@mark.group @mark.group
async def test_add_multi_group(app): async def test_add_multi_group(app):
db = app.db db = app.db
names = ['group1', 'group2'] names = ['group1', 'group2']
r = await api_request(app, 'groups', method='post', r = await api_request(
data=json.dumps({'groups': names}), app, 'groups', method='post', data=json.dumps({'groups': names})
) )
assert r.status_code == 201 assert r.status_code == 201
reply = r.json() reply = r.json()
@@ -1340,8 +1325,8 @@ async def test_add_multi_group(app):
assert names == r_names assert names == r_names
# try to create the same groups again # try to create the same groups again
r = await api_request(app, 'groups', method='post', r = await api_request(
data=json.dumps({'groups': names}), app, 'groups', method='post', data=json.dumps({'groups': names})
) )
assert r.status_code == 409 assert r.status_code == 409
@@ -1359,11 +1344,7 @@ async def test_group_get(app):
r = await api_request(app, 'groups/alphaflight') r = await api_request(app, 'groups/alphaflight')
r.raise_for_status() r.raise_for_status()
reply = r.json() reply = r.json()
assert reply == { assert reply == {'kind': 'group', 'name': 'alphaflight', 'users': ['sasquatch']}
'kind': 'group',
'name': 'alphaflight',
'users': ['sasquatch']
}
@mark.group @mark.group
@@ -1372,13 +1353,16 @@ async def test_group_create_delete(app):
r = await api_request(app, 'groups/runaways', method='delete') r = await api_request(app, 'groups/runaways', method='delete')
assert r.status_code == 404 assert r.status_code == 404
r = await api_request(app, 'groups/new', method='post', r = await api_request(
data=json.dumps({'users': ['doesntexist']}), app, 'groups/new', method='post', data=json.dumps({'users': ['doesntexist']})
) )
assert r.status_code == 400 assert r.status_code == 400
assert orm.Group.find(db, name='new') is None assert orm.Group.find(db, name='new') is None
r = await api_request(app, 'groups/omegaflight', method='post', r = await api_request(
app,
'groups/omegaflight',
method='post',
data=json.dumps({'users': ['sasquatch']}), data=json.dumps({'users': ['sasquatch']}),
) )
r.raise_for_status() r.raise_for_status()
@@ -1410,10 +1394,15 @@ async def test_group_add_users(app):
assert r.status_code == 400 assert r.status_code == 400
names = ['aurora', 'guardian', 'northstar', 'sasquatch', 'shaman', 'snowbird'] names = ['aurora', 'guardian', 'northstar', 'sasquatch', 'shaman', 'snowbird']
users = [ find_user(db, name=name) or add_user(db, app=app, name=name) for name in names ] users = [
r = await api_request(app, 'groups/alphaflight/users', method='post', data=json.dumps({ find_user(db, name=name) or add_user(db, app=app, name=name) for name in names
'users': names, ]
})) r = await api_request(
app,
'groups/alphaflight/users',
method='post',
data=json.dumps({'users': names}),
)
r.raise_for_status() r.raise_for_status()
for user in users: for user in users:
@@ -1433,9 +1422,12 @@ async def test_group_delete_users(app):
names = ['aurora', 'guardian', 'northstar', 'sasquatch', 'shaman', 'snowbird'] names = ['aurora', 'guardian', 'northstar', 'sasquatch', 'shaman', 'snowbird']
users = [find_user(db, name=name) for name in names] users = [find_user(db, name=name) for name in names]
r = await api_request(app, 'groups/alphaflight/users', method='delete', data=json.dumps({ r = await api_request(
'users': names[:2], app,
})) 'groups/alphaflight/users',
method='delete',
data=json.dumps({'users': names[:2]}),
)
r.raise_for_status() r.raise_for_status()
for user in users[:2]: for user in users[:2]:
@@ -1473,9 +1465,7 @@ async def test_get_services(app, mockservice_url):
} }
} }
r = await api_request(app, 'services', r = await api_request(app, 'services', headers=auth_header(db, 'user'))
headers=auth_header(db, 'user'),
)
assert r.status_code == 403 assert r.status_code == 403
@@ -1498,14 +1488,14 @@ async def test_get_service(app, mockservice_url):
'info': {}, 'info': {},
} }
r = await api_request(app, 'services/%s' % mockservice.name, r = await api_request(
headers={ app,
'Authorization': 'token %s' % mockservice.api_token 'services/%s' % mockservice.name,
} headers={'Authorization': 'token %s' % mockservice.api_token},
) )
r.raise_for_status() r.raise_for_status()
r = await api_request(app, 'services/%s' % mockservice.name, r = await api_request(
headers=auth_header(db, 'user'), app, 'services/%s' % mockservice.name, headers=auth_header(db, 'user')
) )
assert r.status_code == 403 assert r.status_code == 403
@@ -1519,9 +1509,7 @@ async def test_root_api(app):
kwargs["verify"] = app.internal_ssl_ca kwargs["verify"] = app.internal_ssl_ca
r = await async_requests.get(url, **kwargs) r = await async_requests.get(url, **kwargs)
r.raise_for_status() r.raise_for_status()
expected = { expected = {'version': jupyterhub.__version__}
'version': jupyterhub.__version__
}
assert r.json() == expected assert r.json() == expected
@@ -1662,15 +1650,20 @@ def test_shutdown(app):
# which makes gen_test unhappy. So we run the loop ourselves. # which makes gen_test unhappy. So we run the loop ourselves.
async def shutdown(): async def shutdown():
r = await api_request(app, 'shutdown', method='post', r = await api_request(
data=json.dumps({'servers': True, 'proxy': True,}), app,
'shutdown',
method='post',
data=json.dumps({'servers': True, 'proxy': True}),
) )
return r return r
real_stop = loop.stop real_stop = loop.stop
def stop(): def stop():
stop.called = True stop.called = True
loop.call_later(1, real_stop) loop.call_later(1, real_stop)
with mock.patch.object(loop, 'stop', stop): with mock.patch.object(loop, 'stop', stop):
r = loop.run_sync(shutdown, timeout=5) r = loop.run_sync(shutdown, timeout=5)
r.raise_for_status() r.raise_for_status()
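The shutdown test posts {'servers': True, 'proxy': True} to the shutdown endpoint and then waits for the loop to stop. Issued by hand against a running hub, the request would look roughly like this; the URL and token are again placeholders:

    import requests

    hub_api = 'http://127.0.0.1:8081/hub/api'  # assumed default hub API address
    headers = {'Authorization': 'token ADMIN_API_TOKEN'}  # placeholder credential

    # ask the hub to stop single-user servers and the proxy before exiting
    r = requests.post(
        hub_api + '/shutdown', headers=headers, json={'servers': True, 'proxy': True}
    )
    r.raise_for_status()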


@@ -1,25 +1,30 @@
"""Test the JupyterHub entry point""" """Test the JupyterHub entry point"""
import binascii import binascii
import os import os
import re import re
import sys import sys
from subprocess import check_output, Popen, PIPE from subprocess import check_output
from tempfile import NamedTemporaryFile, TemporaryDirectory from subprocess import PIPE
from subprocess import Popen
from tempfile import NamedTemporaryFile
from tempfile import TemporaryDirectory
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from tornado import gen from tornado import gen
from traitlets.config import Config from traitlets.config import Config
from .. import orm
from ..app import COOKIE_SECRET_BYTES
from ..app import JupyterHub
from .mocking import MockHub from .mocking import MockHub
from .test_api import add_user from .test_api import add_user
from .. import orm
from ..app import COOKIE_SECRET_BYTES, JupyterHub
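Alongside the call reformatting, this commit rewrites import blocks: combined statements such as the subprocess import above become one import per name, sorted, presumably via an import-rewriting pre-commit hook. The before/after pattern, using the same subprocess imports for illustration:

    # before the autoformat: names grouped by hand on one line
    from subprocess import check_output, Popen, PIPE

    # after the autoformat: one name per statement, alphabetized
    from subprocess import check_output
    from subprocess import PIPE
    from subprocess import Popen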
def test_help_all(): def test_help_all():
out = check_output([sys.executable, '-m', 'jupyterhub', '--help-all']).decode('utf8', 'replace') out = check_output([sys.executable, '-m', 'jupyterhub', '--help-all']).decode(
'utf8', 'replace'
)
assert '--ip' in out assert '--ip' in out
assert '--JupyterHub.ip' in out assert '--JupyterHub.ip' in out
@@ -39,9 +44,11 @@ def test_generate_config():
cfg_file = tf.name cfg_file = tf.name
with open(cfg_file, 'w') as f: with open(cfg_file, 'w') as f:
f.write("c.A = 5") f.write("c.A = 5")
p = Popen([sys.executable, '-m', 'jupyterhub', p = Popen(
'--generate-config', '-f', cfg_file], [sys.executable, '-m', 'jupyterhub', '--generate-config', '-f', cfg_file],
stdout=PIPE, stdin=PIPE) stdout=PIPE,
stdin=PIPE,
)
out, _ = p.communicate(b'n') out, _ = p.communicate(b'n')
out = out.decode('utf8', 'replace') out = out.decode('utf8', 'replace')
assert os.path.exists(cfg_file) assert os.path.exists(cfg_file)
@@ -49,9 +56,11 @@ def test_generate_config():
cfg_text = f.read() cfg_text = f.read()
assert cfg_text == 'c.A = 5' assert cfg_text == 'c.A = 5'
p = Popen([sys.executable, '-m', 'jupyterhub', p = Popen(
'--generate-config', '-f', cfg_file], [sys.executable, '-m', 'jupyterhub', '--generate-config', '-f', cfg_file],
stdout=PIPE, stdin=PIPE) stdout=PIPE,
stdin=PIPE,
)
out, _ = p.communicate(b'x\ny') out, _ = p.communicate(b'x\ny')
out = out.decode('utf8', 'replace') out = out.decode('utf8', 'replace')
assert os.path.exists(cfg_file) assert os.path.exists(cfg_file)
@@ -192,9 +201,13 @@ async def test_load_groups(tmpdir, request):
async def test_resume_spawners(tmpdir, request): async def test_resume_spawners(tmpdir, request):
if not os.getenv('JUPYTERHUB_TEST_DB_URL'): if not os.getenv('JUPYTERHUB_TEST_DB_URL'):
p = patch.dict(os.environ, { p = patch.dict(
'JUPYTERHUB_TEST_DB_URL': 'sqlite:///%s' % tmpdir.join('jupyterhub.sqlite'), os.environ,
}) {
'JUPYTERHUB_TEST_DB_URL': 'sqlite:///%s'
% tmpdir.join('jupyterhub.sqlite')
},
)
p.start() p.start()
request.addfinalizer(p.stop) request.addfinalizer(p.stop)
@@ -253,32 +266,18 @@ async def test_resume_spawners(tmpdir, request):
@pytest.mark.parametrize( @pytest.mark.parametrize(
'hub_config, expected', 'hub_config, expected',
[ [
( ({'ip': '0.0.0.0'}, {'bind_url': 'http://0.0.0.0:8000/'}),
{'ip': '0.0.0.0'},
{'bind_url': 'http://0.0.0.0:8000/'},
),
( (
{'port': 123, 'base_url': '/prefix'}, {'port': 123, 'base_url': '/prefix'},
{ {'bind_url': 'http://:123/prefix/', 'base_url': '/prefix/'},
'bind_url': 'http://:123/prefix/',
'base_url': '/prefix/',
},
),
(
{'bind_url': 'http://0.0.0.0:12345/sub'},
{'base_url': '/sub/'},
), ),
({'bind_url': 'http://0.0.0.0:12345/sub'}, {'base_url': '/sub/'}),
( (
# no config, test defaults # no config, test defaults
{}, {},
{ {'base_url': '/', 'bind_url': 'http://:8000', 'ip': '', 'port': 8000},
'base_url': '/',
'bind_url': 'http://:8000',
'ip': '',
'port': 8000,
},
), ),
] ],
) )
def test_url_config(hub_config, expected): def test_url_config(hub_config, expected):
# construct the config object # construct the config object
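The parametrized cases above pin down how the hub's ip, port, base_url and bind_url settings interact: bind_url implies both the port and the base_url prefix, and with no configuration the defaults are bind_url http://:8000 and base_url /. A minimal configuration sketch using a made-up address, assuming the usual jupyterhub_config.py mechanism:

    # jupyterhub_config.py
    c = get_config()  # injected by JupyterHub when this file is loaded

    # one setting covers ip, port and prefix; the trailing path becomes
    # base_url '/sub/', matching the third case in the test above
    c.JupyterHub.bind_url = 'http://0.0.0.0:12345/sub'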

View File

@@ -1,59 +1,59 @@
"""Tests for PAM authentication""" """Tests for PAM authentication"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import os import os
from unittest import mock from unittest import mock
import pytest import pytest
from requests import HTTPError from requests import HTTPError
from jupyterhub import auth, crypto, orm from .mocking import MockPAMAuthenticator
from .mocking import MockStructGroup
from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd from .mocking import MockStructPasswd
from .utils import add_user from .utils import add_user
from jupyterhub import auth
from jupyterhub import crypto
from jupyterhub import orm
async def test_pam_auth(): async def test_pam_auth():
authenticator = MockPAMAuthenticator() authenticator = MockPAMAuthenticator()
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'match', None, {'username': 'match', 'password': 'match'}
'password': 'match', )
})
assert authorized['name'] == 'match' assert authorized['name'] == 'match'
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'match', None, {'username': 'match', 'password': 'nomatch'}
'password': 'nomatch', )
})
assert authorized is None assert authorized is None
# Account check is on by default for increased security # Account check is on by default for increased security
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'notallowedmatch', None, {'username': 'notallowedmatch', 'password': 'notallowedmatch'}
'password': 'notallowedmatch', )
})
assert authorized is None assert authorized is None
async def test_pam_auth_account_check_disabled(): async def test_pam_auth_account_check_disabled():
authenticator = MockPAMAuthenticator(check_account=False) authenticator = MockPAMAuthenticator(check_account=False)
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'allowedmatch', None, {'username': 'allowedmatch', 'password': 'allowedmatch'}
'password': 'allowedmatch', )
})
assert authorized['name'] == 'allowedmatch' assert authorized['name'] == 'allowedmatch'
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'notallowedmatch', None, {'username': 'notallowedmatch', 'password': 'notallowedmatch'}
'password': 'notallowedmatch', )
})
assert authorized['name'] == 'notallowedmatch' assert authorized['name'] == 'notallowedmatch'
async def test_pam_auth_admin_groups(): async def test_pam_auth_admin_groups():
jh_users = MockStructGroup('jh_users', ['group_admin', 'also_group_admin', 'override_admin', 'non_admin'], 1234) jh_users = MockStructGroup(
'jh_users',
['group_admin', 'also_group_admin', 'override_admin', 'non_admin'],
1234,
)
jh_admins = MockStructGroup('jh_admins', ['group_admin'], 5678) jh_admins = MockStructGroup('jh_admins', ['group_admin'], 5678)
wheel = MockStructGroup('wheel', ['also_group_admin'], 9999) wheel = MockStructGroup('wheel', ['also_group_admin'], 9999)
system_groups = [jh_users, jh_admins, wheel] system_groups = [jh_users, jh_admins, wheel]
@@ -68,7 +68,7 @@ async def test_pam_auth_admin_groups():
'group_admin': [jh_users.gr_gid, jh_admins.gr_gid], 'group_admin': [jh_users.gr_gid, jh_admins.gr_gid],
'also_group_admin': [jh_users.gr_gid, wheel.gr_gid], 'also_group_admin': [jh_users.gr_gid, wheel.gr_gid],
'override_admin': [jh_users.gr_gid], 'override_admin': [jh_users.gr_gid],
'non_admin': [jh_users.gr_gid] 'non_admin': [jh_users.gr_gid],
} }
def getgrnam(name): def getgrnam(name):
@@ -80,76 +80,78 @@ async def test_pam_auth_admin_groups():
def getgrouplist(name, group): def getgrouplist(name, group):
return user_group_map[name] return user_group_map[name]
authenticator = MockPAMAuthenticator(admin_groups={'jh_admins', 'wheel'}, authenticator = MockPAMAuthenticator(
admin_users={'override_admin'}) admin_groups={'jh_admins', 'wheel'}, admin_users={'override_admin'}
)
# Check admin_group applies as expected # Check admin_group applies as expected
with mock.patch.multiple(authenticator, with mock.patch.multiple(
authenticator,
_getgrnam=getgrnam, _getgrnam=getgrnam,
_getpwnam=getpwnam, _getpwnam=getpwnam,
_getgrouplist=getgrouplist): _getgrouplist=getgrouplist,
authorized = await authenticator.get_authenticated_user(None, { ):
'username': 'group_admin', authorized = await authenticator.get_authenticated_user(
'password': 'group_admin' None, {'username': 'group_admin', 'password': 'group_admin'}
}) )
assert authorized['name'] == 'group_admin' assert authorized['name'] == 'group_admin'
assert authorized['admin'] is True assert authorized['admin'] is True
# Check multiple groups work, just in case. # Check multiple groups work, just in case.
with mock.patch.multiple(authenticator, with mock.patch.multiple(
authenticator,
_getgrnam=getgrnam, _getgrnam=getgrnam,
_getpwnam=getpwnam, _getpwnam=getpwnam,
_getgrouplist=getgrouplist): _getgrouplist=getgrouplist,
authorized = await authenticator.get_authenticated_user(None, { ):
'username': 'also_group_admin', authorized = await authenticator.get_authenticated_user(
'password': 'also_group_admin' None, {'username': 'also_group_admin', 'password': 'also_group_admin'}
}) )
assert authorized['name'] == 'also_group_admin' assert authorized['name'] == 'also_group_admin'
assert authorized['admin'] is True assert authorized['admin'] is True
# Check admin_users still applies correctly # Check admin_users still applies correctly
with mock.patch.multiple(authenticator, with mock.patch.multiple(
authenticator,
_getgrnam=getgrnam, _getgrnam=getgrnam,
_getpwnam=getpwnam, _getpwnam=getpwnam,
_getgrouplist=getgrouplist): _getgrouplist=getgrouplist,
authorized = await authenticator.get_authenticated_user(None, { ):
'username': 'override_admin', authorized = await authenticator.get_authenticated_user(
'password': 'override_admin' None, {'username': 'override_admin', 'password': 'override_admin'}
}) )
assert authorized['name'] == 'override_admin' assert authorized['name'] == 'override_admin'
assert authorized['admin'] is True assert authorized['admin'] is True
# Check it doesn't admin everyone # Check it doesn't admin everyone
with mock.patch.multiple(authenticator, with mock.patch.multiple(
authenticator,
_getgrnam=getgrnam, _getgrnam=getgrnam,
_getpwnam=getpwnam, _getpwnam=getpwnam,
_getgrouplist=getgrouplist): _getgrouplist=getgrouplist,
authorized = await authenticator.get_authenticated_user(None, { ):
'username': 'non_admin', authorized = await authenticator.get_authenticated_user(
'password': 'non_admin' None, {'username': 'non_admin', 'password': 'non_admin'}
}) )
assert authorized['name'] == 'non_admin' assert authorized['name'] == 'non_admin'
assert authorized['admin'] is False assert authorized['admin'] is False
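test_pam_auth_admin_groups stubs out _getgrnam, _getpwnam and _getgrouplist so that membership in jh_admins or wheel, or an explicit entry in admin_users, marks the authenticated user as an admin. Expressed as deployment configuration, with the group and user names taken from the test fixtures purely for illustration:

    # jupyterhub_config.py
    c = get_config()  # injected by JupyterHub when this file is loaded

    # members of these system groups become hub admins when they log in
    c.PAMAuthenticator.admin_groups = {'jh_admins', 'wheel'}
    # individual users can still be promoted regardless of group membership
    c.Authenticator.admin_users = {'override_admin'}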
async def test_pam_auth_whitelist(): async def test_pam_auth_whitelist():
authenticator = MockPAMAuthenticator(whitelist={'wash', 'kaylee'}) authenticator = MockPAMAuthenticator(whitelist={'wash', 'kaylee'})
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'kaylee', None, {'username': 'kaylee', 'password': 'kaylee'}
'password': 'kaylee', )
})
assert authorized['name'] == 'kaylee' assert authorized['name'] == 'kaylee'
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'wash', None, {'username': 'wash', 'password': 'nomatch'}
'password': 'nomatch', )
})
assert authorized is None assert authorized is None
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'mal', None, {'username': 'mal', 'password': 'mal'}
'password': 'mal', )
})
assert authorized is None assert authorized is None
@@ -160,80 +162,78 @@ async def test_pam_auth_group_whitelist():
authenticator = MockPAMAuthenticator(group_whitelist={'group'}) authenticator = MockPAMAuthenticator(group_whitelist={'group'})
with mock.patch.object(authenticator, '_getgrnam', getgrnam): with mock.patch.object(authenticator, '_getgrnam', getgrnam):
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'kaylee', None, {'username': 'kaylee', 'password': 'kaylee'}
'password': 'kaylee', )
})
assert authorized['name'] == 'kaylee' assert authorized['name'] == 'kaylee'
with mock.patch.object(authenticator, '_getgrnam', getgrnam): with mock.patch.object(authenticator, '_getgrnam', getgrnam):
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'mal', None, {'username': 'mal', 'password': 'mal'}
'password': 'mal', )
})
assert authorized is None assert authorized is None
async def test_pam_auth_blacklist(): async def test_pam_auth_blacklist():
# Null case compared to next case # Null case compared to next case
authenticator = MockPAMAuthenticator() authenticator = MockPAMAuthenticator()
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'wash', None, {'username': 'wash', 'password': 'wash'}
'password': 'wash', )
})
assert authorized['name'] == 'wash' assert authorized['name'] == 'wash'
# Blacklist basics # Blacklist basics
authenticator = MockPAMAuthenticator(blacklist={'wash'}) authenticator = MockPAMAuthenticator(blacklist={'wash'})
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'wash', None, {'username': 'wash', 'password': 'wash'}
'password': 'wash', )
})
assert authorized is None assert authorized is None
# User in both white and blacklists: default deny. Make error someday? # User in both white and blacklists: default deny. Make error someday?
authenticator = MockPAMAuthenticator(blacklist={'wash'}, whitelist={'wash', 'kaylee'}) authenticator = MockPAMAuthenticator(
authorized = await authenticator.get_authenticated_user(None, { blacklist={'wash'}, whitelist={'wash', 'kaylee'}
'username': 'wash', )
'password': 'wash', authorized = await authenticator.get_authenticated_user(
}) None, {'username': 'wash', 'password': 'wash'}
)
assert authorized is None assert authorized is None
# User not in blacklist can log in # User not in blacklist can log in
authenticator = MockPAMAuthenticator(blacklist={'wash'}, whitelist={'wash', 'kaylee'}) authenticator = MockPAMAuthenticator(
authorized = await authenticator.get_authenticated_user(None, { blacklist={'wash'}, whitelist={'wash', 'kaylee'}
'username': 'kaylee', )
'password': 'kaylee', authorized = await authenticator.get_authenticated_user(
}) None, {'username': 'kaylee', 'password': 'kaylee'}
)
assert authorized['name'] == 'kaylee' assert authorized['name'] == 'kaylee'
    # User in whitelist, blacklist irrelevant # User in whitelist, blacklist irrelevant
authenticator = MockPAMAuthenticator(blacklist={'mal'}, whitelist={'wash', 'kaylee'}) authenticator = MockPAMAuthenticator(
authorized = await authenticator.get_authenticated_user(None, { blacklist={'mal'}, whitelist={'wash', 'kaylee'}
'username': 'wash', )
'password': 'wash', authorized = await authenticator.get_authenticated_user(
}) None, {'username': 'wash', 'password': 'wash'}
)
assert authorized['name'] == 'wash' assert authorized['name'] == 'wash'
# User in neither list # User in neither list
authenticator = MockPAMAuthenticator(blacklist={'mal'}, whitelist={'wash', 'kaylee'}) authenticator = MockPAMAuthenticator(
authorized = await authenticator.get_authenticated_user(None, { blacklist={'mal'}, whitelist={'wash', 'kaylee'}
'username': 'simon', )
'password': 'simon', authorized = await authenticator.get_authenticated_user(
}) None, {'username': 'simon', 'password': 'simon'}
)
assert authorized is None assert authorized is None
# blacklist == {} # blacklist == {}
authenticator = MockPAMAuthenticator(blacklist=set(), whitelist={'wash', 'kaylee'}) authenticator = MockPAMAuthenticator(blacklist=set(), whitelist={'wash', 'kaylee'})
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'kaylee', None, {'username': 'kaylee', 'password': 'kaylee'}
'password': 'kaylee', )
})
assert authorized['name'] == 'kaylee' assert authorized['name'] == 'kaylee'
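Read together, the whitelist/blacklist cases above encode the precedence rules: a user on both lists is denied, an empty blacklist changes nothing, and a user on neither list is still rejected whenever a whitelist is set. The equivalent configuration, reusing the test's names for illustration:

    # jupyterhub_config.py
    c = get_config()  # injected by JupyterHub when this file is loaded

    c.Authenticator.whitelist = {'wash', 'kaylee'}  # only these users may log in
    c.Authenticator.blacklist = {'mal'}             # ...and these are always refused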
async def test_deprecated_signatures(): async def test_deprecated_signatures():
def deprecated_xlist(self, username): def deprecated_xlist(self, username):
return True return True
@@ -244,20 +244,18 @@ async def test_deprecated_signatures():
check_blacklist=deprecated_xlist, check_blacklist=deprecated_xlist,
): ):
deprecated_authenticator = MockPAMAuthenticator() deprecated_authenticator = MockPAMAuthenticator()
authorized = await deprecated_authenticator.get_authenticated_user(None, { authorized = await deprecated_authenticator.get_authenticated_user(
'username': 'test', None, {'username': 'test', 'password': 'test'}
'password': 'test' )
})
assert authorized is not None assert authorized is not None
async def test_pam_auth_no_such_group(): async def test_pam_auth_no_such_group():
authenticator = MockPAMAuthenticator(group_whitelist={'nosuchcrazygroup'}) authenticator = MockPAMAuthenticator(group_whitelist={'nosuchcrazygroup'})
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'kaylee', None, {'username': 'kaylee', 'password': 'kaylee'}
'password': 'kaylee', )
})
assert authorized is None assert authorized is None
@@ -302,6 +300,7 @@ async def test_add_system_user():
authenticator.add_user_cmd = ['echo', '/home/USERNAME'] authenticator.add_user_cmd = ['echo', '/home/USERNAME']
record = {} record = {}
class DummyPopen: class DummyPopen:
def __init__(self, cmd, *args, **kwargs): def __init__(self, cmd, *args, **kwargs):
record['cmd'] = cmd record['cmd'] = cmd
@@ -402,44 +401,38 @@ async def test_auth_state_disabled(app, auth_state_unavailable):
async def test_normalize_names(): async def test_normalize_names():
a = MockPAMAuthenticator() a = MockPAMAuthenticator()
authorized = await a.get_authenticated_user(None, { authorized = await a.get_authenticated_user(
'username': 'ZOE', None, {'username': 'ZOE', 'password': 'ZOE'}
'password': 'ZOE', )
})
assert authorized['name'] == 'zoe' assert authorized['name'] == 'zoe'
authorized = await a.get_authenticated_user(None, { authorized = await a.get_authenticated_user(
'username': 'Glenn', None, {'username': 'Glenn', 'password': 'Glenn'}
'password': 'Glenn', )
})
assert authorized['name'] == 'glenn' assert authorized['name'] == 'glenn'
authorized = await a.get_authenticated_user(None, { authorized = await a.get_authenticated_user(
'username': 'hExi', None, {'username': 'hExi', 'password': 'hExi'}
'password': 'hExi', )
})
assert authorized['name'] == 'hexi' assert authorized['name'] == 'hexi'
authorized = await a.get_authenticated_user(None, { authorized = await a.get_authenticated_user(
'username': 'Test', None, {'username': 'Test', 'password': 'Test'}
'password': 'Test', )
})
assert authorized['name'] == 'test' assert authorized['name'] == 'test'
async def test_username_map(): async def test_username_map():
a = MockPAMAuthenticator(username_map={'wash': 'alpha'}) a = MockPAMAuthenticator(username_map={'wash': 'alpha'})
authorized = await a.get_authenticated_user(None, { authorized = await a.get_authenticated_user(
'username': 'WASH', None, {'username': 'WASH', 'password': 'WASH'}
'password': 'WASH', )
})
assert authorized['name'] == 'alpha' assert authorized['name'] == 'alpha'
authorized = await a.get_authenticated_user(None, { authorized = await a.get_authenticated_user(
'username': 'Inara', None, {'username': 'Inara', 'password': 'Inara'}
'password': 'Inara', )
})
assert authorized['name'] == 'inara' assert authorized['name'] == 'inara'
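test_normalize_names and test_username_map cover the two renaming layers: usernames are lower-cased by default, and username_map then rewrites specific normalized names. As configuration, with the mapping value borrowed from the test:

    # jupyterhub_config.py
    c = get_config()  # injected by JupyterHub when this file is loaded

    # after normalization ('WASH' -> 'wash'), log this person in as 'alpha'
    c.Authenticator.username_map = {'wash': 'alpha'}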
@@ -463,9 +456,8 @@ def test_post_auth_hook():
a = MockPAMAuthenticator(post_auth_hook=test_auth_hook) a = MockPAMAuthenticator(post_auth_hook=test_auth_hook)
authorized = yield a.get_authenticated_user(None, { authorized = yield a.get_authenticated_user(
'username': 'test_user', None, {'username': 'test_user', 'password': 'test_user'}
'password': 'test_user' )
})
assert authorized['testkey'] == 'testvalue' assert authorized['testkey'] == 'testvalue'


@@ -7,15 +7,16 @@ authentication can expire in a number of ways:
- doesn't need refresh - doesn't need refresh
- needs refresh and cannot be refreshed without new login - needs refresh and cannot be refreshed without new login
""" """
import asyncio import asyncio
from contextlib import contextmanager from contextlib import contextmanager
from unittest import mock from unittest import mock
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs
from urllib.parse import urlparse
import pytest import pytest
from .utils import api_request, get_page from .utils import api_request
from .utils import get_page
async def refresh_expired(authenticator, user): async def refresh_expired(authenticator, user):


@@ -1,24 +1,29 @@
from binascii import b2a_hex, b2a_base64
import os import os
from binascii import b2a_base64
import pytest from binascii import b2a_hex
from unittest.mock import patch from unittest.mock import patch
import pytest
from .. import crypto from .. import crypto
from ..crypto import encrypt, decrypt from ..crypto import decrypt
from ..crypto import encrypt
keys = [('%i' % i).encode('ascii') * 32 for i in range(3)] keys = [('%i' % i).encode('ascii') * 32 for i in range(3)]
hex_keys = [b2a_hex(key).decode('ascii') for key in keys] hex_keys = [b2a_hex(key).decode('ascii') for key in keys]
b64_keys = [b2a_base64(key).decode('ascii').strip() for key in keys] b64_keys = [b2a_base64(key).decode('ascii').strip() for key in keys]
@pytest.mark.parametrize("key_env, keys", [ @pytest.mark.parametrize(
"key_env, keys",
[
(hex_keys[0], [keys[0]]), (hex_keys[0], [keys[0]]),
(';'.join([b64_keys[0], hex_keys[1]]), keys[:2]), (';'.join([b64_keys[0], hex_keys[1]]), keys[:2]),
(';'.join([hex_keys[0], b64_keys[1], '']), keys[:2]), (';'.join([hex_keys[0], b64_keys[1], '']), keys[:2]),
('', []), ('', []),
(';', []), (';', []),
]) ],
)
def test_env_constructor(key_env, keys): def test_env_constructor(key_env, keys):
with patch.dict(os.environ, {crypto.KEY_ENV: key_env}): with patch.dict(os.environ, {crypto.KEY_ENV: key_env}):
ck = crypto.CryptKeeper() ck = crypto.CryptKeeper()
@@ -29,12 +34,15 @@ def test_env_constructor(key_env, keys):
assert ck.fernet is None assert ck.fernet is None
@pytest.mark.parametrize("key", [ @pytest.mark.parametrize(
"key",
[
'a' * 44, # base64, not 32 bytes 'a' * 44, # base64, not 32 bytes
('%44s' % 'notbase64'), # not base64 ('%44s' % 'notbase64'), # not base64
b'x' * 64, # not hex b'x' * 64, # not hex
b'short', # not 32 bytes b'short', # not 32 bytes
]) ],
)
def test_bad_keys(key): def test_bad_keys(key):
ck = crypto.CryptKeeper() ck = crypto.CryptKeeper()
with pytest.raises(ValueError): with pytest.raises(ValueError):
@@ -76,4 +84,3 @@ async def test_missing_keys(crypt_keeper):
with pytest.raises(crypto.NoEncryptionKeys): with pytest.raises(crypto.NoEncryptionKeys):
await decrypt(b'whatever') await decrypt(b'whatever')


@@ -1,14 +1,16 @@
from glob import glob
import os import os
from subprocess import check_call
import sys import sys
import tempfile import tempfile
from glob import glob
from subprocess import check_call
import pytest import pytest
from pytest import raises from pytest import raises
from traitlets.config import Config from traitlets.config import Config
from ..app import NewToken, UpgradeDB, JupyterHub from ..app import JupyterHub
from ..app import NewToken
from ..app import UpgradeDB
here = os.path.abspath(os.path.dirname(__file__)) here = os.path.abspath(os.path.dirname(__file__))
@@ -33,13 +35,7 @@ def generate_old_db(env_dir, hub_version, db_url):
check_call([env_py, populate_db, db_url]) check_call([env_py, populate_db, db_url])
@pytest.mark.parametrize( @pytest.mark.parametrize('hub_version', ['0.7.2', '0.8.1'])
'hub_version',
[
'0.7.2',
'0.8.1',
],
)
async def test_upgrade(tmpdir, hub_version): async def test_upgrade(tmpdir, hub_version):
db_url = os.getenv('JUPYTERHUB_TEST_DB_URL') db_url = os.getenv('JUPYTERHUB_TEST_DB_URL')
if db_url: if db_url:


@@ -1,52 +1,47 @@
"""Tests for dummy authentication""" """Tests for dummy authentication"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import pytest import pytest
from jupyterhub.auth import DummyAuthenticator from jupyterhub.auth import DummyAuthenticator
async def test_dummy_auth_without_global_password(): async def test_dummy_auth_without_global_password():
authenticator = DummyAuthenticator() authenticator = DummyAuthenticator()
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'test_user', None, {'username': 'test_user', 'password': 'test_pass'}
'password': 'test_pass', )
})
assert authorized['name'] == 'test_user' assert authorized['name'] == 'test_user'
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'test_user', None, {'username': 'test_user', 'password': ''}
'password': '', )
})
assert authorized['name'] == 'test_user' assert authorized['name'] == 'test_user'
async def test_dummy_auth_without_username(): async def test_dummy_auth_without_username():
authenticator = DummyAuthenticator() authenticator = DummyAuthenticator()
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': '', None, {'username': '', 'password': 'test_pass'}
'password': 'test_pass', )
})
assert authorized is None assert authorized is None
async def test_dummy_auth_with_global_password(): async def test_dummy_auth_with_global_password():
authenticator = DummyAuthenticator() authenticator = DummyAuthenticator()
authenticator.password = "test_password" authenticator.password = "test_password"
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'test_user', None, {'username': 'test_user', 'password': 'test_password'}
'password': 'test_password', )
})
assert authorized['name'] == 'test_user' assert authorized['name'] == 'test_user'
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'test_user', None, {'username': 'test_user', 'password': 'qwerty'}
'password': 'qwerty', )
})
assert authorized is None assert authorized is None
authorized = await authenticator.get_authenticated_user(None, { authorized = await authenticator.get_authenticated_user(
'username': 'some_other_user', None, {'username': 'some_other_user', 'password': 'test_password'}
'password': 'test_password', )
})
assert authorized['name'] == 'some_other_user' assert authorized['name'] == 'some_other_user'


@@ -1,7 +1,6 @@
"""Tests for the SSL enabled REST API.""" """Tests for the SSL enabled REST API."""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from jupyterhub.tests.test_api import * from jupyterhub.tests.test_api import *
ssl_enabled = True ssl_enabled = True


@@ -1,10 +1,9 @@
"""Test the JupyterHub entry point with internal ssl""" """Test the JupyterHub entry point with internal ssl"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import sys import sys
import jupyterhub.tests.mocking
import jupyterhub.tests.mocking
from jupyterhub.tests.test_app import * from jupyterhub.tests.test_app import *
ssl_enabled = True ssl_enabled = True


@@ -1,19 +1,17 @@
"""Tests for jupyterhub internal_ssl connections""" """Tests for jupyterhub internal_ssl connections"""
import sys
import time import time
from subprocess import check_output from subprocess import check_output
import sys from unittest import mock
from urllib.parse import urlparse from urllib.parse import urlparse
import pytest import pytest
from requests.exceptions import SSLError
from tornado import gen
import jupyterhub import jupyterhub
from tornado import gen
from unittest import mock
from requests.exceptions import SSLError
from .utils import async_requests
from .test_api import add_user from .test_api import add_user
from .utils import async_requests
ssl_enabled = True ssl_enabled = True
@@ -25,8 +23,10 @@ def wait_for_spawner(spawner, timeout=10):
polling at shorter intervals for early termination polling at shorter intervals for early termination
""" """
deadline = time.monotonic() + timeout deadline = time.monotonic() + timeout
def wait(): def wait():
return spawner.server.wait_up(timeout=1, http=True) return spawner.server.wait_up(timeout=1, http=True)
while time.monotonic() < deadline: while time.monotonic() < deadline:
status = yield spawner.poll() status = yield spawner.poll()
assert status is None assert status is None
@@ -59,7 +59,7 @@ async def test_connection_notebook_wrong_certs(app):
"""Connecting to a notebook fails without correct certs""" """Connecting to a notebook fails without correct certs"""
with mock.patch.dict( with mock.patch.dict(
app.config.LocalProcessSpawner, app.config.LocalProcessSpawner,
{'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']} {'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']},
): ):
user = add_user(app.db, app, name='foo') user = add_user(app.db, app, name='foo')
await user.spawn() await user.spawn()


@@ -1,7 +1,6 @@
"""Tests for process spawning with internal_ssl""" """Tests for process spawning with internal_ssl"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from jupyterhub.tests.test_spawner import * from jupyterhub.tests.test_spawner import *
ssl_enabled = True ssl_enabled = True


@@ -5,15 +5,21 @@ from unittest import mock
import pytest import pytest
from ..utils import url_path_join from ..utils import url_path_join
from .test_api import api_request, add_user, fill_user, normalize_user, TIMESTAMP
from .mocking import public_url from .mocking import public_url
from .test_api import add_user
from .test_api import api_request
from .test_api import fill_user
from .test_api import normalize_user
from .test_api import TIMESTAMP
from .utils import async_requests from .utils import async_requests
@pytest.fixture @pytest.fixture
def named_servers(app): def named_servers(app):
with mock.patch.dict(app.tornado_settings, with mock.patch.dict(
{'allow_named_servers': True, 'named_server_limit_per_user': 2}): app.tornado_settings,
{'allow_named_servers': True, 'named_server_limit_per_user': 2},
):
yield yield
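The fixture enables named servers and caps them at two per user by patching the tornado settings directly; in a deployment the same knobs are presumably the JupyterHub options of the same names:

    # jupyterhub_config.py
    c = get_config()  # injected by JupyterHub when this file is loaded

    c.JupyterHub.allow_named_servers = True
    # a third named server is refused with a 400, as the limit test below checks
    c.JupyterHub.named_server_limit_per_user = 2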
@@ -30,7 +36,8 @@ async def test_default_server(app, named_servers):
user_model = normalize_user(r.json()) user_model = normalize_user(r.json())
print(user_model) print(user_model)
assert user_model == fill_user({ assert user_model == fill_user(
{
'name': username, 'name': username,
'auth_state': None, 'auth_state': None,
'server': user.url, 'server': user.url,
@@ -42,12 +49,14 @@ async def test_default_server(app, named_servers):
'url': user.url, 'url': user.url,
'pending': None, 'pending': None,
'ready': True, 'ready': True,
'progress_url': 'PREFIX/hub/api/users/{}/server/progress'.format(username), 'progress_url': 'PREFIX/hub/api/users/{}/server/progress'.format(
username
),
'state': {'pid': 0}, 'state': {'pid': 0},
}
}, },
}, }
}) )
# now stop the server # now stop the server
r = await api_request(app, 'users', username, 'server', method='delete') r = await api_request(app, 'users', username, 'server', method='delete')
@@ -58,11 +67,9 @@ async def test_default_server(app, named_servers):
r.raise_for_status() r.raise_for_status()
user_model = normalize_user(r.json()) user_model = normalize_user(r.json())
assert user_model == fill_user({ assert user_model == fill_user(
'name': username, {'name': username, 'servers': {}, 'auth_state': None}
'servers': {}, )
'auth_state': None,
})
async def test_create_named_server(app, named_servers): async def test_create_named_server(app, named_servers):
@@ -89,7 +96,8 @@ async def test_create_named_server(app, named_servers):
r.raise_for_status() r.raise_for_status()
user_model = normalize_user(r.json()) user_model = normalize_user(r.json())
assert user_model == fill_user({ assert user_model == fill_user(
{
'name': username, 'name': username,
'auth_state': None, 'auth_state': None,
'servers': { 'servers': {
@@ -101,12 +109,14 @@ async def test_create_named_server(app, named_servers):
'pending': None, 'pending': None,
'ready': True, 'ready': True,
'progress_url': 'PREFIX/hub/api/users/{}/servers/{}/progress'.format( 'progress_url': 'PREFIX/hub/api/users/{}/servers/{}/progress'.format(
username, servername), username, servername
),
'state': {'pid': 0}, 'state': {'pid': 0},
} }
for name in [servername] for name in [servername]
}, },
}) }
)
async def test_delete_named_server(app, named_servers): async def test_delete_named_server(app, named_servers):
@@ -119,7 +129,9 @@ async def test_delete_named_server(app, named_servers):
r.raise_for_status() r.raise_for_status()
assert r.status_code == 201 assert r.status_code == 201
r = await api_request(app, 'users', username, 'servers', servername, method='delete') r = await api_request(
app, 'users', username, 'servers', servername, method='delete'
)
r.raise_for_status() r.raise_for_status()
assert r.status_code == 204 assert r.status_code == 204
@@ -127,18 +139,20 @@ async def test_delete_named_server(app, named_servers):
r.raise_for_status() r.raise_for_status()
user_model = normalize_user(r.json()) user_model = normalize_user(r.json())
assert user_model == fill_user({ assert user_model == fill_user(
'name': username, {'name': username, 'auth_state': None, 'servers': {}}
'auth_state': None, )
'servers': {},
})
# wrapper Spawner is gone # wrapper Spawner is gone
assert servername not in user.spawners assert servername not in user.spawners
# low-level record still exists # low-level record still exists
assert servername in user.orm_spawners assert servername in user.orm_spawners
r = await api_request( r = await api_request(
app, 'users', username, 'servers', servername, app,
'users',
username,
'servers',
servername,
method='delete', method='delete',
data=json.dumps({'remove': True}), data=json.dumps({'remove': True}),
) )
@@ -153,7 +167,9 @@ async def test_named_server_disabled(app):
servername = 'okay' servername = 'okay'
r = await api_request(app, 'users', username, 'servers', servername, method='post') r = await api_request(app, 'users', username, 'servers', servername, method='post')
assert r.status_code == 400 assert r.status_code == 400
r = await api_request(app, 'users', username, 'servers', servername, method='delete') r = await api_request(
app, 'users', username, 'servers', servername, method='delete'
)
assert r.status_code == 400 assert r.status_code == 400
@@ -180,7 +196,10 @@ async def test_named_server_limit(app, named_servers):
servername3 = 'bar-3' servername3 = 'bar-3'
r = await api_request(app, 'users', username, 'servers', servername3, method='post') r = await api_request(app, 'users', username, 'servers', servername3, method='post')
assert r.status_code == 400 assert r.status_code == 400
assert r.json() == {"status": 400, "message": "User foo already has the maximum of 2 named servers. One must be deleted before a new server can be created"} assert r.json() == {
"status": 400,
"message": "User foo already has the maximum of 2 named servers. One must be deleted before a new server can be created",
}
# Create default server # Create default server
r = await api_request(app, 'users', username, 'server', method='post') r = await api_request(app, 'users', username, 'server', method='post')
@@ -189,7 +208,11 @@ async def test_named_server_limit(app, named_servers):
# Delete 1st named server # Delete 1st named server
r = await api_request( r = await api_request(
app, 'users', username, 'servers', servername1, app,
'users',
username,
'servers',
servername1,
method='delete', method='delete',
data=json.dumps({'remove': True}), data=json.dumps({'remove': True}),
) )


@@ -1,6 +1,6 @@
"""Tests for basic object-wrappers""" """Tests for basic object-wrappers"""
import socket import socket
import pytest import pytest
from jupyterhub.objects import Server from jupyterhub.objects import Server
@@ -16,7 +16,7 @@ from jupyterhub.objects import Server
'port': 123, 'port': 123,
'host': 'http://abc:123', 'host': 'http://abc:123',
'url': 'http://abc:123/x/', 'url': 'http://abc:123/x/',
} },
), ),
( (
'https://abc', 'https://abc',
@@ -26,9 +26,9 @@ from jupyterhub.objects import Server
'proto': 'https', 'proto': 'https',
'host': 'https://abc:443', 'host': 'https://abc:443',
'url': 'https://abc:443/x/', 'url': 'https://abc:443/x/',
} },
), ),
] ],
) )
def test_bind_url(bind_url, attrs): def test_bind_url(bind_url, attrs):
s = Server(bind_url=bind_url, base_url='/x/') s = Server(bind_url=bind_url, base_url='/x/')
@@ -43,26 +43,28 @@ _hostname = socket.gethostname()
'ip, port, attrs', 'ip, port, attrs',
[ [
( (
'', 123, '',
123,
{ {
'ip': '', 'ip': '',
'port': 123, 'port': 123,
'host': 'http://{}:123'.format(_hostname), 'host': 'http://{}:123'.format(_hostname),
'url': 'http://{}:123/x/'.format(_hostname), 'url': 'http://{}:123/x/'.format(_hostname),
'bind_url': 'http://*:123/x/', 'bind_url': 'http://*:123/x/',
} },
), ),
( (
'127.0.0.1', 999, '127.0.0.1',
999,
{ {
'ip': '127.0.0.1', 'ip': '127.0.0.1',
'port': 999, 'port': 999,
'host': 'http://127.0.0.1:999', 'host': 'http://127.0.0.1:999',
'url': 'http://127.0.0.1:999/x/', 'url': 'http://127.0.0.1:999/x/',
'bind_url': 'http://127.0.0.1:999/x/', 'bind_url': 'http://127.0.0.1:999/x/',
} },
), ),
] ],
) )
def test_ip_port(ip, port, attrs): def test_ip_port(ip, port, attrs):
s = Server(ip=ip, port=port, base_url='/x/') s = Server(ip=ip, port=port, base_url='/x/')
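The Server cases spell out how ip, port and base_url combine into host, url and bind_url, with an empty ip falling back to the local hostname and a wildcard bind address. A quick check of the same mapping, copying the second parametrized case:

    from jupyterhub.objects import Server

    s = Server(ip='127.0.0.1', port=999, base_url='/x/')
    assert s.host == 'http://127.0.0.1:999'
    assert s.url == 'http://127.0.0.1:999/x/'
    assert s.bind_url == 'http://127.0.0.1:999/x/'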


@@ -1,22 +1,21 @@
"""Tests for the ORM bits""" """Tests for the ORM bits"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from datetime import datetime, timedelta
import os import os
import socket import socket
from datetime import datetime
from datetime import timedelta
from unittest import mock from unittest import mock
import pytest import pytest
from tornado import gen from tornado import gen
from .. import orm
from .. import objects
from .. import crypto from .. import crypto
from .. import objects
from .. import orm
from ..emptyclass import EmptyClass
from ..user import User from ..user import User
from .mocking import MockSpawner from .mocking import MockSpawner
from ..emptyclass import EmptyClass
def assert_not_found(db, ORMType, id): def assert_not_found(db, ORMType, id):
@@ -124,7 +123,9 @@ def test_token_expiry(db):
# approximate range # approximate range
assert orm_token.expires_at > now + timedelta(seconds=50) assert orm_token.expires_at > now + timedelta(seconds=50)
assert orm_token.expires_at < now + timedelta(seconds=70) assert orm_token.expires_at < now + timedelta(seconds=70)
the_future = mock.patch('jupyterhub.orm.utcnow', lambda : now + timedelta(seconds=70)) the_future = mock.patch(
'jupyterhub.orm.utcnow', lambda: now + timedelta(seconds=70)
)
with the_future: with the_future:
found = orm.APIToken.find(db, token=token) found = orm.APIToken.find(db, token=token)
assert found is None assert found is None
@@ -215,11 +216,9 @@ async def test_spawn_fails(db):
def start(self): def start(self):
raise RuntimeError("Split the party") raise RuntimeError("Split the party")
user = User(orm_user, { user = User(
'spawner_class': BadSpawner, orm_user, {'spawner_class': BadSpawner, 'config': None, 'statsd': EmptyClass()}
'config': None, )
'statsd': EmptyClass(),
})
with pytest.raises(RuntimeError) as exc: with pytest.raises(RuntimeError) as exc:
await user.spawn() await user.spawn()
@@ -346,9 +345,7 @@ def test_user_delete_cascade(db):
oauth_code = orm.OAuthCode(client=oauth_client, user=user) oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code) db.add(oauth_code)
oauth_token = orm.OAuthAccessToken( oauth_token = orm.OAuthAccessToken(
client=oauth_client, client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
user=user,
grant_type=orm.GrantType.authorization_code,
) )
db.add(oauth_token) db.add(oauth_token)
db.commit() db.commit()
@@ -384,9 +381,7 @@ def test_oauth_client_delete_cascade(db):
oauth_code = orm.OAuthCode(client=oauth_client, user=user) oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code) db.add(oauth_code)
oauth_token = orm.OAuthAccessToken( oauth_token = orm.OAuthAccessToken(
client=oauth_client, client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
user=user,
grant_type=orm.GrantType.authorization_code,
) )
db.add(oauth_token) db.add(oauth_token)
db.commit() db.commit()
@@ -477,6 +472,3 @@ def test_group_delete_cascade(db):
db.delete(user1) db.delete(user1)
db.commit() db.commit()
assert user1 not in group1.users assert user1 not in group1.users


@@ -1,30 +1,27 @@
"""Tests for HTML pages""" """Tests for HTML pages"""
import asyncio import asyncio
import sys import sys
from unittest import mock from unittest import mock
from urllib.parse import urlencode, urlparse from urllib.parse import urlencode
from urllib.parse import urlparse
import pytest
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from tornado import gen from tornado import gen
from tornado.httputil import url_concat from tornado.httputil import url_concat
from ..handlers import BaseHandler
from ..utils import url_path_join as ujoin
from .. import orm from .. import orm
from ..auth import Authenticator from ..auth import Authenticator
from ..handlers import BaseHandler
import pytest from ..utils import url_path_join as ujoin
from .mocking import FalsyCallableFormSpawner
from .mocking import FormSpawner, FalsyCallableFormSpawner from .mocking import FormSpawner
from .utils import ( from .utils import add_user
async_requests, from .utils import api_request
api_request, from .utils import async_requests
add_user, from .utils import get_page
get_page, from .utils import public_host
public_url, from .utils import public_url
public_host,
)
async def test_root_no_auth(app): async def test_root_no_auth(app):
@@ -53,8 +50,7 @@ async def test_root_redirect(app):
async def test_root_default_url_noauth(app): async def test_root_default_url_noauth(app):
with mock.patch.dict(app.tornado_settings, with mock.patch.dict(app.tornado_settings, {'default_url': '/foo/bar'}):
{'default_url': '/foo/bar'}):
r = await get_page('/', app, allow_redirects=False) r = await get_page('/', app, allow_redirects=False)
r.raise_for_status() r.raise_for_status()
url = r.headers.get('Location', '') url = r.headers.get('Location', '')
@@ -65,8 +61,7 @@ async def test_root_default_url_noauth(app):
async def test_root_default_url_auth(app): async def test_root_default_url_auth(app):
name = 'wash' name = 'wash'
cookies = await app.login_user(name) cookies = await app.login_user(name)
with mock.patch.dict(app.tornado_settings, with mock.patch.dict(app.tornado_settings, {'default_url': '/foo/bar'}):
{'default_url': '/foo/bar'}):
r = await get_page('/', app, cookies=cookies, allow_redirects=False) r = await get_page('/', app, cookies=cookies, allow_redirects=False)
r.raise_for_status() r.raise_for_status()
url = r.headers.get('Location', '') url = r.headers.get('Location', '')
@@ -106,12 +101,7 @@ async def test_admin(app):
assert r.url.endswith('/admin') assert r.url.endswith('/admin')
@pytest.mark.parametrize('sort', [ @pytest.mark.parametrize('sort', ['running', 'last_activity', 'admin', 'name'])
'running',
'last_activity',
'admin',
'name',
])
async def test_admin_sort(app, sort): async def test_admin_sort(app, sort):
cookies = await app.login_user('admin') cookies = await app.login_user('admin')
r = await get_page('admin?sort=%s' % sort, app, cookies=cookies) r = await get_page('admin?sort=%s' % sort, app, cookies=cookies)
@@ -146,7 +136,9 @@ async def test_spawn_redirect(app):
assert path == ujoin(app.base_url, '/user/%s/' % name) assert path == ujoin(app.base_url, '/user/%s/' % name)
# stop server to ensure /user/name is handled by the Hub # stop server to ensure /user/name is handled by the Hub
r = await api_request(app, 'users', name, 'server', method='delete', cookies=cookies) r = await api_request(
app, 'users', name, 'server', method='delete', cookies=cookies
)
r.raise_for_status() r.raise_for_status()
# test handing of trailing slash on `/user/name` # test handing of trailing slash on `/user/name`
@@ -208,7 +200,9 @@ async def test_spawn_page(app):
async def test_spawn_page_falsy_callable(app): async def test_spawn_page_falsy_callable(app):
with mock.patch.dict(app.users.settings, {'spawner_class': FalsyCallableFormSpawner}): with mock.patch.dict(
app.users.settings, {'spawner_class': FalsyCallableFormSpawner}
):
cookies = await app.login_user('erik') cookies = await app.login_user('erik')
r = await get_page('spawn', app, cookies=cookies) r = await get_page('spawn', app, cookies=cookies)
assert 'user/erik' in r.url assert 'user/erik' in r.url
@@ -276,22 +270,22 @@ async def test_spawn_form_with_file(app):
u = app.users[orm_u] u = app.users[orm_u]
await u.stop() await u.stop()
r = await async_requests.post(ujoin(base_url, 'spawn'), r = await async_requests.post(
ujoin(base_url, 'spawn'),
cookies=cookies, cookies=cookies,
data={ data={'bounds': ['-1', '1'], 'energy': '511keV'},
'bounds': ['-1', '1'], files={'hello': ('hello.txt', b'hello world\n')},
'energy': '511keV',
},
files={'hello': ('hello.txt', b'hello world\n')}
) )
r.raise_for_status() r.raise_for_status()
assert u.spawner.user_options == { assert u.spawner.user_options == {
'energy': '511keV', 'energy': '511keV',
'bounds': [-1, 1], 'bounds': [-1, 1],
'notspecified': 5, 'notspecified': 5,
'hello': {'filename': 'hello.txt', 'hello': {
'filename': 'hello.txt',
'body': b'hello world\n', 'body': b'hello world\n',
'content_type': 'application/unknown'}, 'content_type': 'application/unknown',
},
} }
@@ -305,9 +299,9 @@ async def test_user_redirect(app):
path = urlparse(r.url).path path = urlparse(r.url).path
assert path == ujoin(app.base_url, '/hub/login') assert path == ujoin(app.base_url, '/hub/login')
query = urlparse(r.url).query query = urlparse(r.url).query
assert query == urlencode({ assert query == urlencode(
'next': ujoin(app.hub.base_url, '/user-redirect/tree/top/') {'next': ujoin(app.hub.base_url, '/user-redirect/tree/top/')}
}) )
r = await get_page('/user-redirect/notebooks/test.ipynb', app, cookies=cookies) r = await get_page('/user-redirect/notebooks/test.ipynb', app, cookies=cookies)
r.raise_for_status() r.raise_for_status()
@@ -339,19 +333,17 @@ async def test_user_redirect_deprecated(app):
path = urlparse(r.url).path path = urlparse(r.url).path
assert path == ujoin(app.base_url, '/hub/login') assert path == ujoin(app.base_url, '/hub/login')
query = urlparse(r.url).query query = urlparse(r.url).query
assert query == urlencode({ assert query == urlencode(
'next': ujoin(app.base_url, '/hub/user/baduser/test.ipynb') {'next': ujoin(app.base_url, '/hub/user/baduser/test.ipynb')}
}) )
async def test_login_fail(app): async def test_login_fail(app):
name = 'wash' name = 'wash'
base_url = public_url(app) base_url = public_url(app)
r = await async_requests.post(base_url + 'hub/login', r = await async_requests.post(
data={ base_url + 'hub/login',
'username': name, data={'username': name, 'password': 'wrong'},
'password': 'wrong',
},
allow_redirects=False, allow_redirects=False,
) )
assert not r.cookies assert not r.cookies
@@ -359,20 +351,17 @@ async def test_login_fail(app):
async def test_login_strip(app): async def test_login_strip(app):
"""Test that login form doesn't strip whitespace from passwords""" """Test that login form doesn't strip whitespace from passwords"""
form_data = { form_data = {'username': 'spiff', 'password': ' space man '}
'username': 'spiff',
'password': ' space man ',
}
base_url = public_url(app) base_url = public_url(app)
called_with = [] called_with = []
@gen.coroutine @gen.coroutine
def mock_authenticate(handler, data): def mock_authenticate(handler, data):
called_with.append(data) called_with.append(data)
with mock.patch.object(app.authenticator, 'authenticate', mock_authenticate): with mock.patch.object(app.authenticator, 'authenticate', mock_authenticate):
await async_requests.post(base_url + 'hub/login', await async_requests.post(
data=form_data, base_url + 'hub/login', data=form_data, allow_redirects=False
allow_redirects=False,
) )
assert called_with == [form_data] assert called_with == [form_data]
@@ -389,12 +378,11 @@ async def test_login_strip(app):
(False, '/user/other', '/hub/user/other'), (False, '/user/other', '/hub/user/other'),
(False, '/absolute', '/absolute'), (False, '/absolute', '/absolute'),
(False, '/has?query#andhash', '/has?query#andhash'), (False, '/has?query#andhash', '/has?query#andhash'),
# next_url outside is not allowed # next_url outside is not allowed
(False, 'https://other.domain', ''), (False, 'https://other.domain', ''),
(False, 'ftp://other.domain', ''), (False, 'ftp://other.domain', ''),
(False, '//other.domain', ''), (False, '//other.domain', ''),
] ],
) )
async def test_login_redirect(app, running, next_url, location): async def test_login_redirect(app, running, next_url, location):
cookies = await app.login_user('river') cookies = await app.login_user('river')
@@ -427,10 +415,11 @@ async def test_auto_login(app, request):
class DummyLoginHandler(BaseHandler): class DummyLoginHandler(BaseHandler):
def get(self): def get(self):
self.write('ok!') self.write('ok!')
base_url = public_url(app) + '/' base_url = public_url(app) + '/'
app.tornado_application.add_handlers(".*$", [ app.tornado_application.add_handlers(
(ujoin(app.hub.base_url, 'dummy'), DummyLoginHandler), ".*$", [(ujoin(app.hub.base_url, 'dummy'), DummyLoginHandler)]
]) )
# no auto_login: end up at /hub/login # no auto_login: end up at /hub/login
r = await async_requests.get(base_url) r = await async_requests.get(base_url)
assert r.url == public_url(app, path='hub/login') assert r.url == public_url(app, path='hub/login')
@@ -438,9 +427,7 @@ async def test_auto_login(app, request):
authenticator = Authenticator(auto_login=True) authenticator = Authenticator(auto_login=True)
authenticator.login_url = lambda base_url: ujoin(base_url, 'dummy') authenticator.login_url = lambda base_url: ujoin(base_url, 'dummy')
with mock.patch.dict(app.tornado_settings, { with mock.patch.dict(app.tornado_settings, {'authenticator': authenticator}):
'authenticator': authenticator,
}):
r = await async_requests.get(base_url) r = await async_requests.get(base_url)
assert r.url == public_url(app, path='hub/dummy') assert r.url == public_url(app, path='hub/dummy')
@@ -449,10 +436,12 @@ async def test_auto_login_logout(app):
name = 'burnham' name = 'burnham'
cookies = await app.login_user(name) cookies = await app.login_user(name)
with mock.patch.dict(app.tornado_settings, { with mock.patch.dict(
'authenticator': Authenticator(auto_login=True), app.tornado_settings, {'authenticator': Authenticator(auto_login=True)}
}): ):
r = await async_requests.get(public_host(app) + app.tornado_settings['logout_url'], cookies=cookies) r = await async_requests.get(
public_host(app) + app.tornado_settings['logout_url'], cookies=cookies
)
r.raise_for_status() r.raise_for_status()
logout_url = public_host(app) + app.tornado_settings['logout_url'] logout_url = public_host(app) + app.tornado_settings['logout_url']
assert r.url == logout_url assert r.url == logout_url
@@ -462,7 +451,9 @@ async def test_auto_login_logout(app):
async def test_logout(app): async def test_logout(app):
name = 'wash' name = 'wash'
cookies = await app.login_user(name) cookies = await app.login_user(name)
r = await async_requests.get(public_host(app) + app.tornado_settings['logout_url'], cookies=cookies) r = await async_requests.get(
public_host(app) + app.tornado_settings['logout_url'], cookies=cookies
)
r.raise_for_status() r.raise_for_status()
login_url = public_host(app) + app.tornado_settings['login_url'] login_url = public_host(app) + app.tornado_settings['login_url']
assert r.url == login_url assert r.url == login_url
@@ -489,12 +480,11 @@ async def test_shutdown_on_logout(app, shutdown_on_logout):
assert spawner.active assert spawner.active
# logout # logout
with mock.patch.dict(app.tornado_settings, { with mock.patch.dict(
'shutdown_on_logout': shutdown_on_logout, app.tornado_settings, {'shutdown_on_logout': shutdown_on_logout}
}): ):
r = await async_requests.get( r = await async_requests.get(
public_host(app) + app.tornado_settings['logout_url'], public_host(app) + app.tornado_settings['logout_url'], cookies=cookies
cookies=cookies,
) )
r.raise_for_status() r.raise_for_status()
@@ -549,7 +539,9 @@ async def test_oauth_token_page(app):
user = app.users[orm.User.find(app.db, name)] user = app.users[orm.User.find(app.db, name)]
client = orm.OAuthClient(identifier='token') client = orm.OAuthClient(identifier='token')
app.db.add(client) app.db.add(client)
oauth_token = orm.OAuthAccessToken(client=client, user=user, grant_type=orm.GrantType.authorization_code) oauth_token = orm.OAuthAccessToken(
client=client, user=user, grant_type=orm.GrantType.authorization_code
)
app.db.add(oauth_token) app.db.add(oauth_token)
app.db.commit() app.db.commit()
r = await get_page('token', app, cookies=cookies) r = await get_page('token', app, cookies=cookies)
@@ -557,23 +549,14 @@ async def test_oauth_token_page(app):
assert r.status_code == 200 assert r.status_code == 200
@pytest.mark.parametrize("error_status", [ @pytest.mark.parametrize("error_status", [503, 404])
503,
404,
])
async def test_proxy_error(app, error_status): async def test_proxy_error(app, error_status):
r = await get_page('/error/%i' % error_status, app) r = await get_page('/error/%i' % error_status, app)
assert r.status_code == 200 assert r.status_code == 200
@pytest.mark.parametrize( @pytest.mark.parametrize(
"announcements", "announcements", ["", "spawn", "spawn,home,login", "login,logout"]
[
"",
"spawn",
"spawn,home,login",
"login,logout",
]
) )
async def test_announcements(app, announcements): async def test_announcements(app, announcements):
"""Test announcements on various pages""" """Test announcements on various pages"""
@@ -618,16 +601,13 @@ async def test_announcements(app, announcements):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"params", "params", ["", "redirect_uri=/noexist", "redirect_uri=ok&client_id=nosuchthing"]
[
"",
"redirect_uri=/noexist",
"redirect_uri=ok&client_id=nosuchthing",
]
) )
async def test_bad_oauth_get(app, params): async def test_bad_oauth_get(app, params):
cookies = await app.login_user("authorizer") cookies = await app.login_user("authorizer")
r = await get_page("hub/api/oauth2/authorize?" + params, app, hub=False, cookies=cookies) r = await get_page(
"hub/api/oauth2/authorize?" + params, app, hub=False, cookies=cookies
)
assert r.status_code == 400 assert r.status_code == 400
@@ -637,11 +617,14 @@ async def test_token_page(app):
r = await get_page("token", app, cookies=cookies) r = await get_page("token", app, cookies=cookies)
r.raise_for_status() r.raise_for_status()
assert urlparse(r.url).path.endswith('/hub/token') assert urlparse(r.url).path.endswith('/hub/token')
def extract_body(r): def extract_body(r):
soup = BeautifulSoup(r.text, "html5lib") soup = BeautifulSoup(r.text, "html5lib")
import re import re
# trim empty lines # trim empty lines
return re.sub(r"(\n\s*)+", "\n", soup.body.find(class_="container").text) return re.sub(r"(\n\s*)+", "\n", soup.body.find(class_="container").text)
body = extract_body(r) body = extract_body(r)
assert "Request new API token" in body, body assert "Request new API token" in body, body
# no tokens yet, no lists # no tokens yet, no lists


@@ -1,20 +1,21 @@
"""Test a proxy being started before the Hub""" """Test a proxy being started before the Hub"""
from contextlib import contextmanager
import json import json
import os import os
from contextlib import contextmanager
from queue import Queue from queue import Queue
from subprocess import Popen from subprocess import Popen
from urllib.parse import urlparse, quote from urllib.parse import quote
from urllib.parse import urlparse
from traitlets.config import Config
import pytest import pytest
from traitlets.config import Config
from .. import orm from .. import orm
from ..utils import url_path_join as ujoin
from ..utils import wait_for_http_server
from .mocking import MockHub from .mocking import MockHub
from .test_api import api_request, add_user from .test_api import add_user
from ..utils import wait_for_http_server, url_path_join as ujoin from .test_api import api_request
@pytest.fixture @pytest.fixture
@@ -52,25 +53,30 @@ async def test_external_proxy(request):
env['CONFIGPROXY_AUTH_TOKEN'] = auth_token env['CONFIGPROXY_AUTH_TOKEN'] = auth_token
cmd = [ cmd = [
'configurable-http-proxy', 'configurable-http-proxy',
'--ip', app.ip, '--ip',
'--port', str(app.port), app.ip,
'--api-ip', proxy_ip, '--port',
'--api-port', str(proxy_port), str(app.port),
'--api-ip',
proxy_ip,
'--api-port',
str(proxy_port),
'--log-level=debug', '--log-level=debug',
] ]
if app.subdomain_host: if app.subdomain_host:
cmd.append('--host-routing') cmd.append('--host-routing')
proxy = Popen(cmd, env=env) proxy = Popen(cmd, env=env)
def _cleanup_proxy(): def _cleanup_proxy():
if proxy.poll() is None: if proxy.poll() is None:
proxy.terminate() proxy.terminate()
proxy.wait(timeout=10) proxy.wait(timeout=10)
request.addfinalizer(_cleanup_proxy) request.addfinalizer(_cleanup_proxy)
def wait_for_proxy(): def wait_for_proxy():
return wait_for_http_server('http://%s:%i' % (proxy_ip, proxy_port)) return wait_for_http_server('http://%s:%i' % (proxy_ip, proxy_port))
await wait_for_proxy() await wait_for_proxy()
await app.initialize([]) await app.initialize([])
@@ -84,8 +90,9 @@ async def test_external_proxy(request):
# add user to the db and start a single user server # add user to the db and start a single user server
name = 'river' name = 'river'
add_user(app.db, app, name=name) add_user(app.db, app, name=name)
r = await api_request(app, 'users', name, 'server', method='post', r = await api_request(
bypass_proxy=True) app, 'users', name, 'server', method='post', bypass_proxy=True
)
r.raise_for_status() r.raise_for_status()
routes = await app.proxy.get_all_routes() routes = await app.proxy.get_all_routes()
@@ -122,12 +129,18 @@ async def test_external_proxy(request):
new_auth_token = 'different!' new_auth_token = 'different!'
env['CONFIGPROXY_AUTH_TOKEN'] = new_auth_token env['CONFIGPROXY_AUTH_TOKEN'] = new_auth_token
proxy_port = 55432 proxy_port = 55432
cmd = ['configurable-http-proxy', cmd = [
'--ip', app.ip, 'configurable-http-proxy',
'--port', str(app.port), '--ip',
'--api-ip', proxy_ip, app.ip,
'--api-port', str(proxy_port), '--port',
'--default-target', 'http://%s:%i' % (app.hub_ip, app.hub_port), str(app.port),
'--api-ip',
proxy_ip,
'--api-port',
str(proxy_port),
'--default-target',
'http://%s:%i' % (app.hub_ip, app.hub_port),
] ]
if app.subdomain_host: if app.subdomain_host:
cmd.append('--host-routing') cmd.append('--host-routing')
@@ -140,10 +153,7 @@ async def test_external_proxy(request):
app, app,
'proxy', 'proxy',
method='patch', method='patch',
data=json.dumps({ data=json.dumps({'api_url': new_api_url, 'auth_token': new_auth_token}),
'api_url': new_api_url,
'auth_token': new_auth_token,
}),
bypass_proxy=True, bypass_proxy=True,
) )
r.raise_for_status() r.raise_for_status()
@@ -156,13 +166,7 @@ async def test_external_proxy(request):
assert sorted(routes.keys()) == [app.hub.routespec, user_spec] assert sorted(routes.keys()) == [app.hub.routespec, user_spec]
@pytest.mark.parametrize("username", [ @pytest.mark.parametrize("username", ['zoe', '50fia', '秀樹', '~TestJH', 'has@'])
'zoe',
'50fia',
'秀樹',
'~TestJH',
'has@',
])
async def test_check_routes(app, username, disable_check_routes): async def test_check_routes(app, username, disable_check_routes):
proxy = app.proxy proxy = app.proxy
test_user = add_user(app.db, app, name=username) test_user = add_user(app.db, app, name=username)
@@ -191,14 +195,17 @@ async def test_check_routes(app, username, disable_check_routes):
assert before == after assert before == after
@pytest.mark.parametrize("routespec", [ @pytest.mark.parametrize(
"routespec",
[
'/has%20space/foo/', '/has%20space/foo/',
'/missing-trailing/slash', '/missing-trailing/slash',
'/has/@/', '/has/@/',
'/has/' + quote('üñîçø∂é'), '/has/' + quote('üñîçø∂é'),
'host.name/path/', 'host.name/path/',
'other.host/path/no/slash', 'other.host/path/no/slash',
]) ],
)
async def test_add_get_delete(app, routespec, disable_check_routes): async def test_add_get_delete(app, routespec, disable_check_routes):
arg = routespec arg = routespec
if not routespec.endswith('/'): if not routespec.endswith('/'):
@@ -207,6 +214,7 @@ async def test_add_get_delete(app, routespec, disable_check_routes):
# host-routes when not host-routing raises an error # host-routes when not host-routing raises an error
# and vice versa # and vice versa
expect_value_error = bool(app.subdomain_host) ^ (not routespec.startswith('/')) expect_value_error = bool(app.subdomain_host) ^ (not routespec.startswith('/'))
@contextmanager @contextmanager
def context(): def context():
if expect_value_error: if expect_value_error:


@@ -1,22 +1,26 @@
"""Tests for services""" """Tests for services"""
import asyncio import asyncio
import os
import sys
import time
from binascii import hexlify from binascii import hexlify
from contextlib import contextmanager from contextlib import contextmanager
import os
from subprocess import Popen from subprocess import Popen
import sys
from threading import Event from threading import Event
import time
from async_generator import asynccontextmanager, async_generator, yield_
import pytest import pytest
import requests import requests
from async_generator import async_generator
from async_generator import asynccontextmanager
from async_generator import yield_
from tornado import gen from tornado import gen
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from ..utils import maybe_future
from ..utils import random_port
from ..utils import url_path_join
from ..utils import wait_for_http_server
from .mocking import public_url from .mocking import public_url
from ..utils import url_path_join, wait_for_http_server, random_port, maybe_future
from .utils import async_requests from .utils import async_requests
mockservice_path = os.path.dirname(os.path.abspath(__file__)) mockservice_path = os.path.dirname(os.path.abspath(__file__))
@@ -80,12 +84,14 @@ async def test_proxy_service(app, mockservice_url):
async def test_external_service(app): async def test_external_service(app):
name = 'external' name = 'external'
async with external_service(app, name=name) as env: async with external_service(app, name=name) as env:
app.services = [{ app.services = [
{
'name': name, 'name': name,
'admin': True, 'admin': True,
'url': env['JUPYTERHUB_SERVICE_URL'], 'url': env['JUPYTERHUB_SERVICE_URL'],
'api_token': env['JUPYTERHUB_API_TOKEN'], 'api_token': env['JUPYTERHUB_API_TOKEN'],
}] }
]
await maybe_future(app.init_services()) await maybe_future(app.init_services())
await app.init_api_tokens() await app.init_api_tokens()
await app.proxy.add_all_services(app._service_map) await app.proxy.add_all_services(app._service_map)


@@ -1,36 +1,43 @@
"""Tests for service authentication""" """Tests for service authentication"""
import asyncio import asyncio
from binascii import hexlify
import copy import copy
from functools import partial
import json import json
import os import os
from queue import Queue
import sys import sys
from binascii import hexlify
from functools import partial
from queue import Queue
from threading import Thread from threading import Thread
from unittest import mock from unittest import mock
from urllib.parse import urlparse from urllib.parse import urlparse
import pytest import pytest
from pytest import raises
import requests import requests
import requests_mock import requests_mock
from pytest import raises
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer from tornado.httpserver import HTTPServer
from tornado.web import RequestHandler, Application, authenticated, HTTPError
from tornado.httputil import url_concat from tornado.httputil import url_concat
from tornado.ioloop import IOLoop
from tornado.web import Application
from tornado.web import authenticated
from tornado.web import HTTPError
from tornado.web import RequestHandler
from .. import orm from .. import orm
from ..services.auth import _ExpiringDict, HubAuth, HubAuthenticated from ..services.auth import _ExpiringDict
from ..services.auth import HubAuth
from ..services.auth import HubAuthenticated
from ..utils import url_path_join from ..utils import url_path_join
from .mocking import public_url, public_host from .mocking import public_host
from .mocking import public_url
from .test_api import add_user from .test_api import add_user
from .utils import async_requests, AsyncSession from .utils import async_requests
from .utils import AsyncSession
# mock for sending monotonic counter way into the future # mock for sending monotonic counter way into the future
monotonic_future = mock.patch('time.monotonic', lambda: sys.maxsize) monotonic_future = mock.patch('time.monotonic', lambda: sys.maxsize)
def test_expiring_dict(): def test_expiring_dict():
cache = _ExpiringDict(max_age=30) cache = _ExpiringDict(max_age=30)
cache['key'] = 'cached value' cache['key'] = 'cached value'
@@ -69,9 +76,7 @@ def test_expiring_dict():
def test_hub_auth(): def test_hub_auth():
auth = HubAuth(cookie_name='foo') auth = HubAuth(cookie_name='foo')
mock_model = { mock_model = {'name': 'onyxia'}
'name': 'onyxia'
}
url = url_path_join(auth.api_url, "authorizations/cookie/foo/bar") url = url_path_join(auth.api_url, "authorizations/cookie/foo/bar")
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.get(url, text=json.dumps(mock_model)) m.get(url, text=json.dumps(mock_model))
@@ -87,9 +92,7 @@ def test_hub_auth():
assert user_model is None assert user_model is None
# invalidate cache with timer # invalidate cache with timer
mock_model = { mock_model = {'name': 'willow'}
'name': 'willow'
}
with monotonic_future, requests_mock.Mocker() as m: with monotonic_future, requests_mock.Mocker() as m:
m.get(url, text=json.dumps(mock_model)) m.get(url, text=json.dumps(mock_model))
user_model = auth.user_for_cookie('bar') user_model = auth.user_for_cookie('bar')
@@ -110,16 +113,14 @@ def test_hub_auth():
def test_hub_authenticated(request): def test_hub_authenticated(request):
auth = HubAuth(cookie_name='jubal') auth = HubAuth(cookie_name='jubal')
mock_model = { mock_model = {'name': 'jubalearly', 'groups': ['lions']}
'name': 'jubalearly',
'groups': ['lions'],
}
cookie_url = url_path_join(auth.api_url, "authorizations/cookie", auth.cookie_name) cookie_url = url_path_join(auth.api_url, "authorizations/cookie", auth.cookie_name)
good_url = url_path_join(cookie_url, "early") good_url = url_path_join(cookie_url, "early")
bad_url = url_path_join(cookie_url, "late") bad_url = url_path_join(cookie_url, "late")
class TestHandler(HubAuthenticated, RequestHandler): class TestHandler(HubAuthenticated, RequestHandler):
hub_auth = auth hub_auth = auth
@authenticated @authenticated
def get(self): def get(self):
self.finish(self.get_current_user()) self.finish(self.get_current_user())
@@ -127,11 +128,10 @@ def test_hub_authenticated(request):
# start hub-authenticated service in a thread: # start hub-authenticated service in a thread:
port = 50505 port = 50505
q = Queue() q = Queue()
def run(): def run():
asyncio.set_event_loop(asyncio.new_event_loop()) asyncio.set_event_loop(asyncio.new_event_loop())
app = Application([ app = Application([('/*', TestHandler)], login_url=auth.login_url)
('/*', TestHandler),
], login_url=auth.login_url)
http_server = HTTPServer(app) http_server = HTTPServer(app)
http_server.listen(port) http_server.listen(port)
@@ -146,6 +146,7 @@ def test_hub_authenticated(request):
loop.add_callback(loop.stop) loop.add_callback(loop.stop)
t.join(timeout=30) t.join(timeout=30)
assert not t.is_alive() assert not t.is_alive()
request.addfinalizer(finish_thread) request.addfinalizer(finish_thread)
# wait for thread to start # wait for thread to start
@@ -153,16 +154,15 @@ def test_hub_authenticated(request):
with requests_mock.Mocker(real_http=True) as m: with requests_mock.Mocker(real_http=True) as m:
# no cookie # no cookie
r = requests.get('http://127.0.0.1:%i' % port, r = requests.get('http://127.0.0.1:%i' % port, allow_redirects=False)
allow_redirects=False,
)
r.raise_for_status() r.raise_for_status()
assert r.status_code == 302 assert r.status_code == 302
assert auth.login_url in r.headers['Location'] assert auth.login_url in r.headers['Location']
# wrong cookie # wrong cookie
m.get(bad_url, status_code=404) m.get(bad_url, status_code=404)
r = requests.get('http://127.0.0.1:%i' % port, r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'late'}, cookies={'jubal': 'late'},
allow_redirects=False, allow_redirects=False,
) )
@@ -176,7 +176,8 @@ def test_hub_authenticated(request):
# upstream 403 # upstream 403
m.get(bad_url, status_code=403) m.get(bad_url, status_code=403)
r = requests.get('http://127.0.0.1:%i' % port, r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'late'}, cookies={'jubal': 'late'},
allow_redirects=False, allow_redirects=False,
) )
@@ -185,7 +186,8 @@ def test_hub_authenticated(request):
m.get(good_url, text=json.dumps(mock_model)) m.get(good_url, text=json.dumps(mock_model))
# no whitelist # no whitelist
r = requests.get('http://127.0.0.1:%i' % port, r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'early'}, cookies={'jubal': 'early'},
allow_redirects=False, allow_redirects=False,
) )
@@ -194,7 +196,8 @@ def test_hub_authenticated(request):
# pass whitelist # pass whitelist
TestHandler.hub_users = {'jubalearly'} TestHandler.hub_users = {'jubalearly'}
r = requests.get('http://127.0.0.1:%i' % port, r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'early'}, cookies={'jubal': 'early'},
allow_redirects=False, allow_redirects=False,
) )
@@ -203,7 +206,8 @@ def test_hub_authenticated(request):
# no pass whitelist # no pass whitelist
TestHandler.hub_users = {'kaylee'} TestHandler.hub_users = {'kaylee'}
r = requests.get('http://127.0.0.1:%i' % port, r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'early'}, cookies={'jubal': 'early'},
allow_redirects=False, allow_redirects=False,
) )
@@ -211,7 +215,8 @@ def test_hub_authenticated(request):
# pass group whitelist # pass group whitelist
TestHandler.hub_groups = {'lions'} TestHandler.hub_groups = {'lions'}
r = requests.get('http://127.0.0.1:%i' % port, r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'early'}, cookies={'jubal': 'early'},
allow_redirects=False, allow_redirects=False,
) )
@@ -220,7 +225,8 @@ def test_hub_authenticated(request):
# no pass group whitelist # no pass group whitelist
TestHandler.hub_groups = {'tigers'} TestHandler.hub_groups = {'tigers'}
r = requests.get('http://127.0.0.1:%i' % port, r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'early'}, cookies={'jubal': 'early'},
allow_redirects=False, allow_redirects=False,
) )
@@ -230,15 +236,14 @@ def test_hub_authenticated(request):
async def test_hubauth_cookie(app, mockservice_url): async def test_hubauth_cookie(app, mockservice_url):
"""Test HubAuthenticated service with user cookies""" """Test HubAuthenticated service with user cookies"""
cookies = await app.login_user('badger') cookies = await app.login_user('badger')
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/', cookies=cookies) r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/', cookies=cookies
)
r.raise_for_status() r.raise_for_status()
print(r.text) print(r.text)
reply = r.json() reply = r.json()
sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']} sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']}
assert sub_reply == { assert sub_reply == {'name': 'badger', 'admin': False}
'name': 'badger',
'admin': False,
}
async def test_hubauth_token(app, mockservice_url): async def test_hubauth_token(app, mockservice_url):
@@ -248,28 +253,25 @@ async def test_hubauth_token(app, mockservice_url):
app.db.commit() app.db.commit()
# token in Authorization header # token in Authorization header
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/', r = await async_requests.get(
headers={ public_url(app, mockservice_url) + '/whoami/',
'Authorization': 'token %s' % token, headers={'Authorization': 'token %s' % token},
}) )
reply = r.json() reply = r.json()
sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']} sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']}
assert sub_reply == { assert sub_reply == {'name': 'river', 'admin': False}
'name': 'river',
'admin': False,
}
# token in ?token parameter # token in ?token parameter
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=%s' % token) r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/?token=%s' % token
)
r.raise_for_status() r.raise_for_status()
reply = r.json() reply = r.json()
sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']} sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']}
assert sub_reply == { assert sub_reply == {'name': 'river', 'admin': False}
'name': 'river',
'admin': False,
}
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=no-such-token', r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/?token=no-such-token',
allow_redirects=False, allow_redirects=False,
) )
assert r.status_code == 302 assert r.status_code == 302
@@ -288,30 +290,25 @@ async def test_hubauth_service_token(app, mockservice_url):
await app.init_api_tokens() await app.init_api_tokens()
# token in Authorization header # token in Authorization header
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/', r = await async_requests.get(
headers={ public_url(app, mockservice_url) + '/whoami/',
'Authorization': 'token %s' % token, headers={'Authorization': 'token %s' % token},
}) )
r.raise_for_status() r.raise_for_status()
reply = r.json() reply = r.json()
assert reply == { assert reply == {'kind': 'service', 'name': name, 'admin': False}
'kind': 'service',
'name': name,
'admin': False,
}
assert not r.cookies assert not r.cookies
# token in ?token parameter # token in ?token parameter
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=%s' % token) r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/?token=%s' % token
)
r.raise_for_status() r.raise_for_status()
reply = r.json() reply = r.json()
assert reply == { assert reply == {'kind': 'service', 'name': name, 'admin': False}
'kind': 'service',
'name': name,
'admin': False,
}
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=no-such-token', r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/?token=no-such-token',
allow_redirects=False, allow_redirects=False,
) )
assert r.status_code == 302 assert r.status_code == 302
@@ -351,10 +348,7 @@ async def test_oauth_service(app, mockservice_url):
assert r.status_code == 200 assert r.status_code == 200
reply = r.json() reply = r.json()
sub_reply = {key: reply.get(key, 'missing') for key in ('kind', 'name')} sub_reply = {key: reply.get(key, 'missing') for key in ('kind', 'name')}
assert sub_reply == { assert sub_reply == {'name': 'link', 'kind': 'user'}
'name': 'link',
'kind': 'user',
}
# token-authenticated request to HubOAuth # token-authenticated request to HubOAuth
token = app.users[name].new_api_token() token = app.users[name].new_api_token()
@@ -367,11 +361,7 @@ async def test_oauth_service(app, mockservice_url):
# verify that ?token= requests set a cookie # verify that ?token= requests set a cookie
assert len(r.cookies) != 0 assert len(r.cookies) != 0
# ensure cookie works in future requests # ensure cookie works in future requests
r = await async_requests.get( r = await async_requests.get(url, cookies=r.cookies, allow_redirects=False)
url,
cookies=r.cookies,
allow_redirects=False,
)
r.raise_for_status() r.raise_for_status()
assert r.url == url assert r.url == url
reply = r.json() reply = r.json()
@@ -408,9 +398,7 @@ async def test_oauth_cookie_collision(app, mockservice_url):
# finish oauth 2 # finish oauth 2
# submit the oauth form to complete authorization # submit the oauth form to complete authorization
r = await s.post( r = await s.post(
oauth_2.url, oauth_2.url, data={'scopes': ['identify']}, headers={'Referer': oauth_2.url}
data={'scopes': ['identify']},
headers={'Referer': oauth_2.url},
) )
r.raise_for_status() r.raise_for_status()
assert r.url == url assert r.url == url
@@ -422,9 +410,7 @@ async def test_oauth_cookie_collision(app, mockservice_url):
# finish oauth 1 # finish oauth 1
r = await s.post( r = await s.post(
oauth_1.url, oauth_1.url, data={'scopes': ['identify']}, headers={'Referer': oauth_1.url}
data={'scopes': ['identify']},
headers={'Referer': oauth_1.url},
) )
r.raise_for_status() r.raise_for_status()
assert r.url == url assert r.url == url
@@ -455,11 +441,13 @@ async def test_oauth_logout(app, mockservice_url):
s = AsyncSession() s = AsyncSession()
name = 'propha' name = 'propha'
app_user = add_user(app.db, app=app, name=name) app_user = add_user(app.db, app=app, name=name)
def auth_tokens(): def auth_tokens():
"""Return list of OAuth access tokens for the user""" """Return list of OAuth access tokens for the user"""
return list( return list(
app.db.query(orm.OAuthAccessToken).filter( app.db.query(orm.OAuthAccessToken).filter(
orm.OAuthAccessToken.user_id == app_user.id) orm.OAuthAccessToken.user_id == app_user.id
)
) )
# ensure we start empty # ensure we start empty
@@ -480,14 +468,8 @@ async def test_oauth_logout(app, mockservice_url):
r.raise_for_status() r.raise_for_status()
assert r.status_code == 200 assert r.status_code == 200
reply = r.json() reply = r.json()
sub_reply = { sub_reply = {key: reply.get(key, 'missing') for key in ('kind', 'name')}
key: reply.get(key, 'missing') assert sub_reply == {'name': name, 'kind': 'user'}
for key in ('kind', 'name')
}
assert sub_reply == {
'name': name,
'kind': 'user',
}
# save cookies to verify cache # save cookies to verify cache
saved_cookies = copy.deepcopy(s.cookies) saved_cookies = copy.deepcopy(s.cookies)
session_id = s.cookies['jupyterhub-session-id'] session_id = s.cookies['jupyterhub-session-id']
@@ -522,11 +504,5 @@ async def test_oauth_logout(app, mockservice_url):
r.raise_for_status() r.raise_for_status()
assert r.status_code == 200 assert r.status_code == 200
reply = r.json() reply = r.json()
sub_reply = { sub_reply = {key: reply.get(key, 'missing') for key in ('kind', 'name')}
key: reply.get(key, 'missing') assert sub_reply == {'name': name, 'kind': 'user'}
for key in ('kind', 'name')
}
assert sub_reply == {
'name': name,
'kind': 'user',
}


@@ -1,16 +1,16 @@
"""Tests for jupyterhub.singleuser""" """Tests for jupyterhub.singleuser"""
from subprocess import check_output
import sys import sys
from subprocess import check_output
from urllib.parse import urlparse from urllib.parse import urlparse
import pytest import pytest
import jupyterhub import jupyterhub
from .mocking import StubSingleUserSpawner, public_url
from ..utils import url_path_join from ..utils import url_path_join
from .mocking import public_url
from .utils import async_requests, AsyncSession from .mocking import StubSingleUserSpawner
from .utils import async_requests
from .utils import AsyncSession
async def test_singleuser_auth(app): async def test_singleuser_auth(app):
@@ -47,11 +47,7 @@ async def test_singleuser_auth(app):
r = await s.get(url) r = await s.get(url)
assert urlparse(r.url).path.endswith('/oauth2/authorize') assert urlparse(r.url).path.endswith('/oauth2/authorize')
# submit the oauth form to complete authorization # submit the oauth form to complete authorization
r = await s.post( r = await s.post(r.url, data={'scopes': ['identify']}, headers={'Referer': r.url})
r.url,
data={'scopes': ['identify']},
headers={'Referer': r.url},
)
assert urlparse(r.url).path.rstrip('/').endswith('/user/nandy/tree') assert urlparse(r.url).path.rstrip('/').endswith('/user/nandy/tree')
# user isn't authorized, should raise 403 # user isn't authorized, should raise 403
assert r.status_code == 403 assert r.status_code == 403
@@ -85,11 +81,14 @@ async def test_disable_user_config(app):
def test_help_output(): def test_help_output():
out = check_output([sys.executable, '-m', 'jupyterhub.singleuser', '--help-all']).decode('utf8', 'replace') out = check_output(
[sys.executable, '-m', 'jupyterhub.singleuser', '--help-all']
).decode('utf8', 'replace')
assert 'JupyterHub' in out assert 'JupyterHub' in out
def test_version(): def test_version():
out = check_output([sys.executable, '-m', 'jupyterhub.singleuser', '--version']).decode('utf8', 'replace') out = check_output(
[sys.executable, '-m', 'jupyterhub.singleuser', '--version']
).decode('utf8', 'replace')
assert jupyterhub.__version__ in out assert jupyterhub.__version__ in out


@@ -1,27 +1,28 @@
"""Tests for process spawning""" """Tests for process spawning"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import logging import logging
import os import os
import signal import signal
from subprocess import Popen
import sys import sys
import tempfile import tempfile
import time import time
from subprocess import Popen
from unittest import mock from unittest import mock
from urllib.parse import urlparse from urllib.parse import urlparse
import pytest import pytest
from tornado import gen from tornado import gen
from ..objects import Hub, Server
from .. import orm from .. import orm
from .. import spawner as spawnermod from .. import spawner as spawnermod
from ..spawner import LocalProcessSpawner, Spawner from ..objects import Hub
from ..objects import Server
from ..spawner import LocalProcessSpawner
from ..spawner import Spawner
from ..user import User from ..user import User
from ..utils import new_token, url_path_join from ..utils import new_token
from ..utils import url_path_join
from .mocking import public_url from .mocking import public_url
from .test_api import add_user from .test_api import add_user
from .utils import async_requests from .utils import async_requests
@@ -84,8 +85,10 @@ async def wait_for_spawner(spawner, timeout=10):
polling at shorter intervals for early termination polling at shorter intervals for early termination
""" """
deadline = time.monotonic() + timeout deadline = time.monotonic() + timeout
def wait(): def wait():
return spawner.server.wait_up(timeout=1, http=True) return spawner.server.wait_up(timeout=1, http=True)
while time.monotonic() < deadline: while time.monotonic() < deadline:
status = await spawner.poll() status = await spawner.poll()
assert status is None assert status is None
@@ -187,11 +190,13 @@ def test_setcwd():
os.chdir(cwd) os.chdir(cwd)
chdir = os.chdir chdir = os.chdir
temp_root = os.path.realpath(os.path.abspath(tempfile.gettempdir())) temp_root = os.path.realpath(os.path.abspath(tempfile.gettempdir()))
def raiser(path): def raiser(path):
path = os.path.realpath(os.path.abspath(path)) path = os.path.realpath(os.path.abspath(path))
if not path.startswith(temp_root): if not path.startswith(temp_root):
raise OSError(path) raise OSError(path)
chdir(path) chdir(path)
with mock.patch('os.chdir', raiser): with mock.patch('os.chdir', raiser):
spawnermod._try_setcwd(cwd) spawnermod._try_setcwd(cwd)
assert os.getcwd().startswith(temp_root) assert os.getcwd().startswith(temp_root)
@@ -209,6 +214,7 @@ def test_string_formatting(db):
async def test_popen_kwargs(db): async def test_popen_kwargs(db):
mock_proc = mock.Mock(spec=Popen) mock_proc = mock.Mock(spec=Popen)
def mock_popen(*args, **kwargs): def mock_popen(*args, **kwargs):
mock_proc.args = args mock_proc.args = args
mock_proc.kwargs = kwargs mock_proc.kwargs = kwargs
@@ -226,7 +232,8 @@ async def test_popen_kwargs(db):
async def test_shell_cmd(db, tmpdir, request): async def test_shell_cmd(db, tmpdir, request):
f = tmpdir.join('bashrc') f = tmpdir.join('bashrc')
f.write('export TESTVAR=foo\n') f.write('export TESTVAR=foo\n')
s = new_spawner(db, s = new_spawner(
db,
cmd=[sys.executable, '-m', 'jupyterhub.tests.mocksu'], cmd=[sys.executable, '-m', 'jupyterhub.tests.mocksu'],
shell_cmd=['bash', '--rcfile', str(f), '-i', '-c'], shell_cmd=['bash', '--rcfile', str(f), '-i', '-c'],
) )
@@ -253,6 +260,7 @@ def test_inherit_overwrite():
""" """
if sys.version_info >= (3, 6): if sys.version_info >= (3, 6):
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
class S(Spawner): class S(Spawner):
pass pass
@@ -372,19 +380,14 @@ async def test_spawner_delete_server(app):
assert spawner.server is None assert spawner.server is None
@pytest.mark.parametrize( @pytest.mark.parametrize("name", ["has@x", "has~x", "has%x", "has%40x"])
"name",
[
"has@x",
"has~x",
"has%x",
"has%40x",
]
)
async def test_spawner_routing(app, name): async def test_spawner_routing(app, name):
"""Test routing of names with special characters""" """Test routing of names with special characters"""
db = app.db db = app.db
with mock.patch.dict(app.config.LocalProcessSpawner, {'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']}): with mock.patch.dict(
app.config.LocalProcessSpawner,
{'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']},
):
user = add_user(app.db, app, name=name) user = add_user(app.db, app, name=name)
await user.spawn() await user.spawn()
await wait_for_spawner(user.spawner) await wait_for_spawner(user.spawner)


@@ -1,12 +1,16 @@
import pytest import pytest
from traitlets import HasTraits, TraitError from traitlets import HasTraits
from traitlets import TraitError
from jupyterhub.traitlets import URLPrefix, Command, ByteSpecification from jupyterhub.traitlets import ByteSpecification
from jupyterhub.traitlets import Command
from jupyterhub.traitlets import URLPrefix
def test_url_prefix(): def test_url_prefix():
class C(HasTraits): class C(HasTraits):
url = URLPrefix() url = URLPrefix()
c = C() c = C()
c.url = '/a/b/c/' c.url = '/a/b/c/'
assert c.url == '/a/b/c/' assert c.url == '/a/b/c/'
@@ -20,6 +24,7 @@ def test_command():
class C(HasTraits): class C(HasTraits):
cmd = Command('default command') cmd = Command('default command')
cmd2 = Command(['default_cmd']) cmd2 = Command(['default_cmd'])
c = C() c = C()
assert c.cmd == ['default command'] assert c.cmd == ['default command']
assert c.cmd2 == ['default_cmd'] assert c.cmd2 == ['default_cmd']


@@ -1,9 +1,11 @@
"""Tests for utilities""" """Tests for utilities"""
import asyncio import asyncio
import pytest
from async_generator import aclosing, async_generator, yield_ import pytest
from async_generator import aclosing
from async_generator import async_generator
from async_generator import yield_
from ..utils import iterate_until from ..utils import iterate_until
@@ -26,12 +28,15 @@ def schedule_future(io_loop, *, delay, result=None):
return f return f
@pytest.mark.parametrize("deadline, n, delay, expected", [ @pytest.mark.parametrize(
"deadline, n, delay, expected",
[
(0, 3, 1, []), (0, 3, 1, []),
(0, 3, 0, [0, 1, 2]), (0, 3, 0, [0, 1, 2]),
(5, 3, 0.01, [0, 1, 2]), (5, 3, 0.01, [0, 1, 2]),
(0.5, 10, 0.2, [0, 1]), (0.5, 10, 0.2, [0, 1]),
]) ],
)
async def test_iterate_until(io_loop, deadline, n, delay, expected): async def test_iterate_until(io_loop, deadline, n, delay, expected):
f = schedule_future(io_loop, delay=deadline) f = schedule_future(io_loop, delay=deadline)
