run autoformat

apologies to anyone finding this commit via git blame or log

run the autoformatting by

    pre-commit run --all-files
This commit is contained in:
Min RK
2019-02-14 15:08:59 +01:00
parent ca198e0363
commit 5e60582ef3
118 changed files with 3583 additions and 2934 deletions

View File

@@ -1 +0,0 @@

View File

@@ -1,19 +1,16 @@
#!/usr/bin/env python
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
"""
bower-lite
Since Bower's on its way out,
stage frontend dependencies from node_modules into components
"""
import json
import os
from os.path import join
import shutil
from os.path import join
HERE = os.path.abspath(os.path.dirname(__file__))

View File

@@ -1,16 +1,16 @@
-r requirements.txt
mock
# temporary pin of attrs for jsonschema 0.3.0a1
# seems to be a pip bug
attrs>=17.4.0
beautifulsoup4
codecov
coverage
cryptography
html5lib # needed for beautifulsoup
pytest-cov
pytest-asyncio
pytest>=3.3
mock
notebook
pytest-asyncio
pytest-cov
pytest>=3.3
requests-mock
virtualenv
# temporary pin of attrs for jsonschema 0.3.0a1
# seems to be a pip bug
attrs>=17.4.0

View File

@@ -7,5 +7,3 @@ ENV LANG=en_US.UTF-8
USER nobody
CMD ["jupyterhub"]

View File

@@ -18,4 +18,3 @@ Dockerfile.alpine contains base image for jupyterhub. It does not work independ
* Use dummy authenticator for ease of testing. Update following in jupyterhub_config file
- c.JupyterHub.authenticator_class = 'dummyauthenticator.DummyAuthenticator'
- c.DummyAuthenticator.password = "your strong password"

View File

@@ -1,7 +1,7 @@
# ReadTheDocs uses the `environment.yaml` so make sure to update that as well
# if you change this file
-r ../requirements.txt
sphinx>=1.7
alabaster_jupyterhub
recommonmark==0.4.0
sphinx-copybutton
alabaster_jupyterhub
sphinx>=1.7

View File

@@ -13,4 +13,3 @@ Module: :mod:`jupyterhub.app`
-------------------
.. autoconfigurable:: JupyterHub

View File

@@ -30,4 +30,3 @@ Module: :mod:`jupyterhub.auth`
---------------------------
.. autoconfigurable:: DummyAuthenticator

View File

@@ -20,4 +20,3 @@ Module: :mod:`jupyterhub.proxy`
.. autoconfigurable:: ConfigurableHTTPProxy
:members: debug, auth_token, check_running_interval, api_url, command

View File

@@ -14,4 +14,3 @@ Module: :mod:`jupyterhub.services.service`
.. autoconfigurable:: Service
:members: name, admin, url, api_token, managed, kind, command, cwd, environment, user, oauth_client_id, server, prefix, proxy_spec

View File

@@ -38,4 +38,3 @@ Module: :mod:`jupyterhub.services.auth`
--------------------------------
.. autoclass:: HubOAuthCallbackHandler

View File

@@ -19,4 +19,3 @@ Module: :mod:`jupyterhub.spawner`
----------------------------
.. autoconfigurable:: LocalProcessSpawner

View File

@@ -34,4 +34,3 @@ Module: :mod:`jupyterhub.user`
.. attribute:: spawner
The user's :class:`~.Spawner` instance.

View File

@@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
#
import sys
import os
import shlex
import sys
import recommonmark.parser
# For conversion from markdown to html
import recommonmark.parser
# Set paths
sys.path.insert(0, os.path.abspath('.'))
@@ -21,7 +22,7 @@ extensions = [
'sphinx.ext.intersphinx',
'sphinx.ext.napoleon',
'autodoc_traits',
'sphinx_copybutton'
'sphinx_copybutton',
]
templates_path = ['_templates']
@@ -68,6 +69,7 @@ source_suffix = ['.rst', '.md']
# The theme to use for HTML and HTML Help pages.
import alabaster_jupyterhub
html_theme = 'alabaster_jupyterhub'
html_theme_path = [alabaster_jupyterhub.get_html_theme_path()]

View File

@@ -210,4 +210,3 @@ jupyterhub_config.py amendments:
--This is the address on which the proxy will bind. Sets protocol, ip, base_url
c.JupyterHub.bind_url = 'http://127.0.0.1:8000/jhub/'
```

View File

@@ -1,8 +1,9 @@
"""autodoc extension for configurable traits"""
from traitlets import TraitType, Undefined
from sphinx.domains.python import PyClassmember
from sphinx.ext.autodoc import ClassDocumenter, AttributeDocumenter
from sphinx.ext.autodoc import AttributeDocumenter
from sphinx.ext.autodoc import ClassDocumenter
from traitlets import TraitType
from traitlets import Undefined
class ConfigurableDocumenter(ClassDocumenter):

View File

@@ -5,8 +5,10 @@ create a directory for the user before the spawner starts
# pylint: disable=import-error
import os
import shutil
from jupyter_client.localinterfaces import public_ips
def create_dir_hook(spawner):
""" Create directory """
username = spawner.user.name # get the username
@@ -16,6 +18,7 @@ def create_dir_hook(spawner):
# now do whatever you think your user needs
# ...
def clean_dir_hook(spawner):
""" Delete directory """
username = spawner.user.name # get the username
@@ -23,6 +26,7 @@ def clean_dir_hook(spawner):
if os.path.exists(temp_path) and os.path.isdir(temp_path):
shutil.rmtree(temp_path)
# attach the hook functions to the spawner
# pylint: disable=undefined-variable
c.Spawner.pre_spawn_hook = create_dir_hook

View File

@@ -31,11 +31,11 @@ users and servers, you should add this script to the services list
twice, just with different ``name``s, different values, and one with
the ``--cull-users`` option.
"""
from datetime import datetime, timezone
from functools import partial
import json
import os
from datetime import datetime
from datetime import timezone
from functools import partial
try:
from urllib.parse import quote
@@ -85,23 +85,21 @@ def format_td(td):
@coroutine
def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concurrency=10):
def cull_idle(
url, api_token, inactive_limit, cull_users=False, max_age=0, concurrency=10
):
"""Shutdown idle single-user servers
If cull_users, inactive *users* will be deleted as well.
"""
auth_header = {
'Authorization': 'token %s' % api_token,
}
req = HTTPRequest(
url=url + '/users',
headers=auth_header,
)
auth_header = {'Authorization': 'token %s' % api_token}
req = HTTPRequest(url=url + '/users', headers=auth_header)
now = datetime.now(timezone.utc)
client = AsyncHTTPClient()
if concurrency:
semaphore = Semaphore(concurrency)
@coroutine
def fetch(req):
"""client.fetch wrapped in a semaphore to limit concurrency"""
@@ -110,6 +108,7 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
return (yield client.fetch(req))
finally:
yield semaphore.release()
else:
fetch = client.fetch
@@ -129,8 +128,8 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
log_name = '%s/%s' % (user['name'], server_name)
if server.get('pending'):
app_log.warning(
"Not culling server %s with pending %s",
log_name, server['pending'])
"Not culling server %s with pending %s", log_name, server['pending']
)
return False
# jupyterhub < 0.9 defined 'server.url' once the server was ready
@@ -142,8 +141,8 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
if not server.get('ready', bool(server['url'])):
app_log.warning(
"Not culling not-ready not-pending server %s: %s",
log_name, server)
"Not culling not-ready not-pending server %s: %s", log_name, server
)
return False
if server.get('started'):
@@ -163,12 +162,13 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
# for running servers
inactive = age
should_cull = (inactive is not None and
inactive.total_seconds() >= inactive_limit)
should_cull = (
inactive is not None and inactive.total_seconds() >= inactive_limit
)
if should_cull:
app_log.info(
"Culling server %s (inactive for %s)",
log_name, format_td(inactive))
"Culling server %s (inactive for %s)", log_name, format_td(inactive)
)
if max_age and not should_cull:
# only check started if max_age is specified
@@ -177,32 +177,34 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
if age is not None and age.total_seconds() >= max_age:
app_log.info(
"Culling server %s (age: %s, inactive for %s)",
log_name, format_td(age), format_td(inactive))
log_name,
format_td(age),
format_td(inactive),
)
should_cull = True
if not should_cull:
app_log.debug(
"Not culling server %s (age: %s, inactive for %s)",
log_name, format_td(age), format_td(inactive))
log_name,
format_td(age),
format_td(inactive),
)
return False
if server_name:
# culling a named server
delete_url = url + "/users/%s/servers/%s" % (
quote(user['name']), quote(server['name'])
quote(user['name']),
quote(server['name']),
)
else:
delete_url = url + '/users/%s/server' % quote(user['name'])
req = HTTPRequest(
url=delete_url, method='DELETE', headers=auth_header,
)
req = HTTPRequest(url=delete_url, method='DELETE', headers=auth_header)
resp = yield fetch(req)
if resp.code == 202:
app_log.warning(
"Server %s is slow to stop",
log_name,
)
app_log.warning("Server %s is slow to stop", log_name)
# return False to prevent culling user with pending shutdowns
return False
return True
@@ -245,7 +247,9 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
if still_alive:
app_log.debug(
"Not culling user %s with %i servers still alive",
user['name'], still_alive)
user['name'],
still_alive,
)
return False
should_cull = False
@@ -265,12 +269,11 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
# which introduces the 'created' field which is never None
inactive = age
should_cull = (inactive is not None and
inactive.total_seconds() >= inactive_limit)
should_cull = (
inactive is not None and inactive.total_seconds() >= inactive_limit
)
if should_cull:
app_log.info(
"Culling user %s (inactive for %s)",
user['name'], inactive)
app_log.info("Culling user %s (inactive for %s)", user['name'], inactive)
if max_age and not should_cull:
# only check created if max_age is specified
@@ -279,19 +282,23 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
if age is not None and age.total_seconds() >= max_age:
app_log.info(
"Culling user %s (age: %s, inactive for %s)",
user['name'], format_td(age), format_td(inactive))
user['name'],
format_td(age),
format_td(inactive),
)
should_cull = True
if not should_cull:
app_log.debug(
"Not culling user %s (created: %s, last active: %s)",
user['name'], format_td(age), format_td(inactive))
user['name'],
format_td(age),
format_td(inactive),
)
return False
req = HTTPRequest(
url=url + '/users/%s' % user['name'],
method='DELETE',
headers=auth_header,
url=url + '/users/%s' % user['name'], method='DELETE', headers=auth_header
)
yield fetch(req)
return True
@@ -316,20 +323,30 @@ if __name__ == '__main__':
help="The JupyterHub API URL",
)
define('timeout', default=600, help="The idle timeout (in seconds)")
define('cull_every', default=0,
help="The interval (in seconds) for checking for idle servers to cull")
define('max_age', default=0,
help="The maximum age (in seconds) of servers that should be culled even if they are active")
define('cull_users', default=False,
define(
'cull_every',
default=0,
help="The interval (in seconds) for checking for idle servers to cull",
)
define(
'max_age',
default=0,
help="The maximum age (in seconds) of servers that should be culled even if they are active",
)
define(
'cull_users',
default=False,
help="""Cull users in addition to servers.
This is for use in temporary-user cases such as tmpnb.""",
)
define('concurrency', default=10,
define(
'concurrency',
default=10,
help="""Limit the number of concurrent requests made to the Hub.
Deleting a lot of users at the same time can slow down the Hub,
so limit the number of API requests we have outstanding at any given time.
"""
""",
)
parse_command_line()
@@ -343,7 +360,8 @@ if __name__ == '__main__':
app_log.warning(
"Could not load pycurl: %s\n"
"pycurl is recommended if you have a large number of users.",
e)
e,
)
loop = IOLoop.current()
cull = partial(

View File

@@ -4,7 +4,9 @@ import os
# this could come from anywhere
api_token = os.getenv("JUPYTERHUB_API_TOKEN")
if not api_token:
raise ValueError("Make sure to `export JUPYTERHUB_API_TOKEN=$(openssl rand -hex 32)`")
raise ValueError(
"Make sure to `export JUPYTERHUB_API_TOKEN=$(openssl rand -hex 32)`"
)
# tell JupyterHub to register the service as an external oauth client
@@ -14,5 +16,5 @@ c.JupyterHub.services = [
'oauth_client_id': "whoami-oauth-client-test",
'api_token': api_token,
'oauth_redirect_uri': 'http://127.0.0.1:5555/oauth_callback',
},
}
]

View File

@@ -3,18 +3,19 @@
Implements OAuth handshake manually
so all URLs and requests necessary for OAuth with JupyterHub should be in one place
"""
import json
import os
import sys
from urllib.parse import urlencode, urlparse
from urllib.parse import urlencode
from urllib.parse import urlparse
from tornado.auth import OAuth2Mixin
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from tornado.httputil import url_concat
from tornado.ioloop import IOLoop
from tornado import log
from tornado import web
from tornado.auth import OAuth2Mixin
from tornado.httpclient import AsyncHTTPClient
from tornado.httpclient import HTTPRequest
from tornado.httputil import url_concat
from tornado.ioloop import IOLoop
class JupyterHubLoginHandler(web.RequestHandler):
@@ -32,11 +33,11 @@ class JupyterHubLoginHandler(web.RequestHandler):
code=code,
redirect_uri=self.settings['redirect_uri'],
)
req = HTTPRequest(self.settings['token_url'], method='POST',
req = HTTPRequest(
self.settings['token_url'],
method='POST',
body=urlencode(params).encode('utf8'),
headers={
'Content-Type': 'application/x-www-form-urlencoded',
},
headers={'Content-Type': 'application/x-www-form-urlencoded'},
)
response = await AsyncHTTPClient().fetch(req)
data = json.loads(response.body.decode('utf8', 'replace'))
@@ -55,14 +56,16 @@ class JupyterHubLoginHandler(web.RequestHandler):
# we are the login handler,
# begin oauth process which will come back later with an
# authorization_code
self.redirect(url_concat(
self.redirect(
url_concat(
self.settings['authorize_url'],
dict(
redirect_uri=self.settings['redirect_uri'],
client_id=self.settings['client_id'],
response_type='code',
),
)
)
))
class WhoAmIHandler(web.RequestHandler):
@@ -85,10 +88,7 @@ class WhoAmIHandler(web.RequestHandler):
"""Retrieve the user for a given token, via /hub/api/user"""
req = HTTPRequest(
self.settings['user_url'],
headers={
'Authorization': f'token {token}'
},
self.settings['user_url'], headers={'Authorization': f'token {token}'}
)
response = await AsyncHTTPClient().fetch(req)
return json.loads(response.body.decode('utf8', 'replace'))
@@ -110,23 +110,23 @@ def main():
token_url = hub_api + '/oauth2/token'
user_url = hub_api + '/user'
app = web.Application([
('/oauth_callback', JupyterHubLoginHandler),
('/', WhoAmIHandler),
],
app = web.Application(
[('/oauth_callback', JupyterHubLoginHandler), ('/', WhoAmIHandler)],
login_url='/oauth_callback',
cookie_secret=os.urandom(32),
api_token=os.environ['JUPYTERHUB_API_TOKEN'],
client_id=os.environ['JUPYTERHUB_CLIENT_ID'],
redirect_uri=os.environ['JUPYTERHUB_SERVICE_URL'].rstrip('/') + '/oauth_callback',
redirect_uri=os.environ['JUPYTERHUB_SERVICE_URL'].rstrip('/')
+ '/oauth_callback',
authorize_url=authorize_url,
token_url=token_url,
user_url=user_url,
)
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
log.app_log.info("Running basic whoami service on %s",
os.environ['JUPYTERHUB_SERVICE_URL'])
log.app_log.info(
"Running basic whoami service on %s", os.environ['JUPYTERHUB_SERVICE_URL']
)
app.listen(url.port, url.hostname)
IOLoop.current().start()

View File

@@ -8,10 +8,10 @@ c.Authenticator.whitelist = {'ganymede', 'io', 'rhea'}
# These environment variables are automatically supplied by the linked postgres
# container.
import os;
import os
pg_pass = os.getenv('POSTGRES_ENV_JPY_PSQL_PASSWORD')
pg_host = os.getenv('POSTGRES_PORT_5432_TCP_ADDR')
c.JupyterHub.db_url = 'postgresql://jupyterhub:{}@{}:5432/jupyterhub'.format(
pg_pass,
pg_host,
pg_pass, pg_host
)

View File

@@ -1,11 +1,14 @@
import argparse
import datetime
import json
import os
from tornado import escape
from tornado import gen
from tornado import ioloop
from tornado import web
from jupyterhub.services.auth import HubAuthenticated
from tornado import escape, gen, ioloop, web
class AnnouncementRequestHandler(HubAuthenticated, web.RequestHandler):
@@ -53,19 +56,19 @@ def main():
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--api-prefix", "-a",
parser.add_argument(
"--api-prefix",
"-a",
default=os.environ.get("JUPYTERHUB_SERVICE_PREFIX", "/"),
help="application API prefix")
parser.add_argument("--port", "-p",
default=8888,
help="port for API to listen on",
type=int)
help="application API prefix",
)
parser.add_argument(
"--port", "-p", default=8888, help="port for API to listen on", type=int
)
return parser.parse_args()
def create_application(api_prefix="/",
handler=AnnouncementRequestHandler,
**kwargs):
def create_application(api_prefix="/", handler=AnnouncementRequestHandler, **kwargs):
storage = dict(announcement="", timestamp="", user="")
return web.Application([(api_prefix, handler, dict(storage=storage))])

View File

@@ -22,4 +22,3 @@ In the external example, some extra steps are required to set up supervisor:
3. install `shared-notebook-service` somewhere on your system, and update `/path/to/shared-notebook-service` to the absolute path of this destination
3. copy `shared-notebook.conf` to `/etc/supervisor/conf.d/`
4. `supervisorctl reload`

View File

@@ -1,18 +1,9 @@
# our user list
c.Authenticator.whitelist = [
'minrk',
'ellisonbg',
'willingc',
]
c.Authenticator.whitelist = ['minrk', 'ellisonbg', 'willingc']
# ellisonbg and willingc have access to a shared server:
c.JupyterHub.load_groups = {
'shared': [
'ellisonbg',
'willingc',
]
}
c.JupyterHub.load_groups = {'shared': ['ellisonbg', 'willingc']}
# start the notebook server as a service
c.JupyterHub.services = [

View File

@@ -1,18 +1,9 @@
# our user list
c.Authenticator.whitelist = [
'minrk',
'ellisonbg',
'willingc',
]
c.Authenticator.whitelist = ['minrk', 'ellisonbg', 'willingc']
# ellisonbg and willingc have access to a shared server:
c.JupyterHub.load_groups = {
'shared': [
'ellisonbg',
'willingc',
]
}
c.JupyterHub.load_groups = {'shared': ['ellisonbg', 'willingc']}
service_name = 'shared-notebook'
service_port = 9999
@@ -23,10 +14,6 @@ c.JupyterHub.services = [
{
'name': service_name,
'url': 'http://127.0.0.1:{}'.format(service_port),
'command': [
'jupyterhub-singleuser',
'--group=shared',
'--debug',
],
'command': ['jupyterhub-singleuser', '--group=shared', '--debug'],
}
]

View File

@@ -6,16 +6,12 @@ c.JupyterHub.services = [
'name': 'whoami',
'url': 'http://127.0.0.1:10101',
'command': ['flask', 'run', '--port=10101'],
'environment': {
'FLASK_APP': 'whoami-flask.py',
}
'environment': {'FLASK_APP': 'whoami-flask.py'},
},
{
'name': 'whoami-oauth',
'url': 'http://127.0.0.1:10201',
'command': ['flask', 'run', '--port=10201'],
'environment': {
'FLASK_APP': 'whoami-oauth.py',
}
'environment': {'FLASK_APP': 'whoami-oauth.py'},
},
]

View File

@@ -2,29 +2,29 @@
"""
whoami service authentication with the Hub
"""
from functools import wraps
import json
import os
from functools import wraps
from urllib.parse import quote
from flask import Flask, redirect, request, Response
from flask import Flask
from flask import redirect
from flask import request
from flask import Response
from jupyterhub.services.auth import HubAuth
prefix = os.environ.get('JUPYTERHUB_SERVICE_PREFIX', '/')
auth = HubAuth(
api_token=os.environ['JUPYTERHUB_API_TOKEN'],
cache_max_age=60,
)
auth = HubAuth(api_token=os.environ['JUPYTERHUB_API_TOKEN'], cache_max_age=60)
app = Flask(__name__)
def authenticated(f):
"""Decorator for authenticating with the Hub"""
@wraps(f)
def decorated(*args, **kwargs):
cookie = request.cookies.get(auth.cookie_name)
@@ -40,6 +40,7 @@ def authenticated(f):
else:
# redirect to login url on failed auth
return redirect(auth.login_url + '?next=%s' % quote(request.path))
return decorated
@@ -47,7 +48,5 @@ def authenticated(f):
@authenticated
def whoami(user):
return Response(
json.dumps(user, indent=1, sort_keys=True),
mimetype='application/json',
json.dumps(user, indent=1, sort_keys=True), mimetype='application/json'
)

View File

@@ -2,28 +2,29 @@
"""
whoami service authentication with the Hub
"""
from functools import wraps
import json
import os
from functools import wraps
from flask import Flask, redirect, request, Response, make_response
from flask import Flask
from flask import make_response
from flask import redirect
from flask import request
from flask import Response
from jupyterhub.services.auth import HubOAuth
prefix = os.environ.get('JUPYTERHUB_SERVICE_PREFIX', '/')
auth = HubOAuth(
api_token=os.environ['JUPYTERHUB_API_TOKEN'],
cache_max_age=60,
)
auth = HubOAuth(api_token=os.environ['JUPYTERHUB_API_TOKEN'], cache_max_age=60)
app = Flask(__name__)
def authenticated(f):
"""Decorator for authenticating with the Hub via OAuth"""
@wraps(f)
def decorated(*args, **kwargs):
token = request.cookies.get(auth.cookie_name)
@@ -39,6 +40,7 @@ def authenticated(f):
response = make_response(redirect(auth.login_url + '&state=%s' % state))
response.set_cookie(auth.state_cookie_name, state)
return response
return decorated
@@ -46,10 +48,10 @@ def authenticated(f):
@authenticated
def whoami(user):
return Response(
json.dumps(user, indent=1, sort_keys=True),
mimetype='application/json',
json.dumps(user, indent=1, sort_keys=True), mimetype='application/json'
)
@app.route(prefix + 'oauth_callback')
def oauth_callback():
code = request.args.get('code', None)

View File

@@ -4,18 +4,22 @@ This example service serves `/services/whoami/`,
authenticated with the Hub,
showing the user their own info.
"""
from getpass import getuser
import json
import os
from getpass import getuser
from urllib.parse import urlparse
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer
from tornado.web import RequestHandler, Application, authenticated
from tornado.ioloop import IOLoop
from tornado.web import Application
from tornado.web import authenticated
from tornado.web import RequestHandler
from jupyterhub.services.auth import HubOAuthenticated, HubOAuthCallbackHandler
from jupyterhub.services.auth import HubOAuthCallbackHandler
from jupyterhub.services.auth import HubOAuthenticated
from jupyterhub.utils import url_path_join
class WhoAmIHandler(HubOAuthenticated, RequestHandler):
# hub_users can be a set of users who are allowed to access the service
# `getuser()` here would mean only the user who started the service
@@ -29,12 +33,21 @@ class WhoAmIHandler(HubOAuthenticated, RequestHandler):
self.set_header('content-type', 'application/json')
self.write(json.dumps(user_model, indent=1, sort_keys=True))
def main():
app = Application([
app = Application(
[
(os.environ['JUPYTERHUB_SERVICE_PREFIX'], WhoAmIHandler),
(url_path_join(os.environ['JUPYTERHUB_SERVICE_PREFIX'], 'oauth_callback'), HubOAuthCallbackHandler),
(
url_path_join(
os.environ['JUPYTERHUB_SERVICE_PREFIX'], 'oauth_callback'
),
HubOAuthCallbackHandler,
),
(r'.*', WhoAmIHandler),
], cookie_secret=os.urandom(32))
],
cookie_secret=os.urandom(32),
)
http_server = HTTPServer(app)
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
@@ -43,5 +56,6 @@ def main():
IOLoop.current().start()
if __name__ == '__main__':
main()

View File

@@ -2,14 +2,16 @@
This serves `/services/whoami/`, authenticated with the Hub, showing the user their own info.
"""
from getpass import getuser
import json
import os
from getpass import getuser
from urllib.parse import urlparse
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer
from tornado.web import RequestHandler, Application, authenticated
from tornado.ioloop import IOLoop
from tornado.web import Application
from tornado.web import authenticated
from tornado.web import RequestHandler
from jupyterhub.services.auth import HubAuthenticated
@@ -27,11 +29,14 @@ class WhoAmIHandler(HubAuthenticated, RequestHandler):
self.set_header('content-type', 'application/json')
self.write(json.dumps(user_model, indent=1, sort_keys=True))
def main():
app = Application([
app = Application(
[
(os.environ['JUPYTERHUB_SERVICE_PREFIX'] + '/?', WhoAmIHandler),
(r'.*', WhoAmIHandler),
])
]
)
http_server = HTTPServer(app)
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
@@ -40,5 +45,6 @@ def main():
IOLoop.current().start()
if __name__ == '__main__':
main()

View File

@@ -5,6 +5,7 @@ import shlex
from jupyterhub.spawner import LocalProcessSpawner
class DemoFormSpawner(LocalProcessSpawner):
def _options_form_default(self):
default_env = "YOURNAME=%s\n" % self.user.name
@@ -13,7 +14,9 @@ class DemoFormSpawner(LocalProcessSpawner):
<input name="args" placeholder="e.g. --debug"></input>
<label for="env">Environment variables (one per line)</label>
<textarea name="env">{env}</textarea>
""".format(env=default_env)
""".format(
env=default_env
)
def options_from_form(self, formdata):
options = {}
@@ -43,4 +46,5 @@ class DemoFormSpawner(LocalProcessSpawner):
env.update(self.user_options['env'])
return env
c.JupyterHub.spawner_class = DemoFormSpawner

View File

@@ -1 +1,2 @@
from ._version import version_info, __version__
from ._version import __version__
from ._version import version_info

View File

@@ -1,2 +1,3 @@
from .app import main
main()

View File

@@ -5,6 +5,7 @@ def get_data_files():
"""Walk up until we find share/jupyterhub"""
import sys
from os.path import join, abspath, dirname, exists, split
path = abspath(dirname(__file__))
starting_points = [path]
if not path.startswith(sys.prefix):

View File

@@ -1,5 +1,4 @@
"""JupyterHub version info"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
@@ -23,16 +22,23 @@ __version__ = ".".join(map(str, version_info[:3])) + ".".join(version_info[3:])
def _check_version(hub_version, singleuser_version, log):
"""Compare Hub and single-user server versions"""
if not hub_version:
log.warning("Hub has no version header, which means it is likely < 0.8. Expected %s", __version__)
log.warning(
"Hub has no version header, which means it is likely < 0.8. Expected %s",
__version__,
)
return
if not singleuser_version:
log.warning("Single-user server has no version header, which means it is likely < 0.8. Expected %s", __version__)
log.warning(
"Single-user server has no version header, which means it is likely < 0.8. Expected %s",
__version__,
)
return
# compare minor X.Y versions
if hub_version != singleuser_version:
from distutils.version import LooseVersion as V
hub_major_minor = V(hub_version).version[:2]
singleuser_major_minor = V(singleuser_version).version[:2]
extra = ""
@@ -50,4 +56,6 @@ def _check_version(hub_version, singleuser_version, log):
singleuser_version,
)
else:
log.debug("jupyterhub and jupyterhub-singleuser both on version %s" % hub_version)
log.debug(
"jupyterhub and jupyterhub-singleuser both on version %s" % hub_version
)

View File

@@ -1,9 +1,10 @@
import logging
import sys
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
import logging
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@@ -14,6 +15,7 @@ config = context.config
if 'jupyterhub' in sys.modules:
from traitlets.config import MultipleInstanceError
from jupyterhub.app import JupyterHub
app = None
if JupyterHub.initialized():
try:
@@ -32,6 +34,7 @@ else:
# add your model's MetaData object here for 'autogenerate' support
from jupyterhub import orm
target_metadata = orm.Base.metadata
# other values from the config, defined by the needs of env.py,
@@ -53,8 +56,7 @@ def run_migrations_offline():
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url, target_metadata=target_metadata, literal_binds=True)
context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
with context.begin_transaction():
context.run_migrations()
@@ -70,17 +72,16 @@ def run_migrations_online():
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix='sqlalchemy.',
poolclass=pool.NullPool)
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata
)
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:

View File

@@ -5,7 +5,6 @@ Revises:
Create Date: 2016-04-11 16:05:34.873288
"""
# revision identifiers, used by Alembic.
revision = '19c0846f6344'
down_revision = None

View File

@@ -5,7 +5,6 @@ Revises: 3ec6993fe20c
Create Date: 2017-12-07 14:43:51.500740
"""
# revision identifiers, used by Alembic.
revision = '1cebaf56856c'
down_revision = '3ec6993fe20c'
@@ -13,6 +12,7 @@ branch_labels = None
depends_on = None
import logging
logger = logging.getLogger('alembic')
from alembic import op

View File

@@ -12,7 +12,6 @@ Revises: af4cbdb2d13c
Create Date: 2017-07-28 16:44:40.413648
"""
# revision identifiers, used by Alembic.
revision = '3ec6993fe20c'
down_revision = 'af4cbdb2d13c'
@@ -44,7 +43,9 @@ def upgrade():
except sa.exc.OperationalError:
# this won't be a problem moving forward, but downgrade will fail
if op.get_context().dialect.name == 'sqlite':
logger.warning("sqlite cannot drop columns. Leaving unused old columns in place.")
logger.warning(
"sqlite cannot drop columns. Leaving unused old columns in place."
)
else:
raise
@@ -54,15 +55,13 @@ def upgrade():
def downgrade():
# drop all the new tables
engine = op.get_bind().engine
for table in ('oauth_clients',
'oauth_codes',
'oauth_access_tokens',
'spawners'):
for table in ('oauth_clients', 'oauth_codes', 'oauth_access_tokens', 'spawners'):
if engine.has_table(table):
op.drop_table(table)
op.drop_column('users', 'encrypted_auth_state')
op.add_column('users', sa.Column('auth_state', JSONDict))
op.add_column('users', sa.Column('_server_id', sa.Integer, sa.ForeignKey('servers.id')))
op.add_column(
'users', sa.Column('_server_id', sa.Integer, sa.ForeignKey('servers.id'))
)

View File

@@ -5,7 +5,6 @@ Revises: 1cebaf56856c
Create Date: 2017-12-19 15:21:09.300513
"""
# revision identifiers, used by Alembic.
revision = '56cc5a70207e'
down_revision = '1cebaf56856c'
@@ -16,22 +15,48 @@ from alembic import op
import sqlalchemy as sa
import logging
logger = logging.getLogger('alembic')
def upgrade():
tables = op.get_bind().engine.table_names()
op.add_column('api_tokens', sa.Column('created', sa.DateTime(), nullable=True))
op.add_column('api_tokens', sa.Column('last_activity', sa.DateTime(), nullable=True))
op.add_column('api_tokens', sa.Column('note', sa.Unicode(length=1023), nullable=True))
op.add_column(
'api_tokens', sa.Column('last_activity', sa.DateTime(), nullable=True)
)
op.add_column(
'api_tokens', sa.Column('note', sa.Unicode(length=1023), nullable=True)
)
if 'oauth_access_tokens' in tables:
op.add_column('oauth_access_tokens', sa.Column('created', sa.DateTime(), nullable=True))
op.add_column('oauth_access_tokens', sa.Column('last_activity', sa.DateTime(), nullable=True))
op.add_column(
'oauth_access_tokens', sa.Column('created', sa.DateTime(), nullable=True)
)
op.add_column(
'oauth_access_tokens',
sa.Column('last_activity', sa.DateTime(), nullable=True),
)
if op.get_context().dialect.name == 'sqlite':
logger.warning("sqlite cannot use ALTER TABLE to create foreign keys. Upgrade will be incomplete.")
logger.warning(
"sqlite cannot use ALTER TABLE to create foreign keys. Upgrade will be incomplete."
)
else:
op.create_foreign_key(None, 'oauth_access_tokens', 'oauth_clients', ['client_id'], ['identifier'], ondelete='CASCADE')
op.create_foreign_key(None, 'oauth_codes', 'oauth_clients', ['client_id'], ['identifier'], ondelete='CASCADE')
op.create_foreign_key(
None,
'oauth_access_tokens',
'oauth_clients',
['client_id'],
['identifier'],
ondelete='CASCADE',
)
op.create_foreign_key(
None,
'oauth_codes',
'oauth_clients',
['client_id'],
['identifier'],
ondelete='CASCADE',
)
def downgrade():

View File

@@ -5,7 +5,6 @@ Revises: d68c98b66cd4
Create Date: 2018-05-07 11:35:58.050542
"""
# revision identifiers, used by Alembic.
revision = '896818069c98'
down_revision = 'd68c98b66cd4'

View File

@@ -5,7 +5,6 @@ Revises: 56cc5a70207e
Create Date: 2018-03-21 14:27:17.466841
"""
# revision identifiers, used by Alembic.
revision = '99a28a4418e1'
down_revision = '56cc5a70207e'
@@ -18,15 +17,18 @@ import sqlalchemy as sa
from datetime import datetime
def upgrade():
op.add_column('users', sa.Column('created', sa.DateTime, nullable=True))
c = op.get_bind()
# fill created date with current time
now = datetime.utcnow()
c.execute("""
c.execute(
"""
UPDATE users
SET created='%s'
""" % (now,)
"""
% (now,)
)
tables = c.engine.table_names()
@@ -34,11 +36,13 @@ def upgrade():
if 'spawners' in tables:
op.add_column('spawners', sa.Column('started', sa.DateTime, nullable=True))
# fill started value with now for running servers
c.execute("""
c.execute(
"""
UPDATE spawners
SET started='%s'
WHERE server_id IS NOT NULL
""" % (now,)
"""
% (now,)
)

View File

@@ -5,7 +5,6 @@ Revises: eeb276e51423
Create Date: 2016-07-28 16:16:38.245348
"""
# revision identifiers, used by Alembic.
revision = 'af4cbdb2d13c'
down_revision = 'eeb276e51423'

View File

@@ -5,7 +5,6 @@ Revises: 99a28a4418e1
Create Date: 2018-04-13 10:50:17.968636
"""
# revision identifiers, used by Alembic.
revision = 'd68c98b66cd4'
down_revision = '99a28a4418e1'
@@ -20,8 +19,7 @@ def upgrade():
tables = op.get_bind().engine.table_names()
if 'oauth_clients' in tables:
op.add_column(
'oauth_clients',
sa.Column('description', sa.Unicode(length=1023))
'oauth_clients', sa.Column('description', sa.Unicode(length=1023))
)

View File

@@ -6,7 +6,6 @@ Revision ID: eeb276e51423
Revises: 19c0846f6344
Create Date: 2016-04-11 16:06:49.239831
"""
# revision identifiers, used by Alembic.
revision = 'eeb276e51423'
down_revision = '19c0846f6344'
@@ -17,6 +16,7 @@ from alembic import op
import sqlalchemy as sa
from jupyterhub.orm import JSONDict
def upgrade():
op.add_column('users', sa.Column('auth_state', JSONDict))

View File

@@ -1,5 +1,10 @@
from . import auth
from . import groups
from . import hub
from . import proxy
from . import services
from . import users
from .base import *
from . import auth, hub, proxy, users, groups, services
default_handlers = []
for mod in (auth, hub, proxy, users, groups, services):

View File

@@ -1,25 +1,23 @@
"""Authorization handlers"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from datetime import datetime
import json
from urllib.parse import (
parse_qsl,
quote,
urlencode,
urlparse,
urlunparse,
)
from datetime import datetime
from urllib.parse import parse_qsl
from urllib.parse import quote
from urllib.parse import urlencode
from urllib.parse import urlparse
from urllib.parse import urlunparse
from oauthlib import oauth2
from tornado import web
from .. import orm
from ..user import User
from ..utils import token_authenticated, compare_token
from .base import BaseHandler, APIHandler
from ..utils import compare_token
from ..utils import token_authenticated
from .base import APIHandler
from .base import BaseHandler
class TokenAPIHandler(APIHandler):
@@ -70,7 +68,9 @@ class TokenAPIHandler(APIHandler):
if data and data.get('username'):
user = self.find_user(data['username'])
if user is not requester and not requester.admin:
raise web.HTTPError(403, "Only admins can request tokens for other users.")
raise web.HTTPError(
403, "Only admins can request tokens for other users."
)
if requester.admin and user is None:
raise web.HTTPError(400, "No such user '%s'" % data['username'])
@@ -82,11 +82,11 @@ class TokenAPIHandler(APIHandler):
note += " by %s %s" % (kind, requester.name)
api_token = user.new_api_token(note=note)
self.write(json.dumps({
'token': api_token,
'warning': warn_msg,
'user': self.user_model(user),
}))
self.write(
json.dumps(
{'token': api_token, 'warning': warn_msg, 'user': self.user_model(user)}
)
)
class CookieAPIHandler(APIHandler):
@@ -94,7 +94,9 @@ class CookieAPIHandler(APIHandler):
def get(self, cookie_name, cookie_value=None):
cookie_name = quote(cookie_name, safe='')
if cookie_value is None:
self.log.warning("Cookie values in request body is deprecated, use `/cookie_name/cookie_value`")
self.log.warning(
"Cookie values in request body is deprecated, use `/cookie_name/cookie_value`"
)
cookie_value = self.request.body
else:
cookie_value = cookie_value.encode('utf8')
@@ -134,7 +136,9 @@ class OAuthHandler:
return uri
# make absolute local redirects full URLs
# to satisfy oauthlib's absolute URI requirement
redirect_uri = self.request.protocol + "://" + self.request.headers['Host'] + redirect_uri
redirect_uri = (
self.request.protocol + "://" + self.request.headers['Host'] + redirect_uri
)
parsed_url = urlparse(uri)
query_list = parse_qsl(parsed_url.query, keep_blank_values=True)
for idx, item in enumerate(query_list):
@@ -142,10 +146,7 @@ class OAuthHandler:
query_list[idx] = ('redirect_uri', redirect_uri)
break
return urlunparse(
urlparse(uri)
._replace(query=urlencode(query_list))
)
return urlunparse(urlparse(uri)._replace(query=urlencode(query_list)))
def add_credentials(self, credentials=None):
"""Add oauth credentials
@@ -164,11 +165,7 @@ class OAuthHandler:
user = self.current_user
# Extra credentials we need in the validator
credentials.update({
'user': user,
'handler': self,
'session_id': session_id,
})
credentials.update({'user': user, 'handler': self, 'session_id': session_id})
return credentials
def send_oauth_response(self, headers, body, status):
@@ -193,7 +190,8 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
def _complete_login(self, uri, headers, scopes, credentials):
try:
headers, body, status = self.oauth_provider.create_authorization_response(
uri, 'POST', '', headers, scopes, credentials)
uri, 'POST', '', headers, scopes, credentials
)
except oauth2.FatalClientError as e:
# TODO: human error page
@@ -213,13 +211,15 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
uri, http_method, body, headers = self.extract_oauth_params()
try:
scopes, credentials = self.oauth_provider.validate_authorization_request(
uri, http_method, body, headers)
uri, http_method, body, headers
)
credentials = self.add_credentials(credentials)
client = self.oauth_provider.fetch_by_client_id(credentials['client_id'])
if client.redirect_uri.startswith(self.current_user.url):
self.log.debug(
"Skipping oauth confirmation for %s accessing %s",
self.current_user, client.description,
self.current_user,
client.description,
)
# access to my own server doesn't require oauth confirmation
# this is the pre-1.0 behavior for all oauth
@@ -228,11 +228,7 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
# Render oauth 'Authorize application...' page
self.write(
self.render_template(
"oauth.html",
scopes=scopes,
oauth_client=client,
)
self.render_template("oauth.html", scopes=scopes, oauth_client=client)
)
# Errors that should be shown to the user on the provider website
@@ -252,7 +248,9 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
if referer != full_url:
# OAuth post must be made to the URL it came from
self.log.error("OAuth POST from %s != %s", referer, full_url)
raise web.HTTPError(403, "Authorization form must be sent from authorization page")
raise web.HTTPError(
403, "Authorization form must be sent from authorization page"
)
# The scopes the user actually authorized, i.e. checkboxes
# that were selected.
@@ -262,7 +260,7 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
try:
headers, body, status = self.oauth_provider.create_authorization_response(
uri, http_method, body, headers, scopes, credentials,
uri, http_method, body, headers, scopes, credentials
)
except oauth2.FatalClientError as e:
raise web.HTTPError(e.status_code, e.description)
@@ -277,7 +275,8 @@ class OAuthTokenHandler(OAuthHandler, APIHandler):
try:
headers, body, status = self.oauth_provider.create_token_response(
uri, http_method, body, headers, credentials)
uri, http_method, body, headers, credentials
)
except oauth2.FatalClientError as e:
raise web.HTTPError(e.status_code, e.description)
else:

View File

@@ -1,10 +1,8 @@
"""Base API handlers"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from datetime import datetime
import json
from datetime import datetime
from http.client import responses
from sqlalchemy.exc import SQLAlchemyError
@@ -12,7 +10,8 @@ from tornado import web
from .. import orm
from ..handlers import BaseHandler
from ..utils import isoformat, url_path_join
from ..utils import isoformat
from ..utils import url_path_join
class APIHandler(BaseHandler):
@@ -55,8 +54,11 @@ class APIHandler(BaseHandler):
host_path = url_path_join(host, self.hub.base_url)
referer_path = referer.split('://', 1)[-1]
if not (referer_path + '/').startswith(host_path):
self.log.warning("Blocking Cross Origin API request. Referer: %s, Host: %s",
referer, host_path)
self.log.warning(
"Blocking Cross Origin API request. Referer: %s, Host: %s",
referer,
host_path,
)
return False
return True
@@ -105,9 +107,13 @@ class APIHandler(BaseHandler):
if exception and isinstance(exception, SQLAlchemyError):
try:
exception_str = str(exception)
self.log.warning("Rolling back session due to database error %s", exception_str)
self.log.warning(
"Rolling back session due to database error %s", exception_str
)
except Exception:
self.log.warning("Rolling back session due to database error %s", type(exception))
self.log.warning(
"Rolling back session due to database error %s", type(exception)
)
self.db.rollback()
self.set_header('Content-Type', 'application/json')
@@ -121,10 +127,9 @@ class APIHandler(BaseHandler):
# Content-Length must be recalculated.
self.clear_header('Content-Length')
self.write(json.dumps({
'status': status_code,
'message': message or status_message,
}))
self.write(
json.dumps({'status': status_code, 'message': message or status_message})
)
def server_model(self, spawner, include_state=False):
"""Get the JSON model for a Spawner"""
@@ -144,21 +149,17 @@ class APIHandler(BaseHandler):
expires_at = None
if isinstance(token, orm.APIToken):
kind = 'api_token'
extra = {
'note': token.note,
}
extra = {'note': token.note}
expires_at = token.expires_at
elif isinstance(token, orm.OAuthAccessToken):
kind = 'oauth'
extra = {
'oauth_client': token.client.description or token.client.client_id,
}
extra = {'oauth_client': token.client.description or token.client.client_id}
if token.expires_at:
expires_at = datetime.fromtimestamp(token.expires_at)
else:
raise TypeError(
"token must be an APIToken or OAuthAccessToken, not %s"
% type(token))
"token must be an APIToken or OAuthAccessToken, not %s" % type(token)
)
if token.user:
owner_key = 'user'
@@ -219,23 +220,11 @@ class APIHandler(BaseHandler):
def service_model(self, service):
"""Get the JSON model for a Service object"""
return {
'kind': 'service',
'name': service.name,
'admin': service.admin,
}
return {'kind': 'service', 'name': service.name, 'admin': service.admin}
_user_model_types = {
'name': str,
'admin': bool,
'groups': list,
'auth_state': dict,
}
_user_model_types = {'name': str, 'admin': bool, 'groups': list, 'auth_state': dict}
_group_model_types = {
'name': str,
'users': list,
}
_group_model_types = {'name': str, 'users': list}
def _check_model(self, model, model_types, name):
"""Check a model provided by a REST API request
@@ -251,24 +240,29 @@ class APIHandler(BaseHandler):
raise web.HTTPError(400, "Invalid JSON keys: %r" % model)
for key, value in model.items():
if not isinstance(value, model_types[key]):
raise web.HTTPError(400, "%s.%s must be %s, not: %r" % (
name, key, model_types[key], type(value)
))
raise web.HTTPError(
400,
"%s.%s must be %s, not: %r"
% (name, key, model_types[key], type(value)),
)
def _check_user_model(self, model):
"""Check a request-provided user model from a REST API"""
self._check_model(model, self._user_model_types, 'user')
for username in model.get('users', []):
if not isinstance(username, str):
raise web.HTTPError(400, ("usernames must be str, not %r", type(username)))
raise web.HTTPError(
400, ("usernames must be str, not %r", type(username))
)
def _check_group_model(self, model):
"""Check a request-provided group model from a REST API"""
self._check_model(model, self._group_model_types, 'group')
for groupname in model.get('groups', []):
if not isinstance(groupname, str):
raise web.HTTPError(400, ("group names must be str, not %r", type(groupname)))
raise web.HTTPError(
400, ("group names must be str, not %r", type(groupname))
)
def options(self, *args, **kwargs):
self.finish()
@@ -279,6 +273,7 @@ class API404(APIHandler):
Ensures JSON 404 errors for malformed URLs
"""
async def prepare(self):
await super().prepare()
raise web.HTTPError(404)

View File

@@ -1,11 +1,10 @@
"""Group handlers"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import json
from tornado import gen, web
from tornado import gen
from tornado import web
from .. import orm
from ..utils import admin_only
@@ -34,6 +33,7 @@ class _GroupAPIHandler(APIHandler):
raise web.HTTPError(404, "No such group: %s", name)
return group
class GroupListAPIHandler(_GroupAPIHandler):
@admin_only
def get(self):
@@ -61,9 +61,7 @@ class GroupListAPIHandler(_GroupAPIHandler):
# check that users exist
users = self._usernames_to_users(usernames)
# create the group
self.log.info("Creating new group %s with %i users",
name, len(users),
)
self.log.info("Creating new group %s with %i users", name, len(users))
self.log.debug("Users: %s", usernames)
group = orm.Group(name=name, users=users)
self.db.add(group)
@@ -99,9 +97,7 @@ class GroupAPIHandler(_GroupAPIHandler):
users = self._usernames_to_users(usernames)
# create the group
self.log.info("Creating new group %s with %i users",
name, len(users),
)
self.log.info("Creating new group %s with %i users", name, len(users))
self.log.debug("Users: %s", usernames)
group = orm.Group(name=name, users=users)
self.db.add(group)
@@ -121,6 +117,7 @@ class GroupAPIHandler(_GroupAPIHandler):
class GroupUsersAPIHandler(_GroupAPIHandler):
"""Modify a group's user list"""
@admin_only
def post(self, name):
"""POST adds users to a group"""

View File

@@ -1,21 +1,18 @@
"""API handlers for administering the Hub itself"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import json
import sys
from tornado import web
from tornado.ioloop import IOLoop
from .._version import __version__
from ..utils import admin_only
from .base import APIHandler
from .._version import __version__
class ShutdownAPIHandler(APIHandler):
@admin_only
def post(self):
"""POST /api/shutdown triggers a clean shutdown
@@ -26,6 +23,7 @@ class ShutdownAPIHandler(APIHandler):
- proxy: specify whether the proxy should be terminated
"""
from ..app import JupyterHub
app = JupyterHub.instance()
data = self.get_json_body()
@@ -33,19 +31,21 @@ class ShutdownAPIHandler(APIHandler):
if 'proxy' in data:
proxy = data['proxy']
if proxy not in {True, False}:
raise web.HTTPError(400, "proxy must be true or false, got %r" % proxy)
raise web.HTTPError(
400, "proxy must be true or false, got %r" % proxy
)
app.cleanup_proxy = proxy
if 'servers' in data:
servers = data['servers']
if servers not in {True, False}:
raise web.HTTPError(400, "servers must be true or false, got %r" % servers)
raise web.HTTPError(
400, "servers must be true or false, got %r" % servers
)
app.cleanup_servers = servers
# finish the request
self.set_status(202)
self.finish(json.dumps({
"message": "Shutting down Hub"
}))
self.finish(json.dumps({"message": "Shutting down Hub"}))
# stop the eventloop, which will trigger cleanup
loop = IOLoop.current()
@@ -53,7 +53,6 @@ class ShutdownAPIHandler(APIHandler):
class RootAPIHandler(APIHandler):
def get(self):
"""GET /api/ returns info about the Hub and its API.
@@ -61,14 +60,11 @@ class RootAPIHandler(APIHandler):
For now, it just returns the version of JupyterHub itself.
"""
data = {
'version': __version__,
}
data = {'version': __version__}
self.finish(json.dumps(data))
class InfoAPIHandler(APIHandler):
@admin_only
def get(self):
"""GET /api/info returns detailed info about the Hub and its API.
@@ -77,10 +73,11 @@ class InfoAPIHandler(APIHandler):
For now, it just returns the version of JupyterHub itself.
"""
def _class_info(typ):
"""info about a class (Spawner or Authenticator)"""
info = {
'class': '{mod}.{name}'.format(mod=typ.__module__, name=typ.__name__),
'class': '{mod}.{name}'.format(mod=typ.__module__, name=typ.__name__)
}
pkg = typ.__module__.split('.')[0]
try:

View File

@@ -1,12 +1,11 @@
"""Proxy handlers"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import json
from urllib.parse import urlparse
from tornado import gen, web
from tornado import gen
from tornado import web
from .. import orm
from ..utils import admin_only
@@ -14,7 +13,6 @@ from .base import APIHandler
class ProxyAPIHandler(APIHandler):
@admin_only
async def get(self):
"""GET /api/proxy fetches the routing table
@@ -58,6 +56,4 @@ class ProxyAPIHandler(APIHandler):
await self.proxy.check_routes(self.users, self.services)
default_handlers = [
(r"/api/proxy", ProxyAPIHandler),
]
default_handlers = [(r"/api/proxy", ProxyAPIHandler)]

View File

@@ -2,10 +2,8 @@
Currently GET-only, no actions can be taken to modify services.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import json
from tornado import web
@@ -14,6 +12,7 @@ from .. import orm
from ..utils import admin_only
from .base import APIHandler
def service_model(service):
"""Produce the model for a service"""
return {
@@ -23,9 +22,10 @@ def service_model(service):
'prefix': service.server.base_url if service.server else '',
'command': service.command,
'pid': service.proc.pid if service.proc else 0,
'info': service.info
'info': service.info,
}
class ServiceListAPIHandler(APIHandler):
@admin_only
def get(self):
@@ -35,6 +35,7 @@ class ServiceListAPIHandler(APIHandler):
def admin_or_self(method):
"""Decorator for restricting access to either the target service or admin"""
def decorated_method(self, name):
current = self.current_user
if current is None:
@@ -49,10 +50,11 @@ def admin_or_self(method):
if name not in self.services:
raise web.HTTPError(404)
return method(self, name)
return decorated_method
class ServiceAPIHandler(APIHandler):
class ServiceAPIHandler(APIHandler):
@admin_or_self
def get(self, name):
service = self.services[name]

View File

@@ -1,11 +1,11 @@
"""User handlers"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
from datetime import datetime, timedelta, timezone
import json
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from async_generator import aclosing
from dateutil.parser import parse as parse_date
@@ -14,7 +14,11 @@ from tornado.iostream import StreamClosedError
from .. import orm
from ..user import User
from ..utils import admin_only, isoformat, iterate_until, maybe_future, url_path_join
from ..utils import admin_only
from ..utils import isoformat
from ..utils import iterate_until
from ..utils import maybe_future
from ..utils import url_path_join
from .base import APIHandler
@@ -89,7 +93,9 @@ class UserListAPIHandler(APIHandler):
except Exception as e:
self.log.error("Failed to create user: %s" % name, exc_info=True)
self.users.delete(user)
raise web.HTTPError(400, "Failed to create user %s: %s" % (name, str(e)))
raise web.HTTPError(
400, "Failed to create user %s: %s" % (name, str(e))
)
else:
created.append(user)
@@ -99,6 +105,7 @@ class UserListAPIHandler(APIHandler):
def admin_or_self(method):
"""Decorator for restricting access to either the target user or admin"""
def m(self, name, *args, **kwargs):
current = self.current_user
if current is None:
@@ -110,15 +117,17 @@ def admin_or_self(method):
if not self.find_user(name):
raise web.HTTPError(404)
return method(self, name, *args, **kwargs)
return m
class UserAPIHandler(APIHandler):
@admin_or_self
async def get(self, name):
user = self.find_user(name)
model = self.user_model(user, include_servers=True, include_state=self.current_user.admin)
model = self.user_model(
user, include_servers=True, include_state=self.current_user.admin
)
# auth state will only be shown if the requester is an admin
# this means users can't see their own auth state unless they
# are admins, Hub admins often are also marked as admins so they
@@ -161,11 +170,16 @@ class UserAPIHandler(APIHandler):
if user.name == self.current_user.name:
raise web.HTTPError(400, "Cannot delete yourself!")
if user.spawner._stop_pending:
raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name)
raise web.HTTPError(
400, "%s's server is in the process of stopping, please wait." % name
)
if user.running:
await self.stop_single_user(user)
if user.spawner._stop_pending:
raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name)
raise web.HTTPError(
400,
"%s's server is in the process of stopping, please wait." % name,
)
await maybe_future(self.authenticator.delete_user(user))
# remove from registry
@@ -183,7 +197,10 @@ class UserAPIHandler(APIHandler):
if 'name' in data and data['name'] != name:
# check if the new name is already taken inside db
if self.find_user(data['name']):
raise web.HTTPError(400, "User %s already exists, username must be unique" % data['name'])
raise web.HTTPError(
400,
"User %s already exists, username must be unique" % data['name'],
)
for key, value in data.items():
if key == 'auth_state':
await user.save_auth_state(value)
@@ -197,6 +214,7 @@ class UserAPIHandler(APIHandler):
class UserTokenListAPIHandler(APIHandler):
"""API endpoint for listing/creating tokens"""
@admin_or_self
def get(self, name):
"""Get tokens for a given user"""
@@ -207,6 +225,7 @@ class UserTokenListAPIHandler(APIHandler):
now = datetime.utcnow()
api_tokens = []
def sort_key(token):
return token.last_activity or token.created
@@ -228,10 +247,7 @@ class UserTokenListAPIHandler(APIHandler):
self.db.commit()
continue
oauth_tokens.append(self.token_model(token))
self.write(json.dumps({
'api_tokens': api_tokens,
'oauth_tokens': oauth_tokens,
}))
self.write(json.dumps({'api_tokens': api_tokens, 'oauth_tokens': oauth_tokens}))
async def post(self, name):
body = self.get_json_body() or {}
@@ -253,8 +269,9 @@ class UserTokenListAPIHandler(APIHandler):
except Exception as e:
# suppress and log error here in case Authenticator
# isn't prepared to handle auth via this data
self.log.error("Error authenticating request for %s: %s",
self.request.uri, e)
self.log.error(
"Error authenticating request for %s: %s", self.request.uri, e
)
raise web.HTTPError(403)
requester = self.find_user(name)
if requester is None:
@@ -274,9 +291,16 @@ class UserTokenListAPIHandler(APIHandler):
if requester is not user:
note += " by %s %s" % (kind, requester.name)
api_token = user.new_api_token(note=note, expires_in=body.get('expires_in', None))
api_token = user.new_api_token(
note=note, expires_in=body.get('expires_in', None)
)
if requester is not user:
self.log.info("%s %s requested API token for %s", kind.title(), requester.name, user.name)
self.log.info(
"%s %s requested API token for %s",
kind.title(),
requester.name,
user.name,
)
else:
user_kind = 'user' if isinstance(user, User) else 'service'
self.log.info("%s %s requested new API token", user_kind.title(), user.name)
@@ -333,8 +357,7 @@ class UserTokenAPIHandler(APIHandler):
if isinstance(token, orm.OAuthAccessToken):
client_id = token.client_id
tokens = [
token for token in user.oauth_tokens
if token.client_id == client_id
token for token in user.oauth_tokens if token.client_id == client_id
]
else:
tokens = [token]
@@ -354,16 +377,19 @@ class UserServerAPIHandler(APIHandler):
if server_name:
if not self.allow_named_servers:
raise web.HTTPError(400, "Named servers are not enabled.")
if self.named_server_limit_per_user > 0 and server_name not in user.orm_spawners:
if (
self.named_server_limit_per_user > 0
and server_name not in user.orm_spawners
):
named_spawners = list(user.all_spawners(include_default=False))
if self.named_server_limit_per_user <= len(named_spawners):
raise web.HTTPError(
400,
"User {} already has the maximum of {} named servers."
" One must be deleted before a new server can be created".format(
name,
self.named_server_limit_per_user
))
name, self.named_server_limit_per_user
),
)
spawner = user.spawners[server_name]
pending = spawner.pending
if pending == 'spawn':
@@ -396,7 +422,6 @@ class UserServerAPIHandler(APIHandler):
options = self.get_json_body()
remove = (options or {}).get('remove', False)
def _remove_spawner(f=None):
if f and f.exception():
return
@@ -408,7 +433,9 @@ class UserServerAPIHandler(APIHandler):
if not self.allow_named_servers:
raise web.HTTPError(400, "Named servers are not enabled.")
if server_name not in user.orm_spawners:
raise web.HTTPError(404, "%s has no server named '%s'" % (name, server_name))
raise web.HTTPError(
404, "%s has no server named '%s'" % (name, server_name)
)
elif remove:
raise web.HTTPError(400, "Cannot delete the default server")
@@ -423,7 +450,8 @@ class UserServerAPIHandler(APIHandler):
if spawner.pending:
raise web.HTTPError(
400, "%s is pending %s, please wait" % (spawner._log_name, spawner.pending)
400,
"%s is pending %s, please wait" % (spawner._log_name, spawner.pending),
)
stop_future = None
@@ -449,13 +477,16 @@ class UserAdminAccessAPIHandler(APIHandler):
This handler sets the necessary cookie for an admin to login to a single-user server.
"""
@admin_only
def post(self, name):
self.log.warning("Deprecated in JupyterHub 0.8."
" Admin access API is not needed now that we use OAuth.")
self.log.warning(
"Deprecated in JupyterHub 0.8."
" Admin access API is not needed now that we use OAuth."
)
current = self.current_user
self.log.warning("Admin user %s has requested access to %s's server",
current.name, name,
self.log.warning(
"Admin user %s has requested access to %s's server", current.name, name
)
if not self.settings.get('admin_access', False):
raise web.HTTPError(403, "admin access to user servers disabled")
@@ -501,10 +532,7 @@ class SpawnProgressAPIHandler(APIHandler):
except (StreamClosedError, RuntimeError):
return
await asyncio.wait(
[self._finish_future],
timeout=self.keepalive_interval,
)
await asyncio.wait([self._finish_future], timeout=self.keepalive_interval)
@admin_or_self
async def get(self, username, server_name=''):
@@ -535,11 +563,7 @@ class SpawnProgressAPIHandler(APIHandler):
'html_message': 'Server ready at <a href="{0}">{0}</a>'.format(url),
'url': url,
}
failed_event = {
'progress': 100,
'failed': True,
'message': "Spawn failed",
}
failed_event = {'progress': 100, 'failed': True, 'message': "Spawn failed"}
if spawner.ready:
# spawner already ready. Trigger progress-completion immediately
@@ -561,7 +585,9 @@ class SpawnProgressAPIHandler(APIHandler):
raise web.HTTPError(400, "%s is not starting...", spawner._log_name)
# retrieve progress events from the Spawner
async with aclosing(iterate_until(spawn_future, spawner._generate_progress())) as events:
async with aclosing(
iterate_until(spawn_future, spawner._generate_progress())
) as events:
async for event in events:
# don't allow events to sneakily set the 'ready' flag
if 'ready' in event:
@@ -584,7 +610,9 @@ class SpawnProgressAPIHandler(APIHandler):
if f and f.done() and f.exception():
failed_event['message'] = "Spawn failed: %s" % f.exception()
else:
self.log.warning("Server %s didn't start for unknown reason", spawner._log_name)
self.log.warning(
"Server %s didn't start for unknown reason", spawner._log_name
)
await self.send_event(failed_event)
@@ -609,13 +637,12 @@ def _parse_timestamp(timestamp):
400,
"Rejecting activity from more than an hour in the future: {}".format(
isoformat(dt)
)
),
)
return dt
class ActivityAPIHandler(APIHandler):
def _validate_servers(self, user, servers):
"""Validate servers dict argument
@@ -632,10 +659,7 @@ class ActivityAPIHandler(APIHandler):
if server_name not in spawners:
raise web.HTTPError(
400,
"No such server '{}' for user {}".format(
server_name,
user.name,
)
"No such server '{}' for user {}".format(server_name, user.name),
)
# check that each per-server field is a dict
if not isinstance(server_info, dict):
@@ -645,7 +669,9 @@ class ActivityAPIHandler(APIHandler):
raise web.HTTPError(400, msg)
# parse last_activity timestamps
# _parse_timestamp above is responsible for raising errors
server_info['last_activity'] = _parse_timestamp(server_info['last_activity'])
server_info['last_activity'] = _parse_timestamp(
server_info['last_activity']
)
return servers
@admin_or_self
@@ -663,8 +689,7 @@ class ActivityAPIHandler(APIHandler):
servers = body.get('servers')
if not last_activity_timestamp and not servers:
raise web.HTTPError(
400,
"body must contain at least one of `last_activity` or `servers`"
400, "body must contain at least one of `last_activity` or `servers`"
)
if servers:
@@ -677,13 +702,9 @@ class ActivityAPIHandler(APIHandler):
# update user.last_activity if specified
if last_activity_timestamp:
last_activity = _parse_timestamp(last_activity_timestamp)
if (
(not user.last_activity)
or last_activity > user.last_activity
):
self.log.debug("Activity for user %s: %s",
user.name,
isoformat(last_activity),
if (not user.last_activity) or last_activity > user.last_activity:
self.log.debug(
"Activity for user %s: %s", user.name, isoformat(last_activity)
)
user.last_activity = last_activity
else:
@@ -699,11 +720,9 @@ class ActivityAPIHandler(APIHandler):
last_activity = server_info['last_activity']
spawner = user.orm_spawners[server_name]
if (
(not spawner.last_activity)
or last_activity > spawner.last_activity
):
self.log.debug("Activity on server %s/%s: %s",
if (not spawner.last_activity) or last_activity > spawner.last_activity:
self.log.debug(
"Activity on server %s/%s: %s",
user.name,
server_name,
isoformat(last_activity),

File diff suppressed because it is too large Load Diff

View File

@@ -1,16 +1,16 @@
"""Base Authenticator class and the default PAM Authenticator"""
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.
from concurrent.futures import ThreadPoolExecutor
import inspect
import pipes
import re
import sys
from shutil import which
from subprocess import Popen, PIPE, STDOUT
import warnings
from concurrent.futures import ThreadPoolExecutor
from shutil import which
from subprocess import PIPE
from subprocess import Popen
from subprocess import STDOUT
try:
import pamela
@@ -33,7 +33,9 @@ class Authenticator(LoggingConfigurable):
db = Any()
enable_auth_state = Bool(False, config=True,
enable_auth_state = Bool(
False,
config=True,
help="""Enable persisting auth_state (if available).
auth_state will be encrypted and stored in the Hub's database.
@@ -62,7 +64,7 @@ class Authenticator(LoggingConfigurable):
See :meth:`.refresh_user` for what happens when user auth info is refreshed
(nothing by default).
"""
""",
)
refresh_pre_spawn = Bool(
@@ -78,7 +80,7 @@ class Authenticator(LoggingConfigurable):
If refresh_user cannot refresh the user auth data,
launch will fail until the user logs in again.
"""
""",
)
admin_users = Set(
@@ -131,8 +133,11 @@ class Authenticator(LoggingConfigurable):
sorted_names = sorted(short_names)
single = ''.join(sorted_names)
string_set_typo = "set('%s')" % single
self.log.warning("whitelist contains single-character names: %s; did you mean set([%r]) instead of %s?",
sorted_names[:8], single, string_set_typo,
self.log.warning(
"whitelist contains single-character names: %s; did you mean set([%r]) instead of %s?",
sorted_names[:8],
single,
string_set_typo,
)
custom_html = Unicode(
@@ -199,7 +204,8 @@ class Authenticator(LoggingConfigurable):
"""
).tag(config=True)
delete_invalid_users = Bool(False,
delete_invalid_users = Bool(
False,
help="""Delete any users from the database that do not pass validation
When JupyterHub starts, `.add_user` will be called
@@ -213,10 +219,11 @@ class Authenticator(LoggingConfigurable):
If False (default), invalid users remain in the Hub's database
and a warning will be issued.
This is the default to avoid data loss due to config changes.
"""
""",
)
post_auth_hook = Any(config=True,
post_auth_hook = Any(
config=True,
help="""
An optional hook function that you can implement to do some
bootstrapping work during authentication. For example, loading user account
@@ -248,12 +255,16 @@ class Authenticator(LoggingConfigurable):
c.Authenticator.post_auth_hook = my_hook
"""
""",
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
for method_name in ('check_whitelist', 'check_blacklist', 'check_group_whitelist'):
for method_name in (
'check_whitelist',
'check_blacklist',
'check_group_whitelist',
):
original_method = getattr(self, method_name, None)
if original_method is None:
# no such method (check_group_whitelist is optional)
@@ -273,14 +284,14 @@ class Authenticator(LoggingConfigurable):
Adapting for compatibility.
""".format(
self.__class__.__name__,
method_name,
self.__class__.__name__, method_name
),
DeprecationWarning
DeprecationWarning,
)
def wrapped_method(username, authentication=None, **kwargs):
return original_method(username, **kwargs)
setattr(self, method_name, wrapped_method)
async def run_post_auth_hook(self, handler, authentication):
@@ -299,11 +310,7 @@ class Authenticator(LoggingConfigurable):
"""
if self.post_auth_hook is not None:
authentication = await maybe_future(
self.post_auth_hook(
self,
handler,
authentication,
)
self.post_auth_hook(self, handler, authentication)
)
return authentication
@@ -380,21 +387,25 @@ class Authenticator(LoggingConfigurable):
if 'name' not in authenticated:
raise ValueError("user missing a name: %r" % authenticated)
else:
authenticated = {
'name': authenticated,
}
authenticated = {'name': authenticated}
authenticated.setdefault('auth_state', None)
# Leave the default as None, but reevaluate later post-whitelist
authenticated.setdefault('admin', None)
# normalize the username
authenticated['name'] = username = self.normalize_username(authenticated['name'])
authenticated['name'] = username = self.normalize_username(
authenticated['name']
)
if not self.validate_username(username):
self.log.warning("Disallowing invalid username %r.", username)
return
blacklist_pass = await maybe_future(self.check_blacklist(username, authenticated))
whitelist_pass = await maybe_future(self.check_whitelist(username, authenticated))
blacklist_pass = await maybe_future(
self.check_blacklist(username, authenticated)
)
whitelist_pass = await maybe_future(
self.check_whitelist(username, authenticated)
)
if blacklist_pass:
pass
@@ -404,7 +415,9 @@ class Authenticator(LoggingConfigurable):
if whitelist_pass:
if authenticated['admin'] is None:
authenticated['admin'] = await maybe_future(self.is_admin(handler, authenticated))
authenticated['admin'] = await maybe_future(
self.is_admin(handler, authenticated)
)
authenticated = await self.run_post_auth_hook(handler, authenticated)
@@ -534,7 +547,9 @@ class Authenticator(LoggingConfigurable):
"""
self.whitelist.discard(user.name)
auto_login = Bool(False, config=True,
auto_login = Bool(
False,
config=True,
help="""Automatically begin the login process
rather than starting with a "Login with..." link at `/hub/login`
@@ -544,7 +559,7 @@ class Authenticator(LoggingConfigurable):
registered with `.get_handlers()`.
.. versionadded:: 0.8
"""
""",
)
def login_url(self, base_url):
@@ -592,9 +607,7 @@ class Authenticator(LoggingConfigurable):
list of ``('/url', Handler)`` tuples passed to tornado.
The Hub prefix is added to any URLs.
"""
return [
('/login', LoginHandler),
]
return [('/login', LoginHandler)]
class LocalAuthenticator(Authenticator):
@@ -603,12 +616,13 @@ class LocalAuthenticator(Authenticator):
Checks for local users, and can attempt to create them if they exist.
"""
create_system_users = Bool(False,
create_system_users = Bool(
False,
help="""
If set to True, will attempt to create local system users if they do not exist already.
Supports Linux and BSD variants only.
"""
""",
).tag(config=True)
add_user_cmd = Command(
@@ -699,8 +713,9 @@ class LocalAuthenticator(Authenticator):
raise KeyError(
"User {} does not exist on the system."
" Set LocalAuthenticator.create_system_users=True"
" to automatically create system users from jupyterhub users."
.format(user.name)
" to automatically create system users from jupyterhub users.".format(
user.name
)
)
await maybe_future(super().add_user(user))
@@ -711,6 +726,7 @@ class LocalAuthenticator(Authenticator):
on Windows
"""
import grp
return grp.getgrnam(name)
@staticmethod
@@ -719,6 +735,7 @@ class LocalAuthenticator(Authenticator):
on Windows
"""
import pwd
return pwd.getpwnam(name)
@staticmethod
@@ -727,6 +744,7 @@ class LocalAuthenticator(Authenticator):
on Windows
"""
import os
return os.getgrouplist(name, group)
def system_user_exists(self, user):
@@ -758,23 +776,27 @@ class PAMAuthenticator(LocalAuthenticator):
# run PAM in a thread, since it can be slow
executor = Any()
@default('executor')
def _default_executor(self):
return ThreadPoolExecutor(1)
encoding = Unicode('utf8',
encoding = Unicode(
'utf8',
help="""
The text encoding to use when communicating with PAM
"""
""",
).tag(config=True)
service = Unicode('login',
service = Unicode(
'login',
help="""
The name of the PAM service to use for authentication
"""
""",
).tag(config=True)
open_sessions = Bool(True,
open_sessions = Bool(
True,
help="""
Whether to open a new PAM session when spawners are started.
@@ -784,10 +806,11 @@ class PAMAuthenticator(LocalAuthenticator):
If any errors are encountered when opening/closing PAM sessions,
this is automatically set to False.
"""
""",
).tag(config=True)
check_account = Bool(True,
check_account = Bool(
True,
help="""
Whether to check the user's account status via PAM during authentication.
@@ -797,7 +820,7 @@ class PAMAuthenticator(LocalAuthenticator):
Disabling this can be dangerous as authenticated but unauthorized users may
be granted access and, therefore, arbitrary execution on the system.
"""
""",
).tag(config=True)
admin_groups = Set(
@@ -809,14 +832,15 @@ class PAMAuthenticator(LocalAuthenticator):
"""
).tag(config=True)
pam_normalize_username = Bool(False,
pam_normalize_username = Bool(
False,
help="""
Round-trip the username via PAM lookups to make sure it is unique
PAM can accept multiple usernames that map to the same user,
for example DOMAIN\\username in some cases. To prevent this,
convert username into uid, then back to uid to normalize.
"""
""",
).tag(config=True)
def __init__(self, **kwargs):
@@ -844,12 +868,19 @@ class PAMAuthenticator(LocalAuthenticator):
# (returning None instead of just the username) as this indicates some sort of system failure
admin_group_gids = {self._getgrnam(x).gr_gid for x in self.admin_groups}
user_group_gids = set(self._getgrouplist(username, self._getpwnam(username).pw_gid))
user_group_gids = set(
self._getgrouplist(username, self._getpwnam(username).pw_gid)
)
admin_status = len(admin_group_gids & user_group_gids) != 0
except Exception as e:
if handler is not None:
self.log.error("PAM Admin Group Check failed (%s@%s): %s", username, handler.request.remote_ip, e)
self.log.error(
"PAM Admin Group Check failed (%s@%s): %s",
username,
handler.request.remote_ip,
e,
)
else:
self.log.error("PAM Admin Group Check failed: %s", e)
# re-raise to return a 500 to the user and indicate a problem. We failed, not them.
@@ -865,27 +896,40 @@ class PAMAuthenticator(LocalAuthenticator):
"""
username = data['username']
try:
pamela.authenticate(username, data['password'], service=self.service, encoding=self.encoding)
pamela.authenticate(
username, data['password'], service=self.service, encoding=self.encoding
)
except pamela.PAMError as e:
if handler is not None:
self.log.warning("PAM Authentication failed (%s@%s): %s", username, handler.request.remote_ip, e)
self.log.warning(
"PAM Authentication failed (%s@%s): %s",
username,
handler.request.remote_ip,
e,
)
else:
self.log.warning("PAM Authentication failed: %s", e)
return None
if self.check_account:
try:
pamela.check_account(username, service=self.service, encoding=self.encoding)
pamela.check_account(
username, service=self.service, encoding=self.encoding
)
except pamela.PAMError as e:
if handler is not None:
self.log.warning("PAM Account Check failed (%s@%s): %s", username, handler.request.remote_ip, e)
self.log.warning(
"PAM Account Check failed (%s@%s): %s",
username,
handler.request.remote_ip,
e,
)
else:
self.log.warning("PAM Account Check failed: %s", e)
return None
return username
@run_on_executor
def pre_spawn_start(self, user, spawner):
"""Open PAM session for user if so configured"""
@@ -904,7 +948,9 @@ class PAMAuthenticator(LocalAuthenticator):
if not self.open_sessions:
return
try:
pamela.close_session(user.name, service=self.service, encoding=self.encoding)
pamela.close_session(
user.name, service=self.service, encoding=self.encoding
)
except pamela.PAMError as e:
self.log.warning("Failed to close PAM session for %s: %s", user.name, e)
self.log.warning("Disabling PAM sessions from now on.")
@@ -916,12 +962,14 @@ class PAMAuthenticator(LocalAuthenticator):
PAM can accept multiple usernames as the same user, normalize them."""
if self.pam_normalize_username:
import pwd
uid = pwd.getpwnam(username).pw_uid
username = pwd.getpwuid(uid).pw_name
username = self.username_map.get(username, username)
else:
return super().normalize_username(username)
class DummyAuthenticator(Authenticator):
"""Dummy Authenticator for testing
@@ -938,7 +986,7 @@ class DummyAuthenticator(Authenticator):
Set a global password for all users wanting to log in.
This allows users with any username to log in with the same static password.
"""
""",
)
async def authenticate(self, handler, data):

View File

@@ -1,35 +1,43 @@
import base64
from binascii import a2b_hex
from concurrent.futures import ThreadPoolExecutor
import json
import os
from binascii import a2b_hex
from concurrent.futures import ThreadPoolExecutor
from traitlets.config import SingletonConfigurable, Config
from traitlets import (
Any, Dict, Integer, List,
default, observe, validate,
)
from traitlets import Any
from traitlets import default
from traitlets import Dict
from traitlets import Integer
from traitlets import List
from traitlets import observe
from traitlets import validate
from traitlets.config import Config
from traitlets.config import SingletonConfigurable
try:
import cryptography
from cryptography.fernet import Fernet, MultiFernet, InvalidToken
except ImportError:
cryptography = None
class InvalidToken(Exception):
pass
from .utils import maybe_future
KEY_ENV = 'JUPYTERHUB_CRYPT_KEY'
class EncryptionUnavailable(Exception):
pass
class CryptographyUnavailable(EncryptionUnavailable):
def __str__(self):
return "cryptography library is required for encryption"
class NoEncryptionKeys(EncryptionUnavailable):
def __str__(self):
return "Encryption keys must be specified in %s env" % KEY_ENV
@@ -70,13 +78,16 @@ def _validate_key(key):
return key
class CryptKeeper(SingletonConfigurable):
"""Encapsulate encryption configuration
Use via the encryption_config singleton below.
"""
n_threads = Integer(max(os.cpu_count(), 1), config=True,
n_threads = Integer(
max(os.cpu_count(), 1),
config=True,
help="The number of threads to allocate for encryption",
)
@@ -84,22 +95,27 @@ class CryptKeeper(SingletonConfigurable):
def _config_default(self):
# load application config by default
from .app import JupyterHub
if JupyterHub.initialized():
return JupyterHub.instance().config
else:
return Config()
executor = Any()
def _executor_default(self):
return ThreadPoolExecutor(self.n_threads)
keys = List(config=True)
def _keys_default(self):
if KEY_ENV not in os.environ:
return []
# key can be a ;-separated sequence for key rotation.
# First item in the list is used for encryption.
return [ _validate_key(key) for key in os.environ[KEY_ENV].split(';') if key.strip() ]
return [
_validate_key(key) for key in os.environ[KEY_ENV].split(';') if key.strip()
]
@validate('keys')
def _ensure_bytes(self, proposal):
@@ -107,6 +123,7 @@ class CryptKeeper(SingletonConfigurable):
return [_validate_key(key) for key in proposal.value]
fernet = Any()
def _fernet_default(self):
if cryptography is None or not self.keys:
return None
@@ -153,6 +170,7 @@ def encrypt(data):
"""
return CryptKeeper.instance().encrypt(data)
def decrypt(data):
"""decrypt some data with the crypt keeper

View File

@@ -1,15 +1,13 @@
"""Database utilities for JupyterHub"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
# Based on pgcontents.utils.migrate, used under the Apache license.
from contextlib import contextmanager
from datetime import datetime
import os
import shutil
from subprocess import check_call
import sys
from contextlib import contextmanager
from datetime import datetime
from subprocess import check_call
from tempfile import TemporaryDirectory
from sqlalchemy import create_engine
@@ -85,9 +83,7 @@ def upgrade(db_url, revision='head'):
The alembic revision to upgrade to.
"""
with _temp_alembic_ini(db_url) as alembic_ini:
check_call(
['alembic', '-c', alembic_ini, 'upgrade', revision]
)
check_call(['alembic', '-c', alembic_ini, 'upgrade', revision])
def backup_db_file(db_file, log=None):
@@ -133,30 +129,27 @@ def upgrade_if_needed(db_url, backup=True, log=None):
def shell(args=None):
"""Start an IPython shell hooked up to the jupyerhub database"""
from .app import JupyterHub
hub = JupyterHub()
hub.load_config_file(hub.config_file)
db_url = hub.db_url
db = orm.new_session_factory(db_url, **hub.db_kwargs)()
ns = {
'db': db,
'db_url': db_url,
'orm': orm,
}
ns = {'db': db, 'db_url': db_url, 'orm': orm}
import IPython
IPython.start_ipython(args, user_ns=ns)
def _alembic(args):
"""Run an alembic command with a temporary alembic.ini"""
from .app import JupyterHub
hub = JupyterHub()
hub.load_config_file(hub.config_file)
db_url = hub.db_url
with _temp_alembic_ini(db_url) as alembic_ini:
check_call(
['alembic', '-c', alembic_ini] + args
)
check_call(['alembic', '-c', alembic_ini] + args)
def main(args=None):

View File

@@ -1,8 +1,10 @@
from . import base
from . import login
from . import metrics
from . import pages
from .base import *
from .login import *
from . import base, pages, login, metrics
default_handlers = []
for mod in (base, pages, login, metrics):
default_handlers.extend(mod.default_handlers)

View File

@@ -1,41 +1,49 @@
"""HTTP Handlers for the hub server"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import copy
from datetime import datetime, timedelta
from http.client import responses
import json
import math
import random
import re
import time
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode
import uuid
from datetime import datetime
from datetime import timedelta
from http.client import responses
from urllib.parse import parse_qs
from urllib.parse import urlencode
from urllib.parse import urlparse
from urllib.parse import urlunparse
from jinja2 import TemplateNotFound
from sqlalchemy.exc import SQLAlchemyError
from tornado.log import app_log
from tornado.httputil import url_concat, HTTPHeaders
from tornado import gen
from tornado import web
from tornado.httputil import HTTPHeaders
from tornado.httputil import url_concat
from tornado.ioloop import IOLoop
from tornado.web import RequestHandler, MissingArgumentError
from tornado import gen, web
from tornado.log import app_log
from tornado.web import MissingArgumentError
from tornado.web import RequestHandler
from .. import __version__
from .. import orm
from ..metrics import PROXY_ADD_DURATION_SECONDS
from ..metrics import ProxyAddStatus
from ..metrics import RUNNING_SERVERS
from ..metrics import SERVER_POLL_DURATION_SECONDS
from ..metrics import SERVER_SPAWN_DURATION_SECONDS
from ..metrics import SERVER_STOP_DURATION_SECONDS
from ..metrics import ServerPollStatus
from ..metrics import ServerSpawnStatus
from ..metrics import ServerStopStatus
from ..objects import Server
from ..spawner import LocalProcessSpawner
from ..user import User
from ..utils import maybe_future, url_path_join
from ..metrics import (
SERVER_SPAWN_DURATION_SECONDS, ServerSpawnStatus,
PROXY_ADD_DURATION_SECONDS, ProxyAddStatus,
SERVER_POLL_DURATION_SECONDS, ServerPollStatus,
RUNNING_SERVERS, SERVER_STOP_DURATION_SECONDS, ServerStopStatus
)
from ..utils import maybe_future
from ..utils import url_path_join
# pattern for the authentication token header
auth_header_pat = re.compile(r'^(?:token|bearer)\s+([^\s]+)$', flags=re.IGNORECASE)
@@ -52,6 +60,7 @@ reasons = {
# constant, not configurable
SESSION_COOKIE_NAME = 'jupyterhub-session-id'
class BaseHandler(RequestHandler):
"""Base Handler class with access to common methods and properties."""
@@ -157,8 +166,8 @@ class BaseHandler(RequestHandler):
@property
def csp_report_uri(self):
return self.settings.get('csp_report_uri',
url_path_join(self.hub.base_url, 'security/csp-report')
return self.settings.get(
'csp_report_uri', url_path_join(self.hub.base_url, 'security/csp-report')
)
@property
@@ -167,10 +176,9 @@ class BaseHandler(RequestHandler):
Can be overridden by defining Content-Security-Policy in settings['headers']
"""
return '; '.join([
"frame-ancestors 'self'",
"report-uri " + self.csp_report_uri,
])
return '; '.join(
["frame-ancestors 'self'", "report-uri " + self.csp_report_uri]
)
def get_content_type(self):
return 'text/html'
@@ -190,7 +198,9 @@ class BaseHandler(RequestHandler):
self.set_header(header_name, header_content)
if 'Access-Control-Allow-Headers' not in headers:
self.set_header('Access-Control-Allow-Headers', 'accept, content-type, authorization')
self.set_header(
'Access-Control-Allow-Headers', 'accept, content-type, authorization'
)
if 'Content-Security-Policy' not in headers:
self.set_header('Content-Security-Policy', self.content_security_policy)
self.set_header('Content-Type', self.get_content_type())
@@ -236,8 +246,7 @@ class BaseHandler(RequestHandler):
orm_token = orm.OAuthAccessToken.find(self.db, token)
if orm_token is None:
return None
orm_token.last_activity = \
orm_token.user.last_activity = datetime.utcnow()
orm_token.last_activity = orm_token.user.last_activity = datetime.utcnow()
self.db.commit()
return self._user_from_orm(orm_token.user)
@@ -259,7 +268,11 @@ class BaseHandler(RequestHandler):
if not refresh_age:
return user
now = time.monotonic()
if not force and user._auth_refreshed and (now - user._auth_refreshed < refresh_age):
if (
not force
and user._auth_refreshed
and (now - user._auth_refreshed < refresh_age)
):
# auth up-to-date
return user
@@ -276,8 +289,7 @@ class BaseHandler(RequestHandler):
if not auth_info:
self.log.warning(
"User %s has stale auth info. Login is required to refresh.",
user.name,
"User %s has stale auth info. Login is required to refresh.", user.name
)
return
@@ -325,10 +337,9 @@ class BaseHandler(RequestHandler):
def _user_for_cookie(self, cookie_name, cookie_value=None):
"""Get the User for a given cookie, if there is one"""
cookie_id = self.get_secure_cookie(
cookie_name,
cookie_value,
max_age_days=self.cookie_max_age_days,
cookie_name, cookie_value, max_age_days=self.cookie_max_age_days
)
def clear():
self.clear_cookie(cookie_name, path=self.hub.base_url)
@@ -434,7 +445,11 @@ class BaseHandler(RequestHandler):
# clear hub cookie
self.clear_cookie(self.hub.cookie_name, path=self.hub.base_url, **kwargs)
# clear services cookie
self.clear_cookie('jupyterhub-services', path=url_path_join(self.base_url, 'services'), **kwargs)
self.clear_cookie(
'jupyterhub-services',
path=url_path_join(self.base_url, 'services'),
**kwargs
)
def _set_cookie(self, key, value, encrypted=True, **overrides):
"""Setting any cookie should go through here
@@ -444,9 +459,7 @@ class BaseHandler(RequestHandler):
"""
# tornado <4.2 have a bug that consider secure==True as soon as
# 'secure' kwarg is passed to set_secure_cookie
kwargs = {
'httponly': True,
}
kwargs = {'httponly': True}
if self.request.protocol == 'https':
kwargs['secure'] = True
if self.subdomain_host:
@@ -463,14 +476,10 @@ class BaseHandler(RequestHandler):
self.log.debug("Setting cookie %s: %s", key, kwargs)
set_cookie(key, value, **kwargs)
def _set_user_cookie(self, user, server):
self.log.debug("Setting cookie for %s: %s", user.name, server.cookie_name)
self._set_cookie(
server.cookie_name,
user.cookie_id,
encrypted=True,
path=server.base_url,
server.cookie_name, user.cookie_id, encrypted=True, path=server.base_url
)
def get_session_cookie(self):
@@ -494,10 +503,13 @@ class BaseHandler(RequestHandler):
def set_service_cookie(self, user):
"""set the login cookie for services"""
self._set_user_cookie(user, orm.Server(
self._set_user_cookie(
user,
orm.Server(
cookie_name='jupyterhub-services',
base_url=url_path_join(self.base_url, 'services')
))
base_url=url_path_join(self.base_url, 'services'),
),
)
def set_hub_cookie(self, user):
"""set the login cookie for the Hub"""
@@ -508,7 +520,9 @@ class BaseHandler(RequestHandler):
if self.subdomain_host and not self.request.host.startswith(self.domain):
self.log.warning(
"Possibly setting cookie on wrong domain: %s != %s",
self.request.host, self.domain)
self.request.host,
self.domain,
)
# set single cookie for services
if self.db.query(orm.Service).filter(orm.Service.server != None).first():
@@ -555,8 +569,10 @@ class BaseHandler(RequestHandler):
# ultimately redirecting to the logged-in user's server.
without_prefix = next_url[len(self.base_url) :]
next_url = url_path_join(self.hub.base_url, without_prefix)
self.log.warning("Redirecting %s to %s. For sharing public links, use /user-redirect/",
self.request.uri, next_url,
self.log.warning(
"Redirecting %s to %s. For sharing public links, use /user-redirect/",
self.request.uri,
next_url,
)
if not next_url:
@@ -627,8 +643,9 @@ class BaseHandler(RequestHandler):
else:
self.statsd.incr('login.failure')
self.statsd.timing('login.authenticate.failure', auth_timer.ms)
self.log.warning("Failed login for %s", (data or {}).get('username', 'unknown user'))
self.log.warning(
"Failed login for %s", (data or {}).get('username', 'unknown user')
)
# ---------------------------------------------------------------
# spawning-related
@@ -659,7 +676,9 @@ class BaseHandler(RequestHandler):
if self.authenticator.refresh_pre_spawn:
auth_user = await self.refresh_auth(user, force=True)
if auth_user is None:
raise web.HTTPError(403, "auth has expired for %s, login again", user.name)
raise web.HTTPError(
403, "auth has expired for %s, login again", user.name
)
spawn_start_time = time.perf_counter()
self.extra_error_html = self.spawn_home_error
@@ -681,7 +700,9 @@ class BaseHandler(RequestHandler):
# but for 10k users this takes ~5ms
# and saves us from bookkeeping errors
active_counts = self.users.count_active_users()
spawn_pending_count = active_counts['spawn_pending'] + active_counts['proxy_pending']
spawn_pending_count = (
active_counts['spawn_pending'] + active_counts['proxy_pending']
)
active_count = active_counts['active']
concurrent_spawn_limit = self.concurrent_spawn_limit
@@ -700,18 +721,21 @@ class BaseHandler(RequestHandler):
# round suggestion to nicer human value (nearest 10 seconds or minute)
if retry_time <= 90:
# round human seconds up to nearest 10
human_retry_time = "%i0 seconds" % math.ceil(retry_time / 10.)
human_retry_time = "%i0 seconds" % math.ceil(retry_time / 10.0)
else:
# round number of minutes
human_retry_time = "%i minutes" % math.round(retry_time / 60.)
human_retry_time = "%i minutes" % math.round(retry_time / 60.0)
self.log.warning(
'%s pending spawns, throttling. Suggested retry in %s seconds.',
spawn_pending_count, retry_time,
spawn_pending_count,
retry_time,
)
err = web.HTTPError(
429,
"Too many users trying to log in right now. Try again in {}.".format(human_retry_time)
"Too many users trying to log in right now. Try again in {}.".format(
human_retry_time
),
)
# can't call set_header directly here because it gets ignored
# when errors are raised
@@ -720,14 +744,13 @@ class BaseHandler(RequestHandler):
raise err
if active_server_limit and active_count >= active_server_limit:
self.log.info(
'%s servers active, no space available',
active_count,
)
self.log.info('%s servers active, no space available', active_count)
SERVER_SPAWN_DURATION_SECONDS.labels(
status=ServerSpawnStatus.too_many_users
).observe(time.perf_counter() - spawn_start_time)
raise web.HTTPError(429, "Active user limit exceeded. Try again in a few minutes.")
raise web.HTTPError(
429, "Active user limit exceeded. Try again in a few minutes."
)
tic = IOLoop.current().time()
@@ -735,12 +758,16 @@ class BaseHandler(RequestHandler):
spawn_future = user.spawn(server_name, options, handler=self)
self.log.debug("%i%s concurrent spawns",
self.log.debug(
"%i%s concurrent spawns",
spawn_pending_count,
'/%i' % concurrent_spawn_limit if concurrent_spawn_limit else '')
self.log.debug("%i%s active servers",
'/%i' % concurrent_spawn_limit if concurrent_spawn_limit else '',
)
self.log.debug(
"%i%s active servers",
active_count,
'/%i' % active_server_limit if active_server_limit else '')
'/%i' % active_server_limit if active_server_limit else '',
)
spawner = user.spawners[server_name]
# set spawn_pending now, so there's no gap where _spawn_pending is False
@@ -756,7 +783,9 @@ class BaseHandler(RequestHandler):
# wait for spawn Future
await spawn_future
toc = IOLoop.current().time()
self.log.info("User %s took %.3f seconds to start", user_server_name, toc-tic)
self.log.info(
"User %s took %.3f seconds to start", user_server_name, toc - tic
)
self.statsd.timing('spawner.success', (toc - tic) * 1000)
RUNNING_SERVERS.inc()
SERVER_SPAWN_DURATION_SECONDS.labels(
@@ -767,18 +796,16 @@ class BaseHandler(RequestHandler):
try:
await self.proxy.add_user(user, server_name)
PROXY_ADD_DURATION_SECONDS.labels(
status='success'
).observe(
PROXY_ADD_DURATION_SECONDS.labels(status='success').observe(
time.perf_counter() - proxy_add_start_time
)
except Exception:
self.log.exception("Failed to add %s to proxy!", user_server_name)
self.log.error("Stopping %s to avoid inconsistent state", user_server_name)
self.log.error(
"Stopping %s to avoid inconsistent state", user_server_name
)
await user.stop()
PROXY_ADD_DURATION_SECONDS.labels(
status='failure'
).observe(
PROXY_ADD_DURATION_SECONDS.labels(status='failure').observe(
time.perf_counter() - proxy_add_start_time
)
else:
@@ -818,7 +845,8 @@ class BaseHandler(RequestHandler):
self.log.warning(
"%i consecutive spawns failed. "
"Hub will exit if failure count reaches %i before succeeding",
failure_count, failure_limit,
failure_count,
failure_limit,
)
if failure_limit and failure_count >= failure_limit:
self.log.critical(
@@ -828,6 +856,7 @@ class BaseHandler(RequestHandler):
# mostly propagating errors for the current failures
def abort():
raise SystemExit(1)
IOLoop.current().call_later(2, abort)
finish_spawn_future.add_done_callback(_track_failure_count)
@@ -842,8 +871,11 @@ class BaseHandler(RequestHandler):
if spawner._spawn_pending and not spawner._waiting_for_response:
# still in Spawner.start, which is taking a long time
# we shouldn't poll while spawn is incomplete.
self.log.warning("User %s is slow to start (timeout=%s)",
user_server_name, self.slow_spawn_timeout)
self.log.warning(
"User %s is slow to start (timeout=%s)",
user_server_name,
self.slow_spawn_timeout,
)
return
# start has finished, but the server hasn't come up
@@ -861,22 +893,34 @@ class BaseHandler(RequestHandler):
status=ServerSpawnStatus.failure
).observe(time.perf_counter() - spawn_start_time)
raise web.HTTPError(500, "Spawner failed to start [status=%s]. The logs for %s may contain details." % (
status, spawner._log_name))
raise web.HTTPError(
500,
"Spawner failed to start [status=%s]. The logs for %s may contain details."
% (status, spawner._log_name),
)
if spawner._waiting_for_response:
# hit timeout waiting for response, but server's running.
# Hope that it'll show up soon enough,
# though it's possible that it started at the wrong URL
self.log.warning("User %s is slow to become responsive (timeout=%s)",
user_server_name, self.slow_spawn_timeout)
self.log.debug("Expecting server for %s at: %s",
user_server_name, spawner.server.url)
self.log.warning(
"User %s is slow to become responsive (timeout=%s)",
user_server_name,
self.slow_spawn_timeout,
)
self.log.debug(
"Expecting server for %s at: %s",
user_server_name,
spawner.server.url,
)
if spawner._proxy_pending:
# User.spawn finished, but it hasn't been added to the proxy
# Could be due to load or a slow proxy
self.log.warning("User %s is slow to be added to the proxy (timeout=%s)",
user_server_name, self.slow_spawn_timeout)
self.log.warning(
"User %s is slow to be added to the proxy (timeout=%s)",
user_server_name,
self.slow_spawn_timeout,
)
async def user_stopped(self, user, server_name):
"""Callback that fires when the spawner has stopped"""
@@ -888,12 +932,11 @@ class BaseHandler(RequestHandler):
status=ServerPollStatus.from_status(status)
).observe(time.perf_counter() - poll_start_time)
if status is None:
status = 'unknown'
self.log.warning("User %s server stopped, with exit code: %s",
user.name, status,
self.log.warning(
"User %s server stopped, with exit code: %s", user.name, status
)
await self.proxy.delete_user(user, server_name)
await user.stop(server_name)
@@ -920,7 +963,9 @@ class BaseHandler(RequestHandler):
await self.proxy.delete_user(user, server_name)
await user.stop(server_name)
toc = time.perf_counter()
self.log.info("User %s server took %.3f seconds to stop", user.name, toc - tic)
self.log.info(
"User %s server took %.3f seconds to stop", user.name, toc - tic
)
self.statsd.timing('spawner.stop', (toc - tic) * 1000)
RUNNING_SERVERS.dec()
SERVER_STOP_DURATION_SECONDS.labels(
@@ -934,14 +979,15 @@ class BaseHandler(RequestHandler):
spawner._stop_future = None
spawner._stop_pending = False
future = spawner._stop_future = asyncio.ensure_future(stop())
try:
await gen.with_timeout(timedelta(seconds=self.slow_stop_timeout), future)
except gen.TimeoutError:
# hit timeout, but stop is still pending
self.log.warning("User %s:%s server is slow to stop", user.name, server_name)
self.log.warning(
"User %s:%s server is slow to stop", user.name, server_name
)
# return handle on the future for hooking up callbacks
return future
@@ -1051,6 +1097,7 @@ class BaseHandler(RequestHandler):
class Template404(BaseHandler):
"""Render our 404 template"""
async def prepare(self):
await super().prepare()
raise web.HTTPError(404)
@@ -1061,6 +1108,7 @@ class PrefixRedirectHandler(BaseHandler):
Redirects /foo to /prefix/foo, etc.
"""
def get(self):
uri = self.request.uri
# Since self.base_url will end with trailing slash.
@@ -1076,9 +1124,7 @@ class PrefixRedirectHandler(BaseHandler):
# default / -> /hub/ redirect
# avoiding extra hop through /hub
path = '/'
self.redirect(url_path_join(
self.hub.base_url, path,
), permanent=False)
self.redirect(url_path_join(self.hub.base_url, path), permanent=False)
class UserSpawnHandler(BaseHandler):
@@ -1113,8 +1159,11 @@ class UserSpawnHandler(BaseHandler):
if user is None:
# no such user
raise web.HTTPError(404, "No such user %s" % user_name)
self.log.info("Admin %s requesting spawn on behalf of %s",
current_user.name, user.name)
self.log.info(
"Admin %s requesting spawn on behalf of %s",
current_user.name,
user.name,
)
admin_spawn = True
should_spawn = True
else:
@@ -1122,7 +1171,7 @@ class UserSpawnHandler(BaseHandler):
admin_spawn = False
# For non-admins, we should spawn if the user matches
# otherwise redirect users to their own server
should_spawn = (current_user and current_user.name == user_name)
should_spawn = current_user and current_user.name == user_name
if "api" in user_path.split("/") and user and not user.active:
# API request for not-running server (e.g. notebook UI left open)
@@ -1142,12 +1191,19 @@ class UserSpawnHandler(BaseHandler):
port = host_info.port
if not port:
port = 443 if host_info.scheme == 'https' else 80
if port != Server.from_url(self.proxy.public_url).connect_port and port == self.hub.connect_port:
self.log.warning("""
if (
port != Server.from_url(self.proxy.public_url).connect_port
and port == self.hub.connect_port
):
self.log.warning(
"""
Detected possible direct connection to Hub's private ip: %s, bypassing proxy.
This will result in a redirect loop.
Make sure to connect to the proxied public URL %s
""", self.request.full_url(), self.proxy.public_url)
""",
self.request.full_url(),
self.proxy.public_url,
)
# logged in as valid user, check for pending spawn
if self.allow_named_servers:
@@ -1167,19 +1223,31 @@ class UserSpawnHandler(BaseHandler):
# Implicit spawn on /user/:name is not allowed if the user's last spawn failed.
# We should point the user to Home if the most recent spawn failed.
exc = spawner._spawn_future.exception()
self.log.error("Preventing implicit spawn for %s because last spawn failed: %s",
spawner._log_name, exc)
self.log.error(
"Preventing implicit spawn for %s because last spawn failed: %s",
spawner._log_name,
exc,
)
# raise a copy because each time an Exception object is re-raised, its traceback grows
raise copy.copy(exc).with_traceback(exc.__traceback__)
# check for pending spawn
if spawner.pending == 'spawn' and spawner._spawn_future:
# wait on the pending spawn
self.log.debug("Waiting for %s pending %s", spawner._log_name, spawner.pending)
self.log.debug(
"Waiting for %s pending %s", spawner._log_name, spawner.pending
)
try:
await gen.with_timeout(timedelta(seconds=self.slow_spawn_timeout), spawner._spawn_future)
await gen.with_timeout(
timedelta(seconds=self.slow_spawn_timeout),
spawner._spawn_future,
)
except gen.TimeoutError:
self.log.info("Pending spawn for %s didn't finish in %.1f seconds", spawner._log_name, self.slow_spawn_timeout)
self.log.info(
"Pending spawn for %s didn't finish in %.1f seconds",
spawner._log_name,
self.slow_spawn_timeout,
)
pass
# we may have waited above, check pending again:
@@ -1194,10 +1262,7 @@ class UserSpawnHandler(BaseHandler):
else:
page = "spawn_pending.html"
html = self.render_template(
page,
user=user,
spawner=spawner,
progress_url=spawner._progress_url,
page, user=user, spawner=spawner, progress_url=spawner._progress_url
)
self.finish(html)
return
@@ -1218,8 +1283,11 @@ class UserSpawnHandler(BaseHandler):
if current_user.name != user.name:
# spawning on behalf of another user
url_parts.append(user.name)
self.redirect(url_concat(url_path_join(*url_parts),
{'next': self.request.uri}))
self.redirect(
url_concat(
url_path_join(*url_parts), {'next': self.request.uri}
)
)
return
else:
await self.spawn_single_user(user, server_name)
@@ -1230,9 +1298,7 @@ class UserSpawnHandler(BaseHandler):
# spawn has started, but not finished
self.statsd.incr('redirects.user_spawn_pending', 1)
html = self.render_template(
"spawn_pending.html",
user=user,
progress_url=spawner._progress_url,
"spawn_pending.html", user=user, progress_url=spawner._progress_url
)
self.finish(html)
return
@@ -1243,14 +1309,16 @@ class UserSpawnHandler(BaseHandler):
try:
redirects = int(self.get_argument('redirects', 0))
except ValueError:
self.log.warning("Invalid redirects argument %r", self.get_argument('redirects'))
self.log.warning(
"Invalid redirects argument %r", self.get_argument('redirects')
)
redirects = 0
# check redirect limit to prevent browser-enforced limits.
# In case of version mismatch, raise on only two redirects.
if redirects >= self.settings.get(
'user_redirect_limit', 4
) or (redirects >= 2 and spawner._jupyterhub_version != __version__):
if redirects >= self.settings.get('user_redirect_limit', 4) or (
redirects >= 2 and spawner._jupyterhub_version != __version__
):
# We stop if we've been redirected too many times.
msg = "Redirect loop detected."
if spawner._jupyterhub_version != __version__:
@@ -1259,7 +1327,8 @@ class UserSpawnHandler(BaseHandler):
" Try installing jupyterhub=={hub} in the user environment"
" if you continue to have problems."
).format(
singleuser=spawner._jupyterhub_version or 'unknown (likely < 0.8)',
singleuser=spawner._jupyterhub_version
or 'unknown (likely < 0.8)',
hub=__version__,
)
raise web.HTTPError(500, msg)
@@ -1297,10 +1366,9 @@ class UserSpawnHandler(BaseHandler):
# not logged in, clear any cookies and reload
self.statsd.incr('redirects.user_to_login', 1)
self.clear_login_cookie()
self.redirect(url_concat(
self.settings['login_url'],
{'next': self.request.uri},
))
self.redirect(
url_concat(self.settings['login_url'], {'next': self.request.uri})
)
class UserRedirectHandler(BaseHandler):
@@ -1314,6 +1382,7 @@ class UserRedirectHandler(BaseHandler):
.. versionadded:: 0.7
"""
@web.authenticated
def get(self, path):
user = self.current_user
@@ -1326,12 +1395,13 @@ class UserRedirectHandler(BaseHandler):
class CSPReportHandler(BaseHandler):
'''Accepts a content security policy violation report'''
@web.authenticated
def post(self):
'''Log a content security policy violation report'''
self.log.warning(
"Content security violation: %s",
self.request.body.decode('utf8', 'replace')
self.request.body.decode('utf8', 'replace'),
)
# Report it to statsd as well
self.statsd.incr('csp_report')
@@ -1339,11 +1409,13 @@ class CSPReportHandler(BaseHandler):
class AddSlashHandler(BaseHandler):
"""Handler for adding trailing slash to URLs that need them"""
def get(self, *args):
src = urlparse(self.request.uri)
dest = src._replace(path=src.path + '/')
self.redirect(urlunparse(dest))
default_handlers = [
(r'', AddSlashHandler), # add trailing / to `/hub`
(r'/user/(?P<user_name>[^/]+)(?P<user_path>/.*)?', UserSpawnHandler),

View File

@@ -1,16 +1,14 @@
"""HTTP Handlers for the hub server"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
from tornado import web
from tornado.escape import url_escape
from tornado.httputil import url_concat
from tornado import web
from .base import BaseHandler
from ..utils import maybe_future
from .base import BaseHandler
class LogoutHandler(BaseHandler):
@@ -52,7 +50,8 @@ class LoginHandler(BaseHandler):
"""Render the login page."""
def _render(self, login_error=None, username=None):
return self.render_template('login.html',
return self.render_template(
'login.html',
next=url_escape(self.get_argument('next', default='')),
username=username,
login_error=login_error,
@@ -87,7 +86,9 @@ class LoginHandler(BaseHandler):
self.redirect(self.get_next_url(user))
else:
if self.get_argument('next', default=False):
auto_login_url = url_concat(auto_login_url, {'next': self.get_next_url()})
auto_login_url = url_concat(
auto_login_url, {'next': self.get_next_url()}
)
self.redirect(auto_login_url)
return
username = self.get_argument('username', default='')
@@ -109,8 +110,7 @@ class LoginHandler(BaseHandler):
self.redirect(self.get_next_url(user))
else:
html = self._render(
login_error='Invalid username or password',
username=data['username'],
login_error='Invalid username or password', username=data['username']
)
self.finish(html)
@@ -118,7 +118,4 @@ class LoginHandler(BaseHandler):
# /login renders the login page or the "Login with..." link,
# so it should always be registered.
# /logout clears cookies.
default_handlers = [
(r"/login", LoginHandler),
(r"/logout", LogoutHandler),
]
default_handlers = [(r"/login", LoginHandler), (r"/logout", LogoutHandler)]

View File

@@ -1,18 +1,21 @@
from prometheus_client import REGISTRY, CONTENT_TYPE_LATEST, generate_latest
from prometheus_client import CONTENT_TYPE_LATEST
from prometheus_client import generate_latest
from prometheus_client import REGISTRY
from tornado import gen
from .base import BaseHandler
from ..utils import metrics_authentication
from .base import BaseHandler
class MetricsHandler(BaseHandler):
"""
Handler to serve Prometheus metrics
"""
@metrics_authentication
async def get(self):
self.set_header('Content-Type', CONTENT_TYPE_LATEST)
self.write(generate_latest(REGISTRY))
default_handlers = [
(r'/metrics$', MetricsHandler)
]
default_handlers = [(r'/metrics$', MetricsHandler)]

View File

@@ -1,18 +1,19 @@
"""Basic html-rendering handlers."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from collections import defaultdict
from datetime import datetime
from http.client import responses
from jinja2 import TemplateNotFound
from tornado import web, gen
from tornado import gen
from tornado import web
from tornado.httputil import url_concat
from .. import orm
from ..utils import admin_only, url_path_join, maybe_future
from ..utils import admin_only
from ..utils import maybe_future
from ..utils import url_path_join
from .base import BaseHandler
@@ -29,6 +30,7 @@ class RootHandler(BaseHandler):
Otherwise, renders login page.
"""
def get(self):
user = self.current_user
if self.default_url:
@@ -53,8 +55,13 @@ class HomeHandler(BaseHandler):
# send the user to /spawn if they have no active servers,
# to establish that this is an explicit spawn request rather
# than an implicit one, which can be caused by any link to `/user/:name(/:server_name)`
url = url_path_join(self.hub.base_url, 'user', user.name) if user.active else url_path_join(self.hub.base_url, 'spawn')
html = self.render_template('home.html',
url = (
url_path_join(self.hub.base_url, 'user', user.name)
if user.active
else url_path_join(self.hub.base_url, 'spawn')
)
html = self.render_template(
'home.html',
user=user,
url=url,
allow_named_servers=self.allow_named_servers,
@@ -74,13 +81,15 @@ class SpawnHandler(BaseHandler):
Only enabled when Spawner.options_form is defined.
"""
def _render_form(self, for_user, spawner_options_form, message=''):
return self.render_template('spawn.html',
return self.render_template(
'spawn.html',
for_user=for_user,
spawner_options_form=spawner_options_form,
error_message=message,
url=self.request.uri,
spawner=for_user.spawner
spawner=for_user.spawner,
)
@web.authenticated
@@ -92,7 +101,9 @@ class SpawnHandler(BaseHandler):
user = current_user = self.current_user
if for_user is not None and for_user != user.name:
if not user.admin:
raise web.HTTPError(403, "Only admins can spawn on behalf of other users")
raise web.HTTPError(
403, "Only admins can spawn on behalf of other users"
)
user = self.find_user(for_user)
if user is None:
@@ -108,7 +119,9 @@ class SpawnHandler(BaseHandler):
if spawner_options_form:
# Add handler to spawner here so you can access query params in form rendering.
user.spawner.handler = self
form = self._render_form(for_user=user, spawner_options_form=spawner_options_form)
form = self._render_form(
for_user=user, spawner_options_form=spawner_options_form
)
self.finish(form)
else:
# Explicit spawn request: clear _spawn_future
@@ -129,7 +142,9 @@ class SpawnHandler(BaseHandler):
user = current_user = self.current_user
if for_user is not None and for_user != user.name:
if not user.admin:
raise web.HTTPError(403, "Only admins can spawn on behalf of other users")
raise web.HTTPError(
403, "Only admins can spawn on behalf of other users"
)
user = self.find_user(for_user)
if user is None:
raise web.HTTPError(404, "No such user: %s" % for_user)
@@ -151,9 +166,13 @@ class SpawnHandler(BaseHandler):
options = await maybe_future(user.spawner.options_from_form(form_options))
await self.spawn_single_user(user, options=options)
except Exception as e:
self.log.error("Failed to spawn single-user server with form", exc_info=True)
self.log.error(
"Failed to spawn single-user server with form", exc_info=True
)
spawner_options_form = await user.spawner.get_options_form()
form = self._render_form(for_user=user, spawner_options_form=spawner_options_form, message=str(e))
form = self._render_form(
for_user=user, spawner_options_form=spawner_options_form, message=str(e)
)
self.finish(form)
return
if current_user is user:
@@ -176,9 +195,7 @@ class AdminHandler(BaseHandler):
def get(self):
available = {'name', 'admin', 'running', 'last_activity'}
default_sort = ['admin', 'name']
mapping = {
'running': orm.Spawner.server_id,
}
mapping = {'running': orm.Spawner.server_id}
for name in available:
if name not in mapping:
mapping[name] = getattr(orm.User, name)
@@ -219,11 +236,13 @@ class AdminHandler(BaseHandler):
users = self.db.query(orm.User).outerjoin(orm.Spawner).order_by(*ordered)
users = [self._user_from_orm(u) for u in users]
from itertools import chain
running = []
for u in users:
running.extend(s for s in u.spawners.values() if s.active)
html = self.render_template('admin.html',
html = self.render_template(
'admin.html',
current_user=self.current_user,
admin_access=self.settings.get('admin_access', False),
users=users,
@@ -243,11 +262,9 @@ class TokenPageHandler(BaseHandler):
never = datetime(1900, 1, 1)
user = self.current_user
def sort_key(token):
return (
token.last_activity or never,
token.created or never,
)
return (token.last_activity or never, token.created or never)
now = datetime.utcnow()
api_tokens = []
@@ -285,13 +302,13 @@ class TokenPageHandler(BaseHandler):
for token in tokens[1:]:
if token.created < created:
created = token.created
if (
last_activity is None or
(token.last_activity and token.last_activity > last_activity)
if last_activity is None or (
token.last_activity and token.last_activity > last_activity
):
last_activity = token.last_activity
token = tokens[0]
oauth_clients.append({
oauth_clients.append(
{
'client': token.client,
'description': token.client.description or token.client.identifier,
'created': created,
@@ -301,21 +318,17 @@ class TokenPageHandler(BaseHandler):
# revoking one oauth token revokes all oauth tokens for that client
'token_id': tokens[0].api_id,
'token_count': len(tokens),
})
}
)
# sort oauth clients by last activity, created
def sort_key(client):
return (
client['last_activity'] or never,
client['created'] or never,
)
return (client['last_activity'] or never, client['created'] or never)
oauth_clients = sorted(oauth_clients, key=sort_key, reverse=True)
html = self.render_template(
'token.html',
api_tokens=api_tokens,
oauth_clients=oauth_clients,
'token.html', api_tokens=api_tokens, oauth_clients=oauth_clients
)
self.finish(html)
@@ -331,10 +344,12 @@ class ProxyErrorHandler(BaseHandler):
hub_home = url_path_join(self.hub.base_url, 'home')
message_html = ''
if status_code == 503:
message_html = ' '.join([
message_html = ' '.join(
[
"Your server appears to be down.",
"Try restarting it <a href='%s'>from the hub</a>" % hub_home
])
"Try restarting it <a href='%s'>from the hub</a>" % hub_home,
]
)
ns = dict(
status_code=status_code,
status_message=status_message,
@@ -355,6 +370,7 @@ class ProxyErrorHandler(BaseHandler):
class HealthCheckHandler(BaseHandler):
"""Answer to health check"""
def get(self, *args):
self.finish()

View File

@@ -1,14 +1,16 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
from tornado.web import StaticFileHandler
class CacheControlStaticFilesHandler(StaticFileHandler):
"""StaticFileHandler subclass that sets Cache-Control: no-cache without `?v=`
rather than relying on default browser cache behavior.
"""
def compute_etag(self):
return None
@@ -16,8 +18,10 @@ class CacheControlStaticFilesHandler(StaticFileHandler):
if "v" not in self.request.arguments:
self.add_header("Cache-Control", "no-cache")
class LogoHandler(StaticFileHandler):
"""A singular handler for serving the logo."""
def get(self):
return super().get('')
@@ -25,4 +29,3 @@ class LogoHandler(StaticFileHandler):
def get_absolute_path(cls, root, path):
"""We only serve one file, ignore relative path"""
return os.path.abspath(root)

View File

@@ -1,13 +1,15 @@
"""logging utilities"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import json
import traceback
from urllib.parse import urlparse, urlunparse
from urllib.parse import urlparse
from urllib.parse import urlunparse
from tornado.log import LogFormatter, access_log
from tornado.web import StaticFileHandler, HTTPError
from tornado.log import access_log
from tornado.log import LogFormatter
from tornado.web import HTTPError
from tornado.web import StaticFileHandler
from .metrics import prometheus_log_method
@@ -23,7 +25,11 @@ def coroutine_frames(all_frames):
continue
# start out conservative with filename + function matching
# maybe just filename matching would be sufficient
elif frame[0].endswith('tornado/gen.py') and frame[2] in {'run', 'wrapper', '__init__'}:
elif frame[0].endswith('tornado/gen.py') and frame[2] in {
'run',
'wrapper',
'__init__',
}:
continue
elif frame[0].endswith('tornado/concurrent.py') and frame[2] == 'result':
continue
@@ -51,9 +57,11 @@ def coroutine_traceback(typ, value, tb):
class CoroutineLogFormatter(LogFormatter):
"""Log formatter that scrubs coroutine frames"""
def formatException(self, exc_info):
return ''.join(coroutine_traceback(*exc_info))
# url params to be scrubbed if seen
# any url param that *contains* one of these
# will be scrubbed from logs
@@ -96,6 +104,7 @@ def _scrub_headers(headers):
# log_request adapted from IPython (BSD)
def log_request(handler):
"""log a bit more information about each request than tornado's default

View File

@@ -17,13 +17,13 @@ them manually here.
"""
from enum import Enum
from prometheus_client import Histogram
from prometheus_client import Gauge
from prometheus_client import Histogram
REQUEST_DURATION_SECONDS = Histogram(
'request_duration_seconds',
'request duration for all HTTP requests',
['method', 'handler', 'code']
['method', 'handler', 'code'],
)
SERVER_SPAWN_DURATION_SECONDS = Histogram(
@@ -32,32 +32,29 @@ SERVER_SPAWN_DURATION_SECONDS = Histogram(
['status'],
# Use custom bucket sizes, since the default bucket ranges
# are meant for quick running processes. Spawns can take a while!
buckets=[0.5, 1, 2.5, 5, 10, 15, 30, 60, 120, float("inf")]
buckets=[0.5, 1, 2.5, 5, 10, 15, 30, 60, 120, float("inf")],
)
RUNNING_SERVERS = Gauge(
'running_servers',
'the number of user servers currently running'
'running_servers', 'the number of user servers currently running'
)
RUNNING_SERVERS.set(0)
TOTAL_USERS = Gauge(
'total_users',
'toal number of users'
)
TOTAL_USERS = Gauge('total_users', 'toal number of users')
TOTAL_USERS.set(0)
CHECK_ROUTES_DURATION_SECONDS = Histogram(
'check_routes_duration_seconds',
'Time taken to validate all routes in proxy'
'check_routes_duration_seconds', 'Time taken to validate all routes in proxy'
)
class ServerSpawnStatus(Enum):
"""
Possible values for 'status' label of SERVER_SPAWN_DURATION_SECONDS
"""
success = 'success'
failure = 'failure'
already_pending = 'already-pending'
@@ -67,27 +64,29 @@ class ServerSpawnStatus(Enum):
def __str__(self):
return self.value
for s in ServerSpawnStatus:
# Create empty metrics with the given status
SERVER_SPAWN_DURATION_SECONDS.labels(status=s)
PROXY_ADD_DURATION_SECONDS = Histogram(
'proxy_add_duration_seconds',
'duration for adding user routes to proxy',
['status']
'proxy_add_duration_seconds', 'duration for adding user routes to proxy', ['status']
)
class ProxyAddStatus(Enum):
"""
Possible values for 'status' label of PROXY_ADD_DURATION_SECONDS
"""
success = 'success'
failure = 'failure'
def __str__(self):
return self.value
for s in ProxyAddStatus:
PROXY_ADD_DURATION_SECONDS.labels(status=s)
@@ -95,13 +94,15 @@ for s in ProxyAddStatus:
SERVER_POLL_DURATION_SECONDS = Histogram(
'server_poll_duration_seconds',
'time taken to poll if server is running',
['status']
['status'],
)
class ServerPollStatus(Enum):
"""
Possible values for 'status' label of SERVER_POLL_DURATION_SECONDS
"""
running = 'running'
stopped = 'stopped'
@@ -112,27 +113,28 @@ class ServerPollStatus(Enum):
return cls.running
return cls.stopped
for s in ServerPollStatus:
SERVER_POLL_DURATION_SECONDS.labels(status=s)
SERVER_STOP_DURATION_SECONDS = Histogram(
'server_stop_seconds',
'time taken for server stopping operation',
['status'],
'server_stop_seconds', 'time taken for server stopping operation', ['status']
)
class ServerStopStatus(Enum):
"""
Possible values for 'status' label of SERVER_STOP_DURATION_SECONDS
"""
success = 'success'
failure = 'failure'
def __str__(self):
return self.value
for s in ServerStopStatus:
SERVER_STOP_DURATION_SECONDS.labels(status=s)
@@ -156,5 +158,5 @@ def prometheus_log_method(handler):
REQUEST_DURATION_SECONDS.labels(
method=handler.request.method,
handler='{}.{}'.format(handler.__class__.__module__, type(handler).__name__),
code=handler.get_status()
code=handler.get_status(),
).observe(handler.request.request_time())

View File

@@ -2,30 +2,29 @@
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
"""
from datetime import datetime
from urllib.parse import urlparse
from oauthlib.oauth2 import RequestValidator, WebApplicationServer
from oauthlib.oauth2 import RequestValidator
from oauthlib.oauth2 import WebApplicationServer
from oauthlib.oauth2.rfc6749.grant_types import authorization_code
from sqlalchemy.orm import scoped_session
from tornado import web
from tornado.escape import url_escape
from tornado.log import app_log
from tornado import web
from .. import orm
from ..utils import url_path_join, hash_token, compare_token
from ..utils import compare_token
from ..utils import hash_token
from ..utils import url_path_join
# patch absolute-uri check
# because we want to allow relative uri oauth
# for internal services
from oauthlib.oauth2.rfc6749.grant_types import authorization_code
authorization_code.is_absolute_uri = lambda uri: True
class JupyterHubRequestValidator(RequestValidator):
def __init__(self, db):
self.db = db
super().__init__()
@@ -51,10 +50,7 @@ class JupyterHubRequestValidator(RequestValidator):
client_id = request.client_id
client_secret = request.client_secret
oauth_client = (
self.db
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
)
if oauth_client is None:
return False
@@ -78,10 +74,7 @@ class JupyterHubRequestValidator(RequestValidator):
- Authorization Code Grant
"""
orm_client = (
self.db
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
)
if orm_client is None:
app_log.warning("No such oauth client %s", client_id)
@@ -89,8 +82,9 @@ class JupyterHubRequestValidator(RequestValidator):
request.client = orm_client
return True
def confirm_redirect_uri(self, client_id, code, redirect_uri, client,
*args, **kwargs):
def confirm_redirect_uri(
self, client_id, code, redirect_uri, client, *args, **kwargs
):
"""Ensure that the authorization process represented by this authorization
code began with this 'redirect_uri'.
If the client specifies a redirect_uri when obtaining code then that
@@ -108,8 +102,10 @@ class JupyterHubRequestValidator(RequestValidator):
"""
# TODO: record redirect_uri used during oauth
# if we ever support multiple destinations
app_log.debug("confirm_redirect_uri: client_id=%s, redirect_uri=%s",
client_id, redirect_uri,
app_log.debug(
"confirm_redirect_uri: client_id=%s, redirect_uri=%s",
client_id,
redirect_uri,
)
if redirect_uri == client.redirect_uri:
return True
@@ -127,10 +123,7 @@ class JupyterHubRequestValidator(RequestValidator):
- Implicit Grant
"""
orm_client = (
self.db
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
)
if orm_client is None:
raise KeyError(client_id)
@@ -159,7 +152,9 @@ class JupyterHubRequestValidator(RequestValidator):
"""
raise NotImplementedError()
def is_within_original_scope(self, request_scopes, refresh_token, request, *args, **kwargs):
def is_within_original_scope(
self, request_scopes, refresh_token, request, *args, **kwargs
):
"""Check if requested scopes are within a scope of the refresh token.
When access tokens are refreshed the scope of the new token
needs to be within the scope of the original token. This is
@@ -227,12 +222,15 @@ class JupyterHubRequestValidator(RequestValidator):
- Authorization Code Grant
"""
log_code = code.get('code', 'undefined')[:3] + '...'
app_log.debug("Saving authorization code %s, %s, %s, %s", client_id, log_code, args, kwargs)
app_log.debug(
"Saving authorization code %s, %s, %s, %s",
client_id,
log_code,
args,
kwargs,
)
orm_client = (
self.db
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
)
if orm_client is None:
raise ValueError("No such client: %s" % client_id)
@@ -330,7 +328,11 @@ class JupyterHubRequestValidator(RequestValidator):
app_log.debug("Saving bearer token %s", log_token)
if request.user is None:
raise ValueError("No user for access token: %s" % request.user)
client = self.db.query(orm.OAuthClient).filter_by(identifier=request.client.client_id).first()
client = (
self.db.query(orm.OAuthClient)
.filter_by(identifier=request.client.client_id)
.first()
)
orm_access_token = orm.OAuthAccessToken(
client=client,
grant_type=orm.GrantType.authorization_code,
@@ -400,10 +402,7 @@ class JupyterHubRequestValidator(RequestValidator):
"""
app_log.debug("Validating client id %s", client_id)
orm_client = (
self.db
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
)
if orm_client is None:
return False
@@ -431,19 +430,13 @@ class JupyterHubRequestValidator(RequestValidator):
Method is used by:
- Authorization Code Grant
"""
orm_code = (
self.db
.query(orm.OAuthCode)
.filter_by(code=code)
.first()
)
orm_code = self.db.query(orm.OAuthCode).filter_by(code=code).first()
if orm_code is None:
app_log.debug("No such code: %s", code)
return False
if orm_code.client_id != client_id:
app_log.debug(
"OAuth code client id mismatch: %s != %s",
client_id, orm_code.client_id,
"OAuth code client id mismatch: %s != %s", client_id, orm_code.client_id
)
return False
request.user = orm_code.user
@@ -453,7 +446,9 @@ class JupyterHubRequestValidator(RequestValidator):
request.scopes = ['identify']
return True
def validate_grant_type(self, client_id, grant_type, client, request, *args, **kwargs):
def validate_grant_type(
self, client_id, grant_type, client, request, *args, **kwargs
):
"""Ensure client is authorized to use the grant_type requested.
:param client_id: Unicode client identifier
:param grant_type: Unicode grant type, i.e. authorization_code, password.
@@ -480,14 +475,13 @@ class JupyterHubRequestValidator(RequestValidator):
- Authorization Code Grant
- Implicit Grant
"""
app_log.debug("validate_redirect_uri: client_id=%s, redirect_uri=%s",
client_id, redirect_uri,
app_log.debug(
"validate_redirect_uri: client_id=%s, redirect_uri=%s",
client_id,
redirect_uri,
)
orm_client = (
self.db
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
)
if orm_client is None:
app_log.warning("No such oauth client %s", client_id)
@@ -495,7 +489,9 @@ class JupyterHubRequestValidator(RequestValidator):
if redirect_uri == orm_client.redirect_uri:
return True
else:
app_log.warning("Redirect uri %s != %s", redirect_uri, orm_client.redirect_uri)
app_log.warning(
"Redirect uri %s != %s", redirect_uri, orm_client.redirect_uri
)
return False
def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs):
@@ -514,7 +510,9 @@ class JupyterHubRequestValidator(RequestValidator):
return False
raise NotImplementedError('Subclasses must implement this method.')
def validate_response_type(self, client_id, response_type, client, request, *args, **kwargs):
def validate_response_type(
self, client_id, response_type, client, request, *args, **kwargs
):
"""Ensure client is authorized to use the response_type requested.
:param client_id: Unicode client identifier
:param response_type: Unicode response type, i.e. code, token.
@@ -555,10 +553,8 @@ class JupyterHubOAuthServer(WebApplicationServer):
hash its client_secret before putting it in the database.
"""
# clear existing clients with same ID
for orm_client in (
self.db
.query(orm.OAuthClient)\
.filter_by(identifier=client_id)
for orm_client in self.db.query(orm.OAuthClient).filter_by(
identifier=client_id
):
self.db.delete(orm_client)
self.db.commit()
@@ -574,12 +570,7 @@ class JupyterHubOAuthServer(WebApplicationServer):
def fetch_by_client_id(self, client_id):
"""Find a client by its id"""
return (
self.db
.query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first()
)
return self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
def make_provider(session_factory, url_prefix, login_url):
@@ -588,4 +579,3 @@ def make_provider(session_factory, url_prefix, login_url):
validator = JupyterHubRequestValidator(db)
server = JupyterHubOAuthServer(db, validator)
return server

View File

@@ -1,22 +1,28 @@
"""Some general objects for use in JupyterHub"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import socket
from urllib.parse import urlparse, urlunparse
import warnings
from urllib.parse import urlparse
from urllib.parse import urlunparse
from traitlets import default
from traitlets import HasTraits
from traitlets import Instance
from traitlets import Integer
from traitlets import observe
from traitlets import Unicode
from traitlets import validate
from traitlets import (
HasTraits, Instance, Integer, Unicode,
default, observe, validate,
)
from .traitlets import URLPrefix
from . import orm
from .utils import (
url_path_join, can_connect, wait_for_server,
wait_for_http_server, random_port, make_ssl_context,
)
from .traitlets import URLPrefix
from .utils import can_connect
from .utils import make_ssl_context
from .utils import random_port
from .utils import url_path_join
from .utils import wait_for_http_server
from .utils import wait_for_server
class Server(HasTraits):
"""An object representing an HTTP endpoint.
@@ -24,6 +30,7 @@ class Server(HasTraits):
*Some* of these reside in the database (user servers),
but others (Hub, proxy) are in-memory only.
"""
orm_server = Instance(orm.Server, allow_none=True)
ip = Unicode()
@@ -141,36 +148,31 @@ class Server(HasTraits):
def host(self):
if self.connect_url:
parsed = urlparse(self.connect_url)
return "{proto}://{host}".format(
proto=parsed.scheme,
host=parsed.netloc,
)
return "{proto}://{host}".format(proto=parsed.scheme, host=parsed.netloc)
return "{proto}://{ip}:{port}".format(
proto=self.proto,
ip=self._connect_ip,
port=self._connect_port,
proto=self.proto, ip=self._connect_ip, port=self._connect_port
)
@property
def url(self):
if self.connect_url:
return self.connect_url
return "{host}{uri}".format(
host=self.host,
uri=self.base_url,
)
return "{host}{uri}".format(host=self.host, uri=self.base_url)
def wait_up(self, timeout=10, http=False, ssl_context=None):
"""Wait for this server to come up"""
if http:
ssl_context = ssl_context or make_ssl_context(
self.keyfile, self.certfile, cafile=self.cafile)
self.keyfile, self.certfile, cafile=self.cafile
)
return wait_for_http_server(
self.url, timeout=timeout, ssl_context=ssl_context)
self.url, timeout=timeout, ssl_context=ssl_context
)
else:
return wait_for_server(self._connect_ip, self._connect_port, timeout=timeout)
return wait_for_server(
self._connect_ip, self._connect_port, timeout=timeout
)
def is_up(self):
"""Is the server accepting connections?"""
@@ -190,11 +192,13 @@ class Hub(Server):
@property
def server(self):
warnings.warn("Hub.server is deprecated in JupyterHub 0.8. Access attributes on the Hub directly.",
warnings.warn(
"Hub.server is deprecated in JupyterHub 0.8. Access attributes on the Hub directly.",
DeprecationWarning,
stacklevel=2,
)
return self
public_host = Unicode()
routespec = Unicode()
@@ -205,5 +209,7 @@ class Hub(Server):
def __repr__(self):
return "<%s %s:%s>" % (
self.__class__.__name__, self.server.ip, self.server.port,
self.__class__.__name__,
self.server.ip,
self.server.port,
)

View File

@@ -1,36 +1,45 @@
"""sqlalchemy ORM tools for the state of the constellation of processes"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from datetime import datetime, timedelta
import enum
import json
from datetime import datetime
from datetime import timedelta
import alembic.config
import alembic.command
import alembic.config
from alembic.script import ScriptDirectory
from tornado.log import app_log
from sqlalchemy.types import TypeDecorator, Text, LargeBinary
from sqlalchemy import (
create_engine, event, exc, inspect, or_, select,
Column, Integer, ForeignKey, Unicode, Boolean,
DateTime, Enum, Table,
)
from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import DateTime
from sqlalchemy import Enum
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy import Unicode
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import (
Session,
interfaces, object_session, relationship, sessionmaker,
)
from sqlalchemy.orm import interfaces
from sqlalchemy.orm import object_session
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from sqlalchemy.sql.expression import bindparam
from sqlalchemy.types import LargeBinary
from sqlalchemy.types import Text
from sqlalchemy.types import TypeDecorator
from tornado.log import app_log
from .utils import (
random_port,
new_token, hash_token, compare_token,
)
from .utils import compare_token
from .utils import hash_token
from .utils import new_token
from .utils import random_port
# top-level variable for easier mocking in tests
utcnow = datetime.utcnow
@@ -68,6 +77,7 @@ class Server(Base):
connection and cookie info
"""
__tablename__ = 'servers'
id = Column(Integer, primary_key=True)
@@ -82,7 +92,9 @@ class Server(Base):
# user:group many:many mapping table
user_group_map = Table('user_group_map', Base.metadata,
user_group_map = Table(
'user_group_map',
Base.metadata,
Column('user_id', ForeignKey('users.id', ondelete='CASCADE'), primary_key=True),
Column('group_id', ForeignKey('groups.id', ondelete='CASCADE'), primary_key=True),
)
@@ -90,6 +102,7 @@ user_group_map = Table('user_group_map', Base.metadata,
class Group(Base):
"""User Groups"""
__tablename__ = 'groups'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(255), unique=True)
@@ -97,7 +110,9 @@ class Group(Base):
def __repr__(self):
return "<%s %s (%i users)>" % (
self.__class__.__name__, self.name, len(self.users)
self.__class__.__name__,
self.name,
len(self.users),
)
@classmethod
@@ -130,15 +145,15 @@ class User(Base):
`servers` is a list that contains a reference for each of the user's single user notebook servers.
The method `server` returns the first entry in the user's `servers` list.
"""
__tablename__ = 'users'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(255), unique=True)
_orm_spawners = relationship(
"Spawner",
backref="user",
cascade="all, delete-orphan",
"Spawner", backref="user", cascade="all, delete-orphan"
)
@property
def orm_spawners(self):
return {s.name: s for s in self._orm_spawners}
@@ -147,20 +162,12 @@ class User(Base):
created = Column(DateTime, default=datetime.utcnow)
last_activity = Column(DateTime, nullable=True)
api_tokens = relationship(
"APIToken",
backref="user",
cascade="all, delete-orphan",
)
api_tokens = relationship("APIToken", backref="user", cascade="all, delete-orphan")
oauth_tokens = relationship(
"OAuthAccessToken",
backref="user",
cascade="all, delete-orphan",
"OAuthAccessToken", backref="user", cascade="all, delete-orphan"
)
oauth_codes = relationship(
"OAuthCode",
backref="user",
cascade="all, delete-orphan",
"OAuthCode", backref="user", cascade="all, delete-orphan"
)
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
# User.state is actually Spawner state
@@ -192,8 +199,10 @@ class User(Base):
"""
return db.query(cls).filter(cls.name == name).first()
class Spawner(Base):
""""State about a Spawner"""
__tablename__ = 'spawners'
id = Column(Integer, primary_key=True, autoincrement=True)
@@ -214,10 +223,12 @@ class Spawner(Base):
# for which these should all be False
active = running = ready = False
pending = None
@property
def orm_spawner(self):
return self
class Service(Base):
"""A service run with JupyterHub
@@ -235,6 +246,7 @@ class Service(Base):
- pid: the process id (if managed)
"""
__tablename__ = 'services'
id = Column(Integer, primary_key=True, autoincrement=True)
@@ -243,9 +255,7 @@ class Service(Base):
admin = Column(Boolean, default=False)
api_tokens = relationship(
"APIToken",
backref="service",
cascade="all, delete-orphan",
"APIToken", backref="service", cascade="all, delete-orphan"
)
# service-specific interface
@@ -270,6 +280,7 @@ class Service(Base):
class Hashed(object):
"""Mixin for tables with hashed tokens"""
prefix_length = 4
algorithm = "sha512"
rounds = 16384
@@ -299,7 +310,9 @@ class Hashed(object):
else:
rounds = self.rounds
salt_bytes = self.salt_bytes
self.hashed = hash_token(token, rounds=rounds, salt=salt_bytes, algorithm=self.algorithm)
self.hashed = hash_token(
token, rounds=rounds, salt=salt_bytes, algorithm=self.algorithm
)
def match(self, token):
"""Is this my token?"""
@@ -309,8 +322,9 @@ class Hashed(object):
def check_token(cls, db, token):
"""Check if a token is acceptable"""
if len(token) < cls.min_length:
raise ValueError("Tokens must be at least %i characters, got %r" % (
cls.min_length, token)
raise ValueError(
"Tokens must be at least %i characters, got %r"
% (cls.min_length, token)
)
found = cls.find(db, token)
if found:
@@ -344,10 +358,13 @@ class Hashed(object):
class APIToken(Hashed, Base):
"""An API token"""
__tablename__ = 'api_tokens'
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
service_id = Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True)
service_id = Column(
Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True
)
id = Column(Integer, primary_key=True)
hashed = Column(Unicode(255), unique=True)
@@ -375,10 +392,7 @@ class APIToken(Hashed, Base):
kind = 'owner'
name = 'unknown'
return "<{cls}('{pre}...', {kind}='{name}')>".format(
cls=self.__class__.__name__,
pre=self.prefix,
kind=kind,
name=name,
cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name
)
@classmethod
@@ -387,9 +401,7 @@ class APIToken(Hashed, Base):
now = utcnow()
deleted = False
for token in (
db.query(cls)
.filter(cls.expires_at != None)
.filter(cls.expires_at < now)
db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now)
):
app_log.debug("Purging expired %s", token)
deleted = True
@@ -421,8 +433,15 @@ class APIToken(Hashed, Base):
return orm_token
@classmethod
def new(cls, token=None, user=None, service=None, note='', generated=True,
expires_in=None):
def new(
cls,
token=None,
user=None,
service=None,
note='',
generated=True,
expires_in=None,
):
"""Generate a new API token for a user or service"""
assert user or service
assert not (user and service)
@@ -473,7 +492,9 @@ class OAuthAccessToken(Hashed, Base):
def api_id(self):
return 'o%i' % self.id
client_id = Column(Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE'))
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
grant_type = Column(Enum(GrantType), nullable=False)
expires_at = Column(Integer)
refresh_token = Column(Unicode(255))
@@ -517,7 +538,9 @@ class OAuthAccessToken(Hashed, Base):
class OAuthCode(Base):
__tablename__ = 'oauth_codes'
id = Column(Integer, primary_key=True, autoincrement=True)
client_id = Column(Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE'))
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
code = Column(Unicode(36))
expires_at = Column(Integer)
redirect_uri = Column(Unicode(1023))
@@ -539,18 +562,14 @@ class OAuthClient(Base):
return self.identifier
access_tokens = relationship(
OAuthAccessToken,
backref='client',
cascade='all, delete-orphan',
)
codes = relationship(
OAuthCode,
backref='client',
cascade='all, delete-orphan',
OAuthAccessToken, backref='client', cascade='all, delete-orphan'
)
codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan')
# General database utilities
class DatabaseSchemaMismatch(Exception):
"""Exception raised when the database schema version does not match
@@ -560,6 +579,7 @@ class DatabaseSchemaMismatch(Exception):
def register_foreign_keys(engine):
"""register PRAGMA foreign_keys=on on connection"""
@event.listens_for(engine, "connect")
def connect(dbapi_con, con_record):
cursor = dbapi_con.cursor()
@@ -609,6 +629,7 @@ def register_ping_connection(engine):
https://docs.sqlalchemy.org/en/rel_1_1/core/pooling.html#disconnect-handling-pessimistic
"""
@event.listens_for(engine, "engine_connect")
def ping_connection(connection, branch):
if branch:
@@ -633,7 +654,9 @@ def register_ping_connection(engine):
# condition, which is based on inspection of the original exception
# by the dialect in use.
if err.connection_invalidated:
app_log.error("Database connection error, attempting to reconnect: %s", err)
app_log.error(
"Database connection error, attempting to reconnect: %s", err
)
# run the same SELECT again - the connection will re-validate
# itself and establish a new connection. The disconnect detection
# here also causes the whole connection pool to be invalidated
@@ -697,29 +720,37 @@ def check_db_revision(engine):
# check database schema version
# it should always be defined at this point
alembic_revision = engine.execute('SELECT version_num FROM alembic_version').first()[0]
alembic_revision = engine.execute(
'SELECT version_num FROM alembic_version'
).first()[0]
if alembic_revision == head:
app_log.debug("database schema version found: %s", alembic_revision)
pass
else:
raise DatabaseSchemaMismatch("Found database schema version {found} != {head}. "
raise DatabaseSchemaMismatch(
"Found database schema version {found} != {head}. "
"Backup your database and run `jupyterhub upgrade-db`"
" to upgrade to the latest schema.".format(
found=alembic_revision,
head=head,
))
found=alembic_revision, head=head
)
)
def mysql_large_prefix_check(engine):
"""Check mysql has innodb_large_prefix set"""
if not str(engine.url).startswith('mysql'):
return False
variables = dict(engine.execute(
variables = dict(
engine.execute(
'show variables where variable_name like '
'"innodb_large_prefix" or '
'variable_name like "innodb_file_format";').fetchall())
if (variables['innodb_file_format'] == 'Barracuda' and
variables['innodb_large_prefix'] == 'ON'):
'variable_name like "innodb_file_format";'
).fetchall()
)
if (
variables['innodb_file_format'] == 'Barracuda'
and variables['innodb_large_prefix'] == 'ON'
):
return True
else:
return False
@@ -730,10 +761,9 @@ def add_row_format(base):
t.dialect_kwargs['mysql_ROW_FORMAT'] = 'DYNAMIC'
def new_session_factory(url="sqlite:///:memory:",
reset=False,
expire_on_commit=False,
**kwargs):
def new_session_factory(
url="sqlite:///:memory:", reset=False, expire_on_commit=False, **kwargs
):
"""Create a new session at url"""
if url.startswith('sqlite'):
kwargs.setdefault('connect_args', {'check_same_thread': False})
@@ -767,7 +797,5 @@ def new_session_factory(url="sqlite:///:memory:",
# SQLAlchemy to expire objects after committing - we don't expect
# concurrent runs of the hub talking to the same db. Turning
# this off gives us a major performance boost
session_factory = sessionmaker(bind=engine,
expire_on_commit=expire_on_commit,
)
session_factory = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
return session_factory

View File

@@ -14,36 +14,38 @@ Route Specification:
'host.tld/path/' for host-based routing or '/path/' for default routing.
- Route paths should be normalized to always start and end with '/'
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
from functools import wraps
import json
import os
import signal
from subprocess import Popen
import time
from urllib.parse import quote, urlparse
from functools import wraps
from subprocess import Popen
from urllib.parse import quote
from urllib.parse import urlparse
from tornado import gen
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPError
from tornado.httpclient import AsyncHTTPClient
from tornado.httpclient import HTTPError
from tornado.httpclient import HTTPRequest
from tornado.ioloop import PeriodicCallback
from traitlets import (
Any, Bool, Instance, Integer, Unicode,
default, observe,
)
from jupyterhub.traitlets import Command
from traitlets import Any
from traitlets import Bool
from traitlets import default
from traitlets import Instance
from traitlets import Integer
from traitlets import observe
from traitlets import Unicode
from traitlets.config import LoggingConfigurable
from . import utils
from .metrics import CHECK_ROUTES_DURATION_SECONDS
from .objects import Server
from . import utils
from .utils import url_path_join, make_ssl_context
from .utils import make_ssl_context
from .utils import url_path_join
from jupyterhub.traitlets import Command
def _one_at_a_time(method):
@@ -53,6 +55,7 @@ def _one_at_a_time(method):
queue them instead of allowing them to be concurrently outstanding.
"""
method._lock = asyncio.Lock()
@wraps(method)
async def locked_method(*args, **kwargs):
async with method._lock:
@@ -86,6 +89,7 @@ class Proxy(LoggingConfigurable):
"""
db_factory = Any()
@property
def db(self):
return self.db_factory()
@@ -97,13 +101,16 @@ class Proxy(LoggingConfigurable):
ssl_cert = Unicode()
host_routing = Bool()
should_start = Bool(True, config=True,
should_start = Bool(
True,
config=True,
help="""Should the Hub start the proxy
If True, the Hub will start the proxy and stop it.
Set to False if the proxy is managed externally,
such as by systemd, docker, or another service manager.
""")
""",
)
def start(self):
"""Start the proxy.
@@ -136,9 +143,13 @@ class Proxy(LoggingConfigurable):
# check host routing
host_route = not routespec.startswith('/')
if host_route and not self.host_routing:
raise ValueError("Cannot add host-based route %r, not using host-routing" % routespec)
raise ValueError(
"Cannot add host-based route %r, not using host-routing" % routespec
)
if self.host_routing and not host_route:
raise ValueError("Cannot add route without host %r, using host-routing" % routespec)
raise ValueError(
"Cannot add route without host %r, using host-routing" % routespec
)
# add trailing slash
if not routespec.endswith('/'):
return routespec + '/'
@@ -220,16 +231,19 @@ class Proxy(LoggingConfigurable):
"""Add a service's server to the proxy table."""
if not service.server:
raise RuntimeError(
"Service %s does not have an http endpoint to add to the proxy.", service.name)
"Service %s does not have an http endpoint to add to the proxy.",
service.name,
)
self.log.info("Adding service %s to proxy %s => %s",
service.name, service.proxy_spec, service.server.host,
self.log.info(
"Adding service %s to proxy %s => %s",
service.name,
service.proxy_spec,
service.server.host,
)
await self.add_route(
service.proxy_spec,
service.server.host,
{'service': service.name}
service.proxy_spec, service.server.host, {'service': service.name}
)
async def delete_service(self, service, client=None):
@@ -240,22 +254,23 @@ class Proxy(LoggingConfigurable):
async def add_user(self, user, server_name='', client=None):
"""Add a user's server to the proxy table."""
spawner = user.spawners[server_name]
self.log.info("Adding user %s to proxy %s => %s",
user.name, spawner.proxy_spec, spawner.server.host,
self.log.info(
"Adding user %s to proxy %s => %s",
user.name,
spawner.proxy_spec,
spawner.server.host,
)
if spawner.pending and spawner.pending != 'spawn':
raise RuntimeError(
"%s is pending %s, shouldn't be added to the proxy yet!" % (spawner._log_name, spawner.pending)
"%s is pending %s, shouldn't be added to the proxy yet!"
% (spawner._log_name, spawner.pending)
)
await self.add_route(
spawner.proxy_spec,
spawner.server.host,
{
'user': user.name,
'server_name': server_name,
}
{'user': user.name, 'server_name': server_name},
)
async def delete_user(self, user, server_name=''):
@@ -314,7 +329,9 @@ class Proxy(LoggingConfigurable):
else:
route = routes[self.app.hub.routespec]
if route['target'] != hub.host:
self.log.warning("Updating default route %s%s", route['target'], hub.host)
self.log.warning(
"Updating default route %s%s", route['target'], hub.host
)
futures.append(self.add_hub_route(hub))
for user in user_dict.values():
@@ -324,14 +341,17 @@ class Proxy(LoggingConfigurable):
good_routes.add(spec)
if spec not in user_routes:
self.log.warning(
"Adding missing route for %s (%s)", spec, spawner.server)
"Adding missing route for %s (%s)", spec, spawner.server
)
futures.append(self.add_user(user, name))
else:
route = routes[spec]
if route['target'] != spawner.server.host:
self.log.warning(
"Updating route for %s (%s%s)",
spec, route['target'], spawner.server,
spec,
route['target'],
spawner.server,
)
futures.append(self.add_user(user, name))
elif spawner.pending:
@@ -341,22 +361,26 @@ class Proxy(LoggingConfigurable):
good_routes.add(spawner.proxy_spec)
# check service routes
service_routes = {r['data']['service']: r
for r in routes.values() if 'service' in r['data']}
service_routes = {
r['data']['service']: r for r in routes.values() if 'service' in r['data']
}
for service in service_dict.values():
if service.server is None:
continue
good_routes.add(service.proxy_spec)
if service.name not in service_routes:
self.log.warning("Adding missing route for %s (%s)",
service.name, service.server)
self.log.warning(
"Adding missing route for %s (%s)", service.name, service.server
)
futures.append(self.add_service(service))
else:
route = service_routes[service.name]
if route['target'] != service.server.host:
self.log.warning(
"Updating route for %s (%s%s)",
route['routespec'], route['target'], service.server.host,
route['routespec'],
route['target'],
service.server.host,
)
futures.append(self.add_service(service))
@@ -424,7 +448,7 @@ class ConfigurableHTTPProxy(Proxy):
help="""The Proxy auth token
Loaded from the CONFIGPROXY_AUTH_TOKEN env variable by default.
""",
"""
).tag(config=True)
check_running_interval = Integer(5, config=True)
@@ -437,8 +461,8 @@ class ConfigurableHTTPProxy(Proxy):
token = utils.new_token()
return token
api_url = Unicode(config=True,
help="""The ip (or hostname) of the proxy's API endpoint"""
api_url = Unicode(
config=True, help="""The ip (or hostname) of the proxy's API endpoint"""
)
@default('api_url')
@@ -448,13 +472,12 @@ class ConfigurableHTTPProxy(Proxy):
if self.app.internal_ssl:
proto = 'https'
return "{proto}://{url}".format(
proto=proto,
url=url,
)
return "{proto}://{url}".format(proto=proto, url=url)
command = Command('configurable-http-proxy', config=True,
help="""The command to start the proxy"""
command = Command(
'configurable-http-proxy',
config=True,
help="""The command to start the proxy""",
)
pid_file = Unicode(
@@ -463,11 +486,14 @@ class ConfigurableHTTPProxy(Proxy):
help="File in which to write the PID of the proxy process.",
)
_check_running_callback = Any(help="PeriodicCallback to check if the proxy is running")
_check_running_callback = Any(
help="PeriodicCallback to check if the proxy is running"
)
def _check_pid(self, pid):
if os.name == 'nt':
import psutil
if not psutil.pid_exists(pid):
raise ProcessLookupError
else:
@@ -558,11 +584,16 @@ class ConfigurableHTTPProxy(Proxy):
env = os.environ.copy()
env['CONFIGPROXY_AUTH_TOKEN'] = self.auth_token
cmd = self.command + [
'--ip', public_server.ip,
'--port', str(public_server.port),
'--api-ip', api_server.ip,
'--api-port', str(api_server.port),
'--error-target', url_path_join(self.hub.url, 'error'),
'--ip',
public_server.ip,
'--port',
str(public_server.port),
'--api-ip',
api_server.ip,
'--api-port',
str(api_server.port),
'--error-target',
url_path_join(self.hub.url, 'error'),
]
if self.app.subdomain_host:
cmd.append('--host-routing')
@@ -595,28 +626,36 @@ class ConfigurableHTTPProxy(Proxy):
cmd.extend(['--client-ssl-request-cert'])
cmd.extend(['--client-ssl-reject-unauthorized'])
if self.app.statsd_host:
cmd.extend([
'--statsd-host', self.app.statsd_host,
'--statsd-port', str(self.app.statsd_port),
'--statsd-prefix', self.app.statsd_prefix + '.chp'
])
cmd.extend(
[
'--statsd-host',
self.app.statsd_host,
'--statsd-port',
str(self.app.statsd_port),
'--statsd-prefix',
self.app.statsd_prefix + '.chp',
]
)
# Warn if SSL is not used
if ' --ssl' not in ' '.join(cmd):
self.log.warning("Running JupyterHub without SSL."
" I hope there is SSL termination happening somewhere else...")
self.log.warning(
"Running JupyterHub without SSL."
" I hope there is SSL termination happening somewhere else..."
)
self.log.info("Starting proxy @ %s", public_server.bind_url)
self.log.debug("Proxy cmd: %s", cmd)
shell = os.name == 'nt'
try:
self.proxy_process = Popen(cmd, env=env, start_new_session=True, shell=shell)
self.proxy_process = Popen(
cmd, env=env, start_new_session=True, shell=shell
)
except FileNotFoundError as e:
self.log.error(
"Failed to find proxy %r\n"
"The proxy can be installed with `npm install -g configurable-http-proxy`."
"To install `npm`, install nodejs which includes `npm`."
"If you see an `EACCES` error or permissions error, refer to the `npm` "
"documentation on How To Prevent Permissions Errors."
% self.command
"documentation on How To Prevent Permissions Errors." % self.command
)
raise
@@ -625,8 +664,7 @@ class ConfigurableHTTPProxy(Proxy):
def _check_process():
status = self.proxy_process.poll()
if status is not None:
e = RuntimeError(
"Proxy failed to start with exit code %i" % status)
e = RuntimeError("Proxy failed to start with exit code %i" % status)
raise e from None
for server in (public_server, api_server):
@@ -678,8 +716,9 @@ class ConfigurableHTTPProxy(Proxy):
"""Check if the proxy is still running"""
if self.proxy_process.poll() is None:
return
self.log.error("Proxy stopped with exit code %r",
'unknown' if self.proxy_process is None else self.proxy_process.poll()
self.log.error(
"Proxy stopped with exit code %r",
'unknown' if self.proxy_process is None else self.proxy_process.poll(),
)
self._remove_pid_file()
await self.start()
@@ -724,10 +763,10 @@ class ConfigurableHTTPProxy(Proxy):
if isinstance(body, dict):
body = json.dumps(body)
self.log.debug("Proxy: Fetching %s %s", method, url)
req = HTTPRequest(url,
req = HTTPRequest(
url,
method=method,
headers={'Authorization': 'token {}'.format(
self.auth_token)},
headers={'Authorization': 'token {}'.format(self.auth_token)},
body=body,
)
async with self.semaphore:
@@ -739,11 +778,7 @@ class ConfigurableHTTPProxy(Proxy):
body['target'] = target
body['jupyterhub'] = True
path = self._routespec_to_chp_path(routespec)
await self.api_request(
path,
method='POST',
body=body,
)
await self.api_request(path, method='POST', body=body)
async def delete_route(self, routespec):
path = self._routespec_to_chp_path(routespec)
@@ -762,11 +797,7 @@ class ConfigurableHTTPProxy(Proxy):
"""Reformat CHP data format to JupyterHub's proxy API."""
target = chp_data.pop('target')
chp_data.pop('jupyterhub')
return {
'routespec': routespec,
'target': target,
'data': chp_data,
}
return {'routespec': routespec, 'target': target, 'data': chp_data}
async def get_all_routes(self, client=None):
"""Fetch the proxy's routes."""
@@ -779,6 +810,5 @@ class ConfigurableHTTPProxy(Proxy):
# exclude routes not associated with JupyterHub
self.log.debug("Omitting non-jupyterhub route %r", routespec)
continue
all_routes[routespec] = self._reformat_routespec(
routespec, chp_data)
all_routes[routespec] = self._reformat_routespec(routespec, chp_data)
return all_routes

View File

@@ -9,7 +9,6 @@ model describing the authenticated user.
authenticate with the Hub.
"""
import base64
import json
import os
@@ -18,22 +17,25 @@ import re
import socket
import string
import time
from urllib.parse import quote, urlencode
import uuid
import warnings
from urllib.parse import quote
from urllib.parse import urlencode
import requests
from tornado.gen import coroutine
from tornado.log import app_log
from tornado.httputil import url_concat
from tornado.web import HTTPError, RequestHandler
from tornado.log import app_log
from tornado.web import HTTPError
from tornado.web import RequestHandler
from traitlets import default
from traitlets import Dict
from traitlets import Instance
from traitlets import Integer
from traitlets import observe
from traitlets import Unicode
from traitlets import validate
from traitlets.config import SingletonConfigurable
from traitlets import (
Unicode, Integer, Instance, Dict,
default, observe, validate,
)
from ..utils import url_path_join
@@ -63,13 +65,14 @@ class _ExpiringDict(dict):
def __repr__(self):
"""include values and timestamps in repr"""
now = time.monotonic()
return repr({
return repr(
{
key: '{value} (age={age:.0f}s)'.format(
value=repr(value)[:16] + '...',
age=now-self.timestamps[key],
value=repr(value)[:16] + '...', age=now - self.timestamps[key]
)
for key, value in self.values.items()
})
}
)
def _check_age(self, key):
"""Check timestamp for a key"""
@@ -131,24 +134,28 @@ class HubAuth(SingletonConfigurable):
"""
hub_host = Unicode('',
hub_host = Unicode(
'',
help="""The public host of JupyterHub
Only used if JupyterHub is spreading servers across subdomains.
"""
""",
).tag(config=True)
@default('hub_host')
def _default_hub_host(self):
return os.getenv('JUPYTERHUB_HOST', '')
base_url = Unicode(os.getenv('JUPYTERHUB_SERVICE_PREFIX') or '/',
base_url = Unicode(
os.getenv('JUPYTERHUB_SERVICE_PREFIX') or '/',
help="""The base URL prefix of this application
e.g. /services/service-name/ or /user/name/
Default: get from JUPYTERHUB_SERVICE_PREFIX
"""
""",
).tag(config=True)
@validate('base_url')
def _add_slash(self, proposal):
"""Ensure base_url starts and ends with /"""
@@ -160,12 +167,14 @@ class HubAuth(SingletonConfigurable):
return value
# where is the hub
api_url = Unicode(os.getenv('JUPYTERHUB_API_URL') or 'http://127.0.0.1:8081/hub/api',
api_url = Unicode(
os.getenv('JUPYTERHUB_API_URL') or 'http://127.0.0.1:8081/hub/api',
help="""The base API URL of the Hub.
Typically `http://hub-ip:hub-port/hub/api`
"""
""",
).tag(config=True)
@default('api_url')
def _api_url(self):
env_url = os.getenv('JUPYTERHUB_API_URL')
@@ -174,56 +183,64 @@ class HubAuth(SingletonConfigurable):
else:
return 'http://127.0.0.1:8081' + url_path_join(self.hub_prefix, 'api')
api_token = Unicode(os.getenv('JUPYTERHUB_API_TOKEN', ''),
api_token = Unicode(
os.getenv('JUPYTERHUB_API_TOKEN', ''),
help="""API key for accessing Hub API.
Generate with `jupyterhub token [username]` or add to JupyterHub.services config.
"""
""",
).tag(config=True)
hub_prefix = Unicode('/hub/',
hub_prefix = Unicode(
'/hub/',
help="""The URL prefix for the Hub itself.
Typically /hub/
"""
""",
).tag(config=True)
@default('hub_prefix')
def _default_hub_prefix(self):
return url_path_join(os.getenv('JUPYTERHUB_BASE_URL') or '/', 'hub') + '/'
login_url = Unicode('/hub/login',
login_url = Unicode(
'/hub/login',
help="""The login URL to use
Typically /hub/login
"""
""",
).tag(config=True)
@default('login_url')
def _default_login_url(self):
return self.hub_host + url_path_join(self.hub_prefix, 'login')
keyfile = Unicode('',
keyfile = Unicode(
'',
help="""The ssl key to use for requests
Use with certfile
"""
""",
).tag(config=True)
certfile = Unicode('',
certfile = Unicode(
'',
help="""The ssl cert to use for requests
Use with keyfile
"""
""",
).tag(config=True)
client_ca = Unicode('',
client_ca = Unicode(
'',
help="""The ssl certificate authority to use to verify requests
Use with keyfile and certfile
"""
""",
).tag(config=True)
cookie_name = Unicode('jupyterhub-services',
help="""The name of the cookie I should be looking for"""
cookie_name = Unicode(
'jupyterhub-services', help="""The name of the cookie I should be looking for"""
).tag(config=True)
cookie_options = Dict(
@@ -245,21 +262,26 @@ class HubAuth(SingletonConfigurable):
return {}
cookie_cache_max_age = Integer(help="DEPRECATED. Use cache_max_age")
@observe('cookie_cache_max_age')
def _deprecated_cookie_cache(self, change):
warnings.warn("cookie_cache_max_age is deprecated in JupyterHub 0.8. Use cache_max_age instead.")
warnings.warn(
"cookie_cache_max_age is deprecated in JupyterHub 0.8. Use cache_max_age instead."
)
self.cache_max_age = change.new
cache_max_age = Integer(300,
cache_max_age = Integer(
300,
help="""The maximum time (in seconds) to cache the Hub's responses for authentication.
A larger value reduces load on the Hub and occasional response lag.
A smaller value reduces propagation time of changes on the Hub (rare).
Default: 300 (five minutes)
"""
""",
).tag(config=True)
cache = Instance(_ExpiringDict, allow_none=False)
@default('cache')
def _default_cache(self):
return _ExpiringDict(self.cache_max_age)
@@ -311,25 +333,42 @@ class HubAuth(SingletonConfigurable):
except requests.ConnectionError as e:
app_log.error("Error connecting to %s: %s", self.api_url, e)
msg = "Failed to connect to Hub API at %r." % self.api_url
msg += " Is the Hub accessible at this URL (from host: %s)?" % socket.gethostname()
msg += (
" Is the Hub accessible at this URL (from host: %s)?"
% socket.gethostname()
)
if '127.0.0.1' in self.api_url:
msg += " Make sure to set c.JupyterHub.hub_ip to an IP accessible to" + \
" single-user servers if the servers are not on the same host as the Hub."
msg += (
" Make sure to set c.JupyterHub.hub_ip to an IP accessible to"
+ " single-user servers if the servers are not on the same host as the Hub."
)
raise HTTPError(500, msg)
data = None
if r.status_code == 404 and allow_404:
pass
elif r.status_code == 403:
app_log.error("I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s", r.status_code, r.reason)
app_log.error(
"I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s",
r.status_code,
r.reason,
)
app_log.error(r.text)
raise HTTPError(500, "Permission failure checking authorization, I may need a new token")
raise HTTPError(
500, "Permission failure checking authorization, I may need a new token"
)
elif r.status_code >= 500:
app_log.error("Upstream failure verifying auth token: [%i] %s", r.status_code, r.reason)
app_log.error(
"Upstream failure verifying auth token: [%i] %s",
r.status_code,
r.reason,
)
app_log.error(r.text)
raise HTTPError(502, "Failed to check authorization (upstream problem)")
elif r.status_code >= 400:
app_log.warning("Failed to check authorization: [%i] %s", r.status_code, r.reason)
app_log.warning(
"Failed to check authorization: [%i] %s", r.status_code, r.reason
)
app_log.warning(r.text)
msg = "Failed to check authorization"
# pass on error_description from oauth failure
@@ -358,10 +397,12 @@ class HubAuth(SingletonConfigurable):
The 'name' field contains the user's name.
"""
return self._check_hub_authorization(
url=url_path_join(self.api_url,
url=url_path_join(
self.api_url,
"authorizations/cookie",
self.cookie_name,
quote(encrypted_cookie, safe='')),
quote(encrypted_cookie, safe=''),
),
cache_key='cookie:{}:{}'.format(session_id, encrypted_cookie),
use_cache=use_cache,
)
@@ -379,9 +420,9 @@ class HubAuth(SingletonConfigurable):
The 'name' field contains the user's name.
"""
return self._check_hub_authorization(
url=url_path_join(self.api_url,
"authorizations/token",
quote(token, safe='')),
url=url_path_join(
self.api_url, "authorizations/token", quote(token, safe='')
),
cache_key='token:{}:{}'.format(session_id, token),
use_cache=use_cache,
)
@@ -399,7 +440,9 @@ class HubAuth(SingletonConfigurable):
user_token = handler.get_argument('token', '')
if not user_token:
# get it from Authorization header
m = self.auth_header_pat.match(handler.request.headers.get(self.auth_header_name, ''))
m = self.auth_header_pat.match(
handler.request.headers.get(self.auth_header_name, '')
)
if m:
user_token = m.group(1)
return user_token
@@ -469,11 +512,14 @@ class HubOAuth(HubAuth):
@default('login_url')
def _login_url(self):
return url_concat(self.oauth_authorization_url, {
return url_concat(
self.oauth_authorization_url,
{
'client_id': self.oauth_client_id,
'redirect_uri': self.oauth_redirect_uri,
'response_type': 'code',
})
},
)
@property
def cookie_name(self):
@@ -511,6 +557,7 @@ class HubOAuth(HubAuth):
Use JUPYTERHUB_CLIENT_ID by default.
"""
).tag(config=True)
@default('oauth_client_id')
def _client_id(self):
return os.getenv('JUPYTERHUB_CLIENT_ID', '')
@@ -527,13 +574,18 @@ class HubOAuth(HubAuth):
Should generally be /base_url/oauth_callback
"""
).tag(config=True)
@default('oauth_redirect_uri')
def _default_redirect(self):
return os.getenv('JUPYTERHUB_OAUTH_CALLBACK_URL') or url_path_join(self.base_url, 'oauth_callback')
return os.getenv('JUPYTERHUB_OAUTH_CALLBACK_URL') or url_path_join(
self.base_url, 'oauth_callback'
)
oauth_authorization_url = Unicode('/hub/api/oauth2/authorize',
oauth_authorization_url = Unicode(
'/hub/api/oauth2/authorize',
help="The URL to redirect to when starting the OAuth process",
).tag(config=True)
@default('oauth_authorization_url')
def _auth_url(self):
return self.hub_host + url_path_join(self.hub_prefix, 'api/oauth2/authorize')
@@ -541,6 +593,7 @@ class HubOAuth(HubAuth):
oauth_token_url = Unicode(
help="""The URL for requesting an OAuth token from JupyterHub"""
).tag(config=True)
@default('oauth_token_url')
def _token_url(self):
return url_path_join(self.api_url, 'oauth2/token')
@@ -565,11 +618,12 @@ class HubOAuth(HubAuth):
redirect_uri=self.oauth_redirect_uri,
)
token_reply = self._api_request('POST', self.oauth_token_url,
token_reply = self._api_request(
'POST',
self.oauth_token_url,
data=urlencode(params).encode('utf8'),
headers={
'Content-Type': 'application/x-www-form-urlencoded'
})
headers={'Content-Type': 'application/x-www-form-urlencoded'},
)
return token_reply['access_token']
@@ -577,9 +631,11 @@ class HubOAuth(HubAuth):
"""Encode a state dict as url-safe base64"""
# trim trailing `=` because = is itself not url-safe!
json_state = json.dumps(state)
return base64.urlsafe_b64encode(
json_state.encode('utf8')
).decode('ascii').rstrip('=')
return (
base64.urlsafe_b64encode(json_state.encode('utf8'))
.decode('ascii')
.rstrip('=')
)
def _decode_state(self, b64_state):
"""Decode a base64 state
@@ -621,7 +677,9 @@ class HubOAuth(HubAuth):
# use a randomized cookie suffix to avoid collisions
# in case of concurrent logins
app_log.warning("Detected unused OAuth state cookies")
cookie_suffix = ''.join(random.choice(string.ascii_letters) for i in range(8))
cookie_suffix = ''.join(
random.choice(string.ascii_letters) for i in range(8)
)
cookie_name = '{}-{}'.format(self.state_cookie_name, cookie_suffix)
extra_state['cookie_name'] = cookie_name
else:
@@ -640,11 +698,7 @@ class HubOAuth(HubAuth):
kwargs['secure'] = True
# load user cookie overrides
kwargs.update(self.cookie_options)
handler.set_secure_cookie(
cookie_name,
b64_state,
**kwargs
)
handler.set_secure_cookie(cookie_name, b64_state, **kwargs)
return b64_state
def generate_state(self, next_url=None, **extra_state):
@@ -658,10 +712,7 @@ class HubOAuth(HubAuth):
-------
state (str): The base64-encoded state string.
"""
state = {
'uuid': uuid.uuid4().hex,
'next_url': next_url,
}
state = {'uuid': uuid.uuid4().hex, 'next_url': next_url}
state.update(extra_state)
return self._encode_state(state)
@@ -681,21 +732,19 @@ class HubOAuth(HubAuth):
def set_cookie(self, handler, access_token):
"""Set a cookie recording OAuth result"""
kwargs = {
'path': self.base_url,
'httponly': True,
}
kwargs = {'path': self.base_url, 'httponly': True}
if handler.request.protocol == 'https':
kwargs['secure'] = True
# load user cookie overrides
kwargs.update(self.cookie_options)
app_log.debug("Setting oauth cookie for %s: %s, %s",
handler.request.remote_ip, self.cookie_name, kwargs)
handler.set_secure_cookie(
app_log.debug(
"Setting oauth cookie for %s: %s, %s",
handler.request.remote_ip,
self.cookie_name,
access_token,
**kwargs
kwargs,
)
handler.set_secure_cookie(self.cookie_name, access_token, **kwargs)
def clear_cookie(self, handler):
"""Clear the OAuth cookie"""
handler.clear_cookie(self.cookie_name, path=self.base_url)
@@ -703,6 +752,7 @@ class HubOAuth(HubAuth):
class UserNotAllowed(Exception):
"""Exception raised when a user is identified and not allowed"""
def __init__(self, model):
self.model = model
@@ -738,6 +788,7 @@ class HubAuthenticated(object):
...
"""
hub_services = None # set of allowed services
hub_users = None # set of allowed users
hub_groups = None # set of allowed groups
@@ -748,9 +799,11 @@ class HubAuthenticated(object):
"""Property indicating that all successfully identified user
or service should be allowed.
"""
return (self.hub_services is None
return (
self.hub_services is None
and self.hub_users is None
and self.hub_groups is None)
and self.hub_groups is None
)
# self.hub_auth must be a HubAuth instance.
# If nothing specified, use default config,
@@ -758,6 +811,7 @@ class HubAuthenticated(object):
# based on JupyterHub environment variables for services.
_hub_auth = None
hub_auth_class = HubAuth
@property
def hub_auth(self):
if self._hub_auth is None:
@@ -794,7 +848,9 @@ class HubAuthenticated(object):
name = model['name']
kind = model.setdefault('kind', 'user')
if self.allow_all:
app_log.debug("Allowing Hub %s %s (all Hub users and services allowed)", kind, name)
app_log.debug(
"Allowing Hub %s %s (all Hub users and services allowed)", kind, name
)
return model
if self.allow_admin and model.get('admin', False):
@@ -816,7 +872,11 @@ class HubAuthenticated(object):
return model
elif self.hub_groups and set(model['groups']).intersection(self.hub_groups):
allowed_groups = set(model['groups']).intersection(self.hub_groups)
app_log.debug("Allowing Hub user %s in group(s) %s", name, ','.join(sorted(allowed_groups)))
app_log.debug(
"Allowing Hub user %s in group(s) %s",
name,
','.join(sorted(allowed_groups)),
)
# group in whitelist
return model
else:
@@ -845,7 +905,10 @@ class HubAuthenticated(object):
# This is not the best, but avoids problems that can be caused
# when get_current_user is allowed to raise.
def raise_on_redirect(*args, **kwargs):
raise HTTPError(403, "{kind} {name} is not allowed.".format(**user_model))
raise HTTPError(
403, "{kind} {name} is not allowed.".format(**user_model)
)
self.redirect = raise_on_redirect
return
except Exception:
@@ -869,6 +932,7 @@ class HubAuthenticated(object):
class HubOAuthenticated(HubAuthenticated):
"""Simple subclass of HubAuthenticated using OAuth instead of old shared cookies"""
hub_auth_class = HubOAuth
@@ -917,5 +981,3 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
app_log.info("Logged-in user %s", user_model)
self.hub_auth.set_cookie(self, token)
self.redirect(next_url or self.hub_auth.base_url)

View File

@@ -38,24 +38,26 @@ A hub-managed service with no URL::
}
"""
import copy
import os
import pipes
import shutil
import os
from subprocess import Popen
from traitlets import (
HasTraits,
Any, Bool, Dict, Unicode, Instance,
default,
)
from traitlets import Any
from traitlets import Bool
from traitlets import default
from traitlets import Dict
from traitlets import HasTraits
from traitlets import Instance
from traitlets import Unicode
from traitlets.config import LoggingConfigurable
from .. import orm
from ..objects import Server
from ..spawner import LocalProcessSpawner
from ..spawner import set_user_setuid
from ..traitlets import Command
from ..spawner import LocalProcessSpawner, set_user_setuid
from ..utils import url_path_join
@@ -81,14 +83,17 @@ class _MockUser(HasTraits):
return ''
return self.server.base_url
# We probably shouldn't use a Spawner here,
# but there are too many concepts to share.
class _ServiceSpawner(LocalProcessSpawner):
"""Subclass of LocalProcessSpawner
Removes notebook-specific-ness from LocalProcessSpawner.
"""
cwd = Unicode()
cmd = Command(minlen=0)
@@ -115,7 +120,9 @@ class _ServiceSpawner(LocalProcessSpawner):
self.log.info("Spawning %s", ' '.join(pipes.quote(s) for s in cmd))
try:
self.proc = Popen(self.cmd, env=env,
self.proc = Popen(
self.cmd,
env=env,
preexec_fn=self.make_preexec_fn(self.user.name),
start_new_session=True, # don't forward signals
cwd=self.cwd or None,
@@ -123,8 +130,10 @@ class _ServiceSpawner(LocalProcessSpawner):
except PermissionError:
# use which to get abspath
script = shutil.which(cmd[0]) or cmd[0]
self.log.error("Permission denied trying to run %r. Does %s have access to this file?",
script, self.user.name,
self.log.error(
"Permission denied trying to run %r. Does %s have access to this file?",
script,
self.user.name,
)
raise
@@ -165,9 +174,9 @@ class Service(LoggingConfigurable):
If the service has an http endpoint, it
"""
).tag(input=True)
admin = Bool(False,
help="Does the service need admin-access to the Hub API?"
).tag(input=True)
admin = Bool(False, help="Does the service need admin-access to the Hub API?").tag(
input=True
)
url = Unicode(
help="""URL of the service.
@@ -205,22 +214,23 @@ class Service(LoggingConfigurable):
"""
return 'managed' if self.managed else 'external'
command = Command(minlen=0,
help="Command to spawn this service, if managed."
).tag(input=True)
cwd = Unicode(
help="""The working directory in which to run the service."""
).tag(input=True)
command = Command(minlen=0, help="Command to spawn this service, if managed.").tag(
input=True
)
cwd = Unicode(help="""The working directory in which to run the service.""").tag(
input=True
)
environment = Dict(
help="""Environment variables to pass to the service.
Only used if the Hub is spawning the service.
"""
).tag(input=True)
user = Unicode("",
user = Unicode(
"",
help="""The user to become when launching the service.
If unspecified, run the service as the same user as the Hub.
"""
""",
).tag(input=True)
domain = Unicode()
@@ -245,6 +255,7 @@ class Service(LoggingConfigurable):
Default: `service-<name>`
"""
).tag(input=True)
@default('oauth_client_id')
def _default_client_id(self):
return 'service-%s' % self.name
@@ -256,6 +267,7 @@ class Service(LoggingConfigurable):
Default: `/services/:name/oauth_callback`
"""
).tag(input=True)
@default('oauth_redirect_uri')
def _default_redirect_uri(self):
if self.server is None:
@@ -328,10 +340,7 @@ class Service(LoggingConfigurable):
cwd=self.cwd,
hub=self.hub,
user=_MockUser(
name=self.user,
service=self,
server=self.orm.server,
host=self.host,
name=self.user, service=self, server=self.orm.server, host=self.host
),
internal_ssl=self.app.internal_ssl,
internal_certs_location=self.app.internal_certs_location,
@@ -344,7 +353,9 @@ class Service(LoggingConfigurable):
def _proc_stopped(self):
"""Called when the service process unexpectedly exits"""
self.log.error("Service %s exited with status %i", self.name, self.proc.returncode)
self.log.error(
"Service %s exited with status %i", self.name, self.proc.returncode
)
self.start()
async def stop(self):
@@ -357,4 +368,4 @@ class Service(LoggingConfigurable):
self.db.delete(self.orm.server)
self.db.commit()
self.spawner.stop_polling()
return (await self.spawner.stop())
return await self.spawner.stop()

View File

@@ -1,23 +1,24 @@
#!/usr/bin/env python
"""Extend regular notebook server to be aware of multiuser things."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
from datetime import datetime, timezone
import json
import os
import random
from datetime import datetime
from datetime import timezone
from textwrap import dedent
from urllib.parse import urlparse
from jinja2 import ChoiceLoader, FunctionLoader
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from jinja2 import ChoiceLoader
from jinja2 import FunctionLoader
from tornado import gen
from tornado import ioloop
from tornado.web import HTTPError, RequestHandler
from tornado.httpclient import AsyncHTTPClient
from tornado.httpclient import HTTPRequest
from tornado.web import HTTPError
from tornado.web import RequestHandler
try:
import notebook
@@ -60,7 +61,9 @@ class HubAuthenticatedHandler(HubOAuthenticated):
@property
def allow_admin(self):
return self.settings.get('admin_access', os.getenv('JUPYTERHUB_ADMIN_ACCESS') or False)
return self.settings.get(
'admin_access', os.getenv('JUPYTERHUB_ADMIN_ACCESS') or False
)
@property
def hub_auth(self):
@@ -79,6 +82,7 @@ class HubAuthenticatedHandler(HubOAuthenticated):
class JupyterHubLoginHandler(LoginHandler):
"""LoginHandler that hooks up Hub authentication"""
@staticmethod
def login_available(settings):
return True
@@ -113,12 +117,14 @@ class JupyterHubLogoutHandler(LogoutHandler):
def get(self):
self.settings['hub_auth'].clear_cookie(self)
self.redirect(
self.settings['hub_host'] +
url_path_join(self.settings['hub_prefix'], 'logout'))
self.settings['hub_host']
+ url_path_join(self.settings['hub_prefix'], 'logout')
)
class OAuthCallbackHandler(HubOAuthCallbackHandler, IPythonHandler):
"""Mixin IPythonHandler to get the right error pages, etc."""
@property
def hub_auth(self):
return self.settings['hub_auth']
@@ -126,7 +132,8 @@ class OAuthCallbackHandler(HubOAuthCallbackHandler, IPythonHandler):
# register new hub related command-line aliases
aliases = dict(notebook_aliases)
aliases.update({
aliases.update(
{
'user': 'SingleUserNotebookApp.user',
'group': 'SingleUserNotebookApp.group',
'cookie-name': 'HubAuth.cookie_name',
@@ -134,15 +141,17 @@ aliases.update({
'hub-host': 'SingleUserNotebookApp.hub_host',
'hub-api-url': 'SingleUserNotebookApp.hub_api_url',
'base-url': 'SingleUserNotebookApp.base_url',
})
flags = dict(notebook_flags)
flags.update({
'disable-user-config': ({
'SingleUserNotebookApp': {
'disable_user_config': True
}
}, "Disable user-controlled configuration of the notebook server.")
})
)
flags = dict(notebook_flags)
flags.update(
{
'disable-user-config': (
{'SingleUserNotebookApp': {'disable_user_config': True}},
"Disable user-controlled configuration of the notebook server.",
)
}
)
page_template = """
{% extends "templates/page.html" %}
@@ -209,11 +218,14 @@ def _exclude_home(path_list):
class SingleUserNotebookApp(NotebookApp):
"""A Subclass of the regular NotebookApp that is aware of the parent multiuser context."""
description = dedent("""
description = dedent(
"""
Single-user server for JupyterHub. Extends the Jupyter Notebook server.
Meant to be invoked by JupyterHub Spawners, and not directly.
""")
"""
)
examples = ""
subcommands = {}
@@ -229,6 +241,7 @@ class SingleUserNotebookApp(NotebookApp):
# ensures that each spawn clears any cookies from previous session,
# triggering OAuth again
cookie_secret = Bytes()
def _cookie_secret_default(self):
return os.urandom(32)
@@ -320,14 +333,17 @@ class SingleUserNotebookApp(NotebookApp):
trust_xheaders = True
login_handler_class = JupyterHubLoginHandler
logout_handler_class = JupyterHubLogoutHandler
port_retries = 0 # disable port-retries, since the Spawner will tell us what port to use
port_retries = (
0
) # disable port-retries, since the Spawner will tell us what port to use
disable_user_config = Bool(False,
disable_user_config = Bool(
False,
help="""Disable user configuration of single-user server.
Prevents user-writable files that normally configure the single-user server
from being loaded, ensuring admins have full control of configuration.
"""
""",
).tag(config=True)
@validate('notebook_dir')
@@ -394,22 +410,15 @@ class SingleUserNotebookApp(NotebookApp):
# create dynamic default http client,
# configured with any relevant ssl config
hub_http_client = Any()
@default('hub_http_client')
def _default_client(self):
ssl_context = make_ssl_context(
self.keyfile,
self.certfile,
cafile=self.client_ca,
)
AsyncHTTPClient.configure(
None,
defaults={
"ssl_options": ssl_context,
},
self.keyfile, self.certfile, cafile=self.client_ca
)
AsyncHTTPClient.configure(None, defaults={"ssl_options": ssl_context})
return AsyncHTTPClient()
async def check_hub_version(self):
"""Test a connection to my Hub
@@ -422,8 +431,12 @@ class SingleUserNotebookApp(NotebookApp):
try:
resp = await client.fetch(self.hub_api_url)
except Exception:
self.log.exception("Failed to connect to my Hub at %s (attempt %i/%i). Is it running?",
self.hub_api_url, i, RETRIES)
self.log.exception(
"Failed to connect to my Hub at %s (attempt %i/%i). Is it running?",
self.hub_api_url,
i,
RETRIES,
)
await gen.sleep(min(2 ** i, 16))
else:
break
@@ -434,14 +447,15 @@ class SingleUserNotebookApp(NotebookApp):
_check_version(hub_version, __version__, self.log)
server_name = Unicode()
@default('server_name')
def _server_name_default(self):
return os.environ.get('JUPYTERHUB_SERVER_NAME', '')
hub_activity_url = Unicode(
config=True,
help="URL for sending JupyterHub activity updates",
config=True, help="URL for sending JupyterHub activity updates"
)
@default('hub_activity_url')
def _default_activity_url(self):
return os.environ.get('JUPYTERHUB_ACTIVITY_URL', '')
@@ -452,8 +466,9 @@ class SingleUserNotebookApp(NotebookApp):
help="""
Interval (in seconds) on which to update the Hub
with our latest activity.
"""
""",
)
@default('hub_activity_interval')
def _default_activity_interval(self):
env_value = os.environ.get('JUPYTERHUB_ACTIVITY_INTERVAL')
@@ -478,10 +493,7 @@ class SingleUserNotebookApp(NotebookApp):
self.log.warning("last activity is using naïve timestamps")
last_activity = last_activity.replace(tzinfo=timezone.utc)
if (
self._last_activity_sent
and last_activity < self._last_activity_sent
):
if self._last_activity_sent and last_activity < self._last_activity_sent:
self.log.debug("No activity since %s", self._last_activity_sent)
return
@@ -496,14 +508,14 @@ class SingleUserNotebookApp(NotebookApp):
"Authorization": "token {}".format(self.hub_auth.api_token),
"Content-Type": "application/json",
},
body=json.dumps({
body=json.dumps(
{
'servers': {
self.server_name: {
'last_activity': last_activity_timestamp,
},
self.server_name: {'last_activity': last_activity_timestamp}
},
'last_activity': last_activity_timestamp,
})
}
),
)
try:
await client.fetch(req)
@@ -526,8 +538,8 @@ class SingleUserNotebookApp(NotebookApp):
if not self.hub_activity_url or not self.hub_activity_interval:
self.log.warning("Activity events disabled")
return
self.log.info("Updating Hub with activity every %s seconds",
self.hub_activity_interval
self.log.info(
"Updating Hub with activity every %s seconds", self.hub_activity_interval
)
while True:
try:
@@ -561,7 +573,9 @@ class SingleUserNotebookApp(NotebookApp):
api_token = os.environ['JUPYTERHUB_API_TOKEN']
if not api_token:
self.exit("JUPYTERHUB_API_TOKEN env is required to run jupyterhub-singleuser. Did you launch it manually?")
self.exit(
"JUPYTERHUB_API_TOKEN env is required to run jupyterhub-singleuser. Did you launch it manually?"
)
self.hub_auth = HubOAuth(
parent=self,
api_token=api_token,
@@ -586,21 +600,23 @@ class SingleUserNotebookApp(NotebookApp):
s['hub_prefix'] = self.hub_prefix
s['hub_host'] = self.hub_host
s['hub_auth'] = self.hub_auth
csp_report_uri = s['csp_report_uri'] = self.hub_host + url_path_join(self.hub_prefix, 'security/csp-report')
csp_report_uri = s['csp_report_uri'] = self.hub_host + url_path_join(
self.hub_prefix, 'security/csp-report'
)
headers = s.setdefault('headers', {})
headers['X-JupyterHub-Version'] = __version__
# set CSP header directly to workaround bugs in jupyter/notebook 5.0
headers.setdefault('Content-Security-Policy', ';'.join([
"frame-ancestors 'self'",
"report-uri " + csp_report_uri,
]))
headers.setdefault(
'Content-Security-Policy',
';'.join(["frame-ancestors 'self'", "report-uri " + csp_report_uri]),
)
super(SingleUserNotebookApp, self).init_webapp()
# add OAuth callback
self.web_app.add_handlers(r".*$", [(
urlparse(self.hub_auth.oauth_redirect_uri).path,
OAuthCallbackHandler
)])
self.web_app.add_handlers(
r".*$",
[(urlparse(self.hub_auth.oauth_redirect_uri).path, OAuthCallbackHandler)],
)
# apply X-JupyterHub-Version to *all* request handlers (even redirects)
self.patch_default_headers()
@@ -610,6 +626,7 @@ class SingleUserNotebookApp(NotebookApp):
if hasattr(RequestHandler, '_orig_set_default_headers'):
return
RequestHandler._orig_set_default_headers = RequestHandler.set_default_headers
def set_jupyterhub_header(self):
self._orig_set_default_headers()
self.set_header('X-JupyterHub-Version', __version__)
@@ -619,13 +636,16 @@ class SingleUserNotebookApp(NotebookApp):
def patch_templates(self):
"""Patch page templates to add Hub-related buttons"""
self.jinja_template_vars['logo_url'] = self.hub_host + url_path_join(self.hub_prefix, 'logo')
self.jinja_template_vars['logo_url'] = self.hub_host + url_path_join(
self.hub_prefix, 'logo'
)
self.jinja_template_vars['hub_host'] = self.hub_host
self.jinja_template_vars['hub_prefix'] = self.hub_prefix
env = self.web_app.settings['jinja2_env']
env.globals['hub_control_panel_url'] = \
self.hub_host + url_path_join(self.hub_prefix, 'home')
env.globals['hub_control_panel_url'] = self.hub_host + url_path_join(
self.hub_prefix, 'home'
)
# patch jinja env loading to modify page template
def get_page(name):
@@ -633,10 +653,7 @@ class SingleUserNotebookApp(NotebookApp):
return page_template
orig_loader = env.loader
env.loader = ChoiceLoader([
FunctionLoader(get_page),
orig_loader,
])
env.loader = ChoiceLoader([FunctionLoader(get_page), orig_loader])
def main(argv=None):

View File

@@ -1,10 +1,8 @@
"""
Contains base Spawner class & default implementation
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import ast
import asyncio
import errno
@@ -18,22 +16,35 @@ import warnings
from subprocess import Popen
from tempfile import mkdtemp
# FIXME: remove when we drop Python 3.5 support
from async_generator import async_generator, yield_
from async_generator import async_generator
from async_generator import yield_
from sqlalchemy import inspect
from tornado.ioloop import PeriodicCallback
from traitlets import Any
from traitlets import Bool
from traitlets import default
from traitlets import Dict
from traitlets import Float
from traitlets import Instance
from traitlets import Integer
from traitlets import List
from traitlets import observe
from traitlets import Unicode
from traitlets import Union
from traitlets import validate
from traitlets.config import LoggingConfigurable
from traitlets import (
Any, Bool, Dict, Instance, Integer, Float, List, Unicode, Union,
default, observe, validate,
)
from .objects import Server
from .traitlets import Command, ByteSpecification, Callable
from .utils import iterate_until, maybe_future, random_port, url_path_join, exponential_backoff
from .traitlets import ByteSpecification
from .traitlets import Callable
from .traitlets import Command
from .utils import exponential_backoff
from .utils import iterate_until
from .utils import maybe_future
from .utils import random_port
from .utils import url_path_join
# FIXME: remove when we drop Python 3.5 support
def _quote_safe(s):
@@ -53,6 +64,7 @@ def _quote_safe(s):
# to avoid getting interpreted by traitlets
return repr(s)
class Spawner(LoggingConfigurable):
"""Base class for spawning single-user notebook servers.
@@ -146,8 +158,12 @@ class Spawner(LoggingConfigurable):
missing.append(attr)
if missing:
raise NotImplementedError("class `{}` needs to redefine the `start`,"
"`stop` and `poll` methods. `{}` not redefined.".format(cls.__name__, '`, `'.join(missing)))
raise NotImplementedError(
"class `{}` needs to redefine the `start`,"
"`stop` and `poll` methods. `{}` not redefined.".format(
cls.__name__, '`, `'.join(missing)
)
)
proxy_spec = Unicode()
@@ -180,6 +196,7 @@ class Spawner(LoggingConfigurable):
if self.orm_spawner:
return self.orm_spawner.name
return ''
hub = Any()
authenticator = Any()
internal_ssl = Bool(False)
@@ -191,7 +208,8 @@ class Spawner(LoggingConfigurable):
oauth_client_id = Unicode()
handler = Any()
will_resume = Bool(False,
will_resume = Bool(
False,
help="""Whether the Spawner will resume on next start
@@ -199,18 +217,20 @@ class Spawner(LoggingConfigurable):
If True, an existing Spawner will resume instead of starting anew
(e.g. resuming a Docker container),
and API tokens in use when the Spawner stops will not be deleted.
"""
""",
)
ip = Unicode('',
ip = Unicode(
'',
help="""
The IP address (or hostname) the single-user server should listen on.
The JupyterHub proxy implementation should be able to send packets to this interface.
"""
""",
).tag(config=True)
port = Integer(0,
port = Integer(
0,
help="""
The port for single-user servers to listen on.
@@ -221,7 +241,7 @@ class Spawner(LoggingConfigurable):
e.g. in containers.
New in version 0.7.
"""
""",
).tag(config=True)
consecutive_failure_limit = Integer(
@@ -237,47 +257,48 @@ class Spawner(LoggingConfigurable):
""",
).tag(config=True)
start_timeout = Integer(60,
start_timeout = Integer(
60,
help="""
Timeout (in seconds) before giving up on starting of single-user server.
This is the timeout for start to return, not the timeout for the server to respond.
Callers of spawner.start will assume that startup has failed if it takes longer than this.
start should return when the server process is started and its location is known.
"""
""",
).tag(config=True)
http_timeout = Integer(30,
http_timeout = Integer(
30,
help="""
Timeout (in seconds) before giving up on a spawned HTTP server
Once a server has successfully been spawned, this is the amount of time
we wait before assuming that the server is unable to accept
connections.
"""
""",
).tag(config=True)
poll_interval = Integer(30,
poll_interval = Integer(
30,
help="""
Interval (in seconds) on which to poll the spawner for single-user server's status.
At every poll interval, each spawner's `.poll` method is called, which checks
if the single-user server is still running. If it isn't running, then JupyterHub modifies
its own state accordingly and removes appropriate routes from the configurable proxy.
"""
""",
).tag(config=True)
_callbacks = List()
_poll_callback = Any()
debug = Bool(False,
help="Enable debug-logging of the single-user server"
).tag(config=True)
debug = Bool(False, help="Enable debug-logging of the single-user server").tag(
config=True
)
options_form = Union([
Unicode(),
Callable()
],
options_form = Union(
[Unicode(), Callable()],
help="""
An HTML form for options a user can specify on launching their server.
@@ -303,7 +324,8 @@ class Spawner(LoggingConfigurable):
be called asynchronously if it returns a future, rather than a str. Note that
the interface of the spawner class is not deemed stable across versions,
so using this functionality might cause your JupyterHub upgrades to break.
""").tag(config=True)
""",
).tag(config=True)
async def get_options_form(self):
"""Get the options form
@@ -341,9 +363,11 @@ class Spawner(LoggingConfigurable):
These user options are usually provided by the `options_form` displayed to the user when they start
their server.
""")
"""
)
env_keep = List([
env_keep = List(
[
'PATH',
'PYTHONPATH',
'CONDA_ROOT',
@@ -357,14 +381,16 @@ class Spawner(LoggingConfigurable):
This whitelist is used to ensure that sensitive information in the JupyterHub process's environment
(such as `CONFIGPROXY_AUTH_TOKEN`) is not passed to the single-user server's process.
"""
""",
).tag(config=True)
env = Dict(help="""Deprecated: use Spawner.get_env or Spawner.environment
env = Dict(
help="""Deprecated: use Spawner.get_env or Spawner.environment
- extend Spawner.get_env for adding required env in Spawner subclasses
- Spawner.environment for config-specified env
""")
"""
)
environment = Dict(
help="""
@@ -386,7 +412,8 @@ class Spawner(LoggingConfigurable):
"""
).tag(config=True)
cmd = Command(['jupyterhub-singleuser'],
cmd = Command(
['jupyterhub-singleuser'],
allow_none=True,
help="""
The command used for starting the single-user server.
@@ -399,16 +426,17 @@ class Spawner(LoggingConfigurable):
Some spawners allow shell-style expansion here, allowing you to use environment variables.
Most, including the default, do not. Consult the documentation for your spawner to verify!
"""
""",
).tag(config=True)
args = List(Unicode(),
args = List(
Unicode(),
help="""
Extra arguments to be passed to the single-user server.
Some spawners allow shell-style expansion here, allowing you to use environment variables here.
Most, including the default, do not. Consult the documentation for your spawner to verify!
"""
""",
).tag(config=True)
notebook_dir = Unicode(
@@ -446,14 +474,16 @@ class Spawner(LoggingConfigurable):
def _deprecate_percent_u(self, proposal):
v = proposal['value']
if '%U' in v:
self.log.warning("%%U for username in %s is deprecated in JupyterHub 0.7, use {username}",
self.log.warning(
"%%U for username in %s is deprecated in JupyterHub 0.7, use {username}",
proposal['trait'].name,
)
v = v.replace('%U', '{username}')
self.log.warning("Converting %r to %r", proposal['value'], v)
return v
disable_user_config = Bool(False,
disable_user_config = Bool(
False,
help="""
Disable per-user configuration of single-user servers.
@@ -462,10 +492,11 @@ class Spawner(LoggingConfigurable):
Note: a user could circumvent this if the user modifies their Python environment, such as when
they have their own conda environments / virtualenvs / containers.
"""
""",
).tag(config=True)
mem_limit = ByteSpecification(None,
mem_limit = ByteSpecification(
None,
help="""
Maximum number of bytes a single-user notebook server is allowed to use.
@@ -484,10 +515,11 @@ class Spawner(LoggingConfigurable):
for the limit to work.** The default spawner, `LocalProcessSpawner`,
does **not** implement this support. A custom spawner **must** add
support for this setting for it to be enforced.
"""
""",
).tag(config=True)
cpu_limit = Float(None,
cpu_limit = Float(
None,
allow_none=True,
help="""
Maximum number of cpu-cores a single-user notebook server is allowed to use.
@@ -503,10 +535,11 @@ class Spawner(LoggingConfigurable):
for the limit to work.** The default spawner, `LocalProcessSpawner`,
does **not** implement this support. A custom spawner **must** add
support for this setting for it to be enforced.
"""
""",
).tag(config=True)
mem_guarantee = ByteSpecification(None,
mem_guarantee = ByteSpecification(
None,
help="""
Minimum number of bytes a single-user notebook server is guaranteed to have available.
@@ -520,10 +553,11 @@ class Spawner(LoggingConfigurable):
for the limit to work.** The default spawner, `LocalProcessSpawner`,
does **not** implement this support. A custom spawner **must** add
support for this setting for it to be enforced.
"""
""",
).tag(config=True)
cpu_guarantee = Float(None,
cpu_guarantee = Float(
None,
allow_none=True,
help="""
Minimum number of cpu-cores a single-user notebook server is guaranteed to have available.
@@ -535,7 +569,7 @@ class Spawner(LoggingConfigurable):
for the limit to work.** The default spawner, `LocalProcessSpawner`,
does **not** implement this support. A custom spawner **must** add
support for this setting for it to be enforced.
"""
""",
).tag(config=True)
pre_spawn_hook = Any(
@@ -621,7 +655,9 @@ class Spawner(LoggingConfigurable):
"""
env = {}
if self.env:
warnings.warn("Spawner.env is deprecated, found %s" % self.env, DeprecationWarning)
warnings.warn(
"Spawner.env is deprecated, found %s" % self.env, DeprecationWarning
)
env.update(self.env)
for key in self.env_keep:
@@ -648,8 +684,9 @@ class Spawner(LoggingConfigurable):
if self.cookie_options:
env['JUPYTERHUB_COOKIE_OPTIONS'] = json.dumps(self.cookie_options)
env['JUPYTERHUB_HOST'] = self.hub.public_host
env['JUPYTERHUB_OAUTH_CALLBACK_URL'] = \
url_path_join(self.user.url, self.name, 'oauth_callback')
env['JUPYTERHUB_OAUTH_CALLBACK_URL'] = url_path_join(
self.user.url, self.name, 'oauth_callback'
)
# Info previously passed on args
env['JUPYTERHUB_USER'] = self.user.name
@@ -749,7 +786,7 @@ class Spawner(LoggingConfigurable):
May be set in config if all spawners should have the same value(s),
or set at runtime by Spawner that know their names.
"""
""",
)
@default('ssl_alt_names')
@@ -793,6 +830,7 @@ class Spawner(LoggingConfigurable):
to the host by either IP or DNS name (note the `default_names` below).
"""
from certipy import Certipy
default_names = ["DNS:localhost", "IP:127.0.0.1"]
alt_names = []
alt_names.extend(self.ssl_alt_names)
@@ -800,10 +838,7 @@ class Spawner(LoggingConfigurable):
if self.ssl_alt_names_include_local:
alt_names = default_names + alt_names
self.log.info("Creating certs for %s: %s",
self._log_name,
';'.join(alt_names),
)
self.log.info("Creating certs for %s: %s", self._log_name, ';'.join(alt_names))
common_name = self.user.name or 'service'
certipy = Certipy(store_dir=self.internal_certs_location)
@@ -812,7 +847,7 @@ class Spawner(LoggingConfigurable):
'user-' + common_name,
notebook_component,
alt_names=alt_names,
overwrite=True
overwrite=True,
)
paths = {
"keyfile": notebook_key_pair['files']['key'],
@@ -862,7 +897,9 @@ class Spawner(LoggingConfigurable):
if self.port:
args.append('--port=%i' % self.port)
elif self.server and self.server.port:
self.log.warning("Setting port from user.server is deprecated as of JupyterHub 0.7.")
self.log.warning(
"Setting port from user.server is deprecated as of JupyterHub 0.7."
)
args.append('--port=%i' % self.server.port)
if self.notebook_dir:
@@ -903,13 +940,12 @@ class Spawner(LoggingConfigurable):
This method is always an async generator and will always yield at least one event.
"""
if not self._spawn_pending:
self.log.warning("Spawn not pending, can't generate progress for %s", self._log_name)
self.log.warning(
"Spawn not pending, can't generate progress for %s", self._log_name
)
return
await yield_({
"progress": 0,
"message": "Server requested",
})
await yield_({"progress": 0, "message": "Server requested"})
from async_generator import aclosing
async with aclosing(self.progress()) as progress:
@@ -940,10 +976,7 @@ class Spawner(LoggingConfigurable):
.. versionadded:: 0.9
"""
await yield_({
"progress": 50,
"message": "Spawning server...",
})
await yield_({"progress": 50, "message": "Spawning server..."})
async def start(self):
"""Start the single-user server
@@ -954,7 +987,9 @@ class Spawner(LoggingConfigurable):
.. versionchanged:: 0.7
Return ip, port instead of setting on self.user.server directly.
"""
raise NotImplementedError("Override in subclass. Must be a Tornado gen.coroutine.")
raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine."
)
async def stop(self, now=False):
"""Stop the single-user server
@@ -967,7 +1002,9 @@ class Spawner(LoggingConfigurable):
Must be a coroutine.
"""
raise NotImplementedError("Override in subclass. Must be a Tornado gen.coroutine.")
raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine."
)
async def poll(self):
"""Check if the single-user process is running
@@ -993,7 +1030,9 @@ class Spawner(LoggingConfigurable):
process has not yet completed.
"""
raise NotImplementedError("Override in subclass. Must be a Tornado gen.coroutine.")
raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine."
)
def add_poll_callback(self, callback, *args, **kwargs):
"""Add a callback to fire when the single-user server stops"""
@@ -1023,8 +1062,7 @@ class Spawner(LoggingConfigurable):
self.stop_polling()
self._poll_callback = PeriodicCallback(
self.poll_and_notify,
1e3 * self.poll_interval
self.poll_and_notify, 1e3 * self.poll_interval
)
self._poll_callback.start()
@@ -1048,8 +1086,10 @@ class Spawner(LoggingConfigurable):
return status
death_interval = Float(0.1)
async def wait_for_death(self, timeout=10):
"""Wait for the single-user server to die, up to timeout seconds"""
async def _wait_for_death():
status = await self.poll()
return status is not None
@@ -1093,6 +1133,7 @@ def set_user_setuid(username, chdir=True):
"""
import grp
import pwd
user = pwd.getpwnam(username)
uid = user.pw_uid
gid = user.pw_gid
@@ -1132,29 +1173,32 @@ class LocalProcessSpawner(Spawner):
Note: This spawner does not implement CPU / memory guarantees and limits.
"""
interrupt_timeout = Integer(10,
interrupt_timeout = Integer(
10,
help="""
Seconds to wait for single-user server process to halt after SIGINT.
If the process has not exited cleanly after this many seconds, a SIGTERM is sent.
"""
""",
).tag(config=True)
term_timeout = Integer(5,
term_timeout = Integer(
5,
help="""
Seconds to wait for single-user server process to halt after SIGTERM.
If the process does not exit cleanly after this many seconds of SIGTERM, a SIGKILL is sent.
"""
""",
).tag(config=True)
kill_timeout = Integer(5,
kill_timeout = Integer(
5,
help="""
Seconds to wait for process to halt after SIGKILL before giving up.
If the process does not exit cleanly after this many seconds of SIGKILL, it becomes a zombie
process. The hub process will log a warning and then give up.
"""
""",
).tag(config=True)
popen_kwargs = Dict(
@@ -1168,7 +1212,8 @@ class LocalProcessSpawner(Spawner):
"""
).tag(config=True)
shell_cmd = Command(minlen=0,
shell_cmd = Command(
minlen=0,
help="""Specify a shell command to launch.
The single-user command will be appended to this list,
@@ -1185,20 +1230,23 @@ class LocalProcessSpawner(Spawner):
Using shell_cmd gives users control over PATH, etc.,
which could change what the jupyterhub-singleuser launch command does.
Only use this for trusted users.
"""
""",
).tag(config=True)
proc = Instance(Popen,
proc = Instance(
Popen,
allow_none=True,
help="""
The process representing the single-user server process spawned for current user.
Is None if no process has been spawned yet.
""")
pid = Integer(0,
""",
)
pid = Integer(
0,
help="""
The process id (pid) of the single-user server process spawned for current user.
"""
""",
)
def make_preexec_fn(self, name):
@@ -1236,6 +1284,7 @@ class LocalProcessSpawner(Spawner):
def user_env(self, env):
"""Augment environment of spawned process with user specific env variables."""
import pwd
env['USER'] = self.user.name
home = pwd.getpwnam(self.user.name).pw_dir
shell = pwd.getpwnam(self.user.name).pw_shell
@@ -1267,6 +1316,7 @@ class LocalProcessSpawner(Spawner):
and make them readable by the user.
"""
import pwd
key = paths['keyfile']
cert = paths['certfile']
ca = paths['cafile']
@@ -1324,8 +1374,10 @@ class LocalProcessSpawner(Spawner):
except PermissionError:
# use which to get abspath
script = shutil.which(cmd[0]) or cmd[0]
self.log.error("Permission denied trying to run %r. Does %s have access to this file?",
script, self.user.name,
self.log.error(
"Permission denied trying to run %r. Does %s have access to this file?",
script,
self.user.name,
)
raise
@@ -1445,24 +1497,25 @@ class SimpleLocalProcessSpawner(LocalProcessSpawner):
help="""
Template to expand to set the user home.
{username} is expanded to the jupyterhub username.
"""
""",
)
home_dir = Unicode(help="The home directory for the user")
@default('home_dir')
def _default_home_dir(self):
return self.home_dir_template.format(
username=self.user.name,
)
return self.home_dir_template.format(username=self.user.name)
def make_preexec_fn(self, name):
home = self.home_dir
def preexec():
try:
os.makedirs(home, 0o755, exist_ok=True)
os.chdir(home)
except Exception as e:
self.log.exception("Error in preexec for %s", name)
return preexec
def user_env(self, env):
@@ -1474,4 +1527,3 @@ class SimpleLocalProcessSpawner(LocalProcessSpawner):
def move_certs(self, paths):
"""No-op for installing certs"""
return paths

View File

@@ -23,34 +23,33 @@ Fixtures to add functionality or spawning behavior
- `slow_bad_spawn`
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
from getpass import getuser
import inspect
import logging
import os
import sys
from getpass import getuser
from subprocess import TimeoutExpired
from unittest import mock
from pytest import fixture, raises
from tornado import ioloop, gen
from pytest import fixture
from pytest import raises
from tornado import gen
from tornado import ioloop
from tornado.httpclient import HTTPError
from tornado.platform.asyncio import AsyncIOMainLoop
from .. import orm
from .. import crypto
from ..utils import random_port
from . import mocking
from .mocking import MockHub
from .utils import ssl_setup, add_user
from .test_services import mockservice_cmd
import jupyterhub.services.service
from . import mocking
from .. import crypto
from .. import orm
from ..utils import random_port
from .mocking import MockHub
from .test_services import mockservice_cmd
from .utils import add_user
from .utils import ssl_setup
# global db session object
_db = None
@@ -78,12 +77,7 @@ def app(request, io_loop, ssl_tmpdir):
ssl_enabled = getattr(request.module, "ssl_enabled", False)
kwargs = dict()
if ssl_enabled:
kwargs.update(
dict(
internal_ssl=True,
internal_certs_location=str(ssl_tmpdir),
)
)
kwargs.update(dict(internal_ssl=True, internal_certs_location=str(ssl_tmpdir)))
mocked_app = MockHub.instance(**kwargs)
@@ -107,9 +101,7 @@ def app(request, io_loop, ssl_tmpdir):
@fixture
def auth_state_enabled(app):
app.authenticator.auth_state = {
'who': 'cares',
}
app.authenticator.auth_state = {'who': 'cares'}
app.authenticator.enable_auth_state = True
ck = crypto.CryptKeeper.instance()
before_keys = ck.keys
@@ -128,9 +120,7 @@ def db():
global _db
if _db is None:
_db = orm.new_session_factory('sqlite:///:memory:')()
user = orm.User(
name=getuser(),
)
user = orm.User(name=getuser())
_db.add(user)
_db.commit()
return _db
@@ -221,12 +211,12 @@ def admin_user(app, username):
yield user
class MockServiceSpawner(jupyterhub.services.service._ServiceSpawner):
"""mock services for testing.
Shorter intervals, etc.
"""
poll_interval = 1
@@ -237,11 +227,7 @@ def _mockservice(request, app, url=False):
global _mock_service_counter
_mock_service_counter += 1
name = 'mock-service-%i' % _mock_service_counter
spec = {
'name': name,
'command': mockservice_cmd,
'admin': True,
}
spec = {'name': name, 'command': mockservice_cmd, 'admin': True}
if url:
if app.internal_ssl:
spec['url'] = 'https://127.0.0.1:%i' % random_port()
@@ -250,22 +236,29 @@ def _mockservice(request, app, url=False):
io_loop = app.io_loop
with mock.patch.object(jupyterhub.services.service, '_ServiceSpawner', MockServiceSpawner):
with mock.patch.object(
jupyterhub.services.service, '_ServiceSpawner', MockServiceSpawner
):
app.services = [spec]
app.init_services()
assert name in app._service_map
service = app._service_map[name]
@gen.coroutine
def start():
# wait for proxy to be updated before starting the service
yield app.proxy.add_all_services(app._service_map)
service.start()
io_loop.run_sync(start)
def cleanup():
import asyncio
asyncio.get_event_loop().run_until_complete(service.stop())
app.services[:] = []
app._service_map.clear()
request.addfinalizer(cleanup)
# ensure process finishes starting
with raises(TimeoutExpired):
@@ -290,47 +283,44 @@ def mockservice_url(request, app):
@fixture
def admin_access(app):
"""Grant admin-access with this fixture"""
with mock.patch.dict(app.tornado_settings,
{'admin_access': True}):
with mock.patch.dict(app.tornado_settings, {'admin_access': True}):
yield
@fixture
def no_patience(app):
"""Set slow-spawning timeouts to zero"""
with mock.patch.dict(app.tornado_settings,
{'slow_spawn_timeout': 0.1,
'slow_stop_timeout': 0.1}):
with mock.patch.dict(
app.tornado_settings, {'slow_spawn_timeout': 0.1, 'slow_stop_timeout': 0.1}
):
yield
@fixture
def slow_spawn(app):
"""Fixture enabling SlowSpawner"""
with mock.patch.dict(app.tornado_settings,
{'spawner_class': mocking.SlowSpawner}):
with mock.patch.dict(app.tornado_settings, {'spawner_class': mocking.SlowSpawner}):
yield
@fixture
def never_spawn(app):
"""Fixture enabling NeverSpawner"""
with mock.patch.dict(app.tornado_settings,
{'spawner_class': mocking.NeverSpawner}):
with mock.patch.dict(app.tornado_settings, {'spawner_class': mocking.NeverSpawner}):
yield
@fixture
def bad_spawn(app):
"""Fixture enabling BadSpawner"""
with mock.patch.dict(app.tornado_settings,
{'spawner_class': mocking.BadSpawner}):
with mock.patch.dict(app.tornado_settings, {'spawner_class': mocking.BadSpawner}):
yield
@fixture
def slow_bad_spawn(app):
"""Fixture enabling SlowBadSpawner"""
with mock.patch.dict(app.tornado_settings,
{'spawner_class': mocking.SlowBadSpawner}):
with mock.patch.dict(
app.tornado_settings, {'spawner_class': mocking.SlowBadSpawner}
):
yield

View File

@@ -26,32 +26,36 @@ Other components
- public_url
"""
import asyncio
from concurrent.futures import ThreadPoolExecutor
import os
import sys
from tempfile import NamedTemporaryFile
import threading
from concurrent.futures import ThreadPoolExecutor
from tempfile import NamedTemporaryFile
from unittest import mock
from urllib.parse import urlparse
from pamela import PAMError
from tornado import gen
from tornado.concurrent import Future
from tornado.ioloop import IOLoop
from traitlets import Bool
from traitlets import default
from traitlets import Dict
from traitlets import Bool, Dict, default
from .. import orm
from ..app import JupyterHub
from ..auth import PAMAuthenticator
from .. import orm
from ..objects import Server
from ..spawner import LocalProcessSpawner, SimpleLocalProcessSpawner
from ..singleuser import SingleUserNotebookApp
from ..utils import random_port, url_path_join
from .utils import async_requests, ssl_setup, public_host, public_url
from pamela import PAMError
from ..spawner import LocalProcessSpawner
from ..spawner import SimpleLocalProcessSpawner
from ..utils import random_port
from ..utils import url_path_join
from .utils import async_requests
from .utils import public_host
from .utils import public_url
from .utils import ssl_setup
def mock_authenticate(username, password, service, encoding):
@@ -79,6 +83,7 @@ class MockSpawner(SimpleLocalProcessSpawner):
- disables user-switching that we need root permissions to do
- spawns `jupyterhub.tests.mocksu` instead of a full single-user server
"""
def user_env(self, env):
env = super().user_env(env)
if self.handler:
@@ -90,6 +95,7 @@ class MockSpawner(SimpleLocalProcessSpawner):
return [sys.executable, '-m', 'jupyterhub.tests.mocksu']
use_this_api_token = None
def start(self):
if self.use_this_api_token:
self.api_token = self.use_this_api_token
@@ -103,6 +109,7 @@ class SlowSpawner(MockSpawner):
delay = 2
_start_future = None
@gen.coroutine
def start(self):
(ip, port) = yield super().start()
@@ -140,6 +147,7 @@ class NeverSpawner(MockSpawner):
class BadSpawner(MockSpawner):
"""Spawner that fails immediately"""
def start(self):
raise RuntimeError("I don't work!")
@@ -154,6 +162,7 @@ class SlowBadSpawner(MockSpawner):
class FormSpawner(MockSpawner):
"""A spawner that has an options form defined"""
options_form = "IMAFORM"
def options_from_form(self, form_data):
@@ -167,6 +176,7 @@ class FormSpawner(MockSpawner):
options['hello'] = form_data['hello_file'][0]
return options
class FalsyCallableFormSpawner(FormSpawner):
"""A spawner that has a callable options form defined returning a falsy value"""
@@ -181,6 +191,7 @@ class MockStructGroup:
self.gr_mem = members
self.gr_gid = gid
class MockStructPasswd:
"""Mock pwd.struct_passwd"""
@@ -193,6 +204,7 @@ class MockPAMAuthenticator(PAMAuthenticator):
auth_state = None
# If true, return admin users marked as admin.
return_admin = False
@default('admin_users')
def _admin_users_default(self):
return {'admin'}
@@ -203,20 +215,20 @@ class MockPAMAuthenticator(PAMAuthenticator):
@gen.coroutine
def authenticate(self, *args, **kwargs):
with mock.patch.multiple('pamela',
with mock.patch.multiple(
'pamela',
authenticate=mock_authenticate,
open_session=mock_open_session,
close_session=mock_open_session,
check_account=mock_check_account,
):
username = yield super(MockPAMAuthenticator, self).authenticate(*args, **kwargs)
username = yield super(MockPAMAuthenticator, self).authenticate(
*args, **kwargs
)
if username is None:
return
elif self.auth_state:
return {
'name': username,
'auth_state': self.auth_state,
}
return {'name': username, 'auth_state': self.auth_state}
else:
return username
@@ -349,11 +361,9 @@ class MockHub(JupyterHub):
external_ca = None
if self.internal_ssl:
external_ca = self.external_certs['files']['ca']
r = yield async_requests.post(base_url + 'hub/login',
data={
'username': name,
'password': name,
},
r = yield async_requests.post(
base_url + 'hub/login',
data={'username': name, 'password': name},
allow_redirects=False,
verify=external_ca,
)
@@ -364,6 +374,7 @@ class MockHub(JupyterHub):
# single-user-server mocking:
class MockSingleUserServer(SingleUserNotebookApp):
"""Mock-out problematic parts of single-user server when run in a thread
@@ -378,7 +389,9 @@ class MockSingleUserServer(SingleUserNotebookApp):
class StubSingleUserSpawner(MockSpawner):
"""Spawner that starts a MockSingleUserServer in a thread."""
_thread = None
@gen.coroutine
def start(self):
ip = self.ip = '127.0.0.1'
@@ -387,6 +400,7 @@ class StubSingleUserSpawner(MockSpawner):
args = self.get_args()
evt = threading.Event()
print(args, env)
def _run():
asyncio.set_event_loop(asyncio.new_event_loop())
io_loop = IOLoop()
@@ -420,4 +434,3 @@ class StubSingleUserSpawner(MockSpawner):
return None
else:
return 0

View File

@@ -12,28 +12,33 @@ Handlers and their purpose include:
- WhoAmIHandler: returns name of user making a request (deprecated cookie login)
- OWhoAmIHandler: returns name of user making a request (OAuth login)
"""
import json
import pprint
import os
import pprint
import sys
from urllib.parse import urlparse
import requests
from tornado import web, httpserver, ioloop
from tornado import httpserver
from tornado import ioloop
from tornado import web
from jupyterhub.services.auth import HubAuthenticated, HubOAuthenticated, HubOAuthCallbackHandler
from jupyterhub.services.auth import HubAuthenticated
from jupyterhub.services.auth import HubOAuthCallbackHandler
from jupyterhub.services.auth import HubOAuthenticated
from jupyterhub.utils import make_ssl_context
class EchoHandler(web.RequestHandler):
"""Reply to an HTTP request with the path of the request."""
def get(self):
self.write(self.request.path)
class EnvHandler(web.RequestHandler):
"""Reply to an HTTP request with the service's environment as JSON."""
def get(self):
self.set_header('Content-Type', 'application/json')
self.write(json.dumps(dict(os.environ)))
@@ -41,11 +46,12 @@ class EnvHandler(web.RequestHandler):
class APIHandler(web.RequestHandler):
"""Relay API requests to the Hub's API using the service's API token."""
def get(self, path):
api_token = os.environ['JUPYTERHUB_API_TOKEN']
api_url = os.environ['JUPYTERHUB_API_URL']
r = requests.get(api_url + path,
headers={'Authorization': 'token %s' % api_token},
r = requests.get(
api_url + path, headers={'Authorization': 'token %s' % api_token}
)
r.raise_for_status()
self.set_header('Content-Type', 'application/json')
@@ -57,6 +63,7 @@ class WhoAmIHandler(HubAuthenticated, web.RequestHandler):
Uses "deprecated" cookie login
"""
@web.authenticated
def get(self):
self.write(self.get_current_user())
@@ -67,6 +74,7 @@ class OWhoAmIHandler(HubOAuthenticated, web.RequestHandler):
Uses OAuth login flow
"""
@web.authenticated
def get(self):
self.write(self.get_current_user())
@@ -77,14 +85,17 @@ def main():
if os.getenv('JUPYTERHUB_SERVICE_URL'):
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
app = web.Application([
app = web.Application(
[
(r'.*/env', EnvHandler),
(r'.*/api/(.*)', APIHandler),
(r'.*/whoami/?', WhoAmIHandler),
(r'.*/owhoami/?', OWhoAmIHandler),
(r'.*/oauth_callback', HubOAuthCallbackHandler),
(r'.*', EchoHandler),
], cookie_secret=os.urandom(32))
],
cookie_secret=os.urandom(32),
)
ssl_context = None
key = os.environ.get('JUPYTERHUB_SSL_KEYFILE') or ''
@@ -92,11 +103,7 @@ def main():
ca = os.environ.get('JUPYTERHUB_SSL_CLIENT_CA') or ''
if key and cert and ca:
ssl_context = make_ssl_context(
key,
cert,
cafile = ca,
check_hostname = False)
ssl_context = make_ssl_context(key, cert, cafile=ca, check_hostname=False)
server = httpserver.HTTPServer(app, ssl_options=ssl_context)
server.listen(url.port, url.hostname)
@@ -108,5 +115,6 @@ def main():
if __name__ == '__main__':
from tornado.options import parse_command_line
parse_command_line()
main()

View File

@@ -13,28 +13,32 @@ Handlers and their purpose include:
"""
import argparse
import json
import sys
import os
import sys
from tornado import httpserver
from tornado import ioloop
from tornado import web
from tornado import web, httpserver, ioloop
from .mockservice import EnvHandler
from ..utils import make_ssl_context
from .mockservice import EnvHandler
class EchoHandler(web.RequestHandler):
def get(self):
self.write(self.request.path)
class ArgsHandler(web.RequestHandler):
def get(self):
self.write(json.dumps(sys.argv))
def main(args):
app = web.Application([
(r'.*/args', ArgsHandler),
(r'.*/env', EnvHandler),
(r'.*', EchoHandler),
])
app = web.Application(
[(r'.*/args', ArgsHandler), (r'.*/env', EnvHandler), (r'.*', EchoHandler)]
)
ssl_context = None
key = os.environ.get('JUPYTERHUB_SSL_KEYFILE') or ''
@@ -42,12 +46,7 @@ def main(args):
ca = os.environ.get('JUPYTERHUB_SSL_CLIENT_CA') or ''
if key and cert and ca:
ssl_context = make_ssl_context(
key,
cert,
cafile = ca,
check_hostname = False
)
ssl_context = make_ssl_context(key, cert, cafile=ca, check_hostname=False)
server = httpserver.HTTPServer(app, ssl_options=ssl_context)
server.listen(args.port)
@@ -56,6 +55,7 @@ def main(args):
except KeyboardInterrupt:
print('\nInterrupted')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int)

View File

@@ -4,9 +4,8 @@ Run with old versions of jupyterhub to test upgrade/downgrade
used in test_db.py
"""
from datetime import datetime
import os
from datetime import datetime
import jupyterhub
from jupyterhub import orm
@@ -90,6 +89,7 @@ def populate_db(url):
if __name__ == '__main__':
import sys
if len(sys.argv) > 1:
url = sys.argv[1]
else:

View File

@@ -1,30 +1,32 @@
"""Tests for the REST API."""
import asyncio
from datetime import datetime, timedelta
from concurrent.futures import Future
import json
import re
import sys
from unittest import mock
from urllib.parse import urlparse, quote
import uuid
from async_generator import async_generator, yield_
from concurrent.futures import Future
from datetime import datetime
from datetime import timedelta
from unittest import mock
from urllib.parse import quote
from urllib.parse import urlparse
from async_generator import async_generator
from async_generator import yield_
from pytest import mark
from tornado import gen
import jupyterhub
from .. import orm
from ..utils import url_path_join as ujoin, utcnow
from .mocking import public_host, public_url
from .utils import (
add_user,
api_request,
async_requests,
auth_header,
find_user,
)
from ..utils import url_path_join as ujoin
from ..utils import utcnow
from .mocking import public_host
from .mocking import public_url
from .utils import add_user
from .utils import api_request
from .utils import async_requests
from .utils import auth_header
from .utils import find_user
# --------------------
@@ -48,12 +50,15 @@ async def test_auth_api(app):
assert reply['name'] == user.name
# check fail
r = await api_request(app, 'authorizations/token', api_token,
headers={'Authorization': 'no sir'},
r = await api_request(
app, 'authorizations/token', api_token, headers={'Authorization': 'no sir'}
)
assert r.status_code == 403
r = await api_request(app, 'authorizations/token', api_token,
r = await api_request(
app,
'authorizations/token',
api_token,
headers={'Authorization': 'token: %s' % user.cookie_id},
)
assert r.status_code == 403
@@ -67,37 +72,39 @@ async def test_referer_check(app):
user = add_user(app.db, name='admin', admin=True)
cookies = await app.login_user('admin')
r = await api_request(app, 'users',
headers={
'Authorization': '',
'Referer': 'null',
}, cookies=cookies,
r = await api_request(
app, 'users', headers={'Authorization': '', 'Referer': 'null'}, cookies=cookies
)
assert r.status_code == 403
r = await api_request(app, 'users',
r = await api_request(
app,
'users',
headers={
'Authorization': '',
'Referer': 'http://attack.com/csrf/vulnerability',
}, cookies=cookies,
},
cookies=cookies,
)
assert r.status_code == 403
r = await api_request(app, 'users',
headers={
'Authorization': '',
'Referer': url,
'Host': host,
}, cookies=cookies,
r = await api_request(
app,
'users',
headers={'Authorization': '', 'Referer': url, 'Host': host},
cookies=cookies,
)
assert r.status_code == 200
r = await api_request(app, 'users',
r = await api_request(
app,
'users',
headers={
'Authorization': '',
'Referer': ujoin(url, 'foo/bar/baz/bat'),
'Host': host,
}, cookies=cookies,
},
cookies=cookies,
)
assert r.status_code == 200
@@ -106,6 +113,7 @@ async def test_referer_check(app):
# User API tests
# --------------
def normalize_timestamp(ts):
"""Normalize a timestamp
@@ -128,12 +136,16 @@ def normalize_user(user):
for server in user['servers'].values():
for key in ('started', 'last_activity'):
server[key] = normalize_timestamp(server[key])
server['progress_url'] = re.sub(r'.*/hub/api', 'PREFIX/hub/api', server['progress_url'])
if (isinstance(server['state'], dict)
and isinstance(server['state'].get('pid', None), int)):
server['progress_url'] = re.sub(
r'.*/hub/api', 'PREFIX/hub/api', server['progress_url']
)
if isinstance(server['state'], dict) and isinstance(
server['state'].get('pid', None), int
):
server['state']['pid'] = 0
return user
def fill_user(model):
"""Fill a default user model
@@ -153,6 +165,7 @@ def fill_user(model):
TIMESTAMP = normalize_timestamp(datetime.now().isoformat() + 'Z')
@mark.user
async def test_get_users(app):
db = app.db
@@ -162,20 +175,11 @@ async def test_get_users(app):
users = sorted(r.json(), key=lambda d: d['name'])
users = [normalize_user(u) for u in users]
assert users == [
fill_user({
'name': 'admin',
'admin': True,
}),
fill_user({
'name': 'user',
'admin': False,
'last_activity': None,
}),
fill_user({'name': 'admin', 'admin': True}),
fill_user({'name': 'user', 'admin': False, 'last_activity': None}),
]
r = await api_request(app, 'users',
headers=auth_header(db, 'user'),
)
r = await api_request(app, 'users', headers=auth_header(db, 'user'))
assert r.status_code == 403
@@ -202,17 +206,13 @@ async def test_get_self(app):
)
db.add(oauth_token)
db.commit()
r = await api_request(app, 'user', headers={
'Authorization': 'token ' + token,
})
r = await api_request(app, 'user', headers={'Authorization': 'token ' + token})
r.raise_for_status()
model = r.json()
assert model['name'] == u.name
# invalid auth gets 403
r = await api_request(app, 'user', headers={
'Authorization': 'token notvalid',
})
r = await api_request(app, 'user', headers={'Authorization': 'token notvalid'})
assert r.status_code == 403
@@ -251,8 +251,11 @@ async def test_add_multi_user_bad(app):
@mark.user
async def test_add_multi_user_invalid(app):
app.authenticator.username_pattern = r'w.*'
r = await api_request(app, 'users', method='post',
data=json.dumps({'usernames': ['Willow', 'Andrew', 'Tara']})
r = await api_request(
app,
'users',
method='post',
data=json.dumps({'usernames': ['Willow', 'Andrew', 'Tara']}),
)
app.authenticator.username_pattern = ''
assert r.status_code == 400
@@ -263,8 +266,8 @@ async def test_add_multi_user_invalid(app):
async def test_add_multi_user(app):
db = app.db
names = ['a', 'b']
r = await api_request(app, 'users', method='post',
data=json.dumps({'usernames': names}),
r = await api_request(
app, 'users', method='post', data=json.dumps({'usernames': names})
)
assert r.status_code == 201
reply = r.json()
@@ -278,16 +281,16 @@ async def test_add_multi_user(app):
assert not user.admin
# try to create the same users again
r = await api_request(app, 'users', method='post',
data=json.dumps({'usernames': names}),
r = await api_request(
app, 'users', method='post', data=json.dumps({'usernames': names})
)
assert r.status_code == 409
names = ['a', 'b', 'ab']
# try to create the same users again
r = await api_request(app, 'users', method='post',
data=json.dumps({'usernames': names}),
r = await api_request(
app, 'users', method='post', data=json.dumps({'usernames': names})
)
assert r.status_code == 201
reply = r.json()
@@ -299,7 +302,10 @@ async def test_add_multi_user(app):
async def test_add_multi_user_admin(app):
db = app.db
names = ['c', 'd']
r = await api_request(app, 'users', method='post',
r = await api_request(
app,
'users',
method='post',
data=json.dumps({'usernames': names, 'admin': True}),
)
assert r.status_code == 201
@@ -340,8 +346,8 @@ async def test_add_user_duplicate(app):
async def test_add_admin(app):
db = app.db
name = 'newadmin'
r = await api_request(app, 'users', name, method='post',
data=json.dumps({'admin': True}),
r = await api_request(
app, 'users', name, method='post', data=json.dumps({'admin': True})
)
assert r.status_code == 201
user = find_user(db, name)
@@ -369,8 +375,8 @@ async def test_make_admin(app):
assert user.name == name
assert not user.admin
r = await api_request(app, 'users', name, method='patch',
data=json.dumps({'admin': True})
r = await api_request(
app, 'users', name, method='patch', data=json.dumps({'admin': True})
)
assert r.status_code == 200
user = find_user(db, name)
@@ -388,8 +394,8 @@ async def test_set_auth_state(app, auth_state_enabled):
assert user is not None
assert user.name == name
r = await api_request(app, 'users', name, method='patch',
data=json.dumps({'auth_state': auth_state})
r = await api_request(
app, 'users', name, method='patch', data=json.dumps({'auth_state': auth_state})
)
assert r.status_code == 200
@@ -409,7 +415,10 @@ async def test_user_set_auth_state(app, auth_state_enabled):
assert user_auth_state is None
r = await api_request(
app, 'users', name, method='patch',
app,
'users',
name,
method='patch',
data=json.dumps({'auth_state': auth_state}),
headers=auth_header(app.db, name),
)
@@ -446,8 +455,7 @@ async def test_user_get_auth_state(app, auth_state_enabled):
assert user.name == name
await user.save_auth_state(auth_state)
r = await api_request(app, 'users', name,
headers=auth_header(app.db, name))
r = await api_request(app, 'users', name, headers=auth_header(app.db, name))
assert r.status_code == 200
assert 'auth_state' not in r.json()
@@ -457,13 +465,10 @@ async def test_spawn(app):
db = app.db
name = 'wash'
user = add_user(db, app=app, name=name)
options = {
's': ['value'],
'i': 5,
}
options = {'s': ['value'], 'i': 5}
before_servers = sorted(db.query(orm.Server), key=lambda s: s.url)
r = await api_request(app, 'users', name, 'server', method='post',
data=json.dumps(options),
r = await api_request(
app, 'users', name, 'server', method='post', data=json.dumps(options)
)
assert r.status_code == 201
assert 'pid' in user.orm_spawners[''].state
@@ -520,7 +525,9 @@ async def test_spawn_handler(app):
app_user = app.users[name]
# spawn via API with ?foo=bar
r = await api_request(app, 'users', name, 'server', method='post', params={'foo': 'bar'})
r = await api_request(
app, 'users', name, 'server', method='post', params={'foo': 'bar'}
)
r.raise_for_status()
# verify that request params got passed down
@@ -640,6 +647,7 @@ def next_event(it):
if line.startswith('data:'):
return json.loads(line.split(':', 1)[1])
@mark.slow
async def test_progress(request, app, no_patience, slow_spawn):
db = app.db
@@ -655,15 +663,9 @@ async def test_progress(request, app, no_patience, slow_spawn):
ex = async_requests.executor
line_iter = iter(r.iter_lines(decode_unicode=True))
evt = await ex.submit(next_event, line_iter)
assert evt == {
'progress': 0,
'message': 'Server requested',
}
assert evt == {'progress': 0, 'message': 'Server requested'}
evt = await ex.submit(next_event, line_iter)
assert evt == {
'progress': 50,
'message': 'Spawning server...',
}
assert evt == {'progress': 50, 'message': 'Spawning server...'}
evt = await ex.submit(next_event, line_iter)
url = app_user.url
assert evt == {
@@ -769,10 +771,7 @@ async def test_progress_bad_slow(request, app, no_patience, slow_bad_spawn):
async def progress_forever():
"""progress function that yields messages forever"""
for i in range(1, 10):
await yield_({
'progress': i,
'message': 'Stage %s' % i,
})
await yield_({'progress': i, 'message': 'Stage %s' % i})
# wait a long time before the next event
await gen.sleep(10)
@@ -781,7 +780,8 @@ if sys.version_info >= (3, 6):
# additional progress_forever defined as native
# async generator
# to test for issues with async_generator wrappers
exec("""
exec(
"""
async def progress_forever_native():
for i in range(1, 10):
yield {
@@ -790,7 +790,9 @@ async def progress_forever_native():
}
# wait a long time before the next event
await gen.sleep(10)
""", globals())
""",
globals(),
)
async def test_spawn_progress_cutoff(request, app, no_patience, slow_spawn):
@@ -818,18 +820,14 @@ async def test_spawn_progress_cutoff(request, app, no_patience, slow_spawn):
evt = await ex.submit(next_event, line_iter)
assert evt['progress'] == 0
evt = await ex.submit(next_event, line_iter)
assert evt == {
'progress': 1,
'message': 'Stage 1',
}
assert evt == {'progress': 1, 'message': 'Stage 1'}
evt = await ex.submit(next_event, line_iter)
assert evt['progress'] == 100
async def test_spawn_limit(app, no_patience, slow_spawn, request):
db = app.db
p = mock.patch.dict(app.tornado_settings,
{'concurrent_spawn_limit': 2})
p = mock.patch.dict(app.tornado_settings, {'concurrent_spawn_limit': 2})
p.start()
request.addfinalizer(p.stop)
@@ -875,11 +873,11 @@ async def test_spawn_limit(app, no_patience, slow_spawn, request):
while any(u.spawner.active for u in users):
await gen.sleep(0.1)
@mark.slow
async def test_active_server_limit(app, request):
db = app.db
p = mock.patch.dict(app.tornado_settings,
{'active_server_limit': 2})
p = mock.patch.dict(app.tornado_settings, {'active_server_limit': 2})
p.start()
request.addfinalizer(p.stop)
@@ -932,6 +930,7 @@ async def test_active_server_limit(app, request):
assert counts['ready'] == 0
assert counts['pending'] == 0
@mark.slow
async def test_start_stop_race(app, no_patience, slow_spawn):
user = add_user(app.db, app, name='panda')
@@ -996,22 +995,18 @@ async def test_cookie(app):
cookie_name = app.hub.cookie_name
# cookie jar gives '"cookie-value"', we want 'cookie-value'
cookie = cookies[cookie_name][1:-1]
r = await api_request(app, 'authorizations/cookie',
cookie_name, "nothintoseehere",
)
r = await api_request(app, 'authorizations/cookie', cookie_name, "nothintoseehere")
assert r.status_code == 404
r = await api_request(app, 'authorizations/cookie',
cookie_name, quote(cookie, safe=''),
r = await api_request(
app, 'authorizations/cookie', cookie_name, quote(cookie, safe='')
)
r.raise_for_status()
reply = r.json()
assert reply['name'] == name
# deprecated cookie in body:
r = await api_request(app, 'authorizations/cookie',
cookie_name, data=cookie,
)
r = await api_request(app, 'authorizations/cookie', cookie_name, data=cookie)
r.raise_for_status()
reply = r.json()
assert reply['name'] == name
@@ -1035,15 +1030,11 @@ async def test_check_token(app):
assert r.status_code == 404
@mark.parametrize("headers, status", [
({}, 200),
({'Authorization': 'token bad'}, 403),
])
@mark.parametrize("headers, status", [({}, 200), ({'Authorization': 'token bad'}, 403)])
async def test_get_new_token_deprecated(app, headers, status):
# request a new token
r = await api_request(app, 'authorizations', 'token',
method='post',
headers=headers,
r = await api_request(
app, 'authorizations', 'token', method='post', headers=headers
)
assert r.status_code == status
if status != 200:
@@ -1058,11 +1049,11 @@ async def test_get_new_token_deprecated(app, headers, status):
async def test_token_formdata_deprecated(app):
"""Create a token for a user with formdata and no auth header"""
data = {
'username': 'fake',
'password': 'fake',
}
r = await api_request(app, 'authorizations', 'token',
data = {'username': 'fake', 'password': 'fake'}
r = await api_request(
app,
'authorizations',
'token',
method='post',
data=json.dumps(data) if data else None,
noauth=True,
@@ -1076,22 +1067,26 @@ async def test_token_formdata_deprecated(app):
assert reply['name'] == data['username']
@mark.parametrize("as_user, for_user, status", [
@mark.parametrize(
"as_user, for_user, status",
[
('admin', 'other', 200),
('admin', 'missing', 400),
('user', 'other', 403),
('user', 'user', 200),
])
],
)
async def test_token_as_user_deprecated(app, as_user, for_user, status):
# ensure both users exist
u = add_user(app.db, app, name=as_user)
if for_user != 'missing':
add_user(app.db, app, name=for_user)
data = {'username': for_user}
headers = {
'Authorization': 'token %s' % u.new_api_token(),
}
r = await api_request(app, 'authorizations', 'token',
headers = {'Authorization': 'token %s' % u.new_api_token()}
r = await api_request(
app,
'authorizations',
'token',
method='post',
data=json.dumps(data),
headers=headers,
@@ -1107,11 +1102,14 @@ async def test_token_as_user_deprecated(app, as_user, for_user, status):
assert reply['name'] == data['username']
@mark.parametrize("headers, status, note, expires_in", [
@mark.parametrize(
"headers, status, note, expires_in",
[
({}, 200, 'test note', None),
({}, 200, '', 100),
({'Authorization': 'token bad'}, 403, '', None),
])
],
)
async def test_get_new_token(app, headers, status, note, expires_in):
options = {}
if note:
@@ -1123,10 +1121,8 @@ async def test_get_new_token(app, headers, status, note, expires_in):
else:
body = ''
# request a new token
r = await api_request(app, 'users/admin/tokens',
method='post',
headers=headers,
data=body,
r = await api_request(
app, 'users/admin/tokens', method='post', headers=headers, data=body
)
assert r.status_code == status
if status != 200:
@@ -1157,30 +1153,34 @@ async def test_get_new_token(app, headers, status, note, expires_in):
assert normalize_token(reply) == initial
# delete the token
r = await api_request(app, 'users/admin/tokens', token_id,
method='delete')
r = await api_request(app, 'users/admin/tokens', token_id, method='delete')
assert r.status_code == 204
# verify deletion
r = await api_request(app, 'users/admin/tokens', token_id)
assert r.status_code == 404
@mark.parametrize("as_user, for_user, status", [
@mark.parametrize(
"as_user, for_user, status",
[
('admin', 'other', 200),
('admin', 'missing', 404),
('user', 'other', 403),
('user', 'user', 200),
])
],
)
async def test_token_for_user(app, as_user, for_user, status):
# ensure both users exist
u = add_user(app.db, app, name=as_user)
if for_user != 'missing':
add_user(app.db, app, name=for_user)
data = {'username': for_user}
headers = {
'Authorization': 'token %s' % u.new_api_token(),
}
r = await api_request(app, 'users', for_user, 'tokens',
headers = {'Authorization': 'token %s' % u.new_api_token()}
r = await api_request(
app,
'users',
for_user,
'tokens',
method='post',
data=json.dumps(data),
headers=headers,
@@ -1191,9 +1191,7 @@ async def test_token_for_user(app, as_user, for_user, status):
return
assert 'token' in reply
token_id = reply['id']
r = await api_request(app, 'users', for_user, 'tokens', token_id,
headers=headers,
)
r = await api_request(app, 'users', for_user, 'tokens', token_id, headers=headers)
r.raise_for_status()
reply = r.json()
assert reply['user'] == for_user
@@ -1203,30 +1201,25 @@ async def test_token_for_user(app, as_user, for_user, status):
note = 'Requested via api by user %s' % as_user
assert reply['note'] == note
# delete the token
r = await api_request(app, 'users', for_user, 'tokens', token_id,
method='delete',
headers=headers,
r = await api_request(
app, 'users', for_user, 'tokens', token_id, method='delete', headers=headers
)
assert r.status_code == 204
r = await api_request(app, 'users', for_user, 'tokens', token_id,
headers=headers,
)
r = await api_request(app, 'users', for_user, 'tokens', token_id, headers=headers)
assert r.status_code == 404
async def test_token_authenticator_noauth(app):
"""Create a token for a user relying on Authenticator.authenticate and no auth header"""
name = 'user'
data = {
'auth': {
'username': name,
'password': name,
},
}
r = await api_request(app, 'users', name, 'tokens',
data = {'auth': {'username': name, 'password': name}}
r = await api_request(
app,
'users',
name,
'tokens',
method='post',
data=json.dumps(data) if data else None,
noauth=True,
@@ -1242,17 +1235,14 @@ async def test_token_authenticator_noauth(app):
async def test_token_authenticator_dict_noauth(app):
"""Create a token for a user relying on Authenticator.authenticate and no auth header"""
app.authenticator.auth_state = {
'who': 'cares',
}
app.authenticator.auth_state = {'who': 'cares'}
name = 'user'
data = {
'auth': {
'username': name,
'password': name,
},
}
r = await api_request(app, 'users', name, 'tokens',
data = {'auth': {'username': name, 'password': name}}
r = await api_request(
app,
'users',
name,
'tokens',
method='post',
data=json.dumps(data) if data else None,
noauth=True,
@@ -1266,22 +1256,21 @@ async def test_token_authenticator_dict_noauth(app):
assert reply['name'] == name
@mark.parametrize("as_user, for_user, status", [
@mark.parametrize(
"as_user, for_user, status",
[
('admin', 'other', 200),
('admin', 'missing', 404),
('user', 'other', 403),
('user', 'user', 200),
])
],
)
async def test_token_list(app, as_user, for_user, status):
u = add_user(app.db, app, name=as_user)
if for_user != 'missing':
for_user_obj = add_user(app.db, app, name=for_user)
headers = {
'Authorization': 'token %s' % u.new_api_token(),
}
r = await api_request(app, 'users', for_user, 'tokens',
headers=headers,
)
headers = {'Authorization': 'token %s' % u.new_api_token()}
r = await api_request(app, 'users', for_user, 'tokens', headers=headers)
assert r.status_code == status
if status != 200:
return
@@ -1292,8 +1281,8 @@ async def test_token_list(app, as_user, for_user, status):
assert all(token['user'] == for_user for token in reply['oauth_tokens'])
# validate individual token ids
for token in reply['api_tokens'] + reply['oauth_tokens']:
r = await api_request(app, 'users', for_user, 'tokens', token['id'],
headers=headers,
r = await api_request(
app, 'users', for_user, 'tokens', token['id'], headers=headers
)
r.raise_for_status()
reply = r.json()
@@ -1320,19 +1309,15 @@ async def test_groups_list(app):
r = await api_request(app, 'groups')
r.raise_for_status()
reply = r.json()
assert reply == [{
'kind': 'group',
'name': 'alphaflight',
'users': []
}]
assert reply == [{'kind': 'group', 'name': 'alphaflight', 'users': []}]
@mark.group
async def test_add_multi_group(app):
db = app.db
names = ['group1', 'group2']
r = await api_request(app, 'groups', method='post',
data=json.dumps({'groups': names}),
r = await api_request(
app, 'groups', method='post', data=json.dumps({'groups': names})
)
assert r.status_code == 201
reply = r.json()
@@ -1340,8 +1325,8 @@ async def test_add_multi_group(app):
assert names == r_names
# try to create the same groups again
r = await api_request(app, 'groups', method='post',
data=json.dumps({'groups': names}),
r = await api_request(
app, 'groups', method='post', data=json.dumps({'groups': names})
)
assert r.status_code == 409
@@ -1359,11 +1344,7 @@ async def test_group_get(app):
r = await api_request(app, 'groups/alphaflight')
r.raise_for_status()
reply = r.json()
assert reply == {
'kind': 'group',
'name': 'alphaflight',
'users': ['sasquatch']
}
assert reply == {'kind': 'group', 'name': 'alphaflight', 'users': ['sasquatch']}
@mark.group
@@ -1372,13 +1353,16 @@ async def test_group_create_delete(app):
r = await api_request(app, 'groups/runaways', method='delete')
assert r.status_code == 404
r = await api_request(app, 'groups/new', method='post',
data=json.dumps({'users': ['doesntexist']}),
r = await api_request(
app, 'groups/new', method='post', data=json.dumps({'users': ['doesntexist']})
)
assert r.status_code == 400
assert orm.Group.find(db, name='new') is None
r = await api_request(app, 'groups/omegaflight', method='post',
r = await api_request(
app,
'groups/omegaflight',
method='post',
data=json.dumps({'users': ['sasquatch']}),
)
r.raise_for_status()
@@ -1410,10 +1394,15 @@ async def test_group_add_users(app):
assert r.status_code == 400
names = ['aurora', 'guardian', 'northstar', 'sasquatch', 'shaman', 'snowbird']
users = [ find_user(db, name=name) or add_user(db, app=app, name=name) for name in names ]
r = await api_request(app, 'groups/alphaflight/users', method='post', data=json.dumps({
'users': names,
}))
users = [
find_user(db, name=name) or add_user(db, app=app, name=name) for name in names
]
r = await api_request(
app,
'groups/alphaflight/users',
method='post',
data=json.dumps({'users': names}),
)
r.raise_for_status()
for user in users:
@@ -1433,9 +1422,12 @@ async def test_group_delete_users(app):
names = ['aurora', 'guardian', 'northstar', 'sasquatch', 'shaman', 'snowbird']
users = [find_user(db, name=name) for name in names]
r = await api_request(app, 'groups/alphaflight/users', method='delete', data=json.dumps({
'users': names[:2],
}))
r = await api_request(
app,
'groups/alphaflight/users',
method='delete',
data=json.dumps({'users': names[:2]}),
)
r.raise_for_status()
for user in users[:2]:
@@ -1473,9 +1465,7 @@ async def test_get_services(app, mockservice_url):
}
}
r = await api_request(app, 'services',
headers=auth_header(db, 'user'),
)
r = await api_request(app, 'services', headers=auth_header(db, 'user'))
assert r.status_code == 403
@@ -1498,14 +1488,14 @@ async def test_get_service(app, mockservice_url):
'info': {},
}
r = await api_request(app, 'services/%s' % mockservice.name,
headers={
'Authorization': 'token %s' % mockservice.api_token
}
r = await api_request(
app,
'services/%s' % mockservice.name,
headers={'Authorization': 'token %s' % mockservice.api_token},
)
r.raise_for_status()
r = await api_request(app, 'services/%s' % mockservice.name,
headers=auth_header(db, 'user'),
r = await api_request(
app, 'services/%s' % mockservice.name, headers=auth_header(db, 'user')
)
assert r.status_code == 403
@@ -1519,9 +1509,7 @@ async def test_root_api(app):
kwargs["verify"] = app.internal_ssl_ca
r = await async_requests.get(url, **kwargs)
r.raise_for_status()
expected = {
'version': jupyterhub.__version__
}
expected = {'version': jupyterhub.__version__}
assert r.json() == expected
@@ -1662,15 +1650,20 @@ def test_shutdown(app):
# which makes gen_test unhappy. So we run the loop ourselves.
async def shutdown():
r = await api_request(app, 'shutdown', method='post',
data=json.dumps({'servers': True, 'proxy': True,}),
r = await api_request(
app,
'shutdown',
method='post',
data=json.dumps({'servers': True, 'proxy': True}),
)
return r
real_stop = loop.stop
def stop():
stop.called = True
loop.call_later(1, real_stop)
with mock.patch.object(loop, 'stop', stop):
r = loop.run_sync(shutdown, timeout=5)
r.raise_for_status()

View File

@@ -1,25 +1,30 @@
"""Test the JupyterHub entry point"""
import binascii
import os
import re
import sys
from subprocess import check_output, Popen, PIPE
from tempfile import NamedTemporaryFile, TemporaryDirectory
from subprocess import check_output
from subprocess import PIPE
from subprocess import Popen
from tempfile import NamedTemporaryFile
from tempfile import TemporaryDirectory
from unittest.mock import patch
import pytest
from tornado import gen
from traitlets.config import Config
from .. import orm
from ..app import COOKIE_SECRET_BYTES
from ..app import JupyterHub
from .mocking import MockHub
from .test_api import add_user
from .. import orm
from ..app import COOKIE_SECRET_BYTES, JupyterHub
def test_help_all():
out = check_output([sys.executable, '-m', 'jupyterhub', '--help-all']).decode('utf8', 'replace')
out = check_output([sys.executable, '-m', 'jupyterhub', '--help-all']).decode(
'utf8', 'replace'
)
assert '--ip' in out
assert '--JupyterHub.ip' in out
@@ -39,9 +44,11 @@ def test_generate_config():
cfg_file = tf.name
with open(cfg_file, 'w') as f:
f.write("c.A = 5")
p = Popen([sys.executable, '-m', 'jupyterhub',
'--generate-config', '-f', cfg_file],
stdout=PIPE, stdin=PIPE)
p = Popen(
[sys.executable, '-m', 'jupyterhub', '--generate-config', '-f', cfg_file],
stdout=PIPE,
stdin=PIPE,
)
out, _ = p.communicate(b'n')
out = out.decode('utf8', 'replace')
assert os.path.exists(cfg_file)
@@ -49,9 +56,11 @@ def test_generate_config():
cfg_text = f.read()
assert cfg_text == 'c.A = 5'
p = Popen([sys.executable, '-m', 'jupyterhub',
'--generate-config', '-f', cfg_file],
stdout=PIPE, stdin=PIPE)
p = Popen(
[sys.executable, '-m', 'jupyterhub', '--generate-config', '-f', cfg_file],
stdout=PIPE,
stdin=PIPE,
)
out, _ = p.communicate(b'x\ny')
out = out.decode('utf8', 'replace')
assert os.path.exists(cfg_file)
@@ -192,9 +201,13 @@ async def test_load_groups(tmpdir, request):
async def test_resume_spawners(tmpdir, request):
if not os.getenv('JUPYTERHUB_TEST_DB_URL'):
p = patch.dict(os.environ, {
'JUPYTERHUB_TEST_DB_URL': 'sqlite:///%s' % tmpdir.join('jupyterhub.sqlite'),
})
p = patch.dict(
os.environ,
{
'JUPYTERHUB_TEST_DB_URL': 'sqlite:///%s'
% tmpdir.join('jupyterhub.sqlite')
},
)
p.start()
request.addfinalizer(p.stop)
@@ -253,32 +266,18 @@ async def test_resume_spawners(tmpdir, request):
@pytest.mark.parametrize(
'hub_config, expected',
[
(
{'ip': '0.0.0.0'},
{'bind_url': 'http://0.0.0.0:8000/'},
),
({'ip': '0.0.0.0'}, {'bind_url': 'http://0.0.0.0:8000/'}),
(
{'port': 123, 'base_url': '/prefix'},
{
'bind_url': 'http://:123/prefix/',
'base_url': '/prefix/',
},
),
(
{'bind_url': 'http://0.0.0.0:12345/sub'},
{'base_url': '/sub/'},
{'bind_url': 'http://:123/prefix/', 'base_url': '/prefix/'},
),
({'bind_url': 'http://0.0.0.0:12345/sub'}, {'base_url': '/sub/'}),
(
# no config, test defaults
{},
{
'base_url': '/',
'bind_url': 'http://:8000',
'ip': '',
'port': 8000,
},
{'base_url': '/', 'bind_url': 'http://:8000', 'ip': '', 'port': 8000},
),
]
],
)
def test_url_config(hub_config, expected):
# construct the config object

View File

@@ -1,59 +1,59 @@
"""Tests for PAM authentication"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
from unittest import mock
import pytest
from requests import HTTPError
from jupyterhub import auth, crypto, orm
from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd
from .mocking import MockPAMAuthenticator
from .mocking import MockStructGroup
from .mocking import MockStructPasswd
from .utils import add_user
from jupyterhub import auth
from jupyterhub import crypto
from jupyterhub import orm
async def test_pam_auth():
authenticator = MockPAMAuthenticator()
authorized = await authenticator.get_authenticated_user(None, {
'username': 'match',
'password': 'match',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'match', 'password': 'match'}
)
assert authorized['name'] == 'match'
authorized = await authenticator.get_authenticated_user(None, {
'username': 'match',
'password': 'nomatch',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'match', 'password': 'nomatch'}
)
assert authorized is None
# Account check is on by default for increased security
authorized = await authenticator.get_authenticated_user(None, {
'username': 'notallowedmatch',
'password': 'notallowedmatch',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'notallowedmatch', 'password': 'notallowedmatch'}
)
assert authorized is None
async def test_pam_auth_account_check_disabled():
authenticator = MockPAMAuthenticator(check_account=False)
authorized = await authenticator.get_authenticated_user(None, {
'username': 'allowedmatch',
'password': 'allowedmatch',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'allowedmatch', 'password': 'allowedmatch'}
)
assert authorized['name'] == 'allowedmatch'
authorized = await authenticator.get_authenticated_user(None, {
'username': 'notallowedmatch',
'password': 'notallowedmatch',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'notallowedmatch', 'password': 'notallowedmatch'}
)
assert authorized['name'] == 'notallowedmatch'
async def test_pam_auth_admin_groups():
jh_users = MockStructGroup('jh_users', ['group_admin', 'also_group_admin', 'override_admin', 'non_admin'], 1234)
jh_users = MockStructGroup(
'jh_users',
['group_admin', 'also_group_admin', 'override_admin', 'non_admin'],
1234,
)
jh_admins = MockStructGroup('jh_admins', ['group_admin'], 5678)
wheel = MockStructGroup('wheel', ['also_group_admin'], 9999)
system_groups = [jh_users, jh_admins, wheel]
@@ -68,7 +68,7 @@ async def test_pam_auth_admin_groups():
'group_admin': [jh_users.gr_gid, jh_admins.gr_gid],
'also_group_admin': [jh_users.gr_gid, wheel.gr_gid],
'override_admin': [jh_users.gr_gid],
'non_admin': [jh_users.gr_gid]
'non_admin': [jh_users.gr_gid],
}
def getgrnam(name):
@@ -80,76 +80,78 @@ async def test_pam_auth_admin_groups():
def getgrouplist(name, group):
return user_group_map[name]
authenticator = MockPAMAuthenticator(admin_groups={'jh_admins', 'wheel'},
admin_users={'override_admin'})
authenticator = MockPAMAuthenticator(
admin_groups={'jh_admins', 'wheel'}, admin_users={'override_admin'}
)
# Check admin_group applies as expected
with mock.patch.multiple(authenticator,
with mock.patch.multiple(
authenticator,
_getgrnam=getgrnam,
_getpwnam=getpwnam,
_getgrouplist=getgrouplist):
authorized = await authenticator.get_authenticated_user(None, {
'username': 'group_admin',
'password': 'group_admin'
})
_getgrouplist=getgrouplist,
):
authorized = await authenticator.get_authenticated_user(
None, {'username': 'group_admin', 'password': 'group_admin'}
)
assert authorized['name'] == 'group_admin'
assert authorized['admin'] is True
# Check multiple groups work, just in case.
with mock.patch.multiple(authenticator,
with mock.patch.multiple(
authenticator,
_getgrnam=getgrnam,
_getpwnam=getpwnam,
_getgrouplist=getgrouplist):
authorized = await authenticator.get_authenticated_user(None, {
'username': 'also_group_admin',
'password': 'also_group_admin'
})
_getgrouplist=getgrouplist,
):
authorized = await authenticator.get_authenticated_user(
None, {'username': 'also_group_admin', 'password': 'also_group_admin'}
)
assert authorized['name'] == 'also_group_admin'
assert authorized['admin'] is True
# Check admin_users still applies correctly
with mock.patch.multiple(authenticator,
with mock.patch.multiple(
authenticator,
_getgrnam=getgrnam,
_getpwnam=getpwnam,
_getgrouplist=getgrouplist):
authorized = await authenticator.get_authenticated_user(None, {
'username': 'override_admin',
'password': 'override_admin'
})
_getgrouplist=getgrouplist,
):
authorized = await authenticator.get_authenticated_user(
None, {'username': 'override_admin', 'password': 'override_admin'}
)
assert authorized['name'] == 'override_admin'
assert authorized['admin'] is True
# Check it doesn't admin everyone
with mock.patch.multiple(authenticator,
with mock.patch.multiple(
authenticator,
_getgrnam=getgrnam,
_getpwnam=getpwnam,
_getgrouplist=getgrouplist):
authorized = await authenticator.get_authenticated_user(None, {
'username': 'non_admin',
'password': 'non_admin'
})
_getgrouplist=getgrouplist,
):
authorized = await authenticator.get_authenticated_user(
None, {'username': 'non_admin', 'password': 'non_admin'}
)
assert authorized['name'] == 'non_admin'
assert authorized['admin'] is False
async def test_pam_auth_whitelist():
authenticator = MockPAMAuthenticator(whitelist={'wash', 'kaylee'})
authorized = await authenticator.get_authenticated_user(None, {
'username': 'kaylee',
'password': 'kaylee',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'kaylee', 'password': 'kaylee'}
)
assert authorized['name'] == 'kaylee'
authorized = await authenticator.get_authenticated_user(None, {
'username': 'wash',
'password': 'nomatch',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'wash', 'password': 'nomatch'}
)
assert authorized is None
authorized = await authenticator.get_authenticated_user(None, {
'username': 'mal',
'password': 'mal',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'mal', 'password': 'mal'}
)
assert authorized is None
@@ -160,80 +162,78 @@ async def test_pam_auth_group_whitelist():
authenticator = MockPAMAuthenticator(group_whitelist={'group'})
with mock.patch.object(authenticator, '_getgrnam', getgrnam):
authorized = await authenticator.get_authenticated_user(None, {
'username': 'kaylee',
'password': 'kaylee',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'kaylee', 'password': 'kaylee'}
)
assert authorized['name'] == 'kaylee'
with mock.patch.object(authenticator, '_getgrnam', getgrnam):
authorized = await authenticator.get_authenticated_user(None, {
'username': 'mal',
'password': 'mal',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'mal', 'password': 'mal'}
)
assert authorized is None
async def test_pam_auth_blacklist():
# Null case compared to next case
authenticator = MockPAMAuthenticator()
authorized = await authenticator.get_authenticated_user(None, {
'username': 'wash',
'password': 'wash',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'wash', 'password': 'wash'}
)
assert authorized['name'] == 'wash'
# Blacklist basics
authenticator = MockPAMAuthenticator(blacklist={'wash'})
authorized = await authenticator.get_authenticated_user(None, {
'username': 'wash',
'password': 'wash',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'wash', 'password': 'wash'}
)
assert authorized is None
# User in both white and blacklists: default deny. Make error someday?
authenticator = MockPAMAuthenticator(blacklist={'wash'}, whitelist={'wash', 'kaylee'})
authorized = await authenticator.get_authenticated_user(None, {
'username': 'wash',
'password': 'wash',
})
authenticator = MockPAMAuthenticator(
blacklist={'wash'}, whitelist={'wash', 'kaylee'}
)
authorized = await authenticator.get_authenticated_user(
None, {'username': 'wash', 'password': 'wash'}
)
assert authorized is None
# User not in blacklist can log in
authenticator = MockPAMAuthenticator(blacklist={'wash'}, whitelist={'wash', 'kaylee'})
authorized = await authenticator.get_authenticated_user(None, {
'username': 'kaylee',
'password': 'kaylee',
})
authenticator = MockPAMAuthenticator(
blacklist={'wash'}, whitelist={'wash', 'kaylee'}
)
authorized = await authenticator.get_authenticated_user(
None, {'username': 'kaylee', 'password': 'kaylee'}
)
assert authorized['name'] == 'kaylee'
    # User in whitelist, blacklist irrelevant
authenticator = MockPAMAuthenticator(blacklist={'mal'}, whitelist={'wash', 'kaylee'})
authorized = await authenticator.get_authenticated_user(None, {
'username': 'wash',
'password': 'wash',
})
authenticator = MockPAMAuthenticator(
blacklist={'mal'}, whitelist={'wash', 'kaylee'}
)
authorized = await authenticator.get_authenticated_user(
None, {'username': 'wash', 'password': 'wash'}
)
assert authorized['name'] == 'wash'
# User in neither list
authenticator = MockPAMAuthenticator(blacklist={'mal'}, whitelist={'wash', 'kaylee'})
authorized = await authenticator.get_authenticated_user(None, {
'username': 'simon',
'password': 'simon',
})
authenticator = MockPAMAuthenticator(
blacklist={'mal'}, whitelist={'wash', 'kaylee'}
)
authorized = await authenticator.get_authenticated_user(
None, {'username': 'simon', 'password': 'simon'}
)
assert authorized is None
# blacklist == {}
authenticator = MockPAMAuthenticator(blacklist=set(), whitelist={'wash', 'kaylee'})
authorized = await authenticator.get_authenticated_user(None, {
'username': 'kaylee',
'password': 'kaylee',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'kaylee', 'password': 'kaylee'}
)
assert authorized['name'] == 'kaylee'
async def test_deprecated_signatures():
def deprecated_xlist(self, username):
return True
@@ -244,20 +244,18 @@ async def test_deprecated_signatures():
check_blacklist=deprecated_xlist,
):
deprecated_authenticator = MockPAMAuthenticator()
authorized = await deprecated_authenticator.get_authenticated_user(None, {
'username': 'test',
'password': 'test'
})
authorized = await deprecated_authenticator.get_authenticated_user(
None, {'username': 'test', 'password': 'test'}
)
assert authorized is not None
async def test_pam_auth_no_such_group():
authenticator = MockPAMAuthenticator(group_whitelist={'nosuchcrazygroup'})
authorized = await authenticator.get_authenticated_user(None, {
'username': 'kaylee',
'password': 'kaylee',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'kaylee', 'password': 'kaylee'}
)
assert authorized is None
@@ -302,6 +300,7 @@ async def test_add_system_user():
authenticator.add_user_cmd = ['echo', '/home/USERNAME']
record = {}
class DummyPopen:
def __init__(self, cmd, *args, **kwargs):
record['cmd'] = cmd
@@ -402,44 +401,38 @@ async def test_auth_state_disabled(app, auth_state_unavailable):
async def test_normalize_names():
a = MockPAMAuthenticator()
authorized = await a.get_authenticated_user(None, {
'username': 'ZOE',
'password': 'ZOE',
})
authorized = await a.get_authenticated_user(
None, {'username': 'ZOE', 'password': 'ZOE'}
)
assert authorized['name'] == 'zoe'
authorized = await a.get_authenticated_user(None, {
'username': 'Glenn',
'password': 'Glenn',
})
authorized = await a.get_authenticated_user(
None, {'username': 'Glenn', 'password': 'Glenn'}
)
assert authorized['name'] == 'glenn'
authorized = await a.get_authenticated_user(None, {
'username': 'hExi',
'password': 'hExi',
})
authorized = await a.get_authenticated_user(
None, {'username': 'hExi', 'password': 'hExi'}
)
assert authorized['name'] == 'hexi'
authorized = await a.get_authenticated_user(None, {
'username': 'Test',
'password': 'Test',
})
authorized = await a.get_authenticated_user(
None, {'username': 'Test', 'password': 'Test'}
)
assert authorized['name'] == 'test'
async def test_username_map():
a = MockPAMAuthenticator(username_map={'wash': 'alpha'})
authorized = await a.get_authenticated_user(None, {
'username': 'WASH',
'password': 'WASH',
})
authorized = await a.get_authenticated_user(
None, {'username': 'WASH', 'password': 'WASH'}
)
assert authorized['name'] == 'alpha'
authorized = await a.get_authenticated_user(None, {
'username': 'Inara',
'password': 'Inara',
})
authorized = await a.get_authenticated_user(
None, {'username': 'Inara', 'password': 'Inara'}
)
assert authorized['name'] == 'inara'
@@ -463,9 +456,8 @@ def test_post_auth_hook():
a = MockPAMAuthenticator(post_auth_hook=test_auth_hook)
authorized = yield a.get_authenticated_user(None, {
'username': 'test_user',
'password': 'test_user'
})
authorized = yield a.get_authenticated_user(
None, {'username': 'test_user', 'password': 'test_user'}
)
assert authorized['testkey'] == 'testvalue'

View File

@@ -7,15 +7,16 @@ authentication can expire in a number of ways:
- doesn't need refresh
- needs refresh and cannot be refreshed without new login
"""
import asyncio
from contextlib import contextmanager
from unittest import mock
from urllib.parse import parse_qs, urlparse
from urllib.parse import parse_qs
from urllib.parse import urlparse
import pytest
from .utils import api_request, get_page
from .utils import api_request
from .utils import get_page
async def refresh_expired(authenticator, user):

View File

@@ -1,24 +1,29 @@
from binascii import b2a_hex, b2a_base64
import os
import pytest
from binascii import b2a_base64
from binascii import b2a_hex
from unittest.mock import patch
import pytest
from .. import crypto
from ..crypto import encrypt, decrypt
from ..crypto import decrypt
from ..crypto import encrypt
keys = [('%i' % i).encode('ascii') * 32 for i in range(3)]
hex_keys = [b2a_hex(key).decode('ascii') for key in keys]
b64_keys = [b2a_base64(key).decode('ascii').strip() for key in keys]
@pytest.mark.parametrize("key_env, keys", [
@pytest.mark.parametrize(
"key_env, keys",
[
(hex_keys[0], [keys[0]]),
(';'.join([b64_keys[0], hex_keys[1]]), keys[:2]),
(';'.join([hex_keys[0], b64_keys[1], '']), keys[:2]),
('', []),
(';', []),
])
],
)
def test_env_constructor(key_env, keys):
with patch.dict(os.environ, {crypto.KEY_ENV: key_env}):
ck = crypto.CryptKeeper()
@@ -29,12 +34,15 @@ def test_env_constructor(key_env, keys):
assert ck.fernet is None
@pytest.mark.parametrize("key", [
@pytest.mark.parametrize(
"key",
[
'a' * 44, # base64, not 32 bytes
('%44s' % 'notbase64'), # not base64
b'x' * 64, # not hex
b'short', # not 32 bytes
])
],
)
def test_bad_keys(key):
ck = crypto.CryptKeeper()
with pytest.raises(ValueError):
@@ -76,4 +84,3 @@ async def test_missing_keys(crypt_keeper):
with pytest.raises(crypto.NoEncryptionKeys):
await decrypt(b'whatever')

View File

@@ -1,14 +1,16 @@
from glob import glob
import os
from subprocess import check_call
import sys
import tempfile
from glob import glob
from subprocess import check_call
import pytest
from pytest import raises
from traitlets.config import Config
from ..app import NewToken, UpgradeDB, JupyterHub
from ..app import JupyterHub
from ..app import NewToken
from ..app import UpgradeDB
here = os.path.abspath(os.path.dirname(__file__))
@@ -33,13 +35,7 @@ def generate_old_db(env_dir, hub_version, db_url):
check_call([env_py, populate_db, db_url])
@pytest.mark.parametrize(
'hub_version',
[
'0.7.2',
'0.8.1',
],
)
@pytest.mark.parametrize('hub_version', ['0.7.2', '0.8.1'])
async def test_upgrade(tmpdir, hub_version):
db_url = os.getenv('JUPYTERHUB_TEST_DB_URL')
if db_url:

View File

@@ -1,52 +1,47 @@
"""Tests for dummy authentication"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import pytest
from jupyterhub.auth import DummyAuthenticator
async def test_dummy_auth_without_global_password():
authenticator = DummyAuthenticator()
authorized = await authenticator.get_authenticated_user(None, {
'username': 'test_user',
'password': 'test_pass',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'test_user', 'password': 'test_pass'}
)
assert authorized['name'] == 'test_user'
authorized = await authenticator.get_authenticated_user(None, {
'username': 'test_user',
'password': '',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'test_user', 'password': ''}
)
assert authorized['name'] == 'test_user'
async def test_dummy_auth_without_username():
authenticator = DummyAuthenticator()
authorized = await authenticator.get_authenticated_user(None, {
'username': '',
'password': 'test_pass',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': '', 'password': 'test_pass'}
)
assert authorized is None
async def test_dummy_auth_with_global_password():
authenticator = DummyAuthenticator()
authenticator.password = "test_password"
authorized = await authenticator.get_authenticated_user(None, {
'username': 'test_user',
'password': 'test_password',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'test_user', 'password': 'test_password'}
)
assert authorized['name'] == 'test_user'
authorized = await authenticator.get_authenticated_user(None, {
'username': 'test_user',
'password': 'qwerty',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'test_user', 'password': 'qwerty'}
)
assert authorized is None
authorized = await authenticator.get_authenticated_user(None, {
'username': 'some_other_user',
'password': 'test_password',
})
authorized = await authenticator.get_authenticated_user(
None, {'username': 'some_other_user', 'password': 'test_password'}
)
assert authorized['name'] == 'some_other_user'

View File

@@ -1,7 +1,6 @@
"""Tests for the SSL enabled REST API."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from jupyterhub.tests.test_api import *
ssl_enabled = True

View File

@@ -1,10 +1,9 @@
"""Test the JupyterHub entry point with internal ssl"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import sys
import jupyterhub.tests.mocking
import jupyterhub.tests.mocking
from jupyterhub.tests.test_app import *
ssl_enabled = True

View File

@@ -1,19 +1,17 @@
"""Tests for jupyterhub internal_ssl connections"""
import sys
import time
from subprocess import check_output
import sys
from unittest import mock
from urllib.parse import urlparse
import pytest
from requests.exceptions import SSLError
from tornado import gen
import jupyterhub
from tornado import gen
from unittest import mock
from requests.exceptions import SSLError
from .utils import async_requests
from .test_api import add_user
from .utils import async_requests
ssl_enabled = True
@@ -25,8 +23,10 @@ def wait_for_spawner(spawner, timeout=10):
polling at shorter intervals for early termination
"""
deadline = time.monotonic() + timeout
def wait():
return spawner.server.wait_up(timeout=1, http=True)
while time.monotonic() < deadline:
status = yield spawner.poll()
assert status is None
@@ -59,7 +59,7 @@ async def test_connection_notebook_wrong_certs(app):
"""Connecting to a notebook fails without correct certs"""
with mock.patch.dict(
app.config.LocalProcessSpawner,
{'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']}
{'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']},
):
user = add_user(app.db, app, name='foo')
await user.spawn()

View File

@@ -1,7 +1,6 @@
"""Tests for process spawning with internal_ssl"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from jupyterhub.tests.test_spawner import *
ssl_enabled = True

View File

@@ -5,15 +5,21 @@ from unittest import mock
import pytest
from ..utils import url_path_join
from .test_api import api_request, add_user, fill_user, normalize_user, TIMESTAMP
from .mocking import public_url
from .test_api import add_user
from .test_api import api_request
from .test_api import fill_user
from .test_api import normalize_user
from .test_api import TIMESTAMP
from .utils import async_requests
@pytest.fixture
def named_servers(app):
with mock.patch.dict(app.tornado_settings,
{'allow_named_servers': True, 'named_server_limit_per_user': 2}):
with mock.patch.dict(
app.tornado_settings,
{'allow_named_servers': True, 'named_server_limit_per_user': 2},
):
yield
@@ -30,7 +36,8 @@ async def test_default_server(app, named_servers):
user_model = normalize_user(r.json())
print(user_model)
assert user_model == fill_user({
assert user_model == fill_user(
{
'name': username,
'auth_state': None,
'server': user.url,
@@ -42,12 +49,14 @@ async def test_default_server(app, named_servers):
'url': user.url,
'pending': None,
'ready': True,
'progress_url': 'PREFIX/hub/api/users/{}/server/progress'.format(username),
'progress_url': 'PREFIX/hub/api/users/{}/server/progress'.format(
username
),
'state': {'pid': 0},
}
},
},
})
}
)
# now stop the server
r = await api_request(app, 'users', username, 'server', method='delete')
@@ -58,11 +67,9 @@ async def test_default_server(app, named_servers):
r.raise_for_status()
user_model = normalize_user(r.json())
assert user_model == fill_user({
'name': username,
'servers': {},
'auth_state': None,
})
assert user_model == fill_user(
{'name': username, 'servers': {}, 'auth_state': None}
)
async def test_create_named_server(app, named_servers):
@@ -89,7 +96,8 @@ async def test_create_named_server(app, named_servers):
r.raise_for_status()
user_model = normalize_user(r.json())
assert user_model == fill_user({
assert user_model == fill_user(
{
'name': username,
'auth_state': None,
'servers': {
@@ -101,12 +109,14 @@ async def test_create_named_server(app, named_servers):
'pending': None,
'ready': True,
'progress_url': 'PREFIX/hub/api/users/{}/servers/{}/progress'.format(
username, servername),
username, servername
),
'state': {'pid': 0},
}
for name in [servername]
},
})
}
)
async def test_delete_named_server(app, named_servers):
@@ -119,7 +129,9 @@ async def test_delete_named_server(app, named_servers):
r.raise_for_status()
assert r.status_code == 201
r = await api_request(app, 'users', username, 'servers', servername, method='delete')
r = await api_request(
app, 'users', username, 'servers', servername, method='delete'
)
r.raise_for_status()
assert r.status_code == 204
@@ -127,18 +139,20 @@ async def test_delete_named_server(app, named_servers):
r.raise_for_status()
user_model = normalize_user(r.json())
assert user_model == fill_user({
'name': username,
'auth_state': None,
'servers': {},
})
assert user_model == fill_user(
{'name': username, 'auth_state': None, 'servers': {}}
)
# wrapper Spawner is gone
assert servername not in user.spawners
# low-level record still exists
assert servername in user.orm_spawners
r = await api_request(
app, 'users', username, 'servers', servername,
app,
'users',
username,
'servers',
servername,
method='delete',
data=json.dumps({'remove': True}),
)
@@ -153,7 +167,9 @@ async def test_named_server_disabled(app):
servername = 'okay'
r = await api_request(app, 'users', username, 'servers', servername, method='post')
assert r.status_code == 400
r = await api_request(app, 'users', username, 'servers', servername, method='delete')
r = await api_request(
app, 'users', username, 'servers', servername, method='delete'
)
assert r.status_code == 400
@@ -180,7 +196,10 @@ async def test_named_server_limit(app, named_servers):
servername3 = 'bar-3'
r = await api_request(app, 'users', username, 'servers', servername3, method='post')
assert r.status_code == 400
assert r.json() == {"status": 400, "message": "User foo already has the maximum of 2 named servers. One must be deleted before a new server can be created"}
assert r.json() == {
"status": 400,
"message": "User foo already has the maximum of 2 named servers. One must be deleted before a new server can be created",
}
# Create default server
r = await api_request(app, 'users', username, 'server', method='post')
@@ -189,7 +208,11 @@ async def test_named_server_limit(app, named_servers):
# Delete 1st named server
r = await api_request(
app, 'users', username, 'servers', servername1,
app,
'users',
username,
'servers',
servername1,
method='delete',
data=json.dumps({'remove': True}),
)

View File

@@ -1,6 +1,6 @@
"""Tests for basic object-wrappers"""
import socket
import pytest
from jupyterhub.objects import Server
@@ -16,7 +16,7 @@ from jupyterhub.objects import Server
'port': 123,
'host': 'http://abc:123',
'url': 'http://abc:123/x/',
}
},
),
(
'https://abc',
@@ -26,9 +26,9 @@ from jupyterhub.objects import Server
'proto': 'https',
'host': 'https://abc:443',
'url': 'https://abc:443/x/',
}
},
),
]
],
)
def test_bind_url(bind_url, attrs):
s = Server(bind_url=bind_url, base_url='/x/')
@@ -43,26 +43,28 @@ _hostname = socket.gethostname()
'ip, port, attrs',
[
(
'', 123,
'',
123,
{
'ip': '',
'port': 123,
'host': 'http://{}:123'.format(_hostname),
'url': 'http://{}:123/x/'.format(_hostname),
'bind_url': 'http://*:123/x/',
}
},
),
(
'127.0.0.1', 999,
'127.0.0.1',
999,
{
'ip': '127.0.0.1',
'port': 999,
'host': 'http://127.0.0.1:999',
'url': 'http://127.0.0.1:999/x/',
'bind_url': 'http://127.0.0.1:999/x/',
}
},
),
]
],
)
def test_ip_port(ip, port, attrs):
s = Server(ip=ip, port=port, base_url='/x/')

View File

@@ -1,22 +1,21 @@
"""Tests for the ORM bits"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from datetime import datetime, timedelta
import os
import socket
from datetime import datetime
from datetime import timedelta
from unittest import mock
import pytest
from tornado import gen
from .. import orm
from .. import objects
from .. import crypto
from .. import objects
from .. import orm
from ..emptyclass import EmptyClass
from ..user import User
from .mocking import MockSpawner
from ..emptyclass import EmptyClass
def assert_not_found(db, ORMType, id):
@@ -124,7 +123,9 @@ def test_token_expiry(db):
# approximate range
assert orm_token.expires_at > now + timedelta(seconds=50)
assert orm_token.expires_at < now + timedelta(seconds=70)
the_future = mock.patch('jupyterhub.orm.utcnow', lambda : now + timedelta(seconds=70))
the_future = mock.patch(
'jupyterhub.orm.utcnow', lambda: now + timedelta(seconds=70)
)
with the_future:
found = orm.APIToken.find(db, token=token)
assert found is None
@@ -215,11 +216,9 @@ async def test_spawn_fails(db):
def start(self):
raise RuntimeError("Split the party")
user = User(orm_user, {
'spawner_class': BadSpawner,
'config': None,
'statsd': EmptyClass(),
})
user = User(
orm_user, {'spawner_class': BadSpawner, 'config': None, 'statsd': EmptyClass()}
)
with pytest.raises(RuntimeError) as exc:
await user.spawn()
@@ -346,9 +345,7 @@ def test_user_delete_cascade(db):
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
client=oauth_client,
user=user,
grant_type=orm.GrantType.authorization_code,
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
)
db.add(oauth_token)
db.commit()
@@ -384,9 +381,7 @@ def test_oauth_client_delete_cascade(db):
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
client=oauth_client,
user=user,
grant_type=orm.GrantType.authorization_code,
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
)
db.add(oauth_token)
db.commit()
@@ -477,6 +472,3 @@ def test_group_delete_cascade(db):
db.delete(user1)
db.commit()
assert user1 not in group1.users

View File

@@ -1,30 +1,27 @@
"""Tests for HTML pages"""
import asyncio
import sys
from unittest import mock
from urllib.parse import urlencode, urlparse
from urllib.parse import urlencode
from urllib.parse import urlparse
import pytest
from bs4 import BeautifulSoup
from tornado import gen
from tornado.httputil import url_concat
from ..handlers import BaseHandler
from ..utils import url_path_join as ujoin
from .. import orm
from ..auth import Authenticator
import pytest
from .mocking import FormSpawner, FalsyCallableFormSpawner
from .utils import (
async_requests,
api_request,
add_user,
get_page,
public_url,
public_host,
)
from ..handlers import BaseHandler
from ..utils import url_path_join as ujoin
from .mocking import FalsyCallableFormSpawner
from .mocking import FormSpawner
from .utils import add_user
from .utils import api_request
from .utils import async_requests
from .utils import get_page
from .utils import public_host
from .utils import public_url
async def test_root_no_auth(app):
@@ -53,8 +50,7 @@ async def test_root_redirect(app):
async def test_root_default_url_noauth(app):
with mock.patch.dict(app.tornado_settings,
{'default_url': '/foo/bar'}):
with mock.patch.dict(app.tornado_settings, {'default_url': '/foo/bar'}):
r = await get_page('/', app, allow_redirects=False)
r.raise_for_status()
url = r.headers.get('Location', '')
@@ -65,8 +61,7 @@ async def test_root_default_url_noauth(app):
async def test_root_default_url_auth(app):
name = 'wash'
cookies = await app.login_user(name)
with mock.patch.dict(app.tornado_settings,
{'default_url': '/foo/bar'}):
with mock.patch.dict(app.tornado_settings, {'default_url': '/foo/bar'}):
r = await get_page('/', app, cookies=cookies, allow_redirects=False)
r.raise_for_status()
url = r.headers.get('Location', '')
@@ -106,12 +101,7 @@ async def test_admin(app):
assert r.url.endswith('/admin')
@pytest.mark.parametrize('sort', [
'running',
'last_activity',
'admin',
'name',
])
@pytest.mark.parametrize('sort', ['running', 'last_activity', 'admin', 'name'])
async def test_admin_sort(app, sort):
cookies = await app.login_user('admin')
r = await get_page('admin?sort=%s' % sort, app, cookies=cookies)
@@ -146,7 +136,9 @@ async def test_spawn_redirect(app):
assert path == ujoin(app.base_url, '/user/%s/' % name)
# stop server to ensure /user/name is handled by the Hub
r = await api_request(app, 'users', name, 'server', method='delete', cookies=cookies)
r = await api_request(
app, 'users', name, 'server', method='delete', cookies=cookies
)
r.raise_for_status()
# test handing of trailing slash on `/user/name`
@@ -208,7 +200,9 @@ async def test_spawn_page(app):
async def test_spawn_page_falsy_callable(app):
with mock.patch.dict(app.users.settings, {'spawner_class': FalsyCallableFormSpawner}):
with mock.patch.dict(
app.users.settings, {'spawner_class': FalsyCallableFormSpawner}
):
cookies = await app.login_user('erik')
r = await get_page('spawn', app, cookies=cookies)
assert 'user/erik' in r.url
@@ -276,22 +270,22 @@ async def test_spawn_form_with_file(app):
u = app.users[orm_u]
await u.stop()
r = await async_requests.post(ujoin(base_url, 'spawn'),
r = await async_requests.post(
ujoin(base_url, 'spawn'),
cookies=cookies,
data={
'bounds': ['-1', '1'],
'energy': '511keV',
},
files={'hello': ('hello.txt', b'hello world\n')}
data={'bounds': ['-1', '1'], 'energy': '511keV'},
files={'hello': ('hello.txt', b'hello world\n')},
)
r.raise_for_status()
assert u.spawner.user_options == {
'energy': '511keV',
'bounds': [-1, 1],
'notspecified': 5,
'hello': {'filename': 'hello.txt',
'hello': {
'filename': 'hello.txt',
'body': b'hello world\n',
'content_type': 'application/unknown'},
'content_type': 'application/unknown',
},
}
@@ -305,9 +299,9 @@ async def test_user_redirect(app):
path = urlparse(r.url).path
assert path == ujoin(app.base_url, '/hub/login')
query = urlparse(r.url).query
assert query == urlencode({
'next': ujoin(app.hub.base_url, '/user-redirect/tree/top/')
})
assert query == urlencode(
{'next': ujoin(app.hub.base_url, '/user-redirect/tree/top/')}
)
r = await get_page('/user-redirect/notebooks/test.ipynb', app, cookies=cookies)
r.raise_for_status()
@@ -339,19 +333,17 @@ async def test_user_redirect_deprecated(app):
path = urlparse(r.url).path
assert path == ujoin(app.base_url, '/hub/login')
query = urlparse(r.url).query
assert query == urlencode({
'next': ujoin(app.base_url, '/hub/user/baduser/test.ipynb')
})
assert query == urlencode(
{'next': ujoin(app.base_url, '/hub/user/baduser/test.ipynb')}
)
async def test_login_fail(app):
name = 'wash'
base_url = public_url(app)
r = await async_requests.post(base_url + 'hub/login',
data={
'username': name,
'password': 'wrong',
},
r = await async_requests.post(
base_url + 'hub/login',
data={'username': name, 'password': 'wrong'},
allow_redirects=False,
)
assert not r.cookies
@@ -359,20 +351,17 @@ async def test_login_fail(app):
async def test_login_strip(app):
"""Test that login form doesn't strip whitespace from passwords"""
form_data = {
'username': 'spiff',
'password': ' space man ',
}
form_data = {'username': 'spiff', 'password': ' space man '}
base_url = public_url(app)
called_with = []
@gen.coroutine
def mock_authenticate(handler, data):
called_with.append(data)
with mock.patch.object(app.authenticator, 'authenticate', mock_authenticate):
await async_requests.post(base_url + 'hub/login',
data=form_data,
allow_redirects=False,
await async_requests.post(
base_url + 'hub/login', data=form_data, allow_redirects=False
)
assert called_with == [form_data]
@@ -389,12 +378,11 @@ async def test_login_strip(app):
(False, '/user/other', '/hub/user/other'),
(False, '/absolute', '/absolute'),
(False, '/has?query#andhash', '/has?query#andhash'),
# next_url outside is not allowed
(False, 'https://other.domain', ''),
(False, 'ftp://other.domain', ''),
(False, '//other.domain', ''),
]
],
)
async def test_login_redirect(app, running, next_url, location):
cookies = await app.login_user('river')
@@ -427,10 +415,11 @@ async def test_auto_login(app, request):
class DummyLoginHandler(BaseHandler):
def get(self):
self.write('ok!')
base_url = public_url(app) + '/'
app.tornado_application.add_handlers(".*$", [
(ujoin(app.hub.base_url, 'dummy'), DummyLoginHandler),
])
app.tornado_application.add_handlers(
".*$", [(ujoin(app.hub.base_url, 'dummy'), DummyLoginHandler)]
)
# no auto_login: end up at /hub/login
r = await async_requests.get(base_url)
assert r.url == public_url(app, path='hub/login')
@@ -438,9 +427,7 @@ async def test_auto_login(app, request):
authenticator = Authenticator(auto_login=True)
authenticator.login_url = lambda base_url: ujoin(base_url, 'dummy')
with mock.patch.dict(app.tornado_settings, {
'authenticator': authenticator,
}):
with mock.patch.dict(app.tornado_settings, {'authenticator': authenticator}):
r = await async_requests.get(base_url)
assert r.url == public_url(app, path='hub/dummy')
@@ -449,10 +436,12 @@ async def test_auto_login_logout(app):
name = 'burnham'
cookies = await app.login_user(name)
with mock.patch.dict(app.tornado_settings, {
'authenticator': Authenticator(auto_login=True),
}):
r = await async_requests.get(public_host(app) + app.tornado_settings['logout_url'], cookies=cookies)
with mock.patch.dict(
app.tornado_settings, {'authenticator': Authenticator(auto_login=True)}
):
r = await async_requests.get(
public_host(app) + app.tornado_settings['logout_url'], cookies=cookies
)
r.raise_for_status()
logout_url = public_host(app) + app.tornado_settings['logout_url']
assert r.url == logout_url
@@ -462,7 +451,9 @@ async def test_auto_login_logout(app):
async def test_logout(app):
name = 'wash'
cookies = await app.login_user(name)
r = await async_requests.get(public_host(app) + app.tornado_settings['logout_url'], cookies=cookies)
r = await async_requests.get(
public_host(app) + app.tornado_settings['logout_url'], cookies=cookies
)
r.raise_for_status()
login_url = public_host(app) + app.tornado_settings['login_url']
assert r.url == login_url
@@ -489,12 +480,11 @@ async def test_shutdown_on_logout(app, shutdown_on_logout):
assert spawner.active
# logout
with mock.patch.dict(app.tornado_settings, {
'shutdown_on_logout': shutdown_on_logout,
}):
with mock.patch.dict(
app.tornado_settings, {'shutdown_on_logout': shutdown_on_logout}
):
r = await async_requests.get(
public_host(app) + app.tornado_settings['logout_url'],
cookies=cookies,
public_host(app) + app.tornado_settings['logout_url'], cookies=cookies
)
r.raise_for_status()
@@ -549,7 +539,9 @@ async def test_oauth_token_page(app):
user = app.users[orm.User.find(app.db, name)]
client = orm.OAuthClient(identifier='token')
app.db.add(client)
oauth_token = orm.OAuthAccessToken(client=client, user=user, grant_type=orm.GrantType.authorization_code)
oauth_token = orm.OAuthAccessToken(
client=client, user=user, grant_type=orm.GrantType.authorization_code
)
app.db.add(oauth_token)
app.db.commit()
r = await get_page('token', app, cookies=cookies)
@@ -557,23 +549,14 @@ async def test_oauth_token_page(app):
assert r.status_code == 200
@pytest.mark.parametrize("error_status", [
503,
404,
])
@pytest.mark.parametrize("error_status", [503, 404])
async def test_proxy_error(app, error_status):
r = await get_page('/error/%i' % error_status, app)
assert r.status_code == 200
@pytest.mark.parametrize(
"announcements",
[
"",
"spawn",
"spawn,home,login",
"login,logout",
]
"announcements", ["", "spawn", "spawn,home,login", "login,logout"]
)
async def test_announcements(app, announcements):
"""Test announcements on various pages"""
@@ -618,16 +601,13 @@ async def test_announcements(app, announcements):
@pytest.mark.parametrize(
"params",
[
"",
"redirect_uri=/noexist",
"redirect_uri=ok&client_id=nosuchthing",
]
"params", ["", "redirect_uri=/noexist", "redirect_uri=ok&client_id=nosuchthing"]
)
async def test_bad_oauth_get(app, params):
cookies = await app.login_user("authorizer")
r = await get_page("hub/api/oauth2/authorize?" + params, app, hub=False, cookies=cookies)
r = await get_page(
"hub/api/oauth2/authorize?" + params, app, hub=False, cookies=cookies
)
assert r.status_code == 400
@@ -637,11 +617,14 @@ async def test_token_page(app):
r = await get_page("token", app, cookies=cookies)
r.raise_for_status()
assert urlparse(r.url).path.endswith('/hub/token')
def extract_body(r):
soup = BeautifulSoup(r.text, "html5lib")
import re
# trim empty lines
return re.sub(r"(\n\s*)+", "\n", soup.body.find(class_="container").text)
body = extract_body(r)
assert "Request new API token" in body, body
# no tokens yet, no lists

View File

@@ -1,20 +1,21 @@
"""Test a proxy being started before the Hub"""
from contextlib import contextmanager
import json
import os
from contextlib import contextmanager
from queue import Queue
from subprocess import Popen
from urllib.parse import urlparse, quote
from traitlets.config import Config
from urllib.parse import quote
from urllib.parse import urlparse
import pytest
from traitlets.config import Config
from .. import orm
from ..utils import url_path_join as ujoin
from ..utils import wait_for_http_server
from .mocking import MockHub
from .test_api import api_request, add_user
from ..utils import wait_for_http_server, url_path_join as ujoin
from .test_api import add_user
from .test_api import api_request
@pytest.fixture
@@ -52,25 +53,30 @@ async def test_external_proxy(request):
env['CONFIGPROXY_AUTH_TOKEN'] = auth_token
cmd = [
'configurable-http-proxy',
'--ip', app.ip,
'--port', str(app.port),
'--api-ip', proxy_ip,
'--api-port', str(proxy_port),
'--ip',
app.ip,
'--port',
str(app.port),
'--api-ip',
proxy_ip,
'--api-port',
str(proxy_port),
'--log-level=debug',
]
if app.subdomain_host:
cmd.append('--host-routing')
proxy = Popen(cmd, env=env)
def _cleanup_proxy():
if proxy.poll() is None:
proxy.terminate()
proxy.wait(timeout=10)
request.addfinalizer(_cleanup_proxy)
def wait_for_proxy():
return wait_for_http_server('http://%s:%i' % (proxy_ip, proxy_port))
await wait_for_proxy()
await app.initialize([])
@@ -84,8 +90,9 @@ async def test_external_proxy(request):
# add user to the db and start a single user server
name = 'river'
add_user(app.db, app, name=name)
r = await api_request(app, 'users', name, 'server', method='post',
bypass_proxy=True)
r = await api_request(
app, 'users', name, 'server', method='post', bypass_proxy=True
)
r.raise_for_status()
routes = await app.proxy.get_all_routes()
@@ -122,12 +129,18 @@ async def test_external_proxy(request):
new_auth_token = 'different!'
env['CONFIGPROXY_AUTH_TOKEN'] = new_auth_token
proxy_port = 55432
cmd = ['configurable-http-proxy',
'--ip', app.ip,
'--port', str(app.port),
'--api-ip', proxy_ip,
'--api-port', str(proxy_port),
'--default-target', 'http://%s:%i' % (app.hub_ip, app.hub_port),
cmd = [
'configurable-http-proxy',
'--ip',
app.ip,
'--port',
str(app.port),
'--api-ip',
proxy_ip,
'--api-port',
str(proxy_port),
'--default-target',
'http://%s:%i' % (app.hub_ip, app.hub_port),
]
if app.subdomain_host:
cmd.append('--host-routing')
@@ -140,10 +153,7 @@ async def test_external_proxy(request):
app,
'proxy',
method='patch',
data=json.dumps({
'api_url': new_api_url,
'auth_token': new_auth_token,
}),
data=json.dumps({'api_url': new_api_url, 'auth_token': new_auth_token}),
bypass_proxy=True,
)
r.raise_for_status()
@@ -156,13 +166,7 @@ async def test_external_proxy(request):
assert sorted(routes.keys()) == [app.hub.routespec, user_spec]
@pytest.mark.parametrize("username", [
'zoe',
'50fia',
'秀樹',
'~TestJH',
'has@',
])
@pytest.mark.parametrize("username", ['zoe', '50fia', '秀樹', '~TestJH', 'has@'])
async def test_check_routes(app, username, disable_check_routes):
proxy = app.proxy
test_user = add_user(app.db, app, name=username)
@@ -191,14 +195,17 @@ async def test_check_routes(app, username, disable_check_routes):
assert before == after
@pytest.mark.parametrize("routespec", [
@pytest.mark.parametrize(
"routespec",
[
'/has%20space/foo/',
'/missing-trailing/slash',
'/has/@/',
'/has/' + quote('üñîçø∂é'),
'host.name/path/',
'other.host/path/no/slash',
])
],
)
async def test_add_get_delete(app, routespec, disable_check_routes):
arg = routespec
if not routespec.endswith('/'):
@@ -207,6 +214,7 @@ async def test_add_get_delete(app, routespec, disable_check_routes):
# host-routes when not host-routing raises an error
# and vice versa
expect_value_error = bool(app.subdomain_host) ^ (not routespec.startswith('/'))
@contextmanager
def context():
if expect_value_error:

View File

@@ -1,22 +1,26 @@
"""Tests for services"""
import asyncio
import os
import sys
import time
from binascii import hexlify
from contextlib import contextmanager
import os
from subprocess import Popen
import sys
from threading import Event
import time
from async_generator import asynccontextmanager, async_generator, yield_
import pytest
import requests
from async_generator import async_generator
from async_generator import asynccontextmanager
from async_generator import yield_
from tornado import gen
from tornado.ioloop import IOLoop
from ..utils import maybe_future
from ..utils import random_port
from ..utils import url_path_join
from ..utils import wait_for_http_server
from .mocking import public_url
from ..utils import url_path_join, wait_for_http_server, random_port, maybe_future
from .utils import async_requests
mockservice_path = os.path.dirname(os.path.abspath(__file__))
@@ -80,12 +84,14 @@ async def test_proxy_service(app, mockservice_url):
async def test_external_service(app):
name = 'external'
async with external_service(app, name=name) as env:
app.services = [{
app.services = [
{
'name': name,
'admin': True,
'url': env['JUPYTERHUB_SERVICE_URL'],
'api_token': env['JUPYTERHUB_API_TOKEN'],
}]
}
]
await maybe_future(app.init_services())
await app.init_api_tokens()
await app.proxy.add_all_services(app._service_map)

View File

@@ -1,36 +1,43 @@
"""Tests for service authentication"""
import asyncio
from binascii import hexlify
import copy
from functools import partial
import json
import os
from queue import Queue
import sys
from binascii import hexlify
from functools import partial
from queue import Queue
from threading import Thread
from unittest import mock
from urllib.parse import urlparse
import pytest
from pytest import raises
import requests
import requests_mock
from tornado.ioloop import IOLoop
from pytest import raises
from tornado.httpserver import HTTPServer
from tornado.web import RequestHandler, Application, authenticated, HTTPError
from tornado.httputil import url_concat
from tornado.ioloop import IOLoop
from tornado.web import Application
from tornado.web import authenticated
from tornado.web import HTTPError
from tornado.web import RequestHandler
from .. import orm
from ..services.auth import _ExpiringDict, HubAuth, HubAuthenticated
from ..services.auth import _ExpiringDict
from ..services.auth import HubAuth
from ..services.auth import HubAuthenticated
from ..utils import url_path_join
from .mocking import public_url, public_host
from .mocking import public_host
from .mocking import public_url
from .test_api import add_user
from .utils import async_requests, AsyncSession
from .utils import async_requests
from .utils import AsyncSession
# mock for sending monotonic counter way into the future
monotonic_future = mock.patch('time.monotonic', lambda: sys.maxsize)
def test_expiring_dict():
cache = _ExpiringDict(max_age=30)
cache['key'] = 'cached value'
@@ -69,9 +76,7 @@ def test_expiring_dict():
def test_hub_auth():
auth = HubAuth(cookie_name='foo')
mock_model = {
'name': 'onyxia'
}
mock_model = {'name': 'onyxia'}
url = url_path_join(auth.api_url, "authorizations/cookie/foo/bar")
with requests_mock.Mocker() as m:
m.get(url, text=json.dumps(mock_model))
@@ -87,9 +92,7 @@ def test_hub_auth():
assert user_model is None
# invalidate cache with timer
mock_model = {
'name': 'willow'
}
mock_model = {'name': 'willow'}
with monotonic_future, requests_mock.Mocker() as m:
m.get(url, text=json.dumps(mock_model))
user_model = auth.user_for_cookie('bar')
@@ -110,16 +113,14 @@ def test_hub_auth():
def test_hub_authenticated(request):
auth = HubAuth(cookie_name='jubal')
mock_model = {
'name': 'jubalearly',
'groups': ['lions'],
}
mock_model = {'name': 'jubalearly', 'groups': ['lions']}
cookie_url = url_path_join(auth.api_url, "authorizations/cookie", auth.cookie_name)
good_url = url_path_join(cookie_url, "early")
bad_url = url_path_join(cookie_url, "late")
class TestHandler(HubAuthenticated, RequestHandler):
hub_auth = auth
@authenticated
def get(self):
self.finish(self.get_current_user())
@@ -127,11 +128,10 @@ def test_hub_authenticated(request):
# start hub-authenticated service in a thread:
port = 50505
q = Queue()
def run():
asyncio.set_event_loop(asyncio.new_event_loop())
app = Application([
('/*', TestHandler),
], login_url=auth.login_url)
app = Application([('/*', TestHandler)], login_url=auth.login_url)
http_server = HTTPServer(app)
http_server.listen(port)
@@ -146,6 +146,7 @@ def test_hub_authenticated(request):
loop.add_callback(loop.stop)
t.join(timeout=30)
assert not t.is_alive()
request.addfinalizer(finish_thread)
# wait for thread to start
@@ -153,16 +154,15 @@ def test_hub_authenticated(request):
with requests_mock.Mocker(real_http=True) as m:
# no cookie
r = requests.get('http://127.0.0.1:%i' % port,
allow_redirects=False,
)
r = requests.get('http://127.0.0.1:%i' % port, allow_redirects=False)
r.raise_for_status()
assert r.status_code == 302
assert auth.login_url in r.headers['Location']
# wrong cookie
m.get(bad_url, status_code=404)
r = requests.get('http://127.0.0.1:%i' % port,
r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'late'},
allow_redirects=False,
)
@@ -176,7 +176,8 @@ def test_hub_authenticated(request):
# upstream 403
m.get(bad_url, status_code=403)
r = requests.get('http://127.0.0.1:%i' % port,
r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'late'},
allow_redirects=False,
)
@@ -185,7 +186,8 @@ def test_hub_authenticated(request):
m.get(good_url, text=json.dumps(mock_model))
# no whitelist
r = requests.get('http://127.0.0.1:%i' % port,
r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'early'},
allow_redirects=False,
)
@@ -194,7 +196,8 @@ def test_hub_authenticated(request):
# pass whitelist
TestHandler.hub_users = {'jubalearly'}
r = requests.get('http://127.0.0.1:%i' % port,
r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'early'},
allow_redirects=False,
)
@@ -203,7 +206,8 @@ def test_hub_authenticated(request):
# no pass whitelist
TestHandler.hub_users = {'kaylee'}
r = requests.get('http://127.0.0.1:%i' % port,
r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'early'},
allow_redirects=False,
)
@@ -211,7 +215,8 @@ def test_hub_authenticated(request):
# pass group whitelist
TestHandler.hub_groups = {'lions'}
r = requests.get('http://127.0.0.1:%i' % port,
r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'early'},
allow_redirects=False,
)
@@ -220,7 +225,8 @@ def test_hub_authenticated(request):
# no pass group whitelist
TestHandler.hub_groups = {'tigers'}
r = requests.get('http://127.0.0.1:%i' % port,
r = requests.get(
'http://127.0.0.1:%i' % port,
cookies={'jubal': 'early'},
allow_redirects=False,
)
@@ -230,15 +236,14 @@ def test_hub_authenticated(request):
async def test_hubauth_cookie(app, mockservice_url):
"""Test HubAuthenticated service with user cookies"""
cookies = await app.login_user('badger')
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/', cookies=cookies)
r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/', cookies=cookies
)
r.raise_for_status()
print(r.text)
reply = r.json()
sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']}
assert sub_reply == {
'name': 'badger',
'admin': False,
}
assert sub_reply == {'name': 'badger', 'admin': False}
async def test_hubauth_token(app, mockservice_url):
@@ -248,28 +253,25 @@ async def test_hubauth_token(app, mockservice_url):
app.db.commit()
# token in Authorization header
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/',
headers={
'Authorization': 'token %s' % token,
})
r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/',
headers={'Authorization': 'token %s' % token},
)
reply = r.json()
sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']}
assert sub_reply == {
'name': 'river',
'admin': False,
}
assert sub_reply == {'name': 'river', 'admin': False}
# token in ?token parameter
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=%s' % token)
r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/?token=%s' % token
)
r.raise_for_status()
reply = r.json()
sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']}
assert sub_reply == {
'name': 'river',
'admin': False,
}
assert sub_reply == {'name': 'river', 'admin': False}
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=no-such-token',
r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/?token=no-such-token',
allow_redirects=False,
)
assert r.status_code == 302
@@ -288,30 +290,25 @@ async def test_hubauth_service_token(app, mockservice_url):
await app.init_api_tokens()
# token in Authorization header
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/',
headers={
'Authorization': 'token %s' % token,
})
r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/',
headers={'Authorization': 'token %s' % token},
)
r.raise_for_status()
reply = r.json()
assert reply == {
'kind': 'service',
'name': name,
'admin': False,
}
assert reply == {'kind': 'service', 'name': name, 'admin': False}
assert not r.cookies
# token in ?token parameter
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=%s' % token)
r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/?token=%s' % token
)
r.raise_for_status()
reply = r.json()
assert reply == {
'kind': 'service',
'name': name,
'admin': False,
}
assert reply == {'kind': 'service', 'name': name, 'admin': False}
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=no-such-token',
r = await async_requests.get(
public_url(app, mockservice_url) + '/whoami/?token=no-such-token',
allow_redirects=False,
)
assert r.status_code == 302
@@ -351,10 +348,7 @@ async def test_oauth_service(app, mockservice_url):
assert r.status_code == 200
reply = r.json()
sub_reply = {key: reply.get(key, 'missing') for key in ('kind', 'name')}
assert sub_reply == {
'name': 'link',
'kind': 'user',
}
assert sub_reply == {'name': 'link', 'kind': 'user'}
# token-authenticated request to HubOAuth
token = app.users[name].new_api_token()
@@ -367,11 +361,7 @@ async def test_oauth_service(app, mockservice_url):
# verify that ?token= requests set a cookie
assert len(r.cookies) != 0
# ensure cookie works in future requests
r = await async_requests.get(
url,
cookies=r.cookies,
allow_redirects=False,
)
r = await async_requests.get(url, cookies=r.cookies, allow_redirects=False)
r.raise_for_status()
assert r.url == url
reply = r.json()
@@ -408,9 +398,7 @@ async def test_oauth_cookie_collision(app, mockservice_url):
# finish oauth 2
# submit the oauth form to complete authorization
r = await s.post(
oauth_2.url,
data={'scopes': ['identify']},
headers={'Referer': oauth_2.url},
oauth_2.url, data={'scopes': ['identify']}, headers={'Referer': oauth_2.url}
)
r.raise_for_status()
assert r.url == url
@@ -422,9 +410,7 @@ async def test_oauth_cookie_collision(app, mockservice_url):
# finish oauth 1
r = await s.post(
oauth_1.url,
data={'scopes': ['identify']},
headers={'Referer': oauth_1.url},
oauth_1.url, data={'scopes': ['identify']}, headers={'Referer': oauth_1.url}
)
r.raise_for_status()
assert r.url == url
@@ -455,11 +441,13 @@ async def test_oauth_logout(app, mockservice_url):
s = AsyncSession()
name = 'propha'
app_user = add_user(app.db, app=app, name=name)
def auth_tokens():
"""Return list of OAuth access tokens for the user"""
return list(
app.db.query(orm.OAuthAccessToken).filter(
orm.OAuthAccessToken.user_id == app_user.id)
orm.OAuthAccessToken.user_id == app_user.id
)
)
# ensure we start empty
@@ -480,14 +468,8 @@ async def test_oauth_logout(app, mockservice_url):
r.raise_for_status()
assert r.status_code == 200
reply = r.json()
sub_reply = {
key: reply.get(key, 'missing')
for key in ('kind', 'name')
}
assert sub_reply == {
'name': name,
'kind': 'user',
}
sub_reply = {key: reply.get(key, 'missing') for key in ('kind', 'name')}
assert sub_reply == {'name': name, 'kind': 'user'}
# save cookies to verify cache
saved_cookies = copy.deepcopy(s.cookies)
session_id = s.cookies['jupyterhub-session-id']
@@ -522,11 +504,5 @@ async def test_oauth_logout(app, mockservice_url):
r.raise_for_status()
assert r.status_code == 200
reply = r.json()
sub_reply = {
key: reply.get(key, 'missing')
for key in ('kind', 'name')
}
assert sub_reply == {
'name': name,
'kind': 'user',
}
sub_reply = {key: reply.get(key, 'missing') for key in ('kind', 'name')}
assert sub_reply == {'name': name, 'kind': 'user'}

View File

@@ -1,16 +1,16 @@
"""Tests for jupyterhub.singleuser"""
from subprocess import check_output
import sys
from subprocess import check_output
from urllib.parse import urlparse
import pytest
import jupyterhub
from .mocking import StubSingleUserSpawner, public_url
from ..utils import url_path_join
from .utils import async_requests, AsyncSession
from .mocking import public_url
from .mocking import StubSingleUserSpawner
from .utils import async_requests
from .utils import AsyncSession
async def test_singleuser_auth(app):
@@ -47,11 +47,7 @@ async def test_singleuser_auth(app):
r = await s.get(url)
assert urlparse(r.url).path.endswith('/oauth2/authorize')
# submit the oauth form to complete authorization
r = await s.post(
r.url,
data={'scopes': ['identify']},
headers={'Referer': r.url},
)
r = await s.post(r.url, data={'scopes': ['identify']}, headers={'Referer': r.url})
assert urlparse(r.url).path.rstrip('/').endswith('/user/nandy/tree')
# user isn't authorized, should raise 403
assert r.status_code == 403
@@ -85,11 +81,14 @@ async def test_disable_user_config(app):
def test_help_output():
out = check_output([sys.executable, '-m', 'jupyterhub.singleuser', '--help-all']).decode('utf8', 'replace')
out = check_output(
[sys.executable, '-m', 'jupyterhub.singleuser', '--help-all']
).decode('utf8', 'replace')
assert 'JupyterHub' in out
def test_version():
out = check_output([sys.executable, '-m', 'jupyterhub.singleuser', '--version']).decode('utf8', 'replace')
out = check_output(
[sys.executable, '-m', 'jupyterhub.singleuser', '--version']
).decode('utf8', 'replace')
assert jupyterhub.__version__ in out

View File

@@ -1,27 +1,28 @@
"""Tests for process spawning"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import logging
import os
import signal
from subprocess import Popen
import sys
import tempfile
import time
from subprocess import Popen
from unittest import mock
from urllib.parse import urlparse
import pytest
from tornado import gen
from ..objects import Hub, Server
from .. import orm
from .. import spawner as spawnermod
from ..spawner import LocalProcessSpawner, Spawner
from ..objects import Hub
from ..objects import Server
from ..spawner import LocalProcessSpawner
from ..spawner import Spawner
from ..user import User
from ..utils import new_token, url_path_join
from ..utils import new_token
from ..utils import url_path_join
from .mocking import public_url
from .test_api import add_user
from .utils import async_requests
@@ -84,8 +85,10 @@ async def wait_for_spawner(spawner, timeout=10):
polling at shorter intervals for early termination
"""
deadline = time.monotonic() + timeout
def wait():
return spawner.server.wait_up(timeout=1, http=True)
while time.monotonic() < deadline:
status = await spawner.poll()
assert status is None
@@ -187,11 +190,13 @@ def test_setcwd():
os.chdir(cwd)
chdir = os.chdir
temp_root = os.path.realpath(os.path.abspath(tempfile.gettempdir()))
def raiser(path):
path = os.path.realpath(os.path.abspath(path))
if not path.startswith(temp_root):
raise OSError(path)
chdir(path)
with mock.patch('os.chdir', raiser):
spawnermod._try_setcwd(cwd)
assert os.getcwd().startswith(temp_root)
@@ -209,6 +214,7 @@ def test_string_formatting(db):
async def test_popen_kwargs(db):
mock_proc = mock.Mock(spec=Popen)
def mock_popen(*args, **kwargs):
mock_proc.args = args
mock_proc.kwargs = kwargs
@@ -226,7 +232,8 @@ async def test_popen_kwargs(db):
async def test_shell_cmd(db, tmpdir, request):
f = tmpdir.join('bashrc')
f.write('export TESTVAR=foo\n')
s = new_spawner(db,
s = new_spawner(
db,
cmd=[sys.executable, '-m', 'jupyterhub.tests.mocksu'],
shell_cmd=['bash', '--rcfile', str(f), '-i', '-c'],
)
@@ -253,6 +260,7 @@ def test_inherit_overwrite():
"""
if sys.version_info >= (3, 6):
with pytest.raises(NotImplementedError):
class S(Spawner):
pass
@@ -372,19 +380,14 @@ async def test_spawner_delete_server(app):
assert spawner.server is None
@pytest.mark.parametrize(
"name",
[
"has@x",
"has~x",
"has%x",
"has%40x",
]
)
@pytest.mark.parametrize("name", ["has@x", "has~x", "has%x", "has%40x"])
async def test_spawner_routing(app, name):
"""Test routing of names with special characters"""
db = app.db
with mock.patch.dict(app.config.LocalProcessSpawner, {'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']}):
with mock.patch.dict(
app.config.LocalProcessSpawner,
{'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']},
):
user = add_user(app.db, app, name=name)
await user.spawn()
await wait_for_spawner(user.spawner)

View File

@@ -1,12 +1,16 @@
import pytest
from traitlets import HasTraits, TraitError
from traitlets import HasTraits
from traitlets import TraitError
from jupyterhub.traitlets import URLPrefix, Command, ByteSpecification
from jupyterhub.traitlets import ByteSpecification
from jupyterhub.traitlets import Command
from jupyterhub.traitlets import URLPrefix
def test_url_prefix():
class C(HasTraits):
url = URLPrefix()
c = C()
c.url = '/a/b/c/'
assert c.url == '/a/b/c/'
@@ -20,6 +24,7 @@ def test_command():
class C(HasTraits):
cmd = Command('default command')
cmd2 = Command(['default_cmd'])
c = C()
assert c.cmd == ['default command']
assert c.cmd2 == ['default_cmd']

View File

@@ -1,9 +1,11 @@
"""Tests for utilities"""
import asyncio
import pytest
from async_generator import aclosing, async_generator, yield_
import pytest
from async_generator import aclosing
from async_generator import async_generator
from async_generator import yield_
from ..utils import iterate_until
@@ -26,12 +28,15 @@ def schedule_future(io_loop, *, delay, result=None):
return f
@pytest.mark.parametrize("deadline, n, delay, expected", [
@pytest.mark.parametrize(
"deadline, n, delay, expected",
[
(0, 3, 1, []),
(0, 3, 0, [0, 1, 2]),
(5, 3, 0.01, [0, 1, 2]),
(0.5, 10, 0.2, [0, 1]),
])
],
)
async def test_iterate_until(io_loop, deadline, n, delay, expected):
f = schedule_future(io_loop, delay=deadline)

Some files were not shown because too many files have changed in this diff Show More