mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-17 23:13:00 +00:00
run autoformat
apologies to anyone finding this commit via git blame or log run the autoformatting by pre-commit run --all-files
This commit is contained in:
@@ -1 +0,0 @@
|
||||
|
||||
|
@@ -1,19 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
"""
|
||||
bower-lite
|
||||
|
||||
Since Bower's on its way out,
|
||||
stage frontend dependencies from node_modules into components
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from os.path import join
|
||||
import shutil
|
||||
from os.path import join
|
||||
|
||||
HERE = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
@@ -1,16 +1,16 @@
|
||||
-r requirements.txt
|
||||
mock
|
||||
# temporary pin of attrs for jsonschema 0.3.0a1
|
||||
# seems to be a pip bug
|
||||
attrs>=17.4.0
|
||||
beautifulsoup4
|
||||
codecov
|
||||
coverage
|
||||
cryptography
|
||||
html5lib # needed for beautifulsoup
|
||||
pytest-cov
|
||||
pytest-asyncio
|
||||
pytest>=3.3
|
||||
mock
|
||||
notebook
|
||||
pytest-asyncio
|
||||
pytest-cov
|
||||
pytest>=3.3
|
||||
requests-mock
|
||||
virtualenv
|
||||
# temporary pin of attrs for jsonschema 0.3.0a1
|
||||
# seems to be a pip bug
|
||||
attrs>=17.4.0
|
||||
|
@@ -7,5 +7,3 @@ ENV LANG=en_US.UTF-8
|
||||
|
||||
USER nobody
|
||||
CMD ["jupyterhub"]
|
||||
|
||||
|
||||
|
@@ -18,4 +18,3 @@ Dockerfile.alpine contains base image for jupyterhub. It does not work independ
|
||||
* Use dummy authenticator for ease of testing. Update following in jupyterhub_config file
|
||||
- c.JupyterHub.authenticator_class = 'dummyauthenticator.DummyAuthenticator'
|
||||
- c.DummyAuthenticator.password = "your strong password"
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
# ReadTheDocs uses the `environment.yaml` so make sure to update that as well
|
||||
# if you change this file
|
||||
-r ../requirements.txt
|
||||
sphinx>=1.7
|
||||
alabaster_jupyterhub
|
||||
recommonmark==0.4.0
|
||||
sphinx-copybutton
|
||||
alabaster_jupyterhub
|
||||
sphinx>=1.7
|
||||
|
@@ -13,4 +13,3 @@ Module: :mod:`jupyterhub.app`
|
||||
-------------------
|
||||
|
||||
.. autoconfigurable:: JupyterHub
|
||||
|
||||
|
@@ -30,4 +30,3 @@ Module: :mod:`jupyterhub.auth`
|
||||
---------------------------
|
||||
|
||||
.. autoconfigurable:: DummyAuthenticator
|
||||
|
||||
|
@@ -20,4 +20,3 @@ Module: :mod:`jupyterhub.proxy`
|
||||
|
||||
.. autoconfigurable:: ConfigurableHTTPProxy
|
||||
:members: debug, auth_token, check_running_interval, api_url, command
|
||||
|
||||
|
@@ -14,4 +14,3 @@ Module: :mod:`jupyterhub.services.service`
|
||||
|
||||
.. autoconfigurable:: Service
|
||||
:members: name, admin, url, api_token, managed, kind, command, cwd, environment, user, oauth_client_id, server, prefix, proxy_spec
|
||||
|
||||
|
@@ -38,4 +38,3 @@ Module: :mod:`jupyterhub.services.auth`
|
||||
--------------------------------
|
||||
|
||||
.. autoclass:: HubOAuthCallbackHandler
|
||||
|
||||
|
@@ -19,4 +19,3 @@ Module: :mod:`jupyterhub.spawner`
|
||||
----------------------------
|
||||
|
||||
.. autoconfigurable:: LocalProcessSpawner
|
||||
|
||||
|
@@ -34,4 +34,3 @@ Module: :mod:`jupyterhub.user`
|
||||
.. attribute:: spawner
|
||||
|
||||
The user's :class:`~.Spawner` instance.
|
||||
|
||||
|
@@ -1,11 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
import sys
|
||||
import os
|
||||
import shlex
|
||||
import sys
|
||||
|
||||
import recommonmark.parser
|
||||
|
||||
# For conversion from markdown to html
|
||||
import recommonmark.parser
|
||||
|
||||
# Set paths
|
||||
sys.path.insert(0, os.path.abspath('.'))
|
||||
@@ -21,7 +22,7 @@ extensions = [
|
||||
'sphinx.ext.intersphinx',
|
||||
'sphinx.ext.napoleon',
|
||||
'autodoc_traits',
|
||||
'sphinx_copybutton'
|
||||
'sphinx_copybutton',
|
||||
]
|
||||
|
||||
templates_path = ['_templates']
|
||||
@@ -68,6 +69,7 @@ source_suffix = ['.rst', '.md']
|
||||
|
||||
# The theme to use for HTML and HTML Help pages.
|
||||
import alabaster_jupyterhub
|
||||
|
||||
html_theme = 'alabaster_jupyterhub'
|
||||
html_theme_path = [alabaster_jupyterhub.get_html_theme_path()]
|
||||
|
||||
|
@@ -210,4 +210,3 @@ jupyterhub_config.py amendments:
|
||||
--This is the address on which the proxy will bind. Sets protocol, ip, base_url
|
||||
c.JupyterHub.bind_url = 'http://127.0.0.1:8000/jhub/'
|
||||
```
|
||||
|
||||
|
@@ -1,8 +1,9 @@
|
||||
"""autodoc extension for configurable traits"""
|
||||
|
||||
from traitlets import TraitType, Undefined
|
||||
from sphinx.domains.python import PyClassmember
|
||||
from sphinx.ext.autodoc import ClassDocumenter, AttributeDocumenter
|
||||
from sphinx.ext.autodoc import AttributeDocumenter
|
||||
from sphinx.ext.autodoc import ClassDocumenter
|
||||
from traitlets import TraitType
|
||||
from traitlets import Undefined
|
||||
|
||||
|
||||
class ConfigurableDocumenter(ClassDocumenter):
|
||||
|
@@ -5,8 +5,10 @@ create a directory for the user before the spawner starts
|
||||
# pylint: disable=import-error
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from jupyter_client.localinterfaces import public_ips
|
||||
|
||||
|
||||
def create_dir_hook(spawner):
|
||||
""" Create directory """
|
||||
username = spawner.user.name # get the username
|
||||
@@ -16,6 +18,7 @@ def create_dir_hook(spawner):
|
||||
# now do whatever you think your user needs
|
||||
# ...
|
||||
|
||||
|
||||
def clean_dir_hook(spawner):
|
||||
""" Delete directory """
|
||||
username = spawner.user.name # get the username
|
||||
@@ -23,6 +26,7 @@ def clean_dir_hook(spawner):
|
||||
if os.path.exists(temp_path) and os.path.isdir(temp_path):
|
||||
shutil.rmtree(temp_path)
|
||||
|
||||
|
||||
# attach the hook functions to the spawner
|
||||
# pylint: disable=undefined-variable
|
||||
c.Spawner.pre_spawn_hook = create_dir_hook
|
||||
|
@@ -31,11 +31,11 @@ users and servers, you should add this script to the services list
|
||||
twice, just with different ``name``s, different values, and one with
|
||||
the ``--cull-users`` option.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from functools import partial
|
||||
|
||||
try:
|
||||
from urllib.parse import quote
|
||||
@@ -85,23 +85,21 @@ def format_td(td):
|
||||
|
||||
|
||||
@coroutine
|
||||
def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concurrency=10):
|
||||
def cull_idle(
|
||||
url, api_token, inactive_limit, cull_users=False, max_age=0, concurrency=10
|
||||
):
|
||||
"""Shutdown idle single-user servers
|
||||
|
||||
If cull_users, inactive *users* will be deleted as well.
|
||||
"""
|
||||
auth_header = {
|
||||
'Authorization': 'token %s' % api_token,
|
||||
}
|
||||
req = HTTPRequest(
|
||||
url=url + '/users',
|
||||
headers=auth_header,
|
||||
)
|
||||
auth_header = {'Authorization': 'token %s' % api_token}
|
||||
req = HTTPRequest(url=url + '/users', headers=auth_header)
|
||||
now = datetime.now(timezone.utc)
|
||||
client = AsyncHTTPClient()
|
||||
|
||||
if concurrency:
|
||||
semaphore = Semaphore(concurrency)
|
||||
|
||||
@coroutine
|
||||
def fetch(req):
|
||||
"""client.fetch wrapped in a semaphore to limit concurrency"""
|
||||
@@ -110,6 +108,7 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
|
||||
return (yield client.fetch(req))
|
||||
finally:
|
||||
yield semaphore.release()
|
||||
|
||||
else:
|
||||
fetch = client.fetch
|
||||
|
||||
@@ -129,8 +128,8 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
|
||||
log_name = '%s/%s' % (user['name'], server_name)
|
||||
if server.get('pending'):
|
||||
app_log.warning(
|
||||
"Not culling server %s with pending %s",
|
||||
log_name, server['pending'])
|
||||
"Not culling server %s with pending %s", log_name, server['pending']
|
||||
)
|
||||
return False
|
||||
|
||||
# jupyterhub < 0.9 defined 'server.url' once the server was ready
|
||||
@@ -142,8 +141,8 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
|
||||
|
||||
if not server.get('ready', bool(server['url'])):
|
||||
app_log.warning(
|
||||
"Not culling not-ready not-pending server %s: %s",
|
||||
log_name, server)
|
||||
"Not culling not-ready not-pending server %s: %s", log_name, server
|
||||
)
|
||||
return False
|
||||
|
||||
if server.get('started'):
|
||||
@@ -163,12 +162,13 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
|
||||
# for running servers
|
||||
inactive = age
|
||||
|
||||
should_cull = (inactive is not None and
|
||||
inactive.total_seconds() >= inactive_limit)
|
||||
should_cull = (
|
||||
inactive is not None and inactive.total_seconds() >= inactive_limit
|
||||
)
|
||||
if should_cull:
|
||||
app_log.info(
|
||||
"Culling server %s (inactive for %s)",
|
||||
log_name, format_td(inactive))
|
||||
"Culling server %s (inactive for %s)", log_name, format_td(inactive)
|
||||
)
|
||||
|
||||
if max_age and not should_cull:
|
||||
# only check started if max_age is specified
|
||||
@@ -177,32 +177,34 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
|
||||
if age is not None and age.total_seconds() >= max_age:
|
||||
app_log.info(
|
||||
"Culling server %s (age: %s, inactive for %s)",
|
||||
log_name, format_td(age), format_td(inactive))
|
||||
log_name,
|
||||
format_td(age),
|
||||
format_td(inactive),
|
||||
)
|
||||
should_cull = True
|
||||
|
||||
if not should_cull:
|
||||
app_log.debug(
|
||||
"Not culling server %s (age: %s, inactive for %s)",
|
||||
log_name, format_td(age), format_td(inactive))
|
||||
log_name,
|
||||
format_td(age),
|
||||
format_td(inactive),
|
||||
)
|
||||
return False
|
||||
|
||||
if server_name:
|
||||
# culling a named server
|
||||
delete_url = url + "/users/%s/servers/%s" % (
|
||||
quote(user['name']), quote(server['name'])
|
||||
quote(user['name']),
|
||||
quote(server['name']),
|
||||
)
|
||||
else:
|
||||
delete_url = url + '/users/%s/server' % quote(user['name'])
|
||||
|
||||
req = HTTPRequest(
|
||||
url=delete_url, method='DELETE', headers=auth_header,
|
||||
)
|
||||
req = HTTPRequest(url=delete_url, method='DELETE', headers=auth_header)
|
||||
resp = yield fetch(req)
|
||||
if resp.code == 202:
|
||||
app_log.warning(
|
||||
"Server %s is slow to stop",
|
||||
log_name,
|
||||
)
|
||||
app_log.warning("Server %s is slow to stop", log_name)
|
||||
# return False to prevent culling user with pending shutdowns
|
||||
return False
|
||||
return True
|
||||
@@ -245,7 +247,9 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
|
||||
if still_alive:
|
||||
app_log.debug(
|
||||
"Not culling user %s with %i servers still alive",
|
||||
user['name'], still_alive)
|
||||
user['name'],
|
||||
still_alive,
|
||||
)
|
||||
return False
|
||||
|
||||
should_cull = False
|
||||
@@ -265,12 +269,11 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
|
||||
# which introduces the 'created' field which is never None
|
||||
inactive = age
|
||||
|
||||
should_cull = (inactive is not None and
|
||||
inactive.total_seconds() >= inactive_limit)
|
||||
should_cull = (
|
||||
inactive is not None and inactive.total_seconds() >= inactive_limit
|
||||
)
|
||||
if should_cull:
|
||||
app_log.info(
|
||||
"Culling user %s (inactive for %s)",
|
||||
user['name'], inactive)
|
||||
app_log.info("Culling user %s (inactive for %s)", user['name'], inactive)
|
||||
|
||||
if max_age and not should_cull:
|
||||
# only check created if max_age is specified
|
||||
@@ -279,19 +282,23 @@ def cull_idle(url, api_token, inactive_limit, cull_users=False, max_age=0, concu
|
||||
if age is not None and age.total_seconds() >= max_age:
|
||||
app_log.info(
|
||||
"Culling user %s (age: %s, inactive for %s)",
|
||||
user['name'], format_td(age), format_td(inactive))
|
||||
user['name'],
|
||||
format_td(age),
|
||||
format_td(inactive),
|
||||
)
|
||||
should_cull = True
|
||||
|
||||
if not should_cull:
|
||||
app_log.debug(
|
||||
"Not culling user %s (created: %s, last active: %s)",
|
||||
user['name'], format_td(age), format_td(inactive))
|
||||
user['name'],
|
||||
format_td(age),
|
||||
format_td(inactive),
|
||||
)
|
||||
return False
|
||||
|
||||
req = HTTPRequest(
|
||||
url=url + '/users/%s' % user['name'],
|
||||
method='DELETE',
|
||||
headers=auth_header,
|
||||
url=url + '/users/%s' % user['name'], method='DELETE', headers=auth_header
|
||||
)
|
||||
yield fetch(req)
|
||||
return True
|
||||
@@ -316,20 +323,30 @@ if __name__ == '__main__':
|
||||
help="The JupyterHub API URL",
|
||||
)
|
||||
define('timeout', default=600, help="The idle timeout (in seconds)")
|
||||
define('cull_every', default=0,
|
||||
help="The interval (in seconds) for checking for idle servers to cull")
|
||||
define('max_age', default=0,
|
||||
help="The maximum age (in seconds) of servers that should be culled even if they are active")
|
||||
define('cull_users', default=False,
|
||||
define(
|
||||
'cull_every',
|
||||
default=0,
|
||||
help="The interval (in seconds) for checking for idle servers to cull",
|
||||
)
|
||||
define(
|
||||
'max_age',
|
||||
default=0,
|
||||
help="The maximum age (in seconds) of servers that should be culled even if they are active",
|
||||
)
|
||||
define(
|
||||
'cull_users',
|
||||
default=False,
|
||||
help="""Cull users in addition to servers.
|
||||
This is for use in temporary-user cases such as tmpnb.""",
|
||||
)
|
||||
define('concurrency', default=10,
|
||||
define(
|
||||
'concurrency',
|
||||
default=10,
|
||||
help="""Limit the number of concurrent requests made to the Hub.
|
||||
|
||||
Deleting a lot of users at the same time can slow down the Hub,
|
||||
so limit the number of API requests we have outstanding at any given time.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
parse_command_line()
|
||||
@@ -343,7 +360,8 @@ if __name__ == '__main__':
|
||||
app_log.warning(
|
||||
"Could not load pycurl: %s\n"
|
||||
"pycurl is recommended if you have a large number of users.",
|
||||
e)
|
||||
e,
|
||||
)
|
||||
|
||||
loop = IOLoop.current()
|
||||
cull = partial(
|
||||
|
@@ -4,7 +4,9 @@ import os
|
||||
# this could come from anywhere
|
||||
api_token = os.getenv("JUPYTERHUB_API_TOKEN")
|
||||
if not api_token:
|
||||
raise ValueError("Make sure to `export JUPYTERHUB_API_TOKEN=$(openssl rand -hex 32)`")
|
||||
raise ValueError(
|
||||
"Make sure to `export JUPYTERHUB_API_TOKEN=$(openssl rand -hex 32)`"
|
||||
)
|
||||
|
||||
# tell JupyterHub to register the service as an external oauth client
|
||||
|
||||
@@ -14,5 +16,5 @@ c.JupyterHub.services = [
|
||||
'oauth_client_id': "whoami-oauth-client-test",
|
||||
'api_token': api_token,
|
||||
'oauth_redirect_uri': 'http://127.0.0.1:5555/oauth_callback',
|
||||
},
|
||||
}
|
||||
]
|
||||
|
@@ -3,18 +3,19 @@
|
||||
Implements OAuth handshake manually
|
||||
so all URLs and requests necessary for OAuth with JupyterHub should be in one place
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from urllib.parse import urlencode, urlparse
|
||||
from urllib.parse import urlencode
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from tornado.auth import OAuth2Mixin
|
||||
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
|
||||
from tornado.httputil import url_concat
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado import log
|
||||
from tornado import web
|
||||
from tornado.auth import OAuth2Mixin
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpclient import HTTPRequest
|
||||
from tornado.httputil import url_concat
|
||||
from tornado.ioloop import IOLoop
|
||||
|
||||
|
||||
class JupyterHubLoginHandler(web.RequestHandler):
|
||||
@@ -32,11 +33,11 @@ class JupyterHubLoginHandler(web.RequestHandler):
|
||||
code=code,
|
||||
redirect_uri=self.settings['redirect_uri'],
|
||||
)
|
||||
req = HTTPRequest(self.settings['token_url'], method='POST',
|
||||
req = HTTPRequest(
|
||||
self.settings['token_url'],
|
||||
method='POST',
|
||||
body=urlencode(params).encode('utf8'),
|
||||
headers={
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'},
|
||||
)
|
||||
response = await AsyncHTTPClient().fetch(req)
|
||||
data = json.loads(response.body.decode('utf8', 'replace'))
|
||||
@@ -55,14 +56,16 @@ class JupyterHubLoginHandler(web.RequestHandler):
|
||||
# we are the login handler,
|
||||
# begin oauth process which will come back later with an
|
||||
# authorization_code
|
||||
self.redirect(url_concat(
|
||||
self.redirect(
|
||||
url_concat(
|
||||
self.settings['authorize_url'],
|
||||
dict(
|
||||
redirect_uri=self.settings['redirect_uri'],
|
||||
client_id=self.settings['client_id'],
|
||||
response_type='code',
|
||||
),
|
||||
)
|
||||
)
|
||||
))
|
||||
|
||||
|
||||
class WhoAmIHandler(web.RequestHandler):
|
||||
@@ -85,10 +88,7 @@ class WhoAmIHandler(web.RequestHandler):
|
||||
"""Retrieve the user for a given token, via /hub/api/user"""
|
||||
|
||||
req = HTTPRequest(
|
||||
self.settings['user_url'],
|
||||
headers={
|
||||
'Authorization': f'token {token}'
|
||||
},
|
||||
self.settings['user_url'], headers={'Authorization': f'token {token}'}
|
||||
)
|
||||
response = await AsyncHTTPClient().fetch(req)
|
||||
return json.loads(response.body.decode('utf8', 'replace'))
|
||||
@@ -110,23 +110,23 @@ def main():
|
||||
token_url = hub_api + '/oauth2/token'
|
||||
user_url = hub_api + '/user'
|
||||
|
||||
app = web.Application([
|
||||
('/oauth_callback', JupyterHubLoginHandler),
|
||||
('/', WhoAmIHandler),
|
||||
],
|
||||
app = web.Application(
|
||||
[('/oauth_callback', JupyterHubLoginHandler), ('/', WhoAmIHandler)],
|
||||
login_url='/oauth_callback',
|
||||
cookie_secret=os.urandom(32),
|
||||
api_token=os.environ['JUPYTERHUB_API_TOKEN'],
|
||||
client_id=os.environ['JUPYTERHUB_CLIENT_ID'],
|
||||
redirect_uri=os.environ['JUPYTERHUB_SERVICE_URL'].rstrip('/') + '/oauth_callback',
|
||||
redirect_uri=os.environ['JUPYTERHUB_SERVICE_URL'].rstrip('/')
|
||||
+ '/oauth_callback',
|
||||
authorize_url=authorize_url,
|
||||
token_url=token_url,
|
||||
user_url=user_url,
|
||||
)
|
||||
|
||||
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
|
||||
log.app_log.info("Running basic whoami service on %s",
|
||||
os.environ['JUPYTERHUB_SERVICE_URL'])
|
||||
log.app_log.info(
|
||||
"Running basic whoami service on %s", os.environ['JUPYTERHUB_SERVICE_URL']
|
||||
)
|
||||
app.listen(url.port, url.hostname)
|
||||
IOLoop.current().start()
|
||||
|
||||
|
@@ -8,10 +8,10 @@ c.Authenticator.whitelist = {'ganymede', 'io', 'rhea'}
|
||||
|
||||
# These environment variables are automatically supplied by the linked postgres
|
||||
# container.
|
||||
import os;
|
||||
import os
|
||||
|
||||
pg_pass = os.getenv('POSTGRES_ENV_JPY_PSQL_PASSWORD')
|
||||
pg_host = os.getenv('POSTGRES_PORT_5432_TCP_ADDR')
|
||||
c.JupyterHub.db_url = 'postgresql://jupyterhub:{}@{}:5432/jupyterhub'.format(
|
||||
pg_pass,
|
||||
pg_host,
|
||||
pg_pass, pg_host
|
||||
)
|
||||
|
@@ -1,11 +1,14 @@
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
|
||||
from tornado import escape
|
||||
from tornado import gen
|
||||
from tornado import ioloop
|
||||
from tornado import web
|
||||
|
||||
from jupyterhub.services.auth import HubAuthenticated
|
||||
from tornado import escape, gen, ioloop, web
|
||||
|
||||
|
||||
class AnnouncementRequestHandler(HubAuthenticated, web.RequestHandler):
|
||||
@@ -53,19 +56,19 @@ def main():
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--api-prefix", "-a",
|
||||
parser.add_argument(
|
||||
"--api-prefix",
|
||||
"-a",
|
||||
default=os.environ.get("JUPYTERHUB_SERVICE_PREFIX", "/"),
|
||||
help="application API prefix")
|
||||
parser.add_argument("--port", "-p",
|
||||
default=8888,
|
||||
help="port for API to listen on",
|
||||
type=int)
|
||||
help="application API prefix",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", "-p", default=8888, help="port for API to listen on", type=int
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def create_application(api_prefix="/",
|
||||
handler=AnnouncementRequestHandler,
|
||||
**kwargs):
|
||||
def create_application(api_prefix="/", handler=AnnouncementRequestHandler, **kwargs):
|
||||
storage = dict(announcement="", timestamp="", user="")
|
||||
return web.Application([(api_prefix, handler, dict(storage=storage))])
|
||||
|
||||
|
@@ -22,4 +22,3 @@ In the external example, some extra steps are required to set up supervisor:
|
||||
3. install `shared-notebook-service` somewhere on your system, and update `/path/to/shared-notebook-service` to the absolute path of this destination
|
||||
3. copy `shared-notebook.conf` to `/etc/supervisor/conf.d/`
|
||||
4. `supervisorctl reload`
|
||||
|
||||
|
@@ -1,18 +1,9 @@
|
||||
# our user list
|
||||
c.Authenticator.whitelist = [
|
||||
'minrk',
|
||||
'ellisonbg',
|
||||
'willingc',
|
||||
]
|
||||
c.Authenticator.whitelist = ['minrk', 'ellisonbg', 'willingc']
|
||||
|
||||
# ellisonbg and willingc have access to a shared server:
|
||||
|
||||
c.JupyterHub.load_groups = {
|
||||
'shared': [
|
||||
'ellisonbg',
|
||||
'willingc',
|
||||
]
|
||||
}
|
||||
c.JupyterHub.load_groups = {'shared': ['ellisonbg', 'willingc']}
|
||||
|
||||
# start the notebook server as a service
|
||||
c.JupyterHub.services = [
|
||||
|
@@ -1,18 +1,9 @@
|
||||
# our user list
|
||||
c.Authenticator.whitelist = [
|
||||
'minrk',
|
||||
'ellisonbg',
|
||||
'willingc',
|
||||
]
|
||||
c.Authenticator.whitelist = ['minrk', 'ellisonbg', 'willingc']
|
||||
|
||||
# ellisonbg and willingc have access to a shared server:
|
||||
|
||||
c.JupyterHub.load_groups = {
|
||||
'shared': [
|
||||
'ellisonbg',
|
||||
'willingc',
|
||||
]
|
||||
}
|
||||
c.JupyterHub.load_groups = {'shared': ['ellisonbg', 'willingc']}
|
||||
|
||||
service_name = 'shared-notebook'
|
||||
service_port = 9999
|
||||
@@ -23,10 +14,6 @@ c.JupyterHub.services = [
|
||||
{
|
||||
'name': service_name,
|
||||
'url': 'http://127.0.0.1:{}'.format(service_port),
|
||||
'command': [
|
||||
'jupyterhub-singleuser',
|
||||
'--group=shared',
|
||||
'--debug',
|
||||
],
|
||||
'command': ['jupyterhub-singleuser', '--group=shared', '--debug'],
|
||||
}
|
||||
]
|
@@ -6,16 +6,12 @@ c.JupyterHub.services = [
|
||||
'name': 'whoami',
|
||||
'url': 'http://127.0.0.1:10101',
|
||||
'command': ['flask', 'run', '--port=10101'],
|
||||
'environment': {
|
||||
'FLASK_APP': 'whoami-flask.py',
|
||||
}
|
||||
'environment': {'FLASK_APP': 'whoami-flask.py'},
|
||||
},
|
||||
{
|
||||
'name': 'whoami-oauth',
|
||||
'url': 'http://127.0.0.1:10201',
|
||||
'command': ['flask', 'run', '--port=10201'],
|
||||
'environment': {
|
||||
'FLASK_APP': 'whoami-oauth.py',
|
||||
}
|
||||
'environment': {'FLASK_APP': 'whoami-oauth.py'},
|
||||
},
|
||||
]
|
||||
|
@@ -2,29 +2,29 @@
|
||||
"""
|
||||
whoami service authentication with the Hub
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
import json
|
||||
import os
|
||||
from functools import wraps
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Flask, redirect, request, Response
|
||||
from flask import Flask
|
||||
from flask import redirect
|
||||
from flask import request
|
||||
from flask import Response
|
||||
|
||||
from jupyterhub.services.auth import HubAuth
|
||||
|
||||
|
||||
prefix = os.environ.get('JUPYTERHUB_SERVICE_PREFIX', '/')
|
||||
|
||||
auth = HubAuth(
|
||||
api_token=os.environ['JUPYTERHUB_API_TOKEN'],
|
||||
cache_max_age=60,
|
||||
)
|
||||
auth = HubAuth(api_token=os.environ['JUPYTERHUB_API_TOKEN'], cache_max_age=60)
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
def authenticated(f):
|
||||
"""Decorator for authenticating with the Hub"""
|
||||
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
cookie = request.cookies.get(auth.cookie_name)
|
||||
@@ -40,6 +40,7 @@ def authenticated(f):
|
||||
else:
|
||||
# redirect to login url on failed auth
|
||||
return redirect(auth.login_url + '?next=%s' % quote(request.path))
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@@ -47,7 +48,5 @@ def authenticated(f):
|
||||
@authenticated
|
||||
def whoami(user):
|
||||
return Response(
|
||||
json.dumps(user, indent=1, sort_keys=True),
|
||||
mimetype='application/json',
|
||||
json.dumps(user, indent=1, sort_keys=True), mimetype='application/json'
|
||||
)
|
||||
|
||||
|
@@ -2,28 +2,29 @@
|
||||
"""
|
||||
whoami service authentication with the Hub
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
import json
|
||||
import os
|
||||
from functools import wraps
|
||||
|
||||
from flask import Flask, redirect, request, Response, make_response
|
||||
from flask import Flask
|
||||
from flask import make_response
|
||||
from flask import redirect
|
||||
from flask import request
|
||||
from flask import Response
|
||||
|
||||
from jupyterhub.services.auth import HubOAuth
|
||||
|
||||
|
||||
prefix = os.environ.get('JUPYTERHUB_SERVICE_PREFIX', '/')
|
||||
|
||||
auth = HubOAuth(
|
||||
api_token=os.environ['JUPYTERHUB_API_TOKEN'],
|
||||
cache_max_age=60,
|
||||
)
|
||||
auth = HubOAuth(api_token=os.environ['JUPYTERHUB_API_TOKEN'], cache_max_age=60)
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
def authenticated(f):
|
||||
"""Decorator for authenticating with the Hub via OAuth"""
|
||||
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
token = request.cookies.get(auth.cookie_name)
|
||||
@@ -39,6 +40,7 @@ def authenticated(f):
|
||||
response = make_response(redirect(auth.login_url + '&state=%s' % state))
|
||||
response.set_cookie(auth.state_cookie_name, state)
|
||||
return response
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@@ -46,10 +48,10 @@ def authenticated(f):
|
||||
@authenticated
|
||||
def whoami(user):
|
||||
return Response(
|
||||
json.dumps(user, indent=1, sort_keys=True),
|
||||
mimetype='application/json',
|
||||
json.dumps(user, indent=1, sort_keys=True), mimetype='application/json'
|
||||
)
|
||||
|
||||
|
||||
@app.route(prefix + 'oauth_callback')
|
||||
def oauth_callback():
|
||||
code = request.args.get('code', None)
|
||||
|
@@ -4,18 +4,22 @@ This example service serves `/services/whoami/`,
|
||||
authenticated with the Hub,
|
||||
showing the user their own info.
|
||||
"""
|
||||
from getpass import getuser
|
||||
import json
|
||||
import os
|
||||
from getpass import getuser
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.web import RequestHandler, Application, authenticated
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.web import Application
|
||||
from tornado.web import authenticated
|
||||
from tornado.web import RequestHandler
|
||||
|
||||
from jupyterhub.services.auth import HubOAuthenticated, HubOAuthCallbackHandler
|
||||
from jupyterhub.services.auth import HubOAuthCallbackHandler
|
||||
from jupyterhub.services.auth import HubOAuthenticated
|
||||
from jupyterhub.utils import url_path_join
|
||||
|
||||
|
||||
class WhoAmIHandler(HubOAuthenticated, RequestHandler):
|
||||
# hub_users can be a set of users who are allowed to access the service
|
||||
# `getuser()` here would mean only the user who started the service
|
||||
@@ -29,12 +33,21 @@ class WhoAmIHandler(HubOAuthenticated, RequestHandler):
|
||||
self.set_header('content-type', 'application/json')
|
||||
self.write(json.dumps(user_model, indent=1, sort_keys=True))
|
||||
|
||||
|
||||
def main():
|
||||
app = Application([
|
||||
app = Application(
|
||||
[
|
||||
(os.environ['JUPYTERHUB_SERVICE_PREFIX'], WhoAmIHandler),
|
||||
(url_path_join(os.environ['JUPYTERHUB_SERVICE_PREFIX'], 'oauth_callback'), HubOAuthCallbackHandler),
|
||||
(
|
||||
url_path_join(
|
||||
os.environ['JUPYTERHUB_SERVICE_PREFIX'], 'oauth_callback'
|
||||
),
|
||||
HubOAuthCallbackHandler,
|
||||
),
|
||||
(r'.*', WhoAmIHandler),
|
||||
], cookie_secret=os.urandom(32))
|
||||
],
|
||||
cookie_secret=os.urandom(32),
|
||||
)
|
||||
|
||||
http_server = HTTPServer(app)
|
||||
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
|
||||
@@ -43,5 +56,6 @@ def main():
|
||||
|
||||
IOLoop.current().start()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@@ -2,14 +2,16 @@
|
||||
|
||||
This serves `/services/whoami/`, authenticated with the Hub, showing the user their own info.
|
||||
"""
|
||||
from getpass import getuser
|
||||
import json
|
||||
import os
|
||||
from getpass import getuser
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.web import RequestHandler, Application, authenticated
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.web import Application
|
||||
from tornado.web import authenticated
|
||||
from tornado.web import RequestHandler
|
||||
|
||||
from jupyterhub.services.auth import HubAuthenticated
|
||||
|
||||
@@ -27,11 +29,14 @@ class WhoAmIHandler(HubAuthenticated, RequestHandler):
|
||||
self.set_header('content-type', 'application/json')
|
||||
self.write(json.dumps(user_model, indent=1, sort_keys=True))
|
||||
|
||||
|
||||
def main():
|
||||
app = Application([
|
||||
app = Application(
|
||||
[
|
||||
(os.environ['JUPYTERHUB_SERVICE_PREFIX'] + '/?', WhoAmIHandler),
|
||||
(r'.*', WhoAmIHandler),
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
http_server = HTTPServer(app)
|
||||
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
|
||||
@@ -40,5 +45,6 @@ def main():
|
||||
|
||||
IOLoop.current().start()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@@ -5,6 +5,7 @@ import shlex
|
||||
|
||||
from jupyterhub.spawner import LocalProcessSpawner
|
||||
|
||||
|
||||
class DemoFormSpawner(LocalProcessSpawner):
|
||||
def _options_form_default(self):
|
||||
default_env = "YOURNAME=%s\n" % self.user.name
|
||||
@@ -13,7 +14,9 @@ class DemoFormSpawner(LocalProcessSpawner):
|
||||
<input name="args" placeholder="e.g. --debug"></input>
|
||||
<label for="env">Environment variables (one per line)</label>
|
||||
<textarea name="env">{env}</textarea>
|
||||
""".format(env=default_env)
|
||||
""".format(
|
||||
env=default_env
|
||||
)
|
||||
|
||||
def options_from_form(self, formdata):
|
||||
options = {}
|
||||
@@ -43,4 +46,5 @@ class DemoFormSpawner(LocalProcessSpawner):
|
||||
env.update(self.user_options['env'])
|
||||
return env
|
||||
|
||||
|
||||
c.JupyterHub.spawner_class = DemoFormSpawner
|
||||
|
@@ -1 +1,2 @@
|
||||
from ._version import version_info, __version__
|
||||
from ._version import __version__
|
||||
from ._version import version_info
|
||||
|
@@ -1,2 +1,3 @@
|
||||
from .app import main
|
||||
|
||||
main()
|
||||
|
@@ -5,6 +5,7 @@ def get_data_files():
|
||||
"""Walk up until we find share/jupyterhub"""
|
||||
import sys
|
||||
from os.path import join, abspath, dirname, exists, split
|
||||
|
||||
path = abspath(dirname(__file__))
|
||||
starting_points = [path]
|
||||
if not path.startswith(sys.prefix):
|
||||
|
@@ -1,5 +1,4 @@
|
||||
"""JupyterHub version info"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
@@ -23,16 +22,23 @@ __version__ = ".".join(map(str, version_info[:3])) + ".".join(version_info[3:])
|
||||
def _check_version(hub_version, singleuser_version, log):
|
||||
"""Compare Hub and single-user server versions"""
|
||||
if not hub_version:
|
||||
log.warning("Hub has no version header, which means it is likely < 0.8. Expected %s", __version__)
|
||||
log.warning(
|
||||
"Hub has no version header, which means it is likely < 0.8. Expected %s",
|
||||
__version__,
|
||||
)
|
||||
return
|
||||
|
||||
if not singleuser_version:
|
||||
log.warning("Single-user server has no version header, which means it is likely < 0.8. Expected %s", __version__)
|
||||
log.warning(
|
||||
"Single-user server has no version header, which means it is likely < 0.8. Expected %s",
|
||||
__version__,
|
||||
)
|
||||
return
|
||||
|
||||
# compare minor X.Y versions
|
||||
if hub_version != singleuser_version:
|
||||
from distutils.version import LooseVersion as V
|
||||
|
||||
hub_major_minor = V(hub_version).version[:2]
|
||||
singleuser_major_minor = V(singleuser_version).version[:2]
|
||||
extra = ""
|
||||
@@ -50,4 +56,6 @@ def _check_version(hub_version, singleuser_version, log):
|
||||
singleuser_version,
|
||||
)
|
||||
else:
|
||||
log.debug("jupyterhub and jupyterhub-singleuser both on version %s" % hub_version)
|
||||
log.debug(
|
||||
"jupyterhub and jupyterhub-singleuser both on version %s" % hub_version
|
||||
)
|
||||
|
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
import logging
|
||||
from logging.config import fileConfig
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@@ -14,6 +15,7 @@ config = context.config
|
||||
if 'jupyterhub' in sys.modules:
|
||||
from traitlets.config import MultipleInstanceError
|
||||
from jupyterhub.app import JupyterHub
|
||||
|
||||
app = None
|
||||
if JupyterHub.initialized():
|
||||
try:
|
||||
@@ -32,6 +34,7 @@ else:
|
||||
|
||||
# add your model's MetaData object here for 'autogenerate' support
|
||||
from jupyterhub import orm
|
||||
|
||||
target_metadata = orm.Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
@@ -53,8 +56,7 @@ def run_migrations_offline():
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url, target_metadata=target_metadata, literal_binds=True)
|
||||
context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
@@ -70,17 +72,16 @@ def run_migrations_online():
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section),
|
||||
prefix='sqlalchemy.',
|
||||
poolclass=pool.NullPool)
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata
|
||||
)
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
|
@@ -5,7 +5,6 @@ Revises:
|
||||
Create Date: 2016-04-11 16:05:34.873288
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '19c0846f6344'
|
||||
down_revision = None
|
||||
|
@@ -5,7 +5,6 @@ Revises: 3ec6993fe20c
|
||||
Create Date: 2017-12-07 14:43:51.500740
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '1cebaf56856c'
|
||||
down_revision = '3ec6993fe20c'
|
||||
@@ -13,6 +12,7 @@ branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('alembic')
|
||||
|
||||
from alembic import op
|
||||
|
@@ -12,7 +12,6 @@ Revises: af4cbdb2d13c
|
||||
Create Date: 2017-07-28 16:44:40.413648
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '3ec6993fe20c'
|
||||
down_revision = 'af4cbdb2d13c'
|
||||
@@ -44,7 +43,9 @@ def upgrade():
|
||||
except sa.exc.OperationalError:
|
||||
# this won't be a problem moving forward, but downgrade will fail
|
||||
if op.get_context().dialect.name == 'sqlite':
|
||||
logger.warning("sqlite cannot drop columns. Leaving unused old columns in place.")
|
||||
logger.warning(
|
||||
"sqlite cannot drop columns. Leaving unused old columns in place."
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -54,15 +55,13 @@ def upgrade():
|
||||
def downgrade():
|
||||
# drop all the new tables
|
||||
engine = op.get_bind().engine
|
||||
for table in ('oauth_clients',
|
||||
'oauth_codes',
|
||||
'oauth_access_tokens',
|
||||
'spawners'):
|
||||
for table in ('oauth_clients', 'oauth_codes', 'oauth_access_tokens', 'spawners'):
|
||||
if engine.has_table(table):
|
||||
op.drop_table(table)
|
||||
|
||||
op.drop_column('users', 'encrypted_auth_state')
|
||||
|
||||
op.add_column('users', sa.Column('auth_state', JSONDict))
|
||||
op.add_column('users', sa.Column('_server_id', sa.Integer, sa.ForeignKey('servers.id')))
|
||||
|
||||
op.add_column(
|
||||
'users', sa.Column('_server_id', sa.Integer, sa.ForeignKey('servers.id'))
|
||||
)
|
||||
|
@@ -5,7 +5,6 @@ Revises: 1cebaf56856c
|
||||
Create Date: 2017-12-19 15:21:09.300513
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '56cc5a70207e'
|
||||
down_revision = '1cebaf56856c'
|
||||
@@ -16,22 +15,48 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('alembic')
|
||||
|
||||
|
||||
def upgrade():
|
||||
tables = op.get_bind().engine.table_names()
|
||||
op.add_column('api_tokens', sa.Column('created', sa.DateTime(), nullable=True))
|
||||
op.add_column('api_tokens', sa.Column('last_activity', sa.DateTime(), nullable=True))
|
||||
op.add_column('api_tokens', sa.Column('note', sa.Unicode(length=1023), nullable=True))
|
||||
op.add_column(
|
||||
'api_tokens', sa.Column('last_activity', sa.DateTime(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
'api_tokens', sa.Column('note', sa.Unicode(length=1023), nullable=True)
|
||||
)
|
||||
if 'oauth_access_tokens' in tables:
|
||||
op.add_column('oauth_access_tokens', sa.Column('created', sa.DateTime(), nullable=True))
|
||||
op.add_column('oauth_access_tokens', sa.Column('last_activity', sa.DateTime(), nullable=True))
|
||||
op.add_column(
|
||||
'oauth_access_tokens', sa.Column('created', sa.DateTime(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
'oauth_access_tokens',
|
||||
sa.Column('last_activity', sa.DateTime(), nullable=True),
|
||||
)
|
||||
if op.get_context().dialect.name == 'sqlite':
|
||||
logger.warning("sqlite cannot use ALTER TABLE to create foreign keys. Upgrade will be incomplete.")
|
||||
logger.warning(
|
||||
"sqlite cannot use ALTER TABLE to create foreign keys. Upgrade will be incomplete."
|
||||
)
|
||||
else:
|
||||
op.create_foreign_key(None, 'oauth_access_tokens', 'oauth_clients', ['client_id'], ['identifier'], ondelete='CASCADE')
|
||||
op.create_foreign_key(None, 'oauth_codes', 'oauth_clients', ['client_id'], ['identifier'], ondelete='CASCADE')
|
||||
op.create_foreign_key(
|
||||
None,
|
||||
'oauth_access_tokens',
|
||||
'oauth_clients',
|
||||
['client_id'],
|
||||
['identifier'],
|
||||
ondelete='CASCADE',
|
||||
)
|
||||
op.create_foreign_key(
|
||||
None,
|
||||
'oauth_codes',
|
||||
'oauth_clients',
|
||||
['client_id'],
|
||||
['identifier'],
|
||||
ondelete='CASCADE',
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
|
@@ -5,7 +5,6 @@ Revises: d68c98b66cd4
|
||||
Create Date: 2018-05-07 11:35:58.050542
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '896818069c98'
|
||||
down_revision = 'd68c98b66cd4'
|
||||
|
@@ -5,7 +5,6 @@ Revises: 56cc5a70207e
|
||||
Create Date: 2018-03-21 14:27:17.466841
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '99a28a4418e1'
|
||||
down_revision = '56cc5a70207e'
|
||||
@@ -18,15 +17,18 @@ import sqlalchemy as sa
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column('users', sa.Column('created', sa.DateTime, nullable=True))
|
||||
c = op.get_bind()
|
||||
# fill created date with current time
|
||||
now = datetime.utcnow()
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET created='%s'
|
||||
""" % (now,)
|
||||
"""
|
||||
% (now,)
|
||||
)
|
||||
|
||||
tables = c.engine.table_names()
|
||||
@@ -34,11 +36,13 @@ def upgrade():
|
||||
if 'spawners' in tables:
|
||||
op.add_column('spawners', sa.Column('started', sa.DateTime, nullable=True))
|
||||
# fill started value with now for running servers
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
UPDATE spawners
|
||||
SET started='%s'
|
||||
WHERE server_id IS NOT NULL
|
||||
""" % (now,)
|
||||
"""
|
||||
% (now,)
|
||||
)
|
||||
|
||||
|
||||
|
@@ -5,7 +5,6 @@ Revises: eeb276e51423
|
||||
Create Date: 2016-07-28 16:16:38.245348
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'af4cbdb2d13c'
|
||||
down_revision = 'eeb276e51423'
|
||||
|
@@ -5,7 +5,6 @@ Revises: 99a28a4418e1
|
||||
Create Date: 2018-04-13 10:50:17.968636
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd68c98b66cd4'
|
||||
down_revision = '99a28a4418e1'
|
||||
@@ -20,8 +19,7 @@ def upgrade():
|
||||
tables = op.get_bind().engine.table_names()
|
||||
if 'oauth_clients' in tables:
|
||||
op.add_column(
|
||||
'oauth_clients',
|
||||
sa.Column('description', sa.Unicode(length=1023))
|
||||
'oauth_clients', sa.Column('description', sa.Unicode(length=1023))
|
||||
)
|
||||
|
||||
|
||||
|
@@ -6,7 +6,6 @@ Revision ID: eeb276e51423
|
||||
Revises: 19c0846f6344
|
||||
Create Date: 2016-04-11 16:06:49.239831
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'eeb276e51423'
|
||||
down_revision = '19c0846f6344'
|
||||
@@ -17,6 +16,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from jupyterhub.orm import JSONDict
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column('users', sa.Column('auth_state', JSONDict))
|
||||
|
||||
|
@@ -1,5 +1,10 @@
|
||||
from . import auth
|
||||
from . import groups
|
||||
from . import hub
|
||||
from . import proxy
|
||||
from . import services
|
||||
from . import users
|
||||
from .base import *
|
||||
from . import auth, hub, proxy, users, groups, services
|
||||
|
||||
default_handlers = []
|
||||
for mod in (auth, hub, proxy, users, groups, services):
|
||||
|
@@ -1,25 +1,23 @@
|
||||
"""Authorization handlers"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
from urllib.parse import (
|
||||
parse_qsl,
|
||||
quote,
|
||||
urlencode,
|
||||
urlparse,
|
||||
urlunparse,
|
||||
)
|
||||
from datetime import datetime
|
||||
from urllib.parse import parse_qsl
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlencode
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlunparse
|
||||
|
||||
from oauthlib import oauth2
|
||||
from tornado import web
|
||||
|
||||
from .. import orm
|
||||
from ..user import User
|
||||
from ..utils import token_authenticated, compare_token
|
||||
from .base import BaseHandler, APIHandler
|
||||
from ..utils import compare_token
|
||||
from ..utils import token_authenticated
|
||||
from .base import APIHandler
|
||||
from .base import BaseHandler
|
||||
|
||||
|
||||
class TokenAPIHandler(APIHandler):
|
||||
@@ -70,7 +68,9 @@ class TokenAPIHandler(APIHandler):
|
||||
if data and data.get('username'):
|
||||
user = self.find_user(data['username'])
|
||||
if user is not requester and not requester.admin:
|
||||
raise web.HTTPError(403, "Only admins can request tokens for other users.")
|
||||
raise web.HTTPError(
|
||||
403, "Only admins can request tokens for other users."
|
||||
)
|
||||
if requester.admin and user is None:
|
||||
raise web.HTTPError(400, "No such user '%s'" % data['username'])
|
||||
|
||||
@@ -82,11 +82,11 @@ class TokenAPIHandler(APIHandler):
|
||||
note += " by %s %s" % (kind, requester.name)
|
||||
|
||||
api_token = user.new_api_token(note=note)
|
||||
self.write(json.dumps({
|
||||
'token': api_token,
|
||||
'warning': warn_msg,
|
||||
'user': self.user_model(user),
|
||||
}))
|
||||
self.write(
|
||||
json.dumps(
|
||||
{'token': api_token, 'warning': warn_msg, 'user': self.user_model(user)}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CookieAPIHandler(APIHandler):
|
||||
@@ -94,7 +94,9 @@ class CookieAPIHandler(APIHandler):
|
||||
def get(self, cookie_name, cookie_value=None):
|
||||
cookie_name = quote(cookie_name, safe='')
|
||||
if cookie_value is None:
|
||||
self.log.warning("Cookie values in request body is deprecated, use `/cookie_name/cookie_value`")
|
||||
self.log.warning(
|
||||
"Cookie values in request body is deprecated, use `/cookie_name/cookie_value`"
|
||||
)
|
||||
cookie_value = self.request.body
|
||||
else:
|
||||
cookie_value = cookie_value.encode('utf8')
|
||||
@@ -134,7 +136,9 @@ class OAuthHandler:
|
||||
return uri
|
||||
# make absolute local redirects full URLs
|
||||
# to satisfy oauthlib's absolute URI requirement
|
||||
redirect_uri = self.request.protocol + "://" + self.request.headers['Host'] + redirect_uri
|
||||
redirect_uri = (
|
||||
self.request.protocol + "://" + self.request.headers['Host'] + redirect_uri
|
||||
)
|
||||
parsed_url = urlparse(uri)
|
||||
query_list = parse_qsl(parsed_url.query, keep_blank_values=True)
|
||||
for idx, item in enumerate(query_list):
|
||||
@@ -142,10 +146,7 @@ class OAuthHandler:
|
||||
query_list[idx] = ('redirect_uri', redirect_uri)
|
||||
break
|
||||
|
||||
return urlunparse(
|
||||
urlparse(uri)
|
||||
._replace(query=urlencode(query_list))
|
||||
)
|
||||
return urlunparse(urlparse(uri)._replace(query=urlencode(query_list)))
|
||||
|
||||
def add_credentials(self, credentials=None):
|
||||
"""Add oauth credentials
|
||||
@@ -164,11 +165,7 @@ class OAuthHandler:
|
||||
user = self.current_user
|
||||
|
||||
# Extra credentials we need in the validator
|
||||
credentials.update({
|
||||
'user': user,
|
||||
'handler': self,
|
||||
'session_id': session_id,
|
||||
})
|
||||
credentials.update({'user': user, 'handler': self, 'session_id': session_id})
|
||||
return credentials
|
||||
|
||||
def send_oauth_response(self, headers, body, status):
|
||||
@@ -193,7 +190,8 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
|
||||
def _complete_login(self, uri, headers, scopes, credentials):
|
||||
try:
|
||||
headers, body, status = self.oauth_provider.create_authorization_response(
|
||||
uri, 'POST', '', headers, scopes, credentials)
|
||||
uri, 'POST', '', headers, scopes, credentials
|
||||
)
|
||||
|
||||
except oauth2.FatalClientError as e:
|
||||
# TODO: human error page
|
||||
@@ -213,13 +211,15 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
|
||||
uri, http_method, body, headers = self.extract_oauth_params()
|
||||
try:
|
||||
scopes, credentials = self.oauth_provider.validate_authorization_request(
|
||||
uri, http_method, body, headers)
|
||||
uri, http_method, body, headers
|
||||
)
|
||||
credentials = self.add_credentials(credentials)
|
||||
client = self.oauth_provider.fetch_by_client_id(credentials['client_id'])
|
||||
if client.redirect_uri.startswith(self.current_user.url):
|
||||
self.log.debug(
|
||||
"Skipping oauth confirmation for %s accessing %s",
|
||||
self.current_user, client.description,
|
||||
self.current_user,
|
||||
client.description,
|
||||
)
|
||||
# access to my own server doesn't require oauth confirmation
|
||||
# this is the pre-1.0 behavior for all oauth
|
||||
@@ -228,11 +228,7 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
|
||||
|
||||
# Render oauth 'Authorize application...' page
|
||||
self.write(
|
||||
self.render_template(
|
||||
"oauth.html",
|
||||
scopes=scopes,
|
||||
oauth_client=client,
|
||||
)
|
||||
self.render_template("oauth.html", scopes=scopes, oauth_client=client)
|
||||
)
|
||||
|
||||
# Errors that should be shown to the user on the provider website
|
||||
@@ -252,7 +248,9 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
|
||||
if referer != full_url:
|
||||
# OAuth post must be made to the URL it came from
|
||||
self.log.error("OAuth POST from %s != %s", referer, full_url)
|
||||
raise web.HTTPError(403, "Authorization form must be sent from authorization page")
|
||||
raise web.HTTPError(
|
||||
403, "Authorization form must be sent from authorization page"
|
||||
)
|
||||
|
||||
# The scopes the user actually authorized, i.e. checkboxes
|
||||
# that were selected.
|
||||
@@ -262,7 +260,7 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
|
||||
|
||||
try:
|
||||
headers, body, status = self.oauth_provider.create_authorization_response(
|
||||
uri, http_method, body, headers, scopes, credentials,
|
||||
uri, http_method, body, headers, scopes, credentials
|
||||
)
|
||||
except oauth2.FatalClientError as e:
|
||||
raise web.HTTPError(e.status_code, e.description)
|
||||
@@ -277,7 +275,8 @@ class OAuthTokenHandler(OAuthHandler, APIHandler):
|
||||
|
||||
try:
|
||||
headers, body, status = self.oauth_provider.create_token_response(
|
||||
uri, http_method, body, headers, credentials)
|
||||
uri, http_method, body, headers, credentials
|
||||
)
|
||||
except oauth2.FatalClientError as e:
|
||||
raise web.HTTPError(e.status_code, e.description)
|
||||
else:
|
||||
|
@@ -1,10 +1,8 @@
|
||||
"""Base API handlers"""
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from datetime import datetime
|
||||
from http.client import responses
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@@ -12,7 +10,8 @@ from tornado import web
|
||||
|
||||
from .. import orm
|
||||
from ..handlers import BaseHandler
|
||||
from ..utils import isoformat, url_path_join
|
||||
from ..utils import isoformat
|
||||
from ..utils import url_path_join
|
||||
|
||||
|
||||
class APIHandler(BaseHandler):
|
||||
@@ -55,8 +54,11 @@ class APIHandler(BaseHandler):
|
||||
host_path = url_path_join(host, self.hub.base_url)
|
||||
referer_path = referer.split('://', 1)[-1]
|
||||
if not (referer_path + '/').startswith(host_path):
|
||||
self.log.warning("Blocking Cross Origin API request. Referer: %s, Host: %s",
|
||||
referer, host_path)
|
||||
self.log.warning(
|
||||
"Blocking Cross Origin API request. Referer: %s, Host: %s",
|
||||
referer,
|
||||
host_path,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -105,9 +107,13 @@ class APIHandler(BaseHandler):
|
||||
if exception and isinstance(exception, SQLAlchemyError):
|
||||
try:
|
||||
exception_str = str(exception)
|
||||
self.log.warning("Rolling back session due to database error %s", exception_str)
|
||||
self.log.warning(
|
||||
"Rolling back session due to database error %s", exception_str
|
||||
)
|
||||
except Exception:
|
||||
self.log.warning("Rolling back session due to database error %s", type(exception))
|
||||
self.log.warning(
|
||||
"Rolling back session due to database error %s", type(exception)
|
||||
)
|
||||
self.db.rollback()
|
||||
|
||||
self.set_header('Content-Type', 'application/json')
|
||||
@@ -121,10 +127,9 @@ class APIHandler(BaseHandler):
|
||||
# Content-Length must be recalculated.
|
||||
self.clear_header('Content-Length')
|
||||
|
||||
self.write(json.dumps({
|
||||
'status': status_code,
|
||||
'message': message or status_message,
|
||||
}))
|
||||
self.write(
|
||||
json.dumps({'status': status_code, 'message': message or status_message})
|
||||
)
|
||||
|
||||
def server_model(self, spawner, include_state=False):
|
||||
"""Get the JSON model for a Spawner"""
|
||||
@@ -144,21 +149,17 @@ class APIHandler(BaseHandler):
|
||||
expires_at = None
|
||||
if isinstance(token, orm.APIToken):
|
||||
kind = 'api_token'
|
||||
extra = {
|
||||
'note': token.note,
|
||||
}
|
||||
extra = {'note': token.note}
|
||||
expires_at = token.expires_at
|
||||
elif isinstance(token, orm.OAuthAccessToken):
|
||||
kind = 'oauth'
|
||||
extra = {
|
||||
'oauth_client': token.client.description or token.client.client_id,
|
||||
}
|
||||
extra = {'oauth_client': token.client.description or token.client.client_id}
|
||||
if token.expires_at:
|
||||
expires_at = datetime.fromtimestamp(token.expires_at)
|
||||
else:
|
||||
raise TypeError(
|
||||
"token must be an APIToken or OAuthAccessToken, not %s"
|
||||
% type(token))
|
||||
"token must be an APIToken or OAuthAccessToken, not %s" % type(token)
|
||||
)
|
||||
|
||||
if token.user:
|
||||
owner_key = 'user'
|
||||
@@ -219,23 +220,11 @@ class APIHandler(BaseHandler):
|
||||
|
||||
def service_model(self, service):
|
||||
"""Get the JSON model for a Service object"""
|
||||
return {
|
||||
'kind': 'service',
|
||||
'name': service.name,
|
||||
'admin': service.admin,
|
||||
}
|
||||
return {'kind': 'service', 'name': service.name, 'admin': service.admin}
|
||||
|
||||
_user_model_types = {
|
||||
'name': str,
|
||||
'admin': bool,
|
||||
'groups': list,
|
||||
'auth_state': dict,
|
||||
}
|
||||
_user_model_types = {'name': str, 'admin': bool, 'groups': list, 'auth_state': dict}
|
||||
|
||||
_group_model_types = {
|
||||
'name': str,
|
||||
'users': list,
|
||||
}
|
||||
_group_model_types = {'name': str, 'users': list}
|
||||
|
||||
def _check_model(self, model, model_types, name):
|
||||
"""Check a model provided by a REST API request
|
||||
@@ -251,24 +240,29 @@ class APIHandler(BaseHandler):
|
||||
raise web.HTTPError(400, "Invalid JSON keys: %r" % model)
|
||||
for key, value in model.items():
|
||||
if not isinstance(value, model_types[key]):
|
||||
raise web.HTTPError(400, "%s.%s must be %s, not: %r" % (
|
||||
name, key, model_types[key], type(value)
|
||||
))
|
||||
raise web.HTTPError(
|
||||
400,
|
||||
"%s.%s must be %s, not: %r"
|
||||
% (name, key, model_types[key], type(value)),
|
||||
)
|
||||
|
||||
def _check_user_model(self, model):
|
||||
"""Check a request-provided user model from a REST API"""
|
||||
self._check_model(model, self._user_model_types, 'user')
|
||||
for username in model.get('users', []):
|
||||
if not isinstance(username, str):
|
||||
raise web.HTTPError(400, ("usernames must be str, not %r", type(username)))
|
||||
raise web.HTTPError(
|
||||
400, ("usernames must be str, not %r", type(username))
|
||||
)
|
||||
|
||||
def _check_group_model(self, model):
|
||||
"""Check a request-provided group model from a REST API"""
|
||||
self._check_model(model, self._group_model_types, 'group')
|
||||
for groupname in model.get('groups', []):
|
||||
if not isinstance(groupname, str):
|
||||
raise web.HTTPError(400, ("group names must be str, not %r", type(groupname)))
|
||||
|
||||
raise web.HTTPError(
|
||||
400, ("group names must be str, not %r", type(groupname))
|
||||
)
|
||||
|
||||
def options(self, *args, **kwargs):
|
||||
self.finish()
|
||||
@@ -279,6 +273,7 @@ class API404(APIHandler):
|
||||
|
||||
Ensures JSON 404 errors for malformed URLs
|
||||
"""
|
||||
|
||||
async def prepare(self):
|
||||
await super().prepare()
|
||||
raise web.HTTPError(404)
|
||||
|
@@ -1,11 +1,10 @@
|
||||
"""Group handlers"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import json
|
||||
|
||||
from tornado import gen, web
|
||||
from tornado import gen
|
||||
from tornado import web
|
||||
|
||||
from .. import orm
|
||||
from ..utils import admin_only
|
||||
@@ -34,6 +33,7 @@ class _GroupAPIHandler(APIHandler):
|
||||
raise web.HTTPError(404, "No such group: %s", name)
|
||||
return group
|
||||
|
||||
|
||||
class GroupListAPIHandler(_GroupAPIHandler):
|
||||
@admin_only
|
||||
def get(self):
|
||||
@@ -61,9 +61,7 @@ class GroupListAPIHandler(_GroupAPIHandler):
|
||||
# check that users exist
|
||||
users = self._usernames_to_users(usernames)
|
||||
# create the group
|
||||
self.log.info("Creating new group %s with %i users",
|
||||
name, len(users),
|
||||
)
|
||||
self.log.info("Creating new group %s with %i users", name, len(users))
|
||||
self.log.debug("Users: %s", usernames)
|
||||
group = orm.Group(name=name, users=users)
|
||||
self.db.add(group)
|
||||
@@ -99,9 +97,7 @@ class GroupAPIHandler(_GroupAPIHandler):
|
||||
users = self._usernames_to_users(usernames)
|
||||
|
||||
# create the group
|
||||
self.log.info("Creating new group %s with %i users",
|
||||
name, len(users),
|
||||
)
|
||||
self.log.info("Creating new group %s with %i users", name, len(users))
|
||||
self.log.debug("Users: %s", usernames)
|
||||
group = orm.Group(name=name, users=users)
|
||||
self.db.add(group)
|
||||
@@ -121,6 +117,7 @@ class GroupAPIHandler(_GroupAPIHandler):
|
||||
|
||||
class GroupUsersAPIHandler(_GroupAPIHandler):
|
||||
"""Modify a group's user list"""
|
||||
|
||||
@admin_only
|
||||
def post(self, name):
|
||||
"""POST adds users to a group"""
|
||||
|
@@ -1,21 +1,18 @@
|
||||
"""API handlers for administering the Hub itself"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
from tornado import web
|
||||
from tornado.ioloop import IOLoop
|
||||
|
||||
from .._version import __version__
|
||||
from ..utils import admin_only
|
||||
from .base import APIHandler
|
||||
from .._version import __version__
|
||||
|
||||
|
||||
class ShutdownAPIHandler(APIHandler):
|
||||
|
||||
@admin_only
|
||||
def post(self):
|
||||
"""POST /api/shutdown triggers a clean shutdown
|
||||
@@ -26,6 +23,7 @@ class ShutdownAPIHandler(APIHandler):
|
||||
- proxy: specify whether the proxy should be terminated
|
||||
"""
|
||||
from ..app import JupyterHub
|
||||
|
||||
app = JupyterHub.instance()
|
||||
|
||||
data = self.get_json_body()
|
||||
@@ -33,19 +31,21 @@ class ShutdownAPIHandler(APIHandler):
|
||||
if 'proxy' in data:
|
||||
proxy = data['proxy']
|
||||
if proxy not in {True, False}:
|
||||
raise web.HTTPError(400, "proxy must be true or false, got %r" % proxy)
|
||||
raise web.HTTPError(
|
||||
400, "proxy must be true or false, got %r" % proxy
|
||||
)
|
||||
app.cleanup_proxy = proxy
|
||||
if 'servers' in data:
|
||||
servers = data['servers']
|
||||
if servers not in {True, False}:
|
||||
raise web.HTTPError(400, "servers must be true or false, got %r" % servers)
|
||||
raise web.HTTPError(
|
||||
400, "servers must be true or false, got %r" % servers
|
||||
)
|
||||
app.cleanup_servers = servers
|
||||
|
||||
# finish the request
|
||||
self.set_status(202)
|
||||
self.finish(json.dumps({
|
||||
"message": "Shutting down Hub"
|
||||
}))
|
||||
self.finish(json.dumps({"message": "Shutting down Hub"}))
|
||||
|
||||
# stop the eventloop, which will trigger cleanup
|
||||
loop = IOLoop.current()
|
||||
@@ -53,7 +53,6 @@ class ShutdownAPIHandler(APIHandler):
|
||||
|
||||
|
||||
class RootAPIHandler(APIHandler):
|
||||
|
||||
def get(self):
|
||||
"""GET /api/ returns info about the Hub and its API.
|
||||
|
||||
@@ -61,14 +60,11 @@ class RootAPIHandler(APIHandler):
|
||||
|
||||
For now, it just returns the version of JupyterHub itself.
|
||||
"""
|
||||
data = {
|
||||
'version': __version__,
|
||||
}
|
||||
data = {'version': __version__}
|
||||
self.finish(json.dumps(data))
|
||||
|
||||
|
||||
class InfoAPIHandler(APIHandler):
|
||||
|
||||
@admin_only
|
||||
def get(self):
|
||||
"""GET /api/info returns detailed info about the Hub and its API.
|
||||
@@ -77,10 +73,11 @@ class InfoAPIHandler(APIHandler):
|
||||
|
||||
For now, it just returns the version of JupyterHub itself.
|
||||
"""
|
||||
|
||||
def _class_info(typ):
|
||||
"""info about a class (Spawner or Authenticator)"""
|
||||
info = {
|
||||
'class': '{mod}.{name}'.format(mod=typ.__module__, name=typ.__name__),
|
||||
'class': '{mod}.{name}'.format(mod=typ.__module__, name=typ.__name__)
|
||||
}
|
||||
pkg = typ.__module__.split('.')[0]
|
||||
try:
|
||||
|
@@ -1,12 +1,11 @@
|
||||
"""Proxy handlers"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import json
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from tornado import gen, web
|
||||
from tornado import gen
|
||||
from tornado import web
|
||||
|
||||
from .. import orm
|
||||
from ..utils import admin_only
|
||||
@@ -14,7 +13,6 @@ from .base import APIHandler
|
||||
|
||||
|
||||
class ProxyAPIHandler(APIHandler):
|
||||
|
||||
@admin_only
|
||||
async def get(self):
|
||||
"""GET /api/proxy fetches the routing table
|
||||
@@ -58,6 +56,4 @@ class ProxyAPIHandler(APIHandler):
|
||||
await self.proxy.check_routes(self.users, self.services)
|
||||
|
||||
|
||||
default_handlers = [
|
||||
(r"/api/proxy", ProxyAPIHandler),
|
||||
]
|
||||
default_handlers = [(r"/api/proxy", ProxyAPIHandler)]
|
||||
|
@@ -2,10 +2,8 @@
|
||||
|
||||
Currently GET-only, no actions can be taken to modify services.
|
||||
"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import json
|
||||
|
||||
from tornado import web
|
||||
@@ -14,6 +12,7 @@ from .. import orm
|
||||
from ..utils import admin_only
|
||||
from .base import APIHandler
|
||||
|
||||
|
||||
def service_model(service):
|
||||
"""Produce the model for a service"""
|
||||
return {
|
||||
@@ -23,9 +22,10 @@ def service_model(service):
|
||||
'prefix': service.server.base_url if service.server else '',
|
||||
'command': service.command,
|
||||
'pid': service.proc.pid if service.proc else 0,
|
||||
'info': service.info
|
||||
'info': service.info,
|
||||
}
|
||||
|
||||
|
||||
class ServiceListAPIHandler(APIHandler):
|
||||
@admin_only
|
||||
def get(self):
|
||||
@@ -35,6 +35,7 @@ class ServiceListAPIHandler(APIHandler):
|
||||
|
||||
def admin_or_self(method):
|
||||
"""Decorator for restricting access to either the target service or admin"""
|
||||
|
||||
def decorated_method(self, name):
|
||||
current = self.current_user
|
||||
if current is None:
|
||||
@@ -49,10 +50,11 @@ def admin_or_self(method):
|
||||
if name not in self.services:
|
||||
raise web.HTTPError(404)
|
||||
return method(self, name)
|
||||
|
||||
return decorated_method
|
||||
|
||||
class ServiceAPIHandler(APIHandler):
|
||||
|
||||
class ServiceAPIHandler(APIHandler):
|
||||
@admin_or_self
|
||||
def get(self, name):
|
||||
service = self.services[name]
|
||||
|
@@ -1,11 +1,11 @@
|
||||
"""User handlers"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from async_generator import aclosing
|
||||
from dateutil.parser import parse as parse_date
|
||||
@@ -14,7 +14,11 @@ from tornado.iostream import StreamClosedError
|
||||
|
||||
from .. import orm
|
||||
from ..user import User
|
||||
from ..utils import admin_only, isoformat, iterate_until, maybe_future, url_path_join
|
||||
from ..utils import admin_only
|
||||
from ..utils import isoformat
|
||||
from ..utils import iterate_until
|
||||
from ..utils import maybe_future
|
||||
from ..utils import url_path_join
|
||||
from .base import APIHandler
|
||||
|
||||
|
||||
@@ -89,7 +93,9 @@ class UserListAPIHandler(APIHandler):
|
||||
except Exception as e:
|
||||
self.log.error("Failed to create user: %s" % name, exc_info=True)
|
||||
self.users.delete(user)
|
||||
raise web.HTTPError(400, "Failed to create user %s: %s" % (name, str(e)))
|
||||
raise web.HTTPError(
|
||||
400, "Failed to create user %s: %s" % (name, str(e))
|
||||
)
|
||||
else:
|
||||
created.append(user)
|
||||
|
||||
@@ -99,6 +105,7 @@ class UserListAPIHandler(APIHandler):
|
||||
|
||||
def admin_or_self(method):
|
||||
"""Decorator for restricting access to either the target user or admin"""
|
||||
|
||||
def m(self, name, *args, **kwargs):
|
||||
current = self.current_user
|
||||
if current is None:
|
||||
@@ -110,15 +117,17 @@ def admin_or_self(method):
|
||||
if not self.find_user(name):
|
||||
raise web.HTTPError(404)
|
||||
return method(self, name, *args, **kwargs)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
class UserAPIHandler(APIHandler):
|
||||
|
||||
@admin_or_self
|
||||
async def get(self, name):
|
||||
user = self.find_user(name)
|
||||
model = self.user_model(user, include_servers=True, include_state=self.current_user.admin)
|
||||
model = self.user_model(
|
||||
user, include_servers=True, include_state=self.current_user.admin
|
||||
)
|
||||
# auth state will only be shown if the requester is an admin
|
||||
# this means users can't see their own auth state unless they
|
||||
# are admins, Hub admins often are also marked as admins so they
|
||||
@@ -161,11 +170,16 @@ class UserAPIHandler(APIHandler):
|
||||
if user.name == self.current_user.name:
|
||||
raise web.HTTPError(400, "Cannot delete yourself!")
|
||||
if user.spawner._stop_pending:
|
||||
raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name)
|
||||
raise web.HTTPError(
|
||||
400, "%s's server is in the process of stopping, please wait." % name
|
||||
)
|
||||
if user.running:
|
||||
await self.stop_single_user(user)
|
||||
if user.spawner._stop_pending:
|
||||
raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name)
|
||||
raise web.HTTPError(
|
||||
400,
|
||||
"%s's server is in the process of stopping, please wait." % name,
|
||||
)
|
||||
|
||||
await maybe_future(self.authenticator.delete_user(user))
|
||||
# remove from registry
|
||||
@@ -183,7 +197,10 @@ class UserAPIHandler(APIHandler):
|
||||
if 'name' in data and data['name'] != name:
|
||||
# check if the new name is already taken inside db
|
||||
if self.find_user(data['name']):
|
||||
raise web.HTTPError(400, "User %s already exists, username must be unique" % data['name'])
|
||||
raise web.HTTPError(
|
||||
400,
|
||||
"User %s already exists, username must be unique" % data['name'],
|
||||
)
|
||||
for key, value in data.items():
|
||||
if key == 'auth_state':
|
||||
await user.save_auth_state(value)
|
||||
@@ -197,6 +214,7 @@ class UserAPIHandler(APIHandler):
|
||||
|
||||
class UserTokenListAPIHandler(APIHandler):
|
||||
"""API endpoint for listing/creating tokens"""
|
||||
|
||||
@admin_or_self
|
||||
def get(self, name):
|
||||
"""Get tokens for a given user"""
|
||||
@@ -207,6 +225,7 @@ class UserTokenListAPIHandler(APIHandler):
|
||||
now = datetime.utcnow()
|
||||
|
||||
api_tokens = []
|
||||
|
||||
def sort_key(token):
|
||||
return token.last_activity or token.created
|
||||
|
||||
@@ -228,10 +247,7 @@ class UserTokenListAPIHandler(APIHandler):
|
||||
self.db.commit()
|
||||
continue
|
||||
oauth_tokens.append(self.token_model(token))
|
||||
self.write(json.dumps({
|
||||
'api_tokens': api_tokens,
|
||||
'oauth_tokens': oauth_tokens,
|
||||
}))
|
||||
self.write(json.dumps({'api_tokens': api_tokens, 'oauth_tokens': oauth_tokens}))
|
||||
|
||||
async def post(self, name):
|
||||
body = self.get_json_body() or {}
|
||||
@@ -253,8 +269,9 @@ class UserTokenListAPIHandler(APIHandler):
|
||||
except Exception as e:
|
||||
# suppress and log error here in case Authenticator
|
||||
# isn't prepared to handle auth via this data
|
||||
self.log.error("Error authenticating request for %s: %s",
|
||||
self.request.uri, e)
|
||||
self.log.error(
|
||||
"Error authenticating request for %s: %s", self.request.uri, e
|
||||
)
|
||||
raise web.HTTPError(403)
|
||||
requester = self.find_user(name)
|
||||
if requester is None:
|
||||
@@ -274,9 +291,16 @@ class UserTokenListAPIHandler(APIHandler):
|
||||
if requester is not user:
|
||||
note += " by %s %s" % (kind, requester.name)
|
||||
|
||||
api_token = user.new_api_token(note=note, expires_in=body.get('expires_in', None))
|
||||
api_token = user.new_api_token(
|
||||
note=note, expires_in=body.get('expires_in', None)
|
||||
)
|
||||
if requester is not user:
|
||||
self.log.info("%s %s requested API token for %s", kind.title(), requester.name, user.name)
|
||||
self.log.info(
|
||||
"%s %s requested API token for %s",
|
||||
kind.title(),
|
||||
requester.name,
|
||||
user.name,
|
||||
)
|
||||
else:
|
||||
user_kind = 'user' if isinstance(user, User) else 'service'
|
||||
self.log.info("%s %s requested new API token", user_kind.title(), user.name)
|
||||
@@ -333,8 +357,7 @@ class UserTokenAPIHandler(APIHandler):
|
||||
if isinstance(token, orm.OAuthAccessToken):
|
||||
client_id = token.client_id
|
||||
tokens = [
|
||||
token for token in user.oauth_tokens
|
||||
if token.client_id == client_id
|
||||
token for token in user.oauth_tokens if token.client_id == client_id
|
||||
]
|
||||
else:
|
||||
tokens = [token]
|
||||
@@ -354,16 +377,19 @@ class UserServerAPIHandler(APIHandler):
|
||||
if server_name:
|
||||
if not self.allow_named_servers:
|
||||
raise web.HTTPError(400, "Named servers are not enabled.")
|
||||
if self.named_server_limit_per_user > 0 and server_name not in user.orm_spawners:
|
||||
if (
|
||||
self.named_server_limit_per_user > 0
|
||||
and server_name not in user.orm_spawners
|
||||
):
|
||||
named_spawners = list(user.all_spawners(include_default=False))
|
||||
if self.named_server_limit_per_user <= len(named_spawners):
|
||||
raise web.HTTPError(
|
||||
400,
|
||||
"User {} already has the maximum of {} named servers."
|
||||
" One must be deleted before a new server can be created".format(
|
||||
name,
|
||||
self.named_server_limit_per_user
|
||||
))
|
||||
name, self.named_server_limit_per_user
|
||||
),
|
||||
)
|
||||
spawner = user.spawners[server_name]
|
||||
pending = spawner.pending
|
||||
if pending == 'spawn':
|
||||
@@ -396,7 +422,6 @@ class UserServerAPIHandler(APIHandler):
|
||||
options = self.get_json_body()
|
||||
remove = (options or {}).get('remove', False)
|
||||
|
||||
|
||||
def _remove_spawner(f=None):
|
||||
if f and f.exception():
|
||||
return
|
||||
@@ -408,7 +433,9 @@ class UserServerAPIHandler(APIHandler):
|
||||
if not self.allow_named_servers:
|
||||
raise web.HTTPError(400, "Named servers are not enabled.")
|
||||
if server_name not in user.orm_spawners:
|
||||
raise web.HTTPError(404, "%s has no server named '%s'" % (name, server_name))
|
||||
raise web.HTTPError(
|
||||
404, "%s has no server named '%s'" % (name, server_name)
|
||||
)
|
||||
elif remove:
|
||||
raise web.HTTPError(400, "Cannot delete the default server")
|
||||
|
||||
@@ -423,7 +450,8 @@ class UserServerAPIHandler(APIHandler):
|
||||
|
||||
if spawner.pending:
|
||||
raise web.HTTPError(
|
||||
400, "%s is pending %s, please wait" % (spawner._log_name, spawner.pending)
|
||||
400,
|
||||
"%s is pending %s, please wait" % (spawner._log_name, spawner.pending),
|
||||
)
|
||||
|
||||
stop_future = None
|
||||
@@ -449,13 +477,16 @@ class UserAdminAccessAPIHandler(APIHandler):
|
||||
|
||||
This handler sets the necessary cookie for an admin to login to a single-user server.
|
||||
"""
|
||||
|
||||
@admin_only
|
||||
def post(self, name):
|
||||
self.log.warning("Deprecated in JupyterHub 0.8."
|
||||
" Admin access API is not needed now that we use OAuth.")
|
||||
self.log.warning(
|
||||
"Deprecated in JupyterHub 0.8."
|
||||
" Admin access API is not needed now that we use OAuth."
|
||||
)
|
||||
current = self.current_user
|
||||
self.log.warning("Admin user %s has requested access to %s's server",
|
||||
current.name, name,
|
||||
self.log.warning(
|
||||
"Admin user %s has requested access to %s's server", current.name, name
|
||||
)
|
||||
if not self.settings.get('admin_access', False):
|
||||
raise web.HTTPError(403, "admin access to user servers disabled")
|
||||
@@ -501,10 +532,7 @@ class SpawnProgressAPIHandler(APIHandler):
|
||||
except (StreamClosedError, RuntimeError):
|
||||
return
|
||||
|
||||
await asyncio.wait(
|
||||
[self._finish_future],
|
||||
timeout=self.keepalive_interval,
|
||||
)
|
||||
await asyncio.wait([self._finish_future], timeout=self.keepalive_interval)
|
||||
|
||||
@admin_or_self
|
||||
async def get(self, username, server_name=''):
|
||||
@@ -535,11 +563,7 @@ class SpawnProgressAPIHandler(APIHandler):
|
||||
'html_message': 'Server ready at <a href="{0}">{0}</a>'.format(url),
|
||||
'url': url,
|
||||
}
|
||||
failed_event = {
|
||||
'progress': 100,
|
||||
'failed': True,
|
||||
'message': "Spawn failed",
|
||||
}
|
||||
failed_event = {'progress': 100, 'failed': True, 'message': "Spawn failed"}
|
||||
|
||||
if spawner.ready:
|
||||
# spawner already ready. Trigger progress-completion immediately
|
||||
@@ -561,7 +585,9 @@ class SpawnProgressAPIHandler(APIHandler):
|
||||
raise web.HTTPError(400, "%s is not starting...", spawner._log_name)
|
||||
|
||||
# retrieve progress events from the Spawner
|
||||
async with aclosing(iterate_until(spawn_future, spawner._generate_progress())) as events:
|
||||
async with aclosing(
|
||||
iterate_until(spawn_future, spawner._generate_progress())
|
||||
) as events:
|
||||
async for event in events:
|
||||
# don't allow events to sneakily set the 'ready' flag
|
||||
if 'ready' in event:
|
||||
@@ -584,7 +610,9 @@ class SpawnProgressAPIHandler(APIHandler):
|
||||
if f and f.done() and f.exception():
|
||||
failed_event['message'] = "Spawn failed: %s" % f.exception()
|
||||
else:
|
||||
self.log.warning("Server %s didn't start for unknown reason", spawner._log_name)
|
||||
self.log.warning(
|
||||
"Server %s didn't start for unknown reason", spawner._log_name
|
||||
)
|
||||
await self.send_event(failed_event)
|
||||
|
||||
|
||||
@@ -609,13 +637,12 @@ def _parse_timestamp(timestamp):
|
||||
400,
|
||||
"Rejecting activity from more than an hour in the future: {}".format(
|
||||
isoformat(dt)
|
||||
)
|
||||
),
|
||||
)
|
||||
return dt
|
||||
|
||||
|
||||
class ActivityAPIHandler(APIHandler):
|
||||
|
||||
def _validate_servers(self, user, servers):
|
||||
"""Validate servers dict argument
|
||||
|
||||
@@ -632,10 +659,7 @@ class ActivityAPIHandler(APIHandler):
|
||||
if server_name not in spawners:
|
||||
raise web.HTTPError(
|
||||
400,
|
||||
"No such server '{}' for user {}".format(
|
||||
server_name,
|
||||
user.name,
|
||||
)
|
||||
"No such server '{}' for user {}".format(server_name, user.name),
|
||||
)
|
||||
# check that each per-server field is a dict
|
||||
if not isinstance(server_info, dict):
|
||||
@@ -645,7 +669,9 @@ class ActivityAPIHandler(APIHandler):
|
||||
raise web.HTTPError(400, msg)
|
||||
# parse last_activity timestamps
|
||||
# _parse_timestamp above is responsible for raising errors
|
||||
server_info['last_activity'] = _parse_timestamp(server_info['last_activity'])
|
||||
server_info['last_activity'] = _parse_timestamp(
|
||||
server_info['last_activity']
|
||||
)
|
||||
return servers
|
||||
|
||||
@admin_or_self
|
||||
@@ -663,8 +689,7 @@ class ActivityAPIHandler(APIHandler):
|
||||
servers = body.get('servers')
|
||||
if not last_activity_timestamp and not servers:
|
||||
raise web.HTTPError(
|
||||
400,
|
||||
"body must contain at least one of `last_activity` or `servers`"
|
||||
400, "body must contain at least one of `last_activity` or `servers`"
|
||||
)
|
||||
|
||||
if servers:
|
||||
@@ -677,13 +702,9 @@ class ActivityAPIHandler(APIHandler):
|
||||
# update user.last_activity if specified
|
||||
if last_activity_timestamp:
|
||||
last_activity = _parse_timestamp(last_activity_timestamp)
|
||||
if (
|
||||
(not user.last_activity)
|
||||
or last_activity > user.last_activity
|
||||
):
|
||||
self.log.debug("Activity for user %s: %s",
|
||||
user.name,
|
||||
isoformat(last_activity),
|
||||
if (not user.last_activity) or last_activity > user.last_activity:
|
||||
self.log.debug(
|
||||
"Activity for user %s: %s", user.name, isoformat(last_activity)
|
||||
)
|
||||
user.last_activity = last_activity
|
||||
else:
|
||||
@@ -699,11 +720,9 @@ class ActivityAPIHandler(APIHandler):
|
||||
last_activity = server_info['last_activity']
|
||||
spawner = user.orm_spawners[server_name]
|
||||
|
||||
if (
|
||||
(not spawner.last_activity)
|
||||
or last_activity > spawner.last_activity
|
||||
):
|
||||
self.log.debug("Activity on server %s/%s: %s",
|
||||
if (not spawner.last_activity) or last_activity > spawner.last_activity:
|
||||
self.log.debug(
|
||||
"Activity on server %s/%s: %s",
|
||||
user.name,
|
||||
server_name,
|
||||
isoformat(last_activity),
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -1,16 +1,16 @@
|
||||
"""Base Authenticator class and the default PAM Authenticator"""
|
||||
|
||||
# Copyright (c) IPython Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import inspect
|
||||
import pipes
|
||||
import re
|
||||
import sys
|
||||
from shutil import which
|
||||
from subprocess import Popen, PIPE, STDOUT
|
||||
import warnings
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from shutil import which
|
||||
from subprocess import PIPE
|
||||
from subprocess import Popen
|
||||
from subprocess import STDOUT
|
||||
|
||||
try:
|
||||
import pamela
|
||||
@@ -33,7 +33,9 @@ class Authenticator(LoggingConfigurable):
|
||||
|
||||
db = Any()
|
||||
|
||||
enable_auth_state = Bool(False, config=True,
|
||||
enable_auth_state = Bool(
|
||||
False,
|
||||
config=True,
|
||||
help="""Enable persisting auth_state (if available).
|
||||
|
||||
auth_state will be encrypted and stored in the Hub's database.
|
||||
@@ -62,7 +64,7 @@ class Authenticator(LoggingConfigurable):
|
||||
|
||||
See :meth:`.refresh_user` for what happens when user auth info is refreshed
|
||||
(nothing by default).
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
refresh_pre_spawn = Bool(
|
||||
@@ -78,7 +80,7 @@ class Authenticator(LoggingConfigurable):
|
||||
|
||||
If refresh_user cannot refresh the user auth data,
|
||||
launch will fail until the user logs in again.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
admin_users = Set(
|
||||
@@ -131,8 +133,11 @@ class Authenticator(LoggingConfigurable):
|
||||
sorted_names = sorted(short_names)
|
||||
single = ''.join(sorted_names)
|
||||
string_set_typo = "set('%s')" % single
|
||||
self.log.warning("whitelist contains single-character names: %s; did you mean set([%r]) instead of %s?",
|
||||
sorted_names[:8], single, string_set_typo,
|
||||
self.log.warning(
|
||||
"whitelist contains single-character names: %s; did you mean set([%r]) instead of %s?",
|
||||
sorted_names[:8],
|
||||
single,
|
||||
string_set_typo,
|
||||
)
|
||||
|
||||
custom_html = Unicode(
|
||||
@@ -199,7 +204,8 @@ class Authenticator(LoggingConfigurable):
|
||||
"""
|
||||
).tag(config=True)
|
||||
|
||||
delete_invalid_users = Bool(False,
|
||||
delete_invalid_users = Bool(
|
||||
False,
|
||||
help="""Delete any users from the database that do not pass validation
|
||||
|
||||
When JupyterHub starts, `.add_user` will be called
|
||||
@@ -213,10 +219,11 @@ class Authenticator(LoggingConfigurable):
|
||||
If False (default), invalid users remain in the Hub's database
|
||||
and a warning will be issued.
|
||||
This is the default to avoid data loss due to config changes.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
post_auth_hook = Any(config=True,
|
||||
post_auth_hook = Any(
|
||||
config=True,
|
||||
help="""
|
||||
An optional hook function that you can implement to do some
|
||||
bootstrapping work during authentication. For example, loading user account
|
||||
@@ -248,12 +255,16 @@ class Authenticator(LoggingConfigurable):
|
||||
|
||||
c.Authenticator.post_auth_hook = my_hook
|
||||
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for method_name in ('check_whitelist', 'check_blacklist', 'check_group_whitelist'):
|
||||
for method_name in (
|
||||
'check_whitelist',
|
||||
'check_blacklist',
|
||||
'check_group_whitelist',
|
||||
):
|
||||
original_method = getattr(self, method_name, None)
|
||||
if original_method is None:
|
||||
# no such method (check_group_whitelist is optional)
|
||||
@@ -273,14 +284,14 @@ class Authenticator(LoggingConfigurable):
|
||||
|
||||
Adapting for compatibility.
|
||||
""".format(
|
||||
self.__class__.__name__,
|
||||
method_name,
|
||||
self.__class__.__name__, method_name
|
||||
),
|
||||
DeprecationWarning
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def wrapped_method(username, authentication=None, **kwargs):
|
||||
return original_method(username, **kwargs)
|
||||
|
||||
setattr(self, method_name, wrapped_method)
|
||||
|
||||
async def run_post_auth_hook(self, handler, authentication):
|
||||
@@ -299,11 +310,7 @@ class Authenticator(LoggingConfigurable):
|
||||
"""
|
||||
if self.post_auth_hook is not None:
|
||||
authentication = await maybe_future(
|
||||
self.post_auth_hook(
|
||||
self,
|
||||
handler,
|
||||
authentication,
|
||||
)
|
||||
self.post_auth_hook(self, handler, authentication)
|
||||
)
|
||||
return authentication
|
||||
|
||||
@@ -380,21 +387,25 @@ class Authenticator(LoggingConfigurable):
|
||||
if 'name' not in authenticated:
|
||||
raise ValueError("user missing a name: %r" % authenticated)
|
||||
else:
|
||||
authenticated = {
|
||||
'name': authenticated,
|
||||
}
|
||||
authenticated = {'name': authenticated}
|
||||
authenticated.setdefault('auth_state', None)
|
||||
# Leave the default as None, but reevaluate later post-whitelist
|
||||
authenticated.setdefault('admin', None)
|
||||
|
||||
# normalize the username
|
||||
authenticated['name'] = username = self.normalize_username(authenticated['name'])
|
||||
authenticated['name'] = username = self.normalize_username(
|
||||
authenticated['name']
|
||||
)
|
||||
if not self.validate_username(username):
|
||||
self.log.warning("Disallowing invalid username %r.", username)
|
||||
return
|
||||
|
||||
blacklist_pass = await maybe_future(self.check_blacklist(username, authenticated))
|
||||
whitelist_pass = await maybe_future(self.check_whitelist(username, authenticated))
|
||||
blacklist_pass = await maybe_future(
|
||||
self.check_blacklist(username, authenticated)
|
||||
)
|
||||
whitelist_pass = await maybe_future(
|
||||
self.check_whitelist(username, authenticated)
|
||||
)
|
||||
|
||||
if blacklist_pass:
|
||||
pass
|
||||
@@ -404,7 +415,9 @@ class Authenticator(LoggingConfigurable):
|
||||
|
||||
if whitelist_pass:
|
||||
if authenticated['admin'] is None:
|
||||
authenticated['admin'] = await maybe_future(self.is_admin(handler, authenticated))
|
||||
authenticated['admin'] = await maybe_future(
|
||||
self.is_admin(handler, authenticated)
|
||||
)
|
||||
|
||||
authenticated = await self.run_post_auth_hook(handler, authenticated)
|
||||
|
||||
@@ -534,7 +547,9 @@ class Authenticator(LoggingConfigurable):
|
||||
"""
|
||||
self.whitelist.discard(user.name)
|
||||
|
||||
auto_login = Bool(False, config=True,
|
||||
auto_login = Bool(
|
||||
False,
|
||||
config=True,
|
||||
help="""Automatically begin the login process
|
||||
|
||||
rather than starting with a "Login with..." link at `/hub/login`
|
||||
@@ -544,7 +559,7 @@ class Authenticator(LoggingConfigurable):
|
||||
registered with `.get_handlers()`.
|
||||
|
||||
.. versionadded:: 0.8
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
def login_url(self, base_url):
|
||||
@@ -592,9 +607,7 @@ class Authenticator(LoggingConfigurable):
|
||||
list of ``('/url', Handler)`` tuples passed to tornado.
|
||||
The Hub prefix is added to any URLs.
|
||||
"""
|
||||
return [
|
||||
('/login', LoginHandler),
|
||||
]
|
||||
return [('/login', LoginHandler)]
|
||||
|
||||
|
||||
class LocalAuthenticator(Authenticator):
|
||||
@@ -603,12 +616,13 @@ class LocalAuthenticator(Authenticator):
|
||||
Checks for local users, and can attempt to create them if they exist.
|
||||
"""
|
||||
|
||||
create_system_users = Bool(False,
|
||||
create_system_users = Bool(
|
||||
False,
|
||||
help="""
|
||||
If set to True, will attempt to create local system users if they do not exist already.
|
||||
|
||||
Supports Linux and BSD variants only.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
add_user_cmd = Command(
|
||||
@@ -699,8 +713,9 @@ class LocalAuthenticator(Authenticator):
|
||||
raise KeyError(
|
||||
"User {} does not exist on the system."
|
||||
" Set LocalAuthenticator.create_system_users=True"
|
||||
" to automatically create system users from jupyterhub users."
|
||||
.format(user.name)
|
||||
" to automatically create system users from jupyterhub users.".format(
|
||||
user.name
|
||||
)
|
||||
)
|
||||
|
||||
await maybe_future(super().add_user(user))
|
||||
@@ -711,6 +726,7 @@ class LocalAuthenticator(Authenticator):
|
||||
on Windows
|
||||
"""
|
||||
import grp
|
||||
|
||||
return grp.getgrnam(name)
|
||||
|
||||
@staticmethod
|
||||
@@ -719,6 +735,7 @@ class LocalAuthenticator(Authenticator):
|
||||
on Windows
|
||||
"""
|
||||
import pwd
|
||||
|
||||
return pwd.getpwnam(name)
|
||||
|
||||
@staticmethod
|
||||
@@ -727,6 +744,7 @@ class LocalAuthenticator(Authenticator):
|
||||
on Windows
|
||||
"""
|
||||
import os
|
||||
|
||||
return os.getgrouplist(name, group)
|
||||
|
||||
def system_user_exists(self, user):
|
||||
@@ -758,23 +776,27 @@ class PAMAuthenticator(LocalAuthenticator):
|
||||
|
||||
# run PAM in a thread, since it can be slow
|
||||
executor = Any()
|
||||
|
||||
@default('executor')
|
||||
def _default_executor(self):
|
||||
return ThreadPoolExecutor(1)
|
||||
|
||||
encoding = Unicode('utf8',
|
||||
encoding = Unicode(
|
||||
'utf8',
|
||||
help="""
|
||||
The text encoding to use when communicating with PAM
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
service = Unicode('login',
|
||||
service = Unicode(
|
||||
'login',
|
||||
help="""
|
||||
The name of the PAM service to use for authentication
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
open_sessions = Bool(True,
|
||||
open_sessions = Bool(
|
||||
True,
|
||||
help="""
|
||||
Whether to open a new PAM session when spawners are started.
|
||||
|
||||
@@ -784,10 +806,11 @@ class PAMAuthenticator(LocalAuthenticator):
|
||||
|
||||
If any errors are encountered when opening/closing PAM sessions,
|
||||
this is automatically set to False.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
check_account = Bool(True,
|
||||
check_account = Bool(
|
||||
True,
|
||||
help="""
|
||||
Whether to check the user's account status via PAM during authentication.
|
||||
|
||||
@@ -797,7 +820,7 @@ class PAMAuthenticator(LocalAuthenticator):
|
||||
|
||||
Disabling this can be dangerous as authenticated but unauthorized users may
|
||||
be granted access and, therefore, arbitrary execution on the system.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
admin_groups = Set(
|
||||
@@ -809,14 +832,15 @@ class PAMAuthenticator(LocalAuthenticator):
|
||||
"""
|
||||
).tag(config=True)
|
||||
|
||||
pam_normalize_username = Bool(False,
|
||||
pam_normalize_username = Bool(
|
||||
False,
|
||||
help="""
|
||||
Round-trip the username via PAM lookups to make sure it is unique
|
||||
|
||||
PAM can accept multiple usernames that map to the same user,
|
||||
for example DOMAIN\\username in some cases. To prevent this,
|
||||
convert username into uid, then back to uid to normalize.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -844,12 +868,19 @@ class PAMAuthenticator(LocalAuthenticator):
|
||||
# (returning None instead of just the username) as this indicates some sort of system failure
|
||||
|
||||
admin_group_gids = {self._getgrnam(x).gr_gid for x in self.admin_groups}
|
||||
user_group_gids = set(self._getgrouplist(username, self._getpwnam(username).pw_gid))
|
||||
user_group_gids = set(
|
||||
self._getgrouplist(username, self._getpwnam(username).pw_gid)
|
||||
)
|
||||
admin_status = len(admin_group_gids & user_group_gids) != 0
|
||||
|
||||
except Exception as e:
|
||||
if handler is not None:
|
||||
self.log.error("PAM Admin Group Check failed (%s@%s): %s", username, handler.request.remote_ip, e)
|
||||
self.log.error(
|
||||
"PAM Admin Group Check failed (%s@%s): %s",
|
||||
username,
|
||||
handler.request.remote_ip,
|
||||
e,
|
||||
)
|
||||
else:
|
||||
self.log.error("PAM Admin Group Check failed: %s", e)
|
||||
# re-raise to return a 500 to the user and indicate a problem. We failed, not them.
|
||||
@@ -865,27 +896,40 @@ class PAMAuthenticator(LocalAuthenticator):
|
||||
"""
|
||||
username = data['username']
|
||||
try:
|
||||
pamela.authenticate(username, data['password'], service=self.service, encoding=self.encoding)
|
||||
pamela.authenticate(
|
||||
username, data['password'], service=self.service, encoding=self.encoding
|
||||
)
|
||||
except pamela.PAMError as e:
|
||||
if handler is not None:
|
||||
self.log.warning("PAM Authentication failed (%s@%s): %s", username, handler.request.remote_ip, e)
|
||||
self.log.warning(
|
||||
"PAM Authentication failed (%s@%s): %s",
|
||||
username,
|
||||
handler.request.remote_ip,
|
||||
e,
|
||||
)
|
||||
else:
|
||||
self.log.warning("PAM Authentication failed: %s", e)
|
||||
return None
|
||||
|
||||
if self.check_account:
|
||||
try:
|
||||
pamela.check_account(username, service=self.service, encoding=self.encoding)
|
||||
pamela.check_account(
|
||||
username, service=self.service, encoding=self.encoding
|
||||
)
|
||||
except pamela.PAMError as e:
|
||||
if handler is not None:
|
||||
self.log.warning("PAM Account Check failed (%s@%s): %s", username, handler.request.remote_ip, e)
|
||||
self.log.warning(
|
||||
"PAM Account Check failed (%s@%s): %s",
|
||||
username,
|
||||
handler.request.remote_ip,
|
||||
e,
|
||||
)
|
||||
else:
|
||||
self.log.warning("PAM Account Check failed: %s", e)
|
||||
return None
|
||||
|
||||
return username
|
||||
|
||||
|
||||
@run_on_executor
|
||||
def pre_spawn_start(self, user, spawner):
|
||||
"""Open PAM session for user if so configured"""
|
||||
@@ -904,7 +948,9 @@ class PAMAuthenticator(LocalAuthenticator):
|
||||
if not self.open_sessions:
|
||||
return
|
||||
try:
|
||||
pamela.close_session(user.name, service=self.service, encoding=self.encoding)
|
||||
pamela.close_session(
|
||||
user.name, service=self.service, encoding=self.encoding
|
||||
)
|
||||
except pamela.PAMError as e:
|
||||
self.log.warning("Failed to close PAM session for %s: %s", user.name, e)
|
||||
self.log.warning("Disabling PAM sessions from now on.")
|
||||
@@ -916,12 +962,14 @@ class PAMAuthenticator(LocalAuthenticator):
|
||||
PAM can accept multiple usernames as the same user, normalize them."""
|
||||
if self.pam_normalize_username:
|
||||
import pwd
|
||||
|
||||
uid = pwd.getpwnam(username).pw_uid
|
||||
username = pwd.getpwuid(uid).pw_name
|
||||
username = self.username_map.get(username, username)
|
||||
else:
|
||||
return super().normalize_username(username)
|
||||
|
||||
|
||||
class DummyAuthenticator(Authenticator):
|
||||
"""Dummy Authenticator for testing
|
||||
|
||||
@@ -938,7 +986,7 @@ class DummyAuthenticator(Authenticator):
|
||||
Set a global password for all users wanting to log in.
|
||||
|
||||
This allows users with any username to log in with the same static password.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
async def authenticate(self, handler, data):
|
||||
|
@@ -1,35 +1,43 @@
|
||||
|
||||
import base64
|
||||
from binascii import a2b_hex
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import json
|
||||
import os
|
||||
from binascii import a2b_hex
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from traitlets.config import SingletonConfigurable, Config
|
||||
from traitlets import (
|
||||
Any, Dict, Integer, List,
|
||||
default, observe, validate,
|
||||
)
|
||||
from traitlets import Any
|
||||
from traitlets import default
|
||||
from traitlets import Dict
|
||||
from traitlets import Integer
|
||||
from traitlets import List
|
||||
from traitlets import observe
|
||||
from traitlets import validate
|
||||
from traitlets.config import Config
|
||||
from traitlets.config import SingletonConfigurable
|
||||
|
||||
try:
|
||||
import cryptography
|
||||
from cryptography.fernet import Fernet, MultiFernet, InvalidToken
|
||||
except ImportError:
|
||||
cryptography = None
|
||||
|
||||
class InvalidToken(Exception):
|
||||
pass
|
||||
|
||||
|
||||
from .utils import maybe_future
|
||||
|
||||
KEY_ENV = 'JUPYTERHUB_CRYPT_KEY'
|
||||
|
||||
|
||||
class EncryptionUnavailable(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CryptographyUnavailable(EncryptionUnavailable):
|
||||
def __str__(self):
|
||||
return "cryptography library is required for encryption"
|
||||
|
||||
|
||||
class NoEncryptionKeys(EncryptionUnavailable):
|
||||
def __str__(self):
|
||||
return "Encryption keys must be specified in %s env" % KEY_ENV
|
||||
@@ -70,13 +78,16 @@ def _validate_key(key):
|
||||
|
||||
return key
|
||||
|
||||
|
||||
class CryptKeeper(SingletonConfigurable):
|
||||
"""Encapsulate encryption configuration
|
||||
|
||||
Use via the encryption_config singleton below.
|
||||
"""
|
||||
|
||||
n_threads = Integer(max(os.cpu_count(), 1), config=True,
|
||||
n_threads = Integer(
|
||||
max(os.cpu_count(), 1),
|
||||
config=True,
|
||||
help="The number of threads to allocate for encryption",
|
||||
)
|
||||
|
||||
@@ -84,22 +95,27 @@ class CryptKeeper(SingletonConfigurable):
|
||||
def _config_default(self):
|
||||
# load application config by default
|
||||
from .app import JupyterHub
|
||||
|
||||
if JupyterHub.initialized():
|
||||
return JupyterHub.instance().config
|
||||
else:
|
||||
return Config()
|
||||
|
||||
executor = Any()
|
||||
|
||||
def _executor_default(self):
|
||||
return ThreadPoolExecutor(self.n_threads)
|
||||
|
||||
keys = List(config=True)
|
||||
|
||||
def _keys_default(self):
|
||||
if KEY_ENV not in os.environ:
|
||||
return []
|
||||
# key can be a ;-separated sequence for key rotation.
|
||||
# First item in the list is used for encryption.
|
||||
return [ _validate_key(key) for key in os.environ[KEY_ENV].split(';') if key.strip() ]
|
||||
return [
|
||||
_validate_key(key) for key in os.environ[KEY_ENV].split(';') if key.strip()
|
||||
]
|
||||
|
||||
@validate('keys')
|
||||
def _ensure_bytes(self, proposal):
|
||||
@@ -107,6 +123,7 @@ class CryptKeeper(SingletonConfigurable):
|
||||
return [_validate_key(key) for key in proposal.value]
|
||||
|
||||
fernet = Any()
|
||||
|
||||
def _fernet_default(self):
|
||||
if cryptography is None or not self.keys:
|
||||
return None
|
||||
@@ -153,6 +170,7 @@ def encrypt(data):
|
||||
"""
|
||||
return CryptKeeper.instance().encrypt(data)
|
||||
|
||||
|
||||
def decrypt(data):
|
||||
"""decrypt some data with the crypt keeper
|
||||
|
||||
|
@@ -1,15 +1,13 @@
|
||||
"""Database utilities for JupyterHub"""
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
# Based on pgcontents.utils.migrate, used under the Apache license.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
import os
|
||||
import shutil
|
||||
from subprocess import check_call
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from subprocess import check_call
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
@@ -85,9 +83,7 @@ def upgrade(db_url, revision='head'):
|
||||
The alembic revision to upgrade to.
|
||||
"""
|
||||
with _temp_alembic_ini(db_url) as alembic_ini:
|
||||
check_call(
|
||||
['alembic', '-c', alembic_ini, 'upgrade', revision]
|
||||
)
|
||||
check_call(['alembic', '-c', alembic_ini, 'upgrade', revision])
|
||||
|
||||
|
||||
def backup_db_file(db_file, log=None):
|
||||
@@ -133,30 +129,27 @@ def upgrade_if_needed(db_url, backup=True, log=None):
|
||||
def shell(args=None):
|
||||
"""Start an IPython shell hooked up to the jupyerhub database"""
|
||||
from .app import JupyterHub
|
||||
|
||||
hub = JupyterHub()
|
||||
hub.load_config_file(hub.config_file)
|
||||
db_url = hub.db_url
|
||||
db = orm.new_session_factory(db_url, **hub.db_kwargs)()
|
||||
ns = {
|
||||
'db': db,
|
||||
'db_url': db_url,
|
||||
'orm': orm,
|
||||
}
|
||||
ns = {'db': db, 'db_url': db_url, 'orm': orm}
|
||||
|
||||
import IPython
|
||||
|
||||
IPython.start_ipython(args, user_ns=ns)
|
||||
|
||||
|
||||
def _alembic(args):
|
||||
"""Run an alembic command with a temporary alembic.ini"""
|
||||
from .app import JupyterHub
|
||||
|
||||
hub = JupyterHub()
|
||||
hub.load_config_file(hub.config_file)
|
||||
db_url = hub.db_url
|
||||
with _temp_alembic_ini(db_url) as alembic_ini:
|
||||
check_call(
|
||||
['alembic', '-c', alembic_ini] + args
|
||||
)
|
||||
check_call(['alembic', '-c', alembic_ini] + args)
|
||||
|
||||
|
||||
def main(args=None):
|
||||
|
@@ -1,8 +1,10 @@
|
||||
from . import base
|
||||
from . import login
|
||||
from . import metrics
|
||||
from . import pages
|
||||
from .base import *
|
||||
from .login import *
|
||||
|
||||
from . import base, pages, login, metrics
|
||||
|
||||
default_handlers = []
|
||||
for mod in (base, pages, login, metrics):
|
||||
default_handlers.extend(mod.default_handlers)
|
||||
|
@@ -1,41 +1,49 @@
|
||||
"""HTTP Handlers for the hub server"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
from datetime import datetime, timedelta
|
||||
from http.client import responses
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from http.client import responses
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlencode
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlunparse
|
||||
|
||||
from jinja2 import TemplateNotFound
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from tornado.log import app_log
|
||||
from tornado.httputil import url_concat, HTTPHeaders
|
||||
from tornado import gen
|
||||
from tornado import web
|
||||
from tornado.httputil import HTTPHeaders
|
||||
from tornado.httputil import url_concat
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.web import RequestHandler, MissingArgumentError
|
||||
from tornado import gen, web
|
||||
from tornado.log import app_log
|
||||
from tornado.web import MissingArgumentError
|
||||
from tornado.web import RequestHandler
|
||||
|
||||
from .. import __version__
|
||||
from .. import orm
|
||||
from ..metrics import PROXY_ADD_DURATION_SECONDS
|
||||
from ..metrics import ProxyAddStatus
|
||||
from ..metrics import RUNNING_SERVERS
|
||||
from ..metrics import SERVER_POLL_DURATION_SECONDS
|
||||
from ..metrics import SERVER_SPAWN_DURATION_SECONDS
|
||||
from ..metrics import SERVER_STOP_DURATION_SECONDS
|
||||
from ..metrics import ServerPollStatus
|
||||
from ..metrics import ServerSpawnStatus
|
||||
from ..metrics import ServerStopStatus
|
||||
from ..objects import Server
|
||||
from ..spawner import LocalProcessSpawner
|
||||
from ..user import User
|
||||
from ..utils import maybe_future, url_path_join
|
||||
from ..metrics import (
|
||||
SERVER_SPAWN_DURATION_SECONDS, ServerSpawnStatus,
|
||||
PROXY_ADD_DURATION_SECONDS, ProxyAddStatus,
|
||||
SERVER_POLL_DURATION_SECONDS, ServerPollStatus,
|
||||
RUNNING_SERVERS, SERVER_STOP_DURATION_SECONDS, ServerStopStatus
|
||||
)
|
||||
from ..utils import maybe_future
|
||||
from ..utils import url_path_join
|
||||
|
||||
# pattern for the authentication token header
|
||||
auth_header_pat = re.compile(r'^(?:token|bearer)\s+([^\s]+)$', flags=re.IGNORECASE)
|
||||
@@ -52,6 +60,7 @@ reasons = {
|
||||
# constant, not configurable
|
||||
SESSION_COOKIE_NAME = 'jupyterhub-session-id'
|
||||
|
||||
|
||||
class BaseHandler(RequestHandler):
|
||||
"""Base Handler class with access to common methods and properties."""
|
||||
|
||||
@@ -157,8 +166,8 @@ class BaseHandler(RequestHandler):
|
||||
|
||||
@property
|
||||
def csp_report_uri(self):
|
||||
return self.settings.get('csp_report_uri',
|
||||
url_path_join(self.hub.base_url, 'security/csp-report')
|
||||
return self.settings.get(
|
||||
'csp_report_uri', url_path_join(self.hub.base_url, 'security/csp-report')
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -167,10 +176,9 @@ class BaseHandler(RequestHandler):
|
||||
|
||||
Can be overridden by defining Content-Security-Policy in settings['headers']
|
||||
"""
|
||||
return '; '.join([
|
||||
"frame-ancestors 'self'",
|
||||
"report-uri " + self.csp_report_uri,
|
||||
])
|
||||
return '; '.join(
|
||||
["frame-ancestors 'self'", "report-uri " + self.csp_report_uri]
|
||||
)
|
||||
|
||||
def get_content_type(self):
|
||||
return 'text/html'
|
||||
@@ -190,7 +198,9 @@ class BaseHandler(RequestHandler):
|
||||
self.set_header(header_name, header_content)
|
||||
|
||||
if 'Access-Control-Allow-Headers' not in headers:
|
||||
self.set_header('Access-Control-Allow-Headers', 'accept, content-type, authorization')
|
||||
self.set_header(
|
||||
'Access-Control-Allow-Headers', 'accept, content-type, authorization'
|
||||
)
|
||||
if 'Content-Security-Policy' not in headers:
|
||||
self.set_header('Content-Security-Policy', self.content_security_policy)
|
||||
self.set_header('Content-Type', self.get_content_type())
|
||||
@@ -236,8 +246,7 @@ class BaseHandler(RequestHandler):
|
||||
orm_token = orm.OAuthAccessToken.find(self.db, token)
|
||||
if orm_token is None:
|
||||
return None
|
||||
orm_token.last_activity = \
|
||||
orm_token.user.last_activity = datetime.utcnow()
|
||||
orm_token.last_activity = orm_token.user.last_activity = datetime.utcnow()
|
||||
self.db.commit()
|
||||
return self._user_from_orm(orm_token.user)
|
||||
|
||||
@@ -259,7 +268,11 @@ class BaseHandler(RequestHandler):
|
||||
if not refresh_age:
|
||||
return user
|
||||
now = time.monotonic()
|
||||
if not force and user._auth_refreshed and (now - user._auth_refreshed < refresh_age):
|
||||
if (
|
||||
not force
|
||||
and user._auth_refreshed
|
||||
and (now - user._auth_refreshed < refresh_age)
|
||||
):
|
||||
# auth up-to-date
|
||||
return user
|
||||
|
||||
@@ -276,8 +289,7 @@ class BaseHandler(RequestHandler):
|
||||
|
||||
if not auth_info:
|
||||
self.log.warning(
|
||||
"User %s has stale auth info. Login is required to refresh.",
|
||||
user.name,
|
||||
"User %s has stale auth info. Login is required to refresh.", user.name
|
||||
)
|
||||
return
|
||||
|
||||
@@ -325,10 +337,9 @@ class BaseHandler(RequestHandler):
|
||||
def _user_for_cookie(self, cookie_name, cookie_value=None):
|
||||
"""Get the User for a given cookie, if there is one"""
|
||||
cookie_id = self.get_secure_cookie(
|
||||
cookie_name,
|
||||
cookie_value,
|
||||
max_age_days=self.cookie_max_age_days,
|
||||
cookie_name, cookie_value, max_age_days=self.cookie_max_age_days
|
||||
)
|
||||
|
||||
def clear():
|
||||
self.clear_cookie(cookie_name, path=self.hub.base_url)
|
||||
|
||||
@@ -434,7 +445,11 @@ class BaseHandler(RequestHandler):
|
||||
# clear hub cookie
|
||||
self.clear_cookie(self.hub.cookie_name, path=self.hub.base_url, **kwargs)
|
||||
# clear services cookie
|
||||
self.clear_cookie('jupyterhub-services', path=url_path_join(self.base_url, 'services'), **kwargs)
|
||||
self.clear_cookie(
|
||||
'jupyterhub-services',
|
||||
path=url_path_join(self.base_url, 'services'),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _set_cookie(self, key, value, encrypted=True, **overrides):
|
||||
"""Setting any cookie should go through here
|
||||
@@ -444,9 +459,7 @@ class BaseHandler(RequestHandler):
|
||||
"""
|
||||
# tornado <4.2 have a bug that consider secure==True as soon as
|
||||
# 'secure' kwarg is passed to set_secure_cookie
|
||||
kwargs = {
|
||||
'httponly': True,
|
||||
}
|
||||
kwargs = {'httponly': True}
|
||||
if self.request.protocol == 'https':
|
||||
kwargs['secure'] = True
|
||||
if self.subdomain_host:
|
||||
@@ -463,14 +476,10 @@ class BaseHandler(RequestHandler):
|
||||
self.log.debug("Setting cookie %s: %s", key, kwargs)
|
||||
set_cookie(key, value, **kwargs)
|
||||
|
||||
|
||||
def _set_user_cookie(self, user, server):
|
||||
self.log.debug("Setting cookie for %s: %s", user.name, server.cookie_name)
|
||||
self._set_cookie(
|
||||
server.cookie_name,
|
||||
user.cookie_id,
|
||||
encrypted=True,
|
||||
path=server.base_url,
|
||||
server.cookie_name, user.cookie_id, encrypted=True, path=server.base_url
|
||||
)
|
||||
|
||||
def get_session_cookie(self):
|
||||
@@ -494,10 +503,13 @@ class BaseHandler(RequestHandler):
|
||||
|
||||
def set_service_cookie(self, user):
|
||||
"""set the login cookie for services"""
|
||||
self._set_user_cookie(user, orm.Server(
|
||||
self._set_user_cookie(
|
||||
user,
|
||||
orm.Server(
|
||||
cookie_name='jupyterhub-services',
|
||||
base_url=url_path_join(self.base_url, 'services')
|
||||
))
|
||||
base_url=url_path_join(self.base_url, 'services'),
|
||||
),
|
||||
)
|
||||
|
||||
def set_hub_cookie(self, user):
|
||||
"""set the login cookie for the Hub"""
|
||||
@@ -508,7 +520,9 @@ class BaseHandler(RequestHandler):
|
||||
if self.subdomain_host and not self.request.host.startswith(self.domain):
|
||||
self.log.warning(
|
||||
"Possibly setting cookie on wrong domain: %s != %s",
|
||||
self.request.host, self.domain)
|
||||
self.request.host,
|
||||
self.domain,
|
||||
)
|
||||
|
||||
# set single cookie for services
|
||||
if self.db.query(orm.Service).filter(orm.Service.server != None).first():
|
||||
@@ -555,8 +569,10 @@ class BaseHandler(RequestHandler):
|
||||
# ultimately redirecting to the logged-in user's server.
|
||||
without_prefix = next_url[len(self.base_url) :]
|
||||
next_url = url_path_join(self.hub.base_url, without_prefix)
|
||||
self.log.warning("Redirecting %s to %s. For sharing public links, use /user-redirect/",
|
||||
self.request.uri, next_url,
|
||||
self.log.warning(
|
||||
"Redirecting %s to %s. For sharing public links, use /user-redirect/",
|
||||
self.request.uri,
|
||||
next_url,
|
||||
)
|
||||
|
||||
if not next_url:
|
||||
@@ -627,8 +643,9 @@ class BaseHandler(RequestHandler):
|
||||
else:
|
||||
self.statsd.incr('login.failure')
|
||||
self.statsd.timing('login.authenticate.failure', auth_timer.ms)
|
||||
self.log.warning("Failed login for %s", (data or {}).get('username', 'unknown user'))
|
||||
|
||||
self.log.warning(
|
||||
"Failed login for %s", (data or {}).get('username', 'unknown user')
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# spawning-related
|
||||
@@ -659,7 +676,9 @@ class BaseHandler(RequestHandler):
|
||||
if self.authenticator.refresh_pre_spawn:
|
||||
auth_user = await self.refresh_auth(user, force=True)
|
||||
if auth_user is None:
|
||||
raise web.HTTPError(403, "auth has expired for %s, login again", user.name)
|
||||
raise web.HTTPError(
|
||||
403, "auth has expired for %s, login again", user.name
|
||||
)
|
||||
|
||||
spawn_start_time = time.perf_counter()
|
||||
self.extra_error_html = self.spawn_home_error
|
||||
@@ -681,7 +700,9 @@ class BaseHandler(RequestHandler):
|
||||
# but for 10k users this takes ~5ms
|
||||
# and saves us from bookkeeping errors
|
||||
active_counts = self.users.count_active_users()
|
||||
spawn_pending_count = active_counts['spawn_pending'] + active_counts['proxy_pending']
|
||||
spawn_pending_count = (
|
||||
active_counts['spawn_pending'] + active_counts['proxy_pending']
|
||||
)
|
||||
active_count = active_counts['active']
|
||||
|
||||
concurrent_spawn_limit = self.concurrent_spawn_limit
|
||||
@@ -700,18 +721,21 @@ class BaseHandler(RequestHandler):
|
||||
# round suggestion to nicer human value (nearest 10 seconds or minute)
|
||||
if retry_time <= 90:
|
||||
# round human seconds up to nearest 10
|
||||
human_retry_time = "%i0 seconds" % math.ceil(retry_time / 10.)
|
||||
human_retry_time = "%i0 seconds" % math.ceil(retry_time / 10.0)
|
||||
else:
|
||||
# round number of minutes
|
||||
human_retry_time = "%i minutes" % math.round(retry_time / 60.)
|
||||
human_retry_time = "%i minutes" % math.round(retry_time / 60.0)
|
||||
|
||||
self.log.warning(
|
||||
'%s pending spawns, throttling. Suggested retry in %s seconds.',
|
||||
spawn_pending_count, retry_time,
|
||||
spawn_pending_count,
|
||||
retry_time,
|
||||
)
|
||||
err = web.HTTPError(
|
||||
429,
|
||||
"Too many users trying to log in right now. Try again in {}.".format(human_retry_time)
|
||||
"Too many users trying to log in right now. Try again in {}.".format(
|
||||
human_retry_time
|
||||
),
|
||||
)
|
||||
# can't call set_header directly here because it gets ignored
|
||||
# when errors are raised
|
||||
@@ -720,14 +744,13 @@ class BaseHandler(RequestHandler):
|
||||
raise err
|
||||
|
||||
if active_server_limit and active_count >= active_server_limit:
|
||||
self.log.info(
|
||||
'%s servers active, no space available',
|
||||
active_count,
|
||||
)
|
||||
self.log.info('%s servers active, no space available', active_count)
|
||||
SERVER_SPAWN_DURATION_SECONDS.labels(
|
||||
status=ServerSpawnStatus.too_many_users
|
||||
).observe(time.perf_counter() - spawn_start_time)
|
||||
raise web.HTTPError(429, "Active user limit exceeded. Try again in a few minutes.")
|
||||
raise web.HTTPError(
|
||||
429, "Active user limit exceeded. Try again in a few minutes."
|
||||
)
|
||||
|
||||
tic = IOLoop.current().time()
|
||||
|
||||
@@ -735,12 +758,16 @@ class BaseHandler(RequestHandler):
|
||||
|
||||
spawn_future = user.spawn(server_name, options, handler=self)
|
||||
|
||||
self.log.debug("%i%s concurrent spawns",
|
||||
self.log.debug(
|
||||
"%i%s concurrent spawns",
|
||||
spawn_pending_count,
|
||||
'/%i' % concurrent_spawn_limit if concurrent_spawn_limit else '')
|
||||
self.log.debug("%i%s active servers",
|
||||
'/%i' % concurrent_spawn_limit if concurrent_spawn_limit else '',
|
||||
)
|
||||
self.log.debug(
|
||||
"%i%s active servers",
|
||||
active_count,
|
||||
'/%i' % active_server_limit if active_server_limit else '')
|
||||
'/%i' % active_server_limit if active_server_limit else '',
|
||||
)
|
||||
|
||||
spawner = user.spawners[server_name]
|
||||
# set spawn_pending now, so there's no gap where _spawn_pending is False
|
||||
@@ -756,7 +783,9 @@ class BaseHandler(RequestHandler):
|
||||
# wait for spawn Future
|
||||
await spawn_future
|
||||
toc = IOLoop.current().time()
|
||||
self.log.info("User %s took %.3f seconds to start", user_server_name, toc-tic)
|
||||
self.log.info(
|
||||
"User %s took %.3f seconds to start", user_server_name, toc - tic
|
||||
)
|
||||
self.statsd.timing('spawner.success', (toc - tic) * 1000)
|
||||
RUNNING_SERVERS.inc()
|
||||
SERVER_SPAWN_DURATION_SECONDS.labels(
|
||||
@@ -767,18 +796,16 @@ class BaseHandler(RequestHandler):
|
||||
try:
|
||||
await self.proxy.add_user(user, server_name)
|
||||
|
||||
PROXY_ADD_DURATION_SECONDS.labels(
|
||||
status='success'
|
||||
).observe(
|
||||
PROXY_ADD_DURATION_SECONDS.labels(status='success').observe(
|
||||
time.perf_counter() - proxy_add_start_time
|
||||
)
|
||||
except Exception:
|
||||
self.log.exception("Failed to add %s to proxy!", user_server_name)
|
||||
self.log.error("Stopping %s to avoid inconsistent state", user_server_name)
|
||||
self.log.error(
|
||||
"Stopping %s to avoid inconsistent state", user_server_name
|
||||
)
|
||||
await user.stop()
|
||||
PROXY_ADD_DURATION_SECONDS.labels(
|
||||
status='failure'
|
||||
).observe(
|
||||
PROXY_ADD_DURATION_SECONDS.labels(status='failure').observe(
|
||||
time.perf_counter() - proxy_add_start_time
|
||||
)
|
||||
else:
|
||||
@@ -818,7 +845,8 @@ class BaseHandler(RequestHandler):
|
||||
self.log.warning(
|
||||
"%i consecutive spawns failed. "
|
||||
"Hub will exit if failure count reaches %i before succeeding",
|
||||
failure_count, failure_limit,
|
||||
failure_count,
|
||||
failure_limit,
|
||||
)
|
||||
if failure_limit and failure_count >= failure_limit:
|
||||
self.log.critical(
|
||||
@@ -828,6 +856,7 @@ class BaseHandler(RequestHandler):
|
||||
# mostly propagating errors for the current failures
|
||||
def abort():
|
||||
raise SystemExit(1)
|
||||
|
||||
IOLoop.current().call_later(2, abort)
|
||||
|
||||
finish_spawn_future.add_done_callback(_track_failure_count)
|
||||
@@ -842,8 +871,11 @@ class BaseHandler(RequestHandler):
|
||||
if spawner._spawn_pending and not spawner._waiting_for_response:
|
||||
# still in Spawner.start, which is taking a long time
|
||||
# we shouldn't poll while spawn is incomplete.
|
||||
self.log.warning("User %s is slow to start (timeout=%s)",
|
||||
user_server_name, self.slow_spawn_timeout)
|
||||
self.log.warning(
|
||||
"User %s is slow to start (timeout=%s)",
|
||||
user_server_name,
|
||||
self.slow_spawn_timeout,
|
||||
)
|
||||
return
|
||||
|
||||
# start has finished, but the server hasn't come up
|
||||
@@ -861,22 +893,34 @@ class BaseHandler(RequestHandler):
|
||||
status=ServerSpawnStatus.failure
|
||||
).observe(time.perf_counter() - spawn_start_time)
|
||||
|
||||
raise web.HTTPError(500, "Spawner failed to start [status=%s]. The logs for %s may contain details." % (
|
||||
status, spawner._log_name))
|
||||
raise web.HTTPError(
|
||||
500,
|
||||
"Spawner failed to start [status=%s]. The logs for %s may contain details."
|
||||
% (status, spawner._log_name),
|
||||
)
|
||||
|
||||
if spawner._waiting_for_response:
|
||||
# hit timeout waiting for response, but server's running.
|
||||
# Hope that it'll show up soon enough,
|
||||
# though it's possible that it started at the wrong URL
|
||||
self.log.warning("User %s is slow to become responsive (timeout=%s)",
|
||||
user_server_name, self.slow_spawn_timeout)
|
||||
self.log.debug("Expecting server for %s at: %s",
|
||||
user_server_name, spawner.server.url)
|
||||
self.log.warning(
|
||||
"User %s is slow to become responsive (timeout=%s)",
|
||||
user_server_name,
|
||||
self.slow_spawn_timeout,
|
||||
)
|
||||
self.log.debug(
|
||||
"Expecting server for %s at: %s",
|
||||
user_server_name,
|
||||
spawner.server.url,
|
||||
)
|
||||
if spawner._proxy_pending:
|
||||
# User.spawn finished, but it hasn't been added to the proxy
|
||||
# Could be due to load or a slow proxy
|
||||
self.log.warning("User %s is slow to be added to the proxy (timeout=%s)",
|
||||
user_server_name, self.slow_spawn_timeout)
|
||||
self.log.warning(
|
||||
"User %s is slow to be added to the proxy (timeout=%s)",
|
||||
user_server_name,
|
||||
self.slow_spawn_timeout,
|
||||
)
|
||||
|
||||
async def user_stopped(self, user, server_name):
|
||||
"""Callback that fires when the spawner has stopped"""
|
||||
@@ -888,12 +932,11 @@ class BaseHandler(RequestHandler):
|
||||
status=ServerPollStatus.from_status(status)
|
||||
).observe(time.perf_counter() - poll_start_time)
|
||||
|
||||
|
||||
if status is None:
|
||||
status = 'unknown'
|
||||
|
||||
self.log.warning("User %s server stopped, with exit code: %s",
|
||||
user.name, status,
|
||||
self.log.warning(
|
||||
"User %s server stopped, with exit code: %s", user.name, status
|
||||
)
|
||||
await self.proxy.delete_user(user, server_name)
|
||||
await user.stop(server_name)
|
||||
@@ -920,7 +963,9 @@ class BaseHandler(RequestHandler):
|
||||
await self.proxy.delete_user(user, server_name)
|
||||
await user.stop(server_name)
|
||||
toc = time.perf_counter()
|
||||
self.log.info("User %s server took %.3f seconds to stop", user.name, toc - tic)
|
||||
self.log.info(
|
||||
"User %s server took %.3f seconds to stop", user.name, toc - tic
|
||||
)
|
||||
self.statsd.timing('spawner.stop', (toc - tic) * 1000)
|
||||
RUNNING_SERVERS.dec()
|
||||
SERVER_STOP_DURATION_SECONDS.labels(
|
||||
@@ -934,14 +979,15 @@ class BaseHandler(RequestHandler):
|
||||
spawner._stop_future = None
|
||||
spawner._stop_pending = False
|
||||
|
||||
|
||||
future = spawner._stop_future = asyncio.ensure_future(stop())
|
||||
|
||||
try:
|
||||
await gen.with_timeout(timedelta(seconds=self.slow_stop_timeout), future)
|
||||
except gen.TimeoutError:
|
||||
# hit timeout, but stop is still pending
|
||||
self.log.warning("User %s:%s server is slow to stop", user.name, server_name)
|
||||
self.log.warning(
|
||||
"User %s:%s server is slow to stop", user.name, server_name
|
||||
)
|
||||
|
||||
# return handle on the future for hooking up callbacks
|
||||
return future
|
||||
@@ -1051,6 +1097,7 @@ class BaseHandler(RequestHandler):
|
||||
|
||||
class Template404(BaseHandler):
|
||||
"""Render our 404 template"""
|
||||
|
||||
async def prepare(self):
|
||||
await super().prepare()
|
||||
raise web.HTTPError(404)
|
||||
@@ -1061,6 +1108,7 @@ class PrefixRedirectHandler(BaseHandler):
|
||||
|
||||
Redirects /foo to /prefix/foo, etc.
|
||||
"""
|
||||
|
||||
def get(self):
|
||||
uri = self.request.uri
|
||||
# Since self.base_url will end with trailing slash.
|
||||
@@ -1076,9 +1124,7 @@ class PrefixRedirectHandler(BaseHandler):
|
||||
# default / -> /hub/ redirect
|
||||
# avoiding extra hop through /hub
|
||||
path = '/'
|
||||
self.redirect(url_path_join(
|
||||
self.hub.base_url, path,
|
||||
), permanent=False)
|
||||
self.redirect(url_path_join(self.hub.base_url, path), permanent=False)
|
||||
|
||||
|
||||
class UserSpawnHandler(BaseHandler):
|
||||
@@ -1113,8 +1159,11 @@ class UserSpawnHandler(BaseHandler):
|
||||
if user is None:
|
||||
# no such user
|
||||
raise web.HTTPError(404, "No such user %s" % user_name)
|
||||
self.log.info("Admin %s requesting spawn on behalf of %s",
|
||||
current_user.name, user.name)
|
||||
self.log.info(
|
||||
"Admin %s requesting spawn on behalf of %s",
|
||||
current_user.name,
|
||||
user.name,
|
||||
)
|
||||
admin_spawn = True
|
||||
should_spawn = True
|
||||
else:
|
||||
@@ -1122,7 +1171,7 @@ class UserSpawnHandler(BaseHandler):
|
||||
admin_spawn = False
|
||||
# For non-admins, we should spawn if the user matches
|
||||
# otherwise redirect users to their own server
|
||||
should_spawn = (current_user and current_user.name == user_name)
|
||||
should_spawn = current_user and current_user.name == user_name
|
||||
|
||||
if "api" in user_path.split("/") and user and not user.active:
|
||||
# API request for not-running server (e.g. notebook UI left open)
|
||||
@@ -1142,12 +1191,19 @@ class UserSpawnHandler(BaseHandler):
|
||||
port = host_info.port
|
||||
if not port:
|
||||
port = 443 if host_info.scheme == 'https' else 80
|
||||
if port != Server.from_url(self.proxy.public_url).connect_port and port == self.hub.connect_port:
|
||||
self.log.warning("""
|
||||
if (
|
||||
port != Server.from_url(self.proxy.public_url).connect_port
|
||||
and port == self.hub.connect_port
|
||||
):
|
||||
self.log.warning(
|
||||
"""
|
||||
Detected possible direct connection to Hub's private ip: %s, bypassing proxy.
|
||||
This will result in a redirect loop.
|
||||
Make sure to connect to the proxied public URL %s
|
||||
""", self.request.full_url(), self.proxy.public_url)
|
||||
""",
|
||||
self.request.full_url(),
|
||||
self.proxy.public_url,
|
||||
)
|
||||
|
||||
# logged in as valid user, check for pending spawn
|
||||
if self.allow_named_servers:
|
||||
@@ -1167,19 +1223,31 @@ class UserSpawnHandler(BaseHandler):
|
||||
# Implicit spawn on /user/:name is not allowed if the user's last spawn failed.
|
||||
# We should point the user to Home if the most recent spawn failed.
|
||||
exc = spawner._spawn_future.exception()
|
||||
self.log.error("Preventing implicit spawn for %s because last spawn failed: %s",
|
||||
spawner._log_name, exc)
|
||||
self.log.error(
|
||||
"Preventing implicit spawn for %s because last spawn failed: %s",
|
||||
spawner._log_name,
|
||||
exc,
|
||||
)
|
||||
# raise a copy because each time an Exception object is re-raised, its traceback grows
|
||||
raise copy.copy(exc).with_traceback(exc.__traceback__)
|
||||
|
||||
# check for pending spawn
|
||||
if spawner.pending == 'spawn' and spawner._spawn_future:
|
||||
# wait on the pending spawn
|
||||
self.log.debug("Waiting for %s pending %s", spawner._log_name, spawner.pending)
|
||||
self.log.debug(
|
||||
"Waiting for %s pending %s", spawner._log_name, spawner.pending
|
||||
)
|
||||
try:
|
||||
await gen.with_timeout(timedelta(seconds=self.slow_spawn_timeout), spawner._spawn_future)
|
||||
await gen.with_timeout(
|
||||
timedelta(seconds=self.slow_spawn_timeout),
|
||||
spawner._spawn_future,
|
||||
)
|
||||
except gen.TimeoutError:
|
||||
self.log.info("Pending spawn for %s didn't finish in %.1f seconds", spawner._log_name, self.slow_spawn_timeout)
|
||||
self.log.info(
|
||||
"Pending spawn for %s didn't finish in %.1f seconds",
|
||||
spawner._log_name,
|
||||
self.slow_spawn_timeout,
|
||||
)
|
||||
pass
|
||||
|
||||
# we may have waited above, check pending again:
|
||||
@@ -1194,10 +1262,7 @@ class UserSpawnHandler(BaseHandler):
|
||||
else:
|
||||
page = "spawn_pending.html"
|
||||
html = self.render_template(
|
||||
page,
|
||||
user=user,
|
||||
spawner=spawner,
|
||||
progress_url=spawner._progress_url,
|
||||
page, user=user, spawner=spawner, progress_url=spawner._progress_url
|
||||
)
|
||||
self.finish(html)
|
||||
return
|
||||
@@ -1218,8 +1283,11 @@ class UserSpawnHandler(BaseHandler):
|
||||
if current_user.name != user.name:
|
||||
# spawning on behalf of another user
|
||||
url_parts.append(user.name)
|
||||
self.redirect(url_concat(url_path_join(*url_parts),
|
||||
{'next': self.request.uri}))
|
||||
self.redirect(
|
||||
url_concat(
|
||||
url_path_join(*url_parts), {'next': self.request.uri}
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
await self.spawn_single_user(user, server_name)
|
||||
@@ -1230,9 +1298,7 @@ class UserSpawnHandler(BaseHandler):
|
||||
# spawn has started, but not finished
|
||||
self.statsd.incr('redirects.user_spawn_pending', 1)
|
||||
html = self.render_template(
|
||||
"spawn_pending.html",
|
||||
user=user,
|
||||
progress_url=spawner._progress_url,
|
||||
"spawn_pending.html", user=user, progress_url=spawner._progress_url
|
||||
)
|
||||
self.finish(html)
|
||||
return
|
||||
@@ -1243,14 +1309,16 @@ class UserSpawnHandler(BaseHandler):
|
||||
try:
|
||||
redirects = int(self.get_argument('redirects', 0))
|
||||
except ValueError:
|
||||
self.log.warning("Invalid redirects argument %r", self.get_argument('redirects'))
|
||||
self.log.warning(
|
||||
"Invalid redirects argument %r", self.get_argument('redirects')
|
||||
)
|
||||
redirects = 0
|
||||
|
||||
# check redirect limit to prevent browser-enforced limits.
|
||||
# In case of version mismatch, raise on only two redirects.
|
||||
if redirects >= self.settings.get(
|
||||
'user_redirect_limit', 4
|
||||
) or (redirects >= 2 and spawner._jupyterhub_version != __version__):
|
||||
if redirects >= self.settings.get('user_redirect_limit', 4) or (
|
||||
redirects >= 2 and spawner._jupyterhub_version != __version__
|
||||
):
|
||||
# We stop if we've been redirected too many times.
|
||||
msg = "Redirect loop detected."
|
||||
if spawner._jupyterhub_version != __version__:
|
||||
@@ -1259,7 +1327,8 @@ class UserSpawnHandler(BaseHandler):
|
||||
" Try installing jupyterhub=={hub} in the user environment"
|
||||
" if you continue to have problems."
|
||||
).format(
|
||||
singleuser=spawner._jupyterhub_version or 'unknown (likely < 0.8)',
|
||||
singleuser=spawner._jupyterhub_version
|
||||
or 'unknown (likely < 0.8)',
|
||||
hub=__version__,
|
||||
)
|
||||
raise web.HTTPError(500, msg)
|
||||
@@ -1297,10 +1366,9 @@ class UserSpawnHandler(BaseHandler):
|
||||
# not logged in, clear any cookies and reload
|
||||
self.statsd.incr('redirects.user_to_login', 1)
|
||||
self.clear_login_cookie()
|
||||
self.redirect(url_concat(
|
||||
self.settings['login_url'],
|
||||
{'next': self.request.uri},
|
||||
))
|
||||
self.redirect(
|
||||
url_concat(self.settings['login_url'], {'next': self.request.uri})
|
||||
)
|
||||
|
||||
|
||||
class UserRedirectHandler(BaseHandler):
|
||||
@@ -1314,6 +1382,7 @@ class UserRedirectHandler(BaseHandler):
|
||||
|
||||
.. versionadded:: 0.7
|
||||
"""
|
||||
|
||||
@web.authenticated
|
||||
def get(self, path):
|
||||
user = self.current_user
|
||||
@@ -1326,12 +1395,13 @@ class UserRedirectHandler(BaseHandler):
|
||||
|
||||
class CSPReportHandler(BaseHandler):
|
||||
'''Accepts a content security policy violation report'''
|
||||
|
||||
@web.authenticated
|
||||
def post(self):
|
||||
'''Log a content security policy violation report'''
|
||||
self.log.warning(
|
||||
"Content security violation: %s",
|
||||
self.request.body.decode('utf8', 'replace')
|
||||
self.request.body.decode('utf8', 'replace'),
|
||||
)
|
||||
# Report it to statsd as well
|
||||
self.statsd.incr('csp_report')
|
||||
@@ -1339,11 +1409,13 @@ class CSPReportHandler(BaseHandler):
|
||||
|
||||
class AddSlashHandler(BaseHandler):
|
||||
"""Handler for adding trailing slash to URLs that need them"""
|
||||
|
||||
def get(self, *args):
|
||||
src = urlparse(self.request.uri)
|
||||
dest = src._replace(path=src.path + '/')
|
||||
self.redirect(urlunparse(dest))
|
||||
|
||||
|
||||
default_handlers = [
|
||||
(r'', AddSlashHandler), # add trailing / to `/hub`
|
||||
(r'/user/(?P<user_name>[^/]+)(?P<user_path>/.*)?', UserSpawnHandler),
|
||||
|
@@ -1,16 +1,14 @@
|
||||
"""HTTP Handlers for the hub server"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import asyncio
|
||||
|
||||
from tornado import web
|
||||
from tornado.escape import url_escape
|
||||
from tornado.httputil import url_concat
|
||||
from tornado import web
|
||||
|
||||
from .base import BaseHandler
|
||||
from ..utils import maybe_future
|
||||
from .base import BaseHandler
|
||||
|
||||
|
||||
class LogoutHandler(BaseHandler):
|
||||
@@ -52,7 +50,8 @@ class LoginHandler(BaseHandler):
|
||||
"""Render the login page."""
|
||||
|
||||
def _render(self, login_error=None, username=None):
|
||||
return self.render_template('login.html',
|
||||
return self.render_template(
|
||||
'login.html',
|
||||
next=url_escape(self.get_argument('next', default='')),
|
||||
username=username,
|
||||
login_error=login_error,
|
||||
@@ -87,7 +86,9 @@ class LoginHandler(BaseHandler):
|
||||
self.redirect(self.get_next_url(user))
|
||||
else:
|
||||
if self.get_argument('next', default=False):
|
||||
auto_login_url = url_concat(auto_login_url, {'next': self.get_next_url()})
|
||||
auto_login_url = url_concat(
|
||||
auto_login_url, {'next': self.get_next_url()}
|
||||
)
|
||||
self.redirect(auto_login_url)
|
||||
return
|
||||
username = self.get_argument('username', default='')
|
||||
@@ -109,8 +110,7 @@ class LoginHandler(BaseHandler):
|
||||
self.redirect(self.get_next_url(user))
|
||||
else:
|
||||
html = self._render(
|
||||
login_error='Invalid username or password',
|
||||
username=data['username'],
|
||||
login_error='Invalid username or password', username=data['username']
|
||||
)
|
||||
self.finish(html)
|
||||
|
||||
@@ -118,7 +118,4 @@ class LoginHandler(BaseHandler):
|
||||
# /login renders the login page or the "Login with..." link,
|
||||
# so it should always be registered.
|
||||
# /logout clears cookies.
|
||||
default_handlers = [
|
||||
(r"/login", LoginHandler),
|
||||
(r"/logout", LogoutHandler),
|
||||
]
|
||||
default_handlers = [(r"/login", LoginHandler), (r"/logout", LogoutHandler)]
|
||||
|
@@ -1,18 +1,21 @@
|
||||
from prometheus_client import REGISTRY, CONTENT_TYPE_LATEST, generate_latest
|
||||
from prometheus_client import CONTENT_TYPE_LATEST
|
||||
from prometheus_client import generate_latest
|
||||
from prometheus_client import REGISTRY
|
||||
from tornado import gen
|
||||
|
||||
from .base import BaseHandler
|
||||
from ..utils import metrics_authentication
|
||||
from .base import BaseHandler
|
||||
|
||||
|
||||
class MetricsHandler(BaseHandler):
|
||||
"""
|
||||
Handler to serve Prometheus metrics
|
||||
"""
|
||||
|
||||
@metrics_authentication
|
||||
async def get(self):
|
||||
self.set_header('Content-Type', CONTENT_TYPE_LATEST)
|
||||
self.write(generate_latest(REGISTRY))
|
||||
|
||||
default_handlers = [
|
||||
(r'/metrics$', MetricsHandler)
|
||||
]
|
||||
|
||||
default_handlers = [(r'/metrics$', MetricsHandler)]
|
||||
|
@@ -1,18 +1,19 @@
|
||||
"""Basic html-rendering handlers."""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from http.client import responses
|
||||
|
||||
from jinja2 import TemplateNotFound
|
||||
from tornado import web, gen
|
||||
from tornado import gen
|
||||
from tornado import web
|
||||
from tornado.httputil import url_concat
|
||||
|
||||
from .. import orm
|
||||
from ..utils import admin_only, url_path_join, maybe_future
|
||||
from ..utils import admin_only
|
||||
from ..utils import maybe_future
|
||||
from ..utils import url_path_join
|
||||
from .base import BaseHandler
|
||||
|
||||
|
||||
@@ -29,6 +30,7 @@ class RootHandler(BaseHandler):
|
||||
|
||||
Otherwise, renders login page.
|
||||
"""
|
||||
|
||||
def get(self):
|
||||
user = self.current_user
|
||||
if self.default_url:
|
||||
@@ -53,8 +55,13 @@ class HomeHandler(BaseHandler):
|
||||
# send the user to /spawn if they have no active servers,
|
||||
# to establish that this is an explicit spawn request rather
|
||||
# than an implicit one, which can be caused by any link to `/user/:name(/:server_name)`
|
||||
url = url_path_join(self.hub.base_url, 'user', user.name) if user.active else url_path_join(self.hub.base_url, 'spawn')
|
||||
html = self.render_template('home.html',
|
||||
url = (
|
||||
url_path_join(self.hub.base_url, 'user', user.name)
|
||||
if user.active
|
||||
else url_path_join(self.hub.base_url, 'spawn')
|
||||
)
|
||||
html = self.render_template(
|
||||
'home.html',
|
||||
user=user,
|
||||
url=url,
|
||||
allow_named_servers=self.allow_named_servers,
|
||||
@@ -74,13 +81,15 @@ class SpawnHandler(BaseHandler):
|
||||
|
||||
Only enabled when Spawner.options_form is defined.
|
||||
"""
|
||||
|
||||
def _render_form(self, for_user, spawner_options_form, message=''):
|
||||
return self.render_template('spawn.html',
|
||||
return self.render_template(
|
||||
'spawn.html',
|
||||
for_user=for_user,
|
||||
spawner_options_form=spawner_options_form,
|
||||
error_message=message,
|
||||
url=self.request.uri,
|
||||
spawner=for_user.spawner
|
||||
spawner=for_user.spawner,
|
||||
)
|
||||
|
||||
@web.authenticated
|
||||
@@ -92,7 +101,9 @@ class SpawnHandler(BaseHandler):
|
||||
user = current_user = self.current_user
|
||||
if for_user is not None and for_user != user.name:
|
||||
if not user.admin:
|
||||
raise web.HTTPError(403, "Only admins can spawn on behalf of other users")
|
||||
raise web.HTTPError(
|
||||
403, "Only admins can spawn on behalf of other users"
|
||||
)
|
||||
|
||||
user = self.find_user(for_user)
|
||||
if user is None:
|
||||
@@ -108,7 +119,9 @@ class SpawnHandler(BaseHandler):
|
||||
if spawner_options_form:
|
||||
# Add handler to spawner here so you can access query params in form rendering.
|
||||
user.spawner.handler = self
|
||||
form = self._render_form(for_user=user, spawner_options_form=spawner_options_form)
|
||||
form = self._render_form(
|
||||
for_user=user, spawner_options_form=spawner_options_form
|
||||
)
|
||||
self.finish(form)
|
||||
else:
|
||||
# Explicit spawn request: clear _spawn_future
|
||||
@@ -129,7 +142,9 @@ class SpawnHandler(BaseHandler):
|
||||
user = current_user = self.current_user
|
||||
if for_user is not None and for_user != user.name:
|
||||
if not user.admin:
|
||||
raise web.HTTPError(403, "Only admins can spawn on behalf of other users")
|
||||
raise web.HTTPError(
|
||||
403, "Only admins can spawn on behalf of other users"
|
||||
)
|
||||
user = self.find_user(for_user)
|
||||
if user is None:
|
||||
raise web.HTTPError(404, "No such user: %s" % for_user)
|
||||
@@ -151,9 +166,13 @@ class SpawnHandler(BaseHandler):
|
||||
options = await maybe_future(user.spawner.options_from_form(form_options))
|
||||
await self.spawn_single_user(user, options=options)
|
||||
except Exception as e:
|
||||
self.log.error("Failed to spawn single-user server with form", exc_info=True)
|
||||
self.log.error(
|
||||
"Failed to spawn single-user server with form", exc_info=True
|
||||
)
|
||||
spawner_options_form = await user.spawner.get_options_form()
|
||||
form = self._render_form(for_user=user, spawner_options_form=spawner_options_form, message=str(e))
|
||||
form = self._render_form(
|
||||
for_user=user, spawner_options_form=spawner_options_form, message=str(e)
|
||||
)
|
||||
self.finish(form)
|
||||
return
|
||||
if current_user is user:
|
||||
@@ -176,9 +195,7 @@ class AdminHandler(BaseHandler):
|
||||
def get(self):
|
||||
available = {'name', 'admin', 'running', 'last_activity'}
|
||||
default_sort = ['admin', 'name']
|
||||
mapping = {
|
||||
'running': orm.Spawner.server_id,
|
||||
}
|
||||
mapping = {'running': orm.Spawner.server_id}
|
||||
for name in available:
|
||||
if name not in mapping:
|
||||
mapping[name] = getattr(orm.User, name)
|
||||
@@ -219,11 +236,13 @@ class AdminHandler(BaseHandler):
|
||||
users = self.db.query(orm.User).outerjoin(orm.Spawner).order_by(*ordered)
|
||||
users = [self._user_from_orm(u) for u in users]
|
||||
from itertools import chain
|
||||
|
||||
running = []
|
||||
for u in users:
|
||||
running.extend(s for s in u.spawners.values() if s.active)
|
||||
|
||||
html = self.render_template('admin.html',
|
||||
html = self.render_template(
|
||||
'admin.html',
|
||||
current_user=self.current_user,
|
||||
admin_access=self.settings.get('admin_access', False),
|
||||
users=users,
|
||||
@@ -243,11 +262,9 @@ class TokenPageHandler(BaseHandler):
|
||||
never = datetime(1900, 1, 1)
|
||||
|
||||
user = self.current_user
|
||||
|
||||
def sort_key(token):
|
||||
return (
|
||||
token.last_activity or never,
|
||||
token.created or never,
|
||||
)
|
||||
return (token.last_activity or never, token.created or never)
|
||||
|
||||
now = datetime.utcnow()
|
||||
api_tokens = []
|
||||
@@ -285,13 +302,13 @@ class TokenPageHandler(BaseHandler):
|
||||
for token in tokens[1:]:
|
||||
if token.created < created:
|
||||
created = token.created
|
||||
if (
|
||||
last_activity is None or
|
||||
(token.last_activity and token.last_activity > last_activity)
|
||||
if last_activity is None or (
|
||||
token.last_activity and token.last_activity > last_activity
|
||||
):
|
||||
last_activity = token.last_activity
|
||||
token = tokens[0]
|
||||
oauth_clients.append({
|
||||
oauth_clients.append(
|
||||
{
|
||||
'client': token.client,
|
||||
'description': token.client.description or token.client.identifier,
|
||||
'created': created,
|
||||
@@ -301,21 +318,17 @@ class TokenPageHandler(BaseHandler):
|
||||
# revoking one oauth token revokes all oauth tokens for that client
|
||||
'token_id': tokens[0].api_id,
|
||||
'token_count': len(tokens),
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
# sort oauth clients by last activity, created
|
||||
def sort_key(client):
|
||||
return (
|
||||
client['last_activity'] or never,
|
||||
client['created'] or never,
|
||||
)
|
||||
return (client['last_activity'] or never, client['created'] or never)
|
||||
|
||||
oauth_clients = sorted(oauth_clients, key=sort_key, reverse=True)
|
||||
|
||||
html = self.render_template(
|
||||
'token.html',
|
||||
api_tokens=api_tokens,
|
||||
oauth_clients=oauth_clients,
|
||||
'token.html', api_tokens=api_tokens, oauth_clients=oauth_clients
|
||||
)
|
||||
self.finish(html)
|
||||
|
||||
@@ -331,10 +344,12 @@ class ProxyErrorHandler(BaseHandler):
|
||||
hub_home = url_path_join(self.hub.base_url, 'home')
|
||||
message_html = ''
|
||||
if status_code == 503:
|
||||
message_html = ' '.join([
|
||||
message_html = ' '.join(
|
||||
[
|
||||
"Your server appears to be down.",
|
||||
"Try restarting it <a href='%s'>from the hub</a>" % hub_home
|
||||
])
|
||||
"Try restarting it <a href='%s'>from the hub</a>" % hub_home,
|
||||
]
|
||||
)
|
||||
ns = dict(
|
||||
status_code=status_code,
|
||||
status_message=status_message,
|
||||
@@ -355,6 +370,7 @@ class ProxyErrorHandler(BaseHandler):
|
||||
|
||||
class HealthCheckHandler(BaseHandler):
|
||||
"""Answer to health check"""
|
||||
|
||||
def get(self, *args):
|
||||
self.finish()
|
||||
|
||||
|
@@ -1,14 +1,16 @@
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import os
|
||||
|
||||
from tornado.web import StaticFileHandler
|
||||
|
||||
|
||||
class CacheControlStaticFilesHandler(StaticFileHandler):
|
||||
"""StaticFileHandler subclass that sets Cache-Control: no-cache without `?v=`
|
||||
|
||||
rather than relying on default browser cache behavior.
|
||||
"""
|
||||
|
||||
def compute_etag(self):
|
||||
return None
|
||||
|
||||
@@ -16,8 +18,10 @@ class CacheControlStaticFilesHandler(StaticFileHandler):
|
||||
if "v" not in self.request.arguments:
|
||||
self.add_header("Cache-Control", "no-cache")
|
||||
|
||||
|
||||
class LogoHandler(StaticFileHandler):
|
||||
"""A singular handler for serving the logo."""
|
||||
|
||||
def get(self):
|
||||
return super().get('')
|
||||
|
||||
@@ -25,4 +29,3 @@ class LogoHandler(StaticFileHandler):
|
||||
def get_absolute_path(cls, root, path):
|
||||
"""We only serve one file, ignore relative path"""
|
||||
return os.path.abspath(root)
|
||||
|
||||
|
@@ -1,13 +1,15 @@
|
||||
"""logging utilities"""
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import json
|
||||
import traceback
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlunparse
|
||||
|
||||
from tornado.log import LogFormatter, access_log
|
||||
from tornado.web import StaticFileHandler, HTTPError
|
||||
from tornado.log import access_log
|
||||
from tornado.log import LogFormatter
|
||||
from tornado.web import HTTPError
|
||||
from tornado.web import StaticFileHandler
|
||||
|
||||
from .metrics import prometheus_log_method
|
||||
|
||||
@@ -23,7 +25,11 @@ def coroutine_frames(all_frames):
|
||||
continue
|
||||
# start out conservative with filename + function matching
|
||||
# maybe just filename matching would be sufficient
|
||||
elif frame[0].endswith('tornado/gen.py') and frame[2] in {'run', 'wrapper', '__init__'}:
|
||||
elif frame[0].endswith('tornado/gen.py') and frame[2] in {
|
||||
'run',
|
||||
'wrapper',
|
||||
'__init__',
|
||||
}:
|
||||
continue
|
||||
elif frame[0].endswith('tornado/concurrent.py') and frame[2] == 'result':
|
||||
continue
|
||||
@@ -51,9 +57,11 @@ def coroutine_traceback(typ, value, tb):
|
||||
|
||||
class CoroutineLogFormatter(LogFormatter):
|
||||
"""Log formatter that scrubs coroutine frames"""
|
||||
|
||||
def formatException(self, exc_info):
|
||||
return ''.join(coroutine_traceback(*exc_info))
|
||||
|
||||
|
||||
# url params to be scrubbed if seen
|
||||
# any url param that *contains* one of these
|
||||
# will be scrubbed from logs
|
||||
@@ -96,6 +104,7 @@ def _scrub_headers(headers):
|
||||
|
||||
# log_request adapted from IPython (BSD)
|
||||
|
||||
|
||||
def log_request(handler):
|
||||
"""log a bit more information about each request than tornado's default
|
||||
|
||||
|
@@ -17,13 +17,13 @@ them manually here.
|
||||
"""
|
||||
from enum import Enum
|
||||
|
||||
from prometheus_client import Histogram
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import Histogram
|
||||
|
||||
REQUEST_DURATION_SECONDS = Histogram(
|
||||
'request_duration_seconds',
|
||||
'request duration for all HTTP requests',
|
||||
['method', 'handler', 'code']
|
||||
['method', 'handler', 'code'],
|
||||
)
|
||||
|
||||
SERVER_SPAWN_DURATION_SECONDS = Histogram(
|
||||
@@ -32,32 +32,29 @@ SERVER_SPAWN_DURATION_SECONDS = Histogram(
|
||||
['status'],
|
||||
# Use custom bucket sizes, since the default bucket ranges
|
||||
# are meant for quick running processes. Spawns can take a while!
|
||||
buckets=[0.5, 1, 2.5, 5, 10, 15, 30, 60, 120, float("inf")]
|
||||
buckets=[0.5, 1, 2.5, 5, 10, 15, 30, 60, 120, float("inf")],
|
||||
)
|
||||
|
||||
RUNNING_SERVERS = Gauge(
|
||||
'running_servers',
|
||||
'the number of user servers currently running'
|
||||
'running_servers', 'the number of user servers currently running'
|
||||
)
|
||||
|
||||
RUNNING_SERVERS.set(0)
|
||||
|
||||
TOTAL_USERS = Gauge(
|
||||
'total_users',
|
||||
'toal number of users'
|
||||
)
|
||||
TOTAL_USERS = Gauge('total_users', 'toal number of users')
|
||||
|
||||
TOTAL_USERS.set(0)
|
||||
|
||||
CHECK_ROUTES_DURATION_SECONDS = Histogram(
|
||||
'check_routes_duration_seconds',
|
||||
'Time taken to validate all routes in proxy'
|
||||
'check_routes_duration_seconds', 'Time taken to validate all routes in proxy'
|
||||
)
|
||||
|
||||
|
||||
class ServerSpawnStatus(Enum):
|
||||
"""
|
||||
Possible values for 'status' label of SERVER_SPAWN_DURATION_SECONDS
|
||||
"""
|
||||
|
||||
success = 'success'
|
||||
failure = 'failure'
|
||||
already_pending = 'already-pending'
|
||||
@@ -67,27 +64,29 @@ class ServerSpawnStatus(Enum):
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
for s in ServerSpawnStatus:
|
||||
# Create empty metrics with the given status
|
||||
SERVER_SPAWN_DURATION_SECONDS.labels(status=s)
|
||||
|
||||
|
||||
PROXY_ADD_DURATION_SECONDS = Histogram(
|
||||
'proxy_add_duration_seconds',
|
||||
'duration for adding user routes to proxy',
|
||||
['status']
|
||||
'proxy_add_duration_seconds', 'duration for adding user routes to proxy', ['status']
|
||||
)
|
||||
|
||||
|
||||
class ProxyAddStatus(Enum):
|
||||
"""
|
||||
Possible values for 'status' label of PROXY_ADD_DURATION_SECONDS
|
||||
"""
|
||||
|
||||
success = 'success'
|
||||
failure = 'failure'
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
for s in ProxyAddStatus:
|
||||
PROXY_ADD_DURATION_SECONDS.labels(status=s)
|
||||
|
||||
@@ -95,13 +94,15 @@ for s in ProxyAddStatus:
|
||||
SERVER_POLL_DURATION_SECONDS = Histogram(
|
||||
'server_poll_duration_seconds',
|
||||
'time taken to poll if server is running',
|
||||
['status']
|
||||
['status'],
|
||||
)
|
||||
|
||||
|
||||
class ServerPollStatus(Enum):
|
||||
"""
|
||||
Possible values for 'status' label of SERVER_POLL_DURATION_SECONDS
|
||||
"""
|
||||
|
||||
running = 'running'
|
||||
stopped = 'stopped'
|
||||
|
||||
@@ -112,27 +113,28 @@ class ServerPollStatus(Enum):
|
||||
return cls.running
|
||||
return cls.stopped
|
||||
|
||||
|
||||
for s in ServerPollStatus:
|
||||
SERVER_POLL_DURATION_SECONDS.labels(status=s)
|
||||
|
||||
|
||||
|
||||
SERVER_STOP_DURATION_SECONDS = Histogram(
|
||||
'server_stop_seconds',
|
||||
'time taken for server stopping operation',
|
||||
['status'],
|
||||
'server_stop_seconds', 'time taken for server stopping operation', ['status']
|
||||
)
|
||||
|
||||
|
||||
class ServerStopStatus(Enum):
|
||||
"""
|
||||
Possible values for 'status' label of SERVER_STOP_DURATION_SECONDS
|
||||
"""
|
||||
|
||||
success = 'success'
|
||||
failure = 'failure'
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
for s in ServerStopStatus:
|
||||
SERVER_STOP_DURATION_SECONDS.labels(status=s)
|
||||
|
||||
@@ -156,5 +158,5 @@ def prometheus_log_method(handler):
|
||||
REQUEST_DURATION_SECONDS.labels(
|
||||
method=handler.request.method,
|
||||
handler='{}.{}'.format(handler.__class__.__module__, type(handler).__name__),
|
||||
code=handler.get_status()
|
||||
code=handler.get_status(),
|
||||
).observe(handler.request.request_time())
|
||||
|
@@ -2,30 +2,29 @@
|
||||
|
||||
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from oauthlib.oauth2 import RequestValidator, WebApplicationServer
|
||||
|
||||
from oauthlib.oauth2 import RequestValidator
|
||||
from oauthlib.oauth2 import WebApplicationServer
|
||||
from oauthlib.oauth2.rfc6749.grant_types import authorization_code
|
||||
from sqlalchemy.orm import scoped_session
|
||||
from tornado import web
|
||||
from tornado.escape import url_escape
|
||||
from tornado.log import app_log
|
||||
from tornado import web
|
||||
|
||||
from .. import orm
|
||||
from ..utils import url_path_join, hash_token, compare_token
|
||||
|
||||
from ..utils import compare_token
|
||||
from ..utils import hash_token
|
||||
from ..utils import url_path_join
|
||||
|
||||
# patch absolute-uri check
|
||||
# because we want to allow relative uri oauth
|
||||
# for internal services
|
||||
from oauthlib.oauth2.rfc6749.grant_types import authorization_code
|
||||
authorization_code.is_absolute_uri = lambda uri: True
|
||||
|
||||
|
||||
class JupyterHubRequestValidator(RequestValidator):
|
||||
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
super().__init__()
|
||||
@@ -51,10 +50,7 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
client_id = request.client_id
|
||||
client_secret = request.client_secret
|
||||
oauth_client = (
|
||||
self.db
|
||||
.query(orm.OAuthClient)
|
||||
.filter_by(identifier=client_id)
|
||||
.first()
|
||||
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
|
||||
)
|
||||
if oauth_client is None:
|
||||
return False
|
||||
@@ -78,10 +74,7 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
- Authorization Code Grant
|
||||
"""
|
||||
orm_client = (
|
||||
self.db
|
||||
.query(orm.OAuthClient)
|
||||
.filter_by(identifier=client_id)
|
||||
.first()
|
||||
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
|
||||
)
|
||||
if orm_client is None:
|
||||
app_log.warning("No such oauth client %s", client_id)
|
||||
@@ -89,8 +82,9 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
request.client = orm_client
|
||||
return True
|
||||
|
||||
def confirm_redirect_uri(self, client_id, code, redirect_uri, client,
|
||||
*args, **kwargs):
|
||||
def confirm_redirect_uri(
|
||||
self, client_id, code, redirect_uri, client, *args, **kwargs
|
||||
):
|
||||
"""Ensure that the authorization process represented by this authorization
|
||||
code began with this 'redirect_uri'.
|
||||
If the client specifies a redirect_uri when obtaining code then that
|
||||
@@ -108,8 +102,10 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
"""
|
||||
# TODO: record redirect_uri used during oauth
|
||||
# if we ever support multiple destinations
|
||||
app_log.debug("confirm_redirect_uri: client_id=%s, redirect_uri=%s",
|
||||
client_id, redirect_uri,
|
||||
app_log.debug(
|
||||
"confirm_redirect_uri: client_id=%s, redirect_uri=%s",
|
||||
client_id,
|
||||
redirect_uri,
|
||||
)
|
||||
if redirect_uri == client.redirect_uri:
|
||||
return True
|
||||
@@ -127,10 +123,7 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
- Implicit Grant
|
||||
"""
|
||||
orm_client = (
|
||||
self.db
|
||||
.query(orm.OAuthClient)
|
||||
.filter_by(identifier=client_id)
|
||||
.first()
|
||||
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
|
||||
)
|
||||
if orm_client is None:
|
||||
raise KeyError(client_id)
|
||||
@@ -159,7 +152,9 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def is_within_original_scope(self, request_scopes, refresh_token, request, *args, **kwargs):
|
||||
def is_within_original_scope(
|
||||
self, request_scopes, refresh_token, request, *args, **kwargs
|
||||
):
|
||||
"""Check if requested scopes are within a scope of the refresh token.
|
||||
When access tokens are refreshed the scope of the new token
|
||||
needs to be within the scope of the original token. This is
|
||||
@@ -227,12 +222,15 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
- Authorization Code Grant
|
||||
"""
|
||||
log_code = code.get('code', 'undefined')[:3] + '...'
|
||||
app_log.debug("Saving authorization code %s, %s, %s, %s", client_id, log_code, args, kwargs)
|
||||
app_log.debug(
|
||||
"Saving authorization code %s, %s, %s, %s",
|
||||
client_id,
|
||||
log_code,
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
orm_client = (
|
||||
self.db
|
||||
.query(orm.OAuthClient)
|
||||
.filter_by(identifier=client_id)
|
||||
.first()
|
||||
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
|
||||
)
|
||||
if orm_client is None:
|
||||
raise ValueError("No such client: %s" % client_id)
|
||||
@@ -330,7 +328,11 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
app_log.debug("Saving bearer token %s", log_token)
|
||||
if request.user is None:
|
||||
raise ValueError("No user for access token: %s" % request.user)
|
||||
client = self.db.query(orm.OAuthClient).filter_by(identifier=request.client.client_id).first()
|
||||
client = (
|
||||
self.db.query(orm.OAuthClient)
|
||||
.filter_by(identifier=request.client.client_id)
|
||||
.first()
|
||||
)
|
||||
orm_access_token = orm.OAuthAccessToken(
|
||||
client=client,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
@@ -400,10 +402,7 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
"""
|
||||
app_log.debug("Validating client id %s", client_id)
|
||||
orm_client = (
|
||||
self.db
|
||||
.query(orm.OAuthClient)
|
||||
.filter_by(identifier=client_id)
|
||||
.first()
|
||||
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
|
||||
)
|
||||
if orm_client is None:
|
||||
return False
|
||||
@@ -431,19 +430,13 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
Method is used by:
|
||||
- Authorization Code Grant
|
||||
"""
|
||||
orm_code = (
|
||||
self.db
|
||||
.query(orm.OAuthCode)
|
||||
.filter_by(code=code)
|
||||
.first()
|
||||
)
|
||||
orm_code = self.db.query(orm.OAuthCode).filter_by(code=code).first()
|
||||
if orm_code is None:
|
||||
app_log.debug("No such code: %s", code)
|
||||
return False
|
||||
if orm_code.client_id != client_id:
|
||||
app_log.debug(
|
||||
"OAuth code client id mismatch: %s != %s",
|
||||
client_id, orm_code.client_id,
|
||||
"OAuth code client id mismatch: %s != %s", client_id, orm_code.client_id
|
||||
)
|
||||
return False
|
||||
request.user = orm_code.user
|
||||
@@ -453,7 +446,9 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
request.scopes = ['identify']
|
||||
return True
|
||||
|
||||
def validate_grant_type(self, client_id, grant_type, client, request, *args, **kwargs):
|
||||
def validate_grant_type(
|
||||
self, client_id, grant_type, client, request, *args, **kwargs
|
||||
):
|
||||
"""Ensure client is authorized to use the grant_type requested.
|
||||
:param client_id: Unicode client identifier
|
||||
:param grant_type: Unicode grant type, i.e. authorization_code, password.
|
||||
@@ -480,14 +475,13 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
- Authorization Code Grant
|
||||
- Implicit Grant
|
||||
"""
|
||||
app_log.debug("validate_redirect_uri: client_id=%s, redirect_uri=%s",
|
||||
client_id, redirect_uri,
|
||||
app_log.debug(
|
||||
"validate_redirect_uri: client_id=%s, redirect_uri=%s",
|
||||
client_id,
|
||||
redirect_uri,
|
||||
)
|
||||
orm_client = (
|
||||
self.db
|
||||
.query(orm.OAuthClient)
|
||||
.filter_by(identifier=client_id)
|
||||
.first()
|
||||
self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
|
||||
)
|
||||
if orm_client is None:
|
||||
app_log.warning("No such oauth client %s", client_id)
|
||||
@@ -495,7 +489,9 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
if redirect_uri == orm_client.redirect_uri:
|
||||
return True
|
||||
else:
|
||||
app_log.warning("Redirect uri %s != %s", redirect_uri, orm_client.redirect_uri)
|
||||
app_log.warning(
|
||||
"Redirect uri %s != %s", redirect_uri, orm_client.redirect_uri
|
||||
)
|
||||
return False
|
||||
|
||||
def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs):
|
||||
@@ -514,7 +510,9 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
return False
|
||||
raise NotImplementedError('Subclasses must implement this method.')
|
||||
|
||||
def validate_response_type(self, client_id, response_type, client, request, *args, **kwargs):
|
||||
def validate_response_type(
|
||||
self, client_id, response_type, client, request, *args, **kwargs
|
||||
):
|
||||
"""Ensure client is authorized to use the response_type requested.
|
||||
:param client_id: Unicode client identifier
|
||||
:param response_type: Unicode response type, i.e. code, token.
|
||||
@@ -555,10 +553,8 @@ class JupyterHubOAuthServer(WebApplicationServer):
|
||||
hash its client_secret before putting it in the database.
|
||||
"""
|
||||
# clear existing clients with same ID
|
||||
for orm_client in (
|
||||
self.db
|
||||
.query(orm.OAuthClient)\
|
||||
.filter_by(identifier=client_id)
|
||||
for orm_client in self.db.query(orm.OAuthClient).filter_by(
|
||||
identifier=client_id
|
||||
):
|
||||
self.db.delete(orm_client)
|
||||
self.db.commit()
|
||||
@@ -574,12 +570,7 @@ class JupyterHubOAuthServer(WebApplicationServer):
|
||||
|
||||
def fetch_by_client_id(self, client_id):
|
||||
"""Find a client by its id"""
|
||||
return (
|
||||
self.db
|
||||
.query(orm.OAuthClient)
|
||||
.filter_by(identifier=client_id)
|
||||
.first()
|
||||
)
|
||||
return self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
|
||||
|
||||
|
||||
def make_provider(session_factory, url_prefix, login_url):
|
||||
@@ -588,4 +579,3 @@ def make_provider(session_factory, url_prefix, login_url):
|
||||
validator = JupyterHubRequestValidator(db)
|
||||
server = JupyterHubOAuthServer(db, validator)
|
||||
return server
|
||||
|
||||
|
@@ -1,22 +1,28 @@
|
||||
"""Some general objects for use in JupyterHub"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import socket
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
import warnings
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlunparse
|
||||
|
||||
from traitlets import default
|
||||
from traitlets import HasTraits
|
||||
from traitlets import Instance
|
||||
from traitlets import Integer
|
||||
from traitlets import observe
|
||||
from traitlets import Unicode
|
||||
from traitlets import validate
|
||||
|
||||
from traitlets import (
|
||||
HasTraits, Instance, Integer, Unicode,
|
||||
default, observe, validate,
|
||||
)
|
||||
from .traitlets import URLPrefix
|
||||
from . import orm
|
||||
from .utils import (
|
||||
url_path_join, can_connect, wait_for_server,
|
||||
wait_for_http_server, random_port, make_ssl_context,
|
||||
)
|
||||
from .traitlets import URLPrefix
|
||||
from .utils import can_connect
|
||||
from .utils import make_ssl_context
|
||||
from .utils import random_port
|
||||
from .utils import url_path_join
|
||||
from .utils import wait_for_http_server
|
||||
from .utils import wait_for_server
|
||||
|
||||
|
||||
class Server(HasTraits):
|
||||
"""An object representing an HTTP endpoint.
|
||||
@@ -24,6 +30,7 @@ class Server(HasTraits):
|
||||
*Some* of these reside in the database (user servers),
|
||||
but others (Hub, proxy) are in-memory only.
|
||||
"""
|
||||
|
||||
orm_server = Instance(orm.Server, allow_none=True)
|
||||
|
||||
ip = Unicode()
|
||||
@@ -141,36 +148,31 @@ class Server(HasTraits):
|
||||
def host(self):
|
||||
if self.connect_url:
|
||||
parsed = urlparse(self.connect_url)
|
||||
return "{proto}://{host}".format(
|
||||
proto=parsed.scheme,
|
||||
host=parsed.netloc,
|
||||
)
|
||||
return "{proto}://{host}".format(proto=parsed.scheme, host=parsed.netloc)
|
||||
return "{proto}://{ip}:{port}".format(
|
||||
proto=self.proto,
|
||||
ip=self._connect_ip,
|
||||
port=self._connect_port,
|
||||
proto=self.proto, ip=self._connect_ip, port=self._connect_port
|
||||
)
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
if self.connect_url:
|
||||
return self.connect_url
|
||||
return "{host}{uri}".format(
|
||||
host=self.host,
|
||||
uri=self.base_url,
|
||||
)
|
||||
|
||||
return "{host}{uri}".format(host=self.host, uri=self.base_url)
|
||||
|
||||
def wait_up(self, timeout=10, http=False, ssl_context=None):
|
||||
"""Wait for this server to come up"""
|
||||
if http:
|
||||
ssl_context = ssl_context or make_ssl_context(
|
||||
self.keyfile, self.certfile, cafile=self.cafile)
|
||||
self.keyfile, self.certfile, cafile=self.cafile
|
||||
)
|
||||
|
||||
return wait_for_http_server(
|
||||
self.url, timeout=timeout, ssl_context=ssl_context)
|
||||
self.url, timeout=timeout, ssl_context=ssl_context
|
||||
)
|
||||
else:
|
||||
return wait_for_server(self._connect_ip, self._connect_port, timeout=timeout)
|
||||
return wait_for_server(
|
||||
self._connect_ip, self._connect_port, timeout=timeout
|
||||
)
|
||||
|
||||
def is_up(self):
|
||||
"""Is the server accepting connections?"""
|
||||
@@ -190,11 +192,13 @@ class Hub(Server):
|
||||
|
||||
@property
|
||||
def server(self):
|
||||
warnings.warn("Hub.server is deprecated in JupyterHub 0.8. Access attributes on the Hub directly.",
|
||||
warnings.warn(
|
||||
"Hub.server is deprecated in JupyterHub 0.8. Access attributes on the Hub directly.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self
|
||||
|
||||
public_host = Unicode()
|
||||
routespec = Unicode()
|
||||
|
||||
@@ -205,5 +209,7 @@ class Hub(Server):
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s %s:%s>" % (
|
||||
self.__class__.__name__, self.server.ip, self.server.port,
|
||||
self.__class__.__name__,
|
||||
self.server.ip,
|
||||
self.server.port,
|
||||
)
|
||||
|
@@ -1,36 +1,45 @@
|
||||
"""sqlalchemy ORM tools for the state of the constellation of processes"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
import enum
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
import alembic.config
|
||||
import alembic.command
|
||||
import alembic.config
|
||||
from alembic.script import ScriptDirectory
|
||||
from tornado.log import app_log
|
||||
|
||||
from sqlalchemy.types import TypeDecorator, Text, LargeBinary
|
||||
from sqlalchemy import (
|
||||
create_engine, event, exc, inspect, or_, select,
|
||||
Column, Integer, ForeignKey, Unicode, Boolean,
|
||||
DateTime, Enum, Table,
|
||||
)
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy import Unicode
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import (
|
||||
Session,
|
||||
interfaces, object_session, relationship, sessionmaker,
|
||||
)
|
||||
|
||||
from sqlalchemy.orm import interfaces
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.sql.expression import bindparam
|
||||
from sqlalchemy.types import LargeBinary
|
||||
from sqlalchemy.types import Text
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from tornado.log import app_log
|
||||
|
||||
from .utils import (
|
||||
random_port,
|
||||
new_token, hash_token, compare_token,
|
||||
)
|
||||
from .utils import compare_token
|
||||
from .utils import hash_token
|
||||
from .utils import new_token
|
||||
from .utils import random_port
|
||||
|
||||
# top-level variable for easier mocking in tests
|
||||
utcnow = datetime.utcnow
|
||||
@@ -68,6 +77,7 @@ class Server(Base):
|
||||
|
||||
connection and cookie info
|
||||
"""
|
||||
|
||||
__tablename__ = 'servers'
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
@@ -82,7 +92,9 @@ class Server(Base):
|
||||
|
||||
|
||||
# user:group many:many mapping table
|
||||
user_group_map = Table('user_group_map', Base.metadata,
|
||||
user_group_map = Table(
|
||||
'user_group_map',
|
||||
Base.metadata,
|
||||
Column('user_id', ForeignKey('users.id', ondelete='CASCADE'), primary_key=True),
|
||||
Column('group_id', ForeignKey('groups.id', ondelete='CASCADE'), primary_key=True),
|
||||
)
|
||||
@@ -90,6 +102,7 @@ user_group_map = Table('user_group_map', Base.metadata,
|
||||
|
||||
class Group(Base):
|
||||
"""User Groups"""
|
||||
|
||||
__tablename__ = 'groups'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(Unicode(255), unique=True)
|
||||
@@ -97,7 +110,9 @@ class Group(Base):
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s %s (%i users)>" % (
|
||||
self.__class__.__name__, self.name, len(self.users)
|
||||
self.__class__.__name__,
|
||||
self.name,
|
||||
len(self.users),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -130,15 +145,15 @@ class User(Base):
|
||||
`servers` is a list that contains a reference for each of the user's single user notebook servers.
|
||||
The method `server` returns the first entry in the user's `servers` list.
|
||||
"""
|
||||
|
||||
__tablename__ = 'users'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(Unicode(255), unique=True)
|
||||
|
||||
_orm_spawners = relationship(
|
||||
"Spawner",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
"Spawner", backref="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@property
|
||||
def orm_spawners(self):
|
||||
return {s.name: s for s in self._orm_spawners}
|
||||
@@ -147,20 +162,12 @@ class User(Base):
|
||||
created = Column(DateTime, default=datetime.utcnow)
|
||||
last_activity = Column(DateTime, nullable=True)
|
||||
|
||||
api_tokens = relationship(
|
||||
"APIToken",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
api_tokens = relationship("APIToken", backref="user", cascade="all, delete-orphan")
|
||||
oauth_tokens = relationship(
|
||||
"OAuthAccessToken",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
"OAuthAccessToken", backref="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oauth_codes = relationship(
|
||||
"OAuthCode",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
"OAuthCode", backref="user", cascade="all, delete-orphan"
|
||||
)
|
||||
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
|
||||
# User.state is actually Spawner state
|
||||
@@ -192,8 +199,10 @@ class User(Base):
|
||||
"""
|
||||
return db.query(cls).filter(cls.name == name).first()
|
||||
|
||||
|
||||
class Spawner(Base):
|
||||
""""State about a Spawner"""
|
||||
|
||||
__tablename__ = 'spawners'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
@@ -214,10 +223,12 @@ class Spawner(Base):
|
||||
# for which these should all be False
|
||||
active = running = ready = False
|
||||
pending = None
|
||||
|
||||
@property
|
||||
def orm_spawner(self):
|
||||
return self
|
||||
|
||||
|
||||
class Service(Base):
|
||||
"""A service run with JupyterHub
|
||||
|
||||
@@ -235,6 +246,7 @@ class Service(Base):
|
||||
- pid: the process id (if managed)
|
||||
|
||||
"""
|
||||
|
||||
__tablename__ = 'services'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
@@ -243,9 +255,7 @@ class Service(Base):
|
||||
admin = Column(Boolean, default=False)
|
||||
|
||||
api_tokens = relationship(
|
||||
"APIToken",
|
||||
backref="service",
|
||||
cascade="all, delete-orphan",
|
||||
"APIToken", backref="service", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# service-specific interface
|
||||
@@ -270,6 +280,7 @@ class Service(Base):
|
||||
|
||||
class Hashed(object):
|
||||
"""Mixin for tables with hashed tokens"""
|
||||
|
||||
prefix_length = 4
|
||||
algorithm = "sha512"
|
||||
rounds = 16384
|
||||
@@ -299,7 +310,9 @@ class Hashed(object):
|
||||
else:
|
||||
rounds = self.rounds
|
||||
salt_bytes = self.salt_bytes
|
||||
self.hashed = hash_token(token, rounds=rounds, salt=salt_bytes, algorithm=self.algorithm)
|
||||
self.hashed = hash_token(
|
||||
token, rounds=rounds, salt=salt_bytes, algorithm=self.algorithm
|
||||
)
|
||||
|
||||
def match(self, token):
|
||||
"""Is this my token?"""
|
||||
@@ -309,8 +322,9 @@ class Hashed(object):
|
||||
def check_token(cls, db, token):
|
||||
"""Check if a token is acceptable"""
|
||||
if len(token) < cls.min_length:
|
||||
raise ValueError("Tokens must be at least %i characters, got %r" % (
|
||||
cls.min_length, token)
|
||||
raise ValueError(
|
||||
"Tokens must be at least %i characters, got %r"
|
||||
% (cls.min_length, token)
|
||||
)
|
||||
found = cls.find(db, token)
|
||||
if found:
|
||||
@@ -344,10 +358,13 @@ class Hashed(object):
|
||||
|
||||
class APIToken(Hashed, Base):
|
||||
"""An API token"""
|
||||
|
||||
__tablename__ = 'api_tokens'
|
||||
|
||||
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
|
||||
service_id = Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True)
|
||||
service_id = Column(
|
||||
Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
hashed = Column(Unicode(255), unique=True)
|
||||
@@ -375,10 +392,7 @@ class APIToken(Hashed, Base):
|
||||
kind = 'owner'
|
||||
name = 'unknown'
|
||||
return "<{cls}('{pre}...', {kind}='{name}')>".format(
|
||||
cls=self.__class__.__name__,
|
||||
pre=self.prefix,
|
||||
kind=kind,
|
||||
name=name,
|
||||
cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -387,9 +401,7 @@ class APIToken(Hashed, Base):
|
||||
now = utcnow()
|
||||
deleted = False
|
||||
for token in (
|
||||
db.query(cls)
|
||||
.filter(cls.expires_at != None)
|
||||
.filter(cls.expires_at < now)
|
||||
db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now)
|
||||
):
|
||||
app_log.debug("Purging expired %s", token)
|
||||
deleted = True
|
||||
@@ -421,8 +433,15 @@ class APIToken(Hashed, Base):
|
||||
return orm_token
|
||||
|
||||
@classmethod
|
||||
def new(cls, token=None, user=None, service=None, note='', generated=True,
|
||||
expires_in=None):
|
||||
def new(
|
||||
cls,
|
||||
token=None,
|
||||
user=None,
|
||||
service=None,
|
||||
note='',
|
||||
generated=True,
|
||||
expires_in=None,
|
||||
):
|
||||
"""Generate a new API token for a user or service"""
|
||||
assert user or service
|
||||
assert not (user and service)
|
||||
@@ -473,7 +492,9 @@ class OAuthAccessToken(Hashed, Base):
|
||||
def api_id(self):
|
||||
return 'o%i' % self.id
|
||||
|
||||
client_id = Column(Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE'))
|
||||
client_id = Column(
|
||||
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
|
||||
)
|
||||
grant_type = Column(Enum(GrantType), nullable=False)
|
||||
expires_at = Column(Integer)
|
||||
refresh_token = Column(Unicode(255))
|
||||
@@ -517,7 +538,9 @@ class OAuthAccessToken(Hashed, Base):
|
||||
class OAuthCode(Base):
|
||||
__tablename__ = 'oauth_codes'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
client_id = Column(Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE'))
|
||||
client_id = Column(
|
||||
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
|
||||
)
|
||||
code = Column(Unicode(36))
|
||||
expires_at = Column(Integer)
|
||||
redirect_uri = Column(Unicode(1023))
|
||||
@@ -539,18 +562,14 @@ class OAuthClient(Base):
|
||||
return self.identifier
|
||||
|
||||
access_tokens = relationship(
|
||||
OAuthAccessToken,
|
||||
backref='client',
|
||||
cascade='all, delete-orphan',
|
||||
)
|
||||
codes = relationship(
|
||||
OAuthCode,
|
||||
backref='client',
|
||||
cascade='all, delete-orphan',
|
||||
OAuthAccessToken, backref='client', cascade='all, delete-orphan'
|
||||
)
|
||||
codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan')
|
||||
|
||||
|
||||
# General database utilities
|
||||
|
||||
|
||||
class DatabaseSchemaMismatch(Exception):
|
||||
"""Exception raised when the database schema version does not match
|
||||
|
||||
@@ -560,6 +579,7 @@ class DatabaseSchemaMismatch(Exception):
|
||||
|
||||
def register_foreign_keys(engine):
|
||||
"""register PRAGMA foreign_keys=on on connection"""
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def connect(dbapi_con, con_record):
|
||||
cursor = dbapi_con.cursor()
|
||||
@@ -609,6 +629,7 @@ def register_ping_connection(engine):
|
||||
|
||||
https://docs.sqlalchemy.org/en/rel_1_1/core/pooling.html#disconnect-handling-pessimistic
|
||||
"""
|
||||
|
||||
@event.listens_for(engine, "engine_connect")
|
||||
def ping_connection(connection, branch):
|
||||
if branch:
|
||||
@@ -633,7 +654,9 @@ def register_ping_connection(engine):
|
||||
# condition, which is based on inspection of the original exception
|
||||
# by the dialect in use.
|
||||
if err.connection_invalidated:
|
||||
app_log.error("Database connection error, attempting to reconnect: %s", err)
|
||||
app_log.error(
|
||||
"Database connection error, attempting to reconnect: %s", err
|
||||
)
|
||||
# run the same SELECT again - the connection will re-validate
|
||||
# itself and establish a new connection. The disconnect detection
|
||||
# here also causes the whole connection pool to be invalidated
|
||||
@@ -697,29 +720,37 @@ def check_db_revision(engine):
|
||||
|
||||
# check database schema version
|
||||
# it should always be defined at this point
|
||||
alembic_revision = engine.execute('SELECT version_num FROM alembic_version').first()[0]
|
||||
alembic_revision = engine.execute(
|
||||
'SELECT version_num FROM alembic_version'
|
||||
).first()[0]
|
||||
if alembic_revision == head:
|
||||
app_log.debug("database schema version found: %s", alembic_revision)
|
||||
pass
|
||||
else:
|
||||
raise DatabaseSchemaMismatch("Found database schema version {found} != {head}. "
|
||||
raise DatabaseSchemaMismatch(
|
||||
"Found database schema version {found} != {head}. "
|
||||
"Backup your database and run `jupyterhub upgrade-db`"
|
||||
" to upgrade to the latest schema.".format(
|
||||
found=alembic_revision,
|
||||
head=head,
|
||||
))
|
||||
found=alembic_revision, head=head
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def mysql_large_prefix_check(engine):
|
||||
"""Check mysql has innodb_large_prefix set"""
|
||||
if not str(engine.url).startswith('mysql'):
|
||||
return False
|
||||
variables = dict(engine.execute(
|
||||
variables = dict(
|
||||
engine.execute(
|
||||
'show variables where variable_name like '
|
||||
'"innodb_large_prefix" or '
|
||||
'variable_name like "innodb_file_format";').fetchall())
|
||||
if (variables['innodb_file_format'] == 'Barracuda' and
|
||||
variables['innodb_large_prefix'] == 'ON'):
|
||||
'variable_name like "innodb_file_format";'
|
||||
).fetchall()
|
||||
)
|
||||
if (
|
||||
variables['innodb_file_format'] == 'Barracuda'
|
||||
and variables['innodb_large_prefix'] == 'ON'
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@@ -730,10 +761,9 @@ def add_row_format(base):
|
||||
t.dialect_kwargs['mysql_ROW_FORMAT'] = 'DYNAMIC'
|
||||
|
||||
|
||||
def new_session_factory(url="sqlite:///:memory:",
|
||||
reset=False,
|
||||
expire_on_commit=False,
|
||||
**kwargs):
|
||||
def new_session_factory(
|
||||
url="sqlite:///:memory:", reset=False, expire_on_commit=False, **kwargs
|
||||
):
|
||||
"""Create a new session at url"""
|
||||
if url.startswith('sqlite'):
|
||||
kwargs.setdefault('connect_args', {'check_same_thread': False})
|
||||
@@ -767,7 +797,5 @@ def new_session_factory(url="sqlite:///:memory:",
|
||||
# SQLAlchemy to expire objects after committing - we don't expect
|
||||
# concurrent runs of the hub talking to the same db. Turning
|
||||
# this off gives us a major performance boost
|
||||
session_factory = sessionmaker(bind=engine,
|
||||
expire_on_commit=expire_on_commit,
|
||||
)
|
||||
session_factory = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
|
||||
return session_factory
|
||||
|
@@ -14,36 +14,38 @@ Route Specification:
|
||||
'host.tld/path/' for host-based routing or '/path/' for default routing.
|
||||
- Route paths should be normalized to always start and end with '/'
|
||||
"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
from subprocess import Popen
|
||||
import time
|
||||
from urllib.parse import quote, urlparse
|
||||
from functools import wraps
|
||||
from subprocess import Popen
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from tornado import gen
|
||||
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPError
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpclient import HTTPError
|
||||
from tornado.httpclient import HTTPRequest
|
||||
from tornado.ioloop import PeriodicCallback
|
||||
|
||||
|
||||
from traitlets import (
|
||||
Any, Bool, Instance, Integer, Unicode,
|
||||
default, observe,
|
||||
)
|
||||
from jupyterhub.traitlets import Command
|
||||
|
||||
from traitlets import Any
|
||||
from traitlets import Bool
|
||||
from traitlets import default
|
||||
from traitlets import Instance
|
||||
from traitlets import Integer
|
||||
from traitlets import observe
|
||||
from traitlets import Unicode
|
||||
from traitlets.config import LoggingConfigurable
|
||||
|
||||
from . import utils
|
||||
from .metrics import CHECK_ROUTES_DURATION_SECONDS
|
||||
from .objects import Server
|
||||
from . import utils
|
||||
from .utils import url_path_join, make_ssl_context
|
||||
from .utils import make_ssl_context
|
||||
from .utils import url_path_join
|
||||
from jupyterhub.traitlets import Command
|
||||
|
||||
|
||||
def _one_at_a_time(method):
|
||||
@@ -53,6 +55,7 @@ def _one_at_a_time(method):
|
||||
queue them instead of allowing them to be concurrently outstanding.
|
||||
"""
|
||||
method._lock = asyncio.Lock()
|
||||
|
||||
@wraps(method)
|
||||
async def locked_method(*args, **kwargs):
|
||||
async with method._lock:
|
||||
@@ -86,6 +89,7 @@ class Proxy(LoggingConfigurable):
|
||||
"""
|
||||
|
||||
db_factory = Any()
|
||||
|
||||
@property
|
||||
def db(self):
|
||||
return self.db_factory()
|
||||
@@ -97,13 +101,16 @@ class Proxy(LoggingConfigurable):
|
||||
ssl_cert = Unicode()
|
||||
host_routing = Bool()
|
||||
|
||||
should_start = Bool(True, config=True,
|
||||
should_start = Bool(
|
||||
True,
|
||||
config=True,
|
||||
help="""Should the Hub start the proxy
|
||||
|
||||
If True, the Hub will start the proxy and stop it.
|
||||
Set to False if the proxy is managed externally,
|
||||
such as by systemd, docker, or another service manager.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
def start(self):
|
||||
"""Start the proxy.
|
||||
@@ -136,9 +143,13 @@ class Proxy(LoggingConfigurable):
|
||||
# check host routing
|
||||
host_route = not routespec.startswith('/')
|
||||
if host_route and not self.host_routing:
|
||||
raise ValueError("Cannot add host-based route %r, not using host-routing" % routespec)
|
||||
raise ValueError(
|
||||
"Cannot add host-based route %r, not using host-routing" % routespec
|
||||
)
|
||||
if self.host_routing and not host_route:
|
||||
raise ValueError("Cannot add route without host %r, using host-routing" % routespec)
|
||||
raise ValueError(
|
||||
"Cannot add route without host %r, using host-routing" % routespec
|
||||
)
|
||||
# add trailing slash
|
||||
if not routespec.endswith('/'):
|
||||
return routespec + '/'
|
||||
@@ -220,16 +231,19 @@ class Proxy(LoggingConfigurable):
|
||||
"""Add a service's server to the proxy table."""
|
||||
if not service.server:
|
||||
raise RuntimeError(
|
||||
"Service %s does not have an http endpoint to add to the proxy.", service.name)
|
||||
"Service %s does not have an http endpoint to add to the proxy.",
|
||||
service.name,
|
||||
)
|
||||
|
||||
self.log.info("Adding service %s to proxy %s => %s",
|
||||
service.name, service.proxy_spec, service.server.host,
|
||||
self.log.info(
|
||||
"Adding service %s to proxy %s => %s",
|
||||
service.name,
|
||||
service.proxy_spec,
|
||||
service.server.host,
|
||||
)
|
||||
|
||||
await self.add_route(
|
||||
service.proxy_spec,
|
||||
service.server.host,
|
||||
{'service': service.name}
|
||||
service.proxy_spec, service.server.host, {'service': service.name}
|
||||
)
|
||||
|
||||
async def delete_service(self, service, client=None):
|
||||
@@ -240,22 +254,23 @@ class Proxy(LoggingConfigurable):
|
||||
async def add_user(self, user, server_name='', client=None):
|
||||
"""Add a user's server to the proxy table."""
|
||||
spawner = user.spawners[server_name]
|
||||
self.log.info("Adding user %s to proxy %s => %s",
|
||||
user.name, spawner.proxy_spec, spawner.server.host,
|
||||
self.log.info(
|
||||
"Adding user %s to proxy %s => %s",
|
||||
user.name,
|
||||
spawner.proxy_spec,
|
||||
spawner.server.host,
|
||||
)
|
||||
|
||||
if spawner.pending and spawner.pending != 'spawn':
|
||||
raise RuntimeError(
|
||||
"%s is pending %s, shouldn't be added to the proxy yet!" % (spawner._log_name, spawner.pending)
|
||||
"%s is pending %s, shouldn't be added to the proxy yet!"
|
||||
% (spawner._log_name, spawner.pending)
|
||||
)
|
||||
|
||||
await self.add_route(
|
||||
spawner.proxy_spec,
|
||||
spawner.server.host,
|
||||
{
|
||||
'user': user.name,
|
||||
'server_name': server_name,
|
||||
}
|
||||
{'user': user.name, 'server_name': server_name},
|
||||
)
|
||||
|
||||
async def delete_user(self, user, server_name=''):
|
||||
@@ -314,7 +329,9 @@ class Proxy(LoggingConfigurable):
|
||||
else:
|
||||
route = routes[self.app.hub.routespec]
|
||||
if route['target'] != hub.host:
|
||||
self.log.warning("Updating default route %s → %s", route['target'], hub.host)
|
||||
self.log.warning(
|
||||
"Updating default route %s → %s", route['target'], hub.host
|
||||
)
|
||||
futures.append(self.add_hub_route(hub))
|
||||
|
||||
for user in user_dict.values():
|
||||
@@ -324,14 +341,17 @@ class Proxy(LoggingConfigurable):
|
||||
good_routes.add(spec)
|
||||
if spec not in user_routes:
|
||||
self.log.warning(
|
||||
"Adding missing route for %s (%s)", spec, spawner.server)
|
||||
"Adding missing route for %s (%s)", spec, spawner.server
|
||||
)
|
||||
futures.append(self.add_user(user, name))
|
||||
else:
|
||||
route = routes[spec]
|
||||
if route['target'] != spawner.server.host:
|
||||
self.log.warning(
|
||||
"Updating route for %s (%s → %s)",
|
||||
spec, route['target'], spawner.server,
|
||||
spec,
|
||||
route['target'],
|
||||
spawner.server,
|
||||
)
|
||||
futures.append(self.add_user(user, name))
|
||||
elif spawner.pending:
|
||||
@@ -341,22 +361,26 @@ class Proxy(LoggingConfigurable):
|
||||
good_routes.add(spawner.proxy_spec)
|
||||
|
||||
# check service routes
|
||||
service_routes = {r['data']['service']: r
|
||||
for r in routes.values() if 'service' in r['data']}
|
||||
service_routes = {
|
||||
r['data']['service']: r for r in routes.values() if 'service' in r['data']
|
||||
}
|
||||
for service in service_dict.values():
|
||||
if service.server is None:
|
||||
continue
|
||||
good_routes.add(service.proxy_spec)
|
||||
if service.name not in service_routes:
|
||||
self.log.warning("Adding missing route for %s (%s)",
|
||||
service.name, service.server)
|
||||
self.log.warning(
|
||||
"Adding missing route for %s (%s)", service.name, service.server
|
||||
)
|
||||
futures.append(self.add_service(service))
|
||||
else:
|
||||
route = service_routes[service.name]
|
||||
if route['target'] != service.server.host:
|
||||
self.log.warning(
|
||||
"Updating route for %s (%s → %s)",
|
||||
route['routespec'], route['target'], service.server.host,
|
||||
route['routespec'],
|
||||
route['target'],
|
||||
service.server.host,
|
||||
)
|
||||
futures.append(self.add_service(service))
|
||||
|
||||
@@ -424,7 +448,7 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
help="""The Proxy auth token
|
||||
|
||||
Loaded from the CONFIGPROXY_AUTH_TOKEN env variable by default.
|
||||
""",
|
||||
"""
|
||||
).tag(config=True)
|
||||
check_running_interval = Integer(5, config=True)
|
||||
|
||||
@@ -437,8 +461,8 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
token = utils.new_token()
|
||||
return token
|
||||
|
||||
api_url = Unicode(config=True,
|
||||
help="""The ip (or hostname) of the proxy's API endpoint"""
|
||||
api_url = Unicode(
|
||||
config=True, help="""The ip (or hostname) of the proxy's API endpoint"""
|
||||
)
|
||||
|
||||
@default('api_url')
|
||||
@@ -448,13 +472,12 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
if self.app.internal_ssl:
|
||||
proto = 'https'
|
||||
|
||||
return "{proto}://{url}".format(
|
||||
proto=proto,
|
||||
url=url,
|
||||
)
|
||||
return "{proto}://{url}".format(proto=proto, url=url)
|
||||
|
||||
command = Command('configurable-http-proxy', config=True,
|
||||
help="""The command to start the proxy"""
|
||||
command = Command(
|
||||
'configurable-http-proxy',
|
||||
config=True,
|
||||
help="""The command to start the proxy""",
|
||||
)
|
||||
|
||||
pid_file = Unicode(
|
||||
@@ -463,11 +486,14 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
help="File in which to write the PID of the proxy process.",
|
||||
)
|
||||
|
||||
_check_running_callback = Any(help="PeriodicCallback to check if the proxy is running")
|
||||
_check_running_callback = Any(
|
||||
help="PeriodicCallback to check if the proxy is running"
|
||||
)
|
||||
|
||||
def _check_pid(self, pid):
|
||||
if os.name == 'nt':
|
||||
import psutil
|
||||
|
||||
if not psutil.pid_exists(pid):
|
||||
raise ProcessLookupError
|
||||
else:
|
||||
@@ -558,11 +584,16 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
env = os.environ.copy()
|
||||
env['CONFIGPROXY_AUTH_TOKEN'] = self.auth_token
|
||||
cmd = self.command + [
|
||||
'--ip', public_server.ip,
|
||||
'--port', str(public_server.port),
|
||||
'--api-ip', api_server.ip,
|
||||
'--api-port', str(api_server.port),
|
||||
'--error-target', url_path_join(self.hub.url, 'error'),
|
||||
'--ip',
|
||||
public_server.ip,
|
||||
'--port',
|
||||
str(public_server.port),
|
||||
'--api-ip',
|
||||
api_server.ip,
|
||||
'--api-port',
|
||||
str(api_server.port),
|
||||
'--error-target',
|
||||
url_path_join(self.hub.url, 'error'),
|
||||
]
|
||||
if self.app.subdomain_host:
|
||||
cmd.append('--host-routing')
|
||||
@@ -595,28 +626,36 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
cmd.extend(['--client-ssl-request-cert'])
|
||||
cmd.extend(['--client-ssl-reject-unauthorized'])
|
||||
if self.app.statsd_host:
|
||||
cmd.extend([
|
||||
'--statsd-host', self.app.statsd_host,
|
||||
'--statsd-port', str(self.app.statsd_port),
|
||||
'--statsd-prefix', self.app.statsd_prefix + '.chp'
|
||||
])
|
||||
cmd.extend(
|
||||
[
|
||||
'--statsd-host',
|
||||
self.app.statsd_host,
|
||||
'--statsd-port',
|
||||
str(self.app.statsd_port),
|
||||
'--statsd-prefix',
|
||||
self.app.statsd_prefix + '.chp',
|
||||
]
|
||||
)
|
||||
# Warn if SSL is not used
|
||||
if ' --ssl' not in ' '.join(cmd):
|
||||
self.log.warning("Running JupyterHub without SSL."
|
||||
" I hope there is SSL termination happening somewhere else...")
|
||||
self.log.warning(
|
||||
"Running JupyterHub without SSL."
|
||||
" I hope there is SSL termination happening somewhere else..."
|
||||
)
|
||||
self.log.info("Starting proxy @ %s", public_server.bind_url)
|
||||
self.log.debug("Proxy cmd: %s", cmd)
|
||||
shell = os.name == 'nt'
|
||||
try:
|
||||
self.proxy_process = Popen(cmd, env=env, start_new_session=True, shell=shell)
|
||||
self.proxy_process = Popen(
|
||||
cmd, env=env, start_new_session=True, shell=shell
|
||||
)
|
||||
except FileNotFoundError as e:
|
||||
self.log.error(
|
||||
"Failed to find proxy %r\n"
|
||||
"The proxy can be installed with `npm install -g configurable-http-proxy`."
|
||||
"To install `npm`, install nodejs which includes `npm`."
|
||||
"If you see an `EACCES` error or permissions error, refer to the `npm` "
|
||||
"documentation on How To Prevent Permissions Errors."
|
||||
% self.command
|
||||
"documentation on How To Prevent Permissions Errors." % self.command
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -625,8 +664,7 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
def _check_process():
|
||||
status = self.proxy_process.poll()
|
||||
if status is not None:
|
||||
e = RuntimeError(
|
||||
"Proxy failed to start with exit code %i" % status)
|
||||
e = RuntimeError("Proxy failed to start with exit code %i" % status)
|
||||
raise e from None
|
||||
|
||||
for server in (public_server, api_server):
|
||||
@@ -678,8 +716,9 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
"""Check if the proxy is still running"""
|
||||
if self.proxy_process.poll() is None:
|
||||
return
|
||||
self.log.error("Proxy stopped with exit code %r",
|
||||
'unknown' if self.proxy_process is None else self.proxy_process.poll()
|
||||
self.log.error(
|
||||
"Proxy stopped with exit code %r",
|
||||
'unknown' if self.proxy_process is None else self.proxy_process.poll(),
|
||||
)
|
||||
self._remove_pid_file()
|
||||
await self.start()
|
||||
@@ -724,10 +763,10 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
if isinstance(body, dict):
|
||||
body = json.dumps(body)
|
||||
self.log.debug("Proxy: Fetching %s %s", method, url)
|
||||
req = HTTPRequest(url,
|
||||
req = HTTPRequest(
|
||||
url,
|
||||
method=method,
|
||||
headers={'Authorization': 'token {}'.format(
|
||||
self.auth_token)},
|
||||
headers={'Authorization': 'token {}'.format(self.auth_token)},
|
||||
body=body,
|
||||
)
|
||||
async with self.semaphore:
|
||||
@@ -739,11 +778,7 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
body['target'] = target
|
||||
body['jupyterhub'] = True
|
||||
path = self._routespec_to_chp_path(routespec)
|
||||
await self.api_request(
|
||||
path,
|
||||
method='POST',
|
||||
body=body,
|
||||
)
|
||||
await self.api_request(path, method='POST', body=body)
|
||||
|
||||
async def delete_route(self, routespec):
|
||||
path = self._routespec_to_chp_path(routespec)
|
||||
@@ -762,11 +797,7 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
"""Reformat CHP data format to JupyterHub's proxy API."""
|
||||
target = chp_data.pop('target')
|
||||
chp_data.pop('jupyterhub')
|
||||
return {
|
||||
'routespec': routespec,
|
||||
'target': target,
|
||||
'data': chp_data,
|
||||
}
|
||||
return {'routespec': routespec, 'target': target, 'data': chp_data}
|
||||
|
||||
async def get_all_routes(self, client=None):
|
||||
"""Fetch the proxy's routes."""
|
||||
@@ -779,6 +810,5 @@ class ConfigurableHTTPProxy(Proxy):
|
||||
# exclude routes not associated with JupyterHub
|
||||
self.log.debug("Omitting non-jupyterhub route %r", routespec)
|
||||
continue
|
||||
all_routes[routespec] = self._reformat_routespec(
|
||||
routespec, chp_data)
|
||||
all_routes[routespec] = self._reformat_routespec(routespec, chp_data)
|
||||
return all_routes
|
||||
|
@@ -9,7 +9,6 @@ model describing the authenticated user.
|
||||
authenticate with the Hub.
|
||||
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
@@ -18,22 +17,25 @@ import re
|
||||
import socket
|
||||
import string
|
||||
import time
|
||||
from urllib.parse import quote, urlencode
|
||||
import uuid
|
||||
import warnings
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
|
||||
from tornado.gen import coroutine
|
||||
from tornado.log import app_log
|
||||
from tornado.httputil import url_concat
|
||||
from tornado.web import HTTPError, RequestHandler
|
||||
|
||||
from tornado.log import app_log
|
||||
from tornado.web import HTTPError
|
||||
from tornado.web import RequestHandler
|
||||
from traitlets import default
|
||||
from traitlets import Dict
|
||||
from traitlets import Instance
|
||||
from traitlets import Integer
|
||||
from traitlets import observe
|
||||
from traitlets import Unicode
|
||||
from traitlets import validate
|
||||
from traitlets.config import SingletonConfigurable
|
||||
from traitlets import (
|
||||
Unicode, Integer, Instance, Dict,
|
||||
default, observe, validate,
|
||||
)
|
||||
|
||||
from ..utils import url_path_join
|
||||
|
||||
@@ -63,13 +65,14 @@ class _ExpiringDict(dict):
|
||||
def __repr__(self):
|
||||
"""include values and timestamps in repr"""
|
||||
now = time.monotonic()
|
||||
return repr({
|
||||
return repr(
|
||||
{
|
||||
key: '{value} (age={age:.0f}s)'.format(
|
||||
value=repr(value)[:16] + '...',
|
||||
age=now-self.timestamps[key],
|
||||
value=repr(value)[:16] + '...', age=now - self.timestamps[key]
|
||||
)
|
||||
for key, value in self.values.items()
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
def _check_age(self, key):
|
||||
"""Check timestamp for a key"""
|
||||
@@ -131,24 +134,28 @@ class HubAuth(SingletonConfigurable):
|
||||
|
||||
"""
|
||||
|
||||
hub_host = Unicode('',
|
||||
hub_host = Unicode(
|
||||
'',
|
||||
help="""The public host of JupyterHub
|
||||
|
||||
Only used if JupyterHub is spreading servers across subdomains.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
@default('hub_host')
|
||||
def _default_hub_host(self):
|
||||
return os.getenv('JUPYTERHUB_HOST', '')
|
||||
|
||||
base_url = Unicode(os.getenv('JUPYTERHUB_SERVICE_PREFIX') or '/',
|
||||
base_url = Unicode(
|
||||
os.getenv('JUPYTERHUB_SERVICE_PREFIX') or '/',
|
||||
help="""The base URL prefix of this application
|
||||
|
||||
e.g. /services/service-name/ or /user/name/
|
||||
|
||||
Default: get from JUPYTERHUB_SERVICE_PREFIX
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
@validate('base_url')
|
||||
def _add_slash(self, proposal):
|
||||
"""Ensure base_url starts and ends with /"""
|
||||
@@ -160,12 +167,14 @@ class HubAuth(SingletonConfigurable):
|
||||
return value
|
||||
|
||||
# where is the hub
|
||||
api_url = Unicode(os.getenv('JUPYTERHUB_API_URL') or 'http://127.0.0.1:8081/hub/api',
|
||||
api_url = Unicode(
|
||||
os.getenv('JUPYTERHUB_API_URL') or 'http://127.0.0.1:8081/hub/api',
|
||||
help="""The base API URL of the Hub.
|
||||
|
||||
Typically `http://hub-ip:hub-port/hub/api`
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
@default('api_url')
|
||||
def _api_url(self):
|
||||
env_url = os.getenv('JUPYTERHUB_API_URL')
|
||||
@@ -174,56 +183,64 @@ class HubAuth(SingletonConfigurable):
|
||||
else:
|
||||
return 'http://127.0.0.1:8081' + url_path_join(self.hub_prefix, 'api')
|
||||
|
||||
api_token = Unicode(os.getenv('JUPYTERHUB_API_TOKEN', ''),
|
||||
api_token = Unicode(
|
||||
os.getenv('JUPYTERHUB_API_TOKEN', ''),
|
||||
help="""API key for accessing Hub API.
|
||||
|
||||
Generate with `jupyterhub token [username]` or add to JupyterHub.services config.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
hub_prefix = Unicode('/hub/',
|
||||
hub_prefix = Unicode(
|
||||
'/hub/',
|
||||
help="""The URL prefix for the Hub itself.
|
||||
|
||||
Typically /hub/
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
@default('hub_prefix')
|
||||
def _default_hub_prefix(self):
|
||||
return url_path_join(os.getenv('JUPYTERHUB_BASE_URL') or '/', 'hub') + '/'
|
||||
|
||||
login_url = Unicode('/hub/login',
|
||||
login_url = Unicode(
|
||||
'/hub/login',
|
||||
help="""The login URL to use
|
||||
|
||||
Typically /hub/login
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
@default('login_url')
|
||||
def _default_login_url(self):
|
||||
return self.hub_host + url_path_join(self.hub_prefix, 'login')
|
||||
|
||||
keyfile = Unicode('',
|
||||
keyfile = Unicode(
|
||||
'',
|
||||
help="""The ssl key to use for requests
|
||||
|
||||
Use with certfile
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
certfile = Unicode('',
|
||||
certfile = Unicode(
|
||||
'',
|
||||
help="""The ssl cert to use for requests
|
||||
|
||||
Use with keyfile
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
client_ca = Unicode('',
|
||||
client_ca = Unicode(
|
||||
'',
|
||||
help="""The ssl certificate authority to use to verify requests
|
||||
|
||||
Use with keyfile and certfile
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
cookie_name = Unicode('jupyterhub-services',
|
||||
help="""The name of the cookie I should be looking for"""
|
||||
cookie_name = Unicode(
|
||||
'jupyterhub-services', help="""The name of the cookie I should be looking for"""
|
||||
).tag(config=True)
|
||||
|
||||
cookie_options = Dict(
|
||||
@@ -245,21 +262,26 @@ class HubAuth(SingletonConfigurable):
|
||||
return {}
|
||||
|
||||
cookie_cache_max_age = Integer(help="DEPRECATED. Use cache_max_age")
|
||||
|
||||
@observe('cookie_cache_max_age')
|
||||
def _deprecated_cookie_cache(self, change):
|
||||
warnings.warn("cookie_cache_max_age is deprecated in JupyterHub 0.8. Use cache_max_age instead.")
|
||||
warnings.warn(
|
||||
"cookie_cache_max_age is deprecated in JupyterHub 0.8. Use cache_max_age instead."
|
||||
)
|
||||
self.cache_max_age = change.new
|
||||
|
||||
cache_max_age = Integer(300,
|
||||
cache_max_age = Integer(
|
||||
300,
|
||||
help="""The maximum time (in seconds) to cache the Hub's responses for authentication.
|
||||
|
||||
A larger value reduces load on the Hub and occasional response lag.
|
||||
A smaller value reduces propagation time of changes on the Hub (rare).
|
||||
|
||||
Default: 300 (five minutes)
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
cache = Instance(_ExpiringDict, allow_none=False)
|
||||
|
||||
@default('cache')
|
||||
def _default_cache(self):
|
||||
return _ExpiringDict(self.cache_max_age)
|
||||
@@ -311,25 +333,42 @@ class HubAuth(SingletonConfigurable):
|
||||
except requests.ConnectionError as e:
|
||||
app_log.error("Error connecting to %s: %s", self.api_url, e)
|
||||
msg = "Failed to connect to Hub API at %r." % self.api_url
|
||||
msg += " Is the Hub accessible at this URL (from host: %s)?" % socket.gethostname()
|
||||
msg += (
|
||||
" Is the Hub accessible at this URL (from host: %s)?"
|
||||
% socket.gethostname()
|
||||
)
|
||||
if '127.0.0.1' in self.api_url:
|
||||
msg += " Make sure to set c.JupyterHub.hub_ip to an IP accessible to" + \
|
||||
" single-user servers if the servers are not on the same host as the Hub."
|
||||
msg += (
|
||||
" Make sure to set c.JupyterHub.hub_ip to an IP accessible to"
|
||||
+ " single-user servers if the servers are not on the same host as the Hub."
|
||||
)
|
||||
raise HTTPError(500, msg)
|
||||
|
||||
data = None
|
||||
if r.status_code == 404 and allow_404:
|
||||
pass
|
||||
elif r.status_code == 403:
|
||||
app_log.error("I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s", r.status_code, r.reason)
|
||||
app_log.error(
|
||||
"I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s",
|
||||
r.status_code,
|
||||
r.reason,
|
||||
)
|
||||
app_log.error(r.text)
|
||||
raise HTTPError(500, "Permission failure checking authorization, I may need a new token")
|
||||
raise HTTPError(
|
||||
500, "Permission failure checking authorization, I may need a new token"
|
||||
)
|
||||
elif r.status_code >= 500:
|
||||
app_log.error("Upstream failure verifying auth token: [%i] %s", r.status_code, r.reason)
|
||||
app_log.error(
|
||||
"Upstream failure verifying auth token: [%i] %s",
|
||||
r.status_code,
|
||||
r.reason,
|
||||
)
|
||||
app_log.error(r.text)
|
||||
raise HTTPError(502, "Failed to check authorization (upstream problem)")
|
||||
elif r.status_code >= 400:
|
||||
app_log.warning("Failed to check authorization: [%i] %s", r.status_code, r.reason)
|
||||
app_log.warning(
|
||||
"Failed to check authorization: [%i] %s", r.status_code, r.reason
|
||||
)
|
||||
app_log.warning(r.text)
|
||||
msg = "Failed to check authorization"
|
||||
# pass on error_description from oauth failure
|
||||
@@ -358,10 +397,12 @@ class HubAuth(SingletonConfigurable):
|
||||
The 'name' field contains the user's name.
|
||||
"""
|
||||
return self._check_hub_authorization(
|
||||
url=url_path_join(self.api_url,
|
||||
url=url_path_join(
|
||||
self.api_url,
|
||||
"authorizations/cookie",
|
||||
self.cookie_name,
|
||||
quote(encrypted_cookie, safe='')),
|
||||
quote(encrypted_cookie, safe=''),
|
||||
),
|
||||
cache_key='cookie:{}:{}'.format(session_id, encrypted_cookie),
|
||||
use_cache=use_cache,
|
||||
)
|
||||
@@ -379,9 +420,9 @@ class HubAuth(SingletonConfigurable):
|
||||
The 'name' field contains the user's name.
|
||||
"""
|
||||
return self._check_hub_authorization(
|
||||
url=url_path_join(self.api_url,
|
||||
"authorizations/token",
|
||||
quote(token, safe='')),
|
||||
url=url_path_join(
|
||||
self.api_url, "authorizations/token", quote(token, safe='')
|
||||
),
|
||||
cache_key='token:{}:{}'.format(session_id, token),
|
||||
use_cache=use_cache,
|
||||
)
|
||||
@@ -399,7 +440,9 @@ class HubAuth(SingletonConfigurable):
|
||||
user_token = handler.get_argument('token', '')
|
||||
if not user_token:
|
||||
# get it from Authorization header
|
||||
m = self.auth_header_pat.match(handler.request.headers.get(self.auth_header_name, ''))
|
||||
m = self.auth_header_pat.match(
|
||||
handler.request.headers.get(self.auth_header_name, '')
|
||||
)
|
||||
if m:
|
||||
user_token = m.group(1)
|
||||
return user_token
|
||||
@@ -469,11 +512,14 @@ class HubOAuth(HubAuth):
|
||||
|
||||
@default('login_url')
|
||||
def _login_url(self):
|
||||
return url_concat(self.oauth_authorization_url, {
|
||||
return url_concat(
|
||||
self.oauth_authorization_url,
|
||||
{
|
||||
'client_id': self.oauth_client_id,
|
||||
'redirect_uri': self.oauth_redirect_uri,
|
||||
'response_type': 'code',
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def cookie_name(self):
|
||||
@@ -511,6 +557,7 @@ class HubOAuth(HubAuth):
|
||||
Use JUPYTERHUB_CLIENT_ID by default.
|
||||
"""
|
||||
).tag(config=True)
|
||||
|
||||
@default('oauth_client_id')
|
||||
def _client_id(self):
|
||||
return os.getenv('JUPYTERHUB_CLIENT_ID', '')
|
||||
@@ -527,13 +574,18 @@ class HubOAuth(HubAuth):
|
||||
Should generally be /base_url/oauth_callback
|
||||
"""
|
||||
).tag(config=True)
|
||||
|
||||
@default('oauth_redirect_uri')
|
||||
def _default_redirect(self):
|
||||
return os.getenv('JUPYTERHUB_OAUTH_CALLBACK_URL') or url_path_join(self.base_url, 'oauth_callback')
|
||||
return os.getenv('JUPYTERHUB_OAUTH_CALLBACK_URL') or url_path_join(
|
||||
self.base_url, 'oauth_callback'
|
||||
)
|
||||
|
||||
oauth_authorization_url = Unicode('/hub/api/oauth2/authorize',
|
||||
oauth_authorization_url = Unicode(
|
||||
'/hub/api/oauth2/authorize',
|
||||
help="The URL to redirect to when starting the OAuth process",
|
||||
).tag(config=True)
|
||||
|
||||
@default('oauth_authorization_url')
|
||||
def _auth_url(self):
|
||||
return self.hub_host + url_path_join(self.hub_prefix, 'api/oauth2/authorize')
|
||||
@@ -541,6 +593,7 @@ class HubOAuth(HubAuth):
|
||||
oauth_token_url = Unicode(
|
||||
help="""The URL for requesting an OAuth token from JupyterHub"""
|
||||
).tag(config=True)
|
||||
|
||||
@default('oauth_token_url')
|
||||
def _token_url(self):
|
||||
return url_path_join(self.api_url, 'oauth2/token')
|
||||
@@ -565,11 +618,12 @@ class HubOAuth(HubAuth):
|
||||
redirect_uri=self.oauth_redirect_uri,
|
||||
)
|
||||
|
||||
token_reply = self._api_request('POST', self.oauth_token_url,
|
||||
token_reply = self._api_request(
|
||||
'POST',
|
||||
self.oauth_token_url,
|
||||
data=urlencode(params).encode('utf8'),
|
||||
headers={
|
||||
'Content-Type': 'application/x-www-form-urlencoded'
|
||||
})
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'},
|
||||
)
|
||||
|
||||
return token_reply['access_token']
|
||||
|
||||
@@ -577,9 +631,11 @@ class HubOAuth(HubAuth):
|
||||
"""Encode a state dict as url-safe base64"""
|
||||
# trim trailing `=` because = is itself not url-safe!
|
||||
json_state = json.dumps(state)
|
||||
return base64.urlsafe_b64encode(
|
||||
json_state.encode('utf8')
|
||||
).decode('ascii').rstrip('=')
|
||||
return (
|
||||
base64.urlsafe_b64encode(json_state.encode('utf8'))
|
||||
.decode('ascii')
|
||||
.rstrip('=')
|
||||
)
|
||||
|
||||
def _decode_state(self, b64_state):
|
||||
"""Decode a base64 state
|
||||
@@ -621,7 +677,9 @@ class HubOAuth(HubAuth):
|
||||
# use a randomized cookie suffix to avoid collisions
|
||||
# in case of concurrent logins
|
||||
app_log.warning("Detected unused OAuth state cookies")
|
||||
cookie_suffix = ''.join(random.choice(string.ascii_letters) for i in range(8))
|
||||
cookie_suffix = ''.join(
|
||||
random.choice(string.ascii_letters) for i in range(8)
|
||||
)
|
||||
cookie_name = '{}-{}'.format(self.state_cookie_name, cookie_suffix)
|
||||
extra_state['cookie_name'] = cookie_name
|
||||
else:
|
||||
@@ -640,11 +698,7 @@ class HubOAuth(HubAuth):
|
||||
kwargs['secure'] = True
|
||||
# load user cookie overrides
|
||||
kwargs.update(self.cookie_options)
|
||||
handler.set_secure_cookie(
|
||||
cookie_name,
|
||||
b64_state,
|
||||
**kwargs
|
||||
)
|
||||
handler.set_secure_cookie(cookie_name, b64_state, **kwargs)
|
||||
return b64_state
|
||||
|
||||
def generate_state(self, next_url=None, **extra_state):
|
||||
@@ -658,10 +712,7 @@ class HubOAuth(HubAuth):
|
||||
-------
|
||||
state (str): The base64-encoded state string.
|
||||
"""
|
||||
state = {
|
||||
'uuid': uuid.uuid4().hex,
|
||||
'next_url': next_url,
|
||||
}
|
||||
state = {'uuid': uuid.uuid4().hex, 'next_url': next_url}
|
||||
state.update(extra_state)
|
||||
return self._encode_state(state)
|
||||
|
||||
@@ -681,21 +732,19 @@ class HubOAuth(HubAuth):
|
||||
|
||||
def set_cookie(self, handler, access_token):
|
||||
"""Set a cookie recording OAuth result"""
|
||||
kwargs = {
|
||||
'path': self.base_url,
|
||||
'httponly': True,
|
||||
}
|
||||
kwargs = {'path': self.base_url, 'httponly': True}
|
||||
if handler.request.protocol == 'https':
|
||||
kwargs['secure'] = True
|
||||
# load user cookie overrides
|
||||
kwargs.update(self.cookie_options)
|
||||
app_log.debug("Setting oauth cookie for %s: %s, %s",
|
||||
handler.request.remote_ip, self.cookie_name, kwargs)
|
||||
handler.set_secure_cookie(
|
||||
app_log.debug(
|
||||
"Setting oauth cookie for %s: %s, %s",
|
||||
handler.request.remote_ip,
|
||||
self.cookie_name,
|
||||
access_token,
|
||||
**kwargs
|
||||
kwargs,
|
||||
)
|
||||
handler.set_secure_cookie(self.cookie_name, access_token, **kwargs)
|
||||
|
||||
def clear_cookie(self, handler):
|
||||
"""Clear the OAuth cookie"""
|
||||
handler.clear_cookie(self.cookie_name, path=self.base_url)
|
||||
@@ -703,6 +752,7 @@ class HubOAuth(HubAuth):
|
||||
|
||||
class UserNotAllowed(Exception):
|
||||
"""Exception raised when a user is identified and not allowed"""
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
@@ -738,6 +788,7 @@ class HubAuthenticated(object):
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
hub_services = None # set of allowed services
|
||||
hub_users = None # set of allowed users
|
||||
hub_groups = None # set of allowed groups
|
||||
@@ -748,9 +799,11 @@ class HubAuthenticated(object):
|
||||
"""Property indicating that all successfully identified user
|
||||
or service should be allowed.
|
||||
"""
|
||||
return (self.hub_services is None
|
||||
return (
|
||||
self.hub_services is None
|
||||
and self.hub_users is None
|
||||
and self.hub_groups is None)
|
||||
and self.hub_groups is None
|
||||
)
|
||||
|
||||
# self.hub_auth must be a HubAuth instance.
|
||||
# If nothing specified, use default config,
|
||||
@@ -758,6 +811,7 @@ class HubAuthenticated(object):
|
||||
# based on JupyterHub environment variables for services.
|
||||
_hub_auth = None
|
||||
hub_auth_class = HubAuth
|
||||
|
||||
@property
|
||||
def hub_auth(self):
|
||||
if self._hub_auth is None:
|
||||
@@ -794,7 +848,9 @@ class HubAuthenticated(object):
|
||||
name = model['name']
|
||||
kind = model.setdefault('kind', 'user')
|
||||
if self.allow_all:
|
||||
app_log.debug("Allowing Hub %s %s (all Hub users and services allowed)", kind, name)
|
||||
app_log.debug(
|
||||
"Allowing Hub %s %s (all Hub users and services allowed)", kind, name
|
||||
)
|
||||
return model
|
||||
|
||||
if self.allow_admin and model.get('admin', False):
|
||||
@@ -816,7 +872,11 @@ class HubAuthenticated(object):
|
||||
return model
|
||||
elif self.hub_groups and set(model['groups']).intersection(self.hub_groups):
|
||||
allowed_groups = set(model['groups']).intersection(self.hub_groups)
|
||||
app_log.debug("Allowing Hub user %s in group(s) %s", name, ','.join(sorted(allowed_groups)))
|
||||
app_log.debug(
|
||||
"Allowing Hub user %s in group(s) %s",
|
||||
name,
|
||||
','.join(sorted(allowed_groups)),
|
||||
)
|
||||
# group in whitelist
|
||||
return model
|
||||
else:
|
||||
@@ -845,7 +905,10 @@ class HubAuthenticated(object):
|
||||
# This is not the best, but avoids problems that can be caused
|
||||
# when get_current_user is allowed to raise.
|
||||
def raise_on_redirect(*args, **kwargs):
|
||||
raise HTTPError(403, "{kind} {name} is not allowed.".format(**user_model))
|
||||
raise HTTPError(
|
||||
403, "{kind} {name} is not allowed.".format(**user_model)
|
||||
)
|
||||
|
||||
self.redirect = raise_on_redirect
|
||||
return
|
||||
except Exception:
|
||||
@@ -869,6 +932,7 @@ class HubAuthenticated(object):
|
||||
|
||||
class HubOAuthenticated(HubAuthenticated):
|
||||
"""Simple subclass of HubAuthenticated using OAuth instead of old shared cookies"""
|
||||
|
||||
hub_auth_class = HubOAuth
|
||||
|
||||
|
||||
@@ -917,5 +981,3 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
||||
app_log.info("Logged-in user %s", user_model)
|
||||
self.hub_auth.set_cookie(self, token)
|
||||
self.redirect(next_url or self.hub_auth.base_url)
|
||||
|
||||
|
||||
|
@@ -38,24 +38,26 @@ A hub-managed service with no URL::
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import pipes
|
||||
import shutil
|
||||
import os
|
||||
from subprocess import Popen
|
||||
|
||||
from traitlets import (
|
||||
HasTraits,
|
||||
Any, Bool, Dict, Unicode, Instance,
|
||||
default,
|
||||
)
|
||||
from traitlets import Any
|
||||
from traitlets import Bool
|
||||
from traitlets import default
|
||||
from traitlets import Dict
|
||||
from traitlets import HasTraits
|
||||
from traitlets import Instance
|
||||
from traitlets import Unicode
|
||||
from traitlets.config import LoggingConfigurable
|
||||
|
||||
from .. import orm
|
||||
from ..objects import Server
|
||||
from ..spawner import LocalProcessSpawner
|
||||
from ..spawner import set_user_setuid
|
||||
from ..traitlets import Command
|
||||
from ..spawner import LocalProcessSpawner, set_user_setuid
|
||||
from ..utils import url_path_join
|
||||
|
||||
|
||||
@@ -81,14 +83,17 @@ class _MockUser(HasTraits):
|
||||
return ''
|
||||
return self.server.base_url
|
||||
|
||||
|
||||
# We probably shouldn't use a Spawner here,
|
||||
# but there are too many concepts to share.
|
||||
|
||||
|
||||
class _ServiceSpawner(LocalProcessSpawner):
|
||||
"""Subclass of LocalProcessSpawner
|
||||
|
||||
Removes notebook-specific-ness from LocalProcessSpawner.
|
||||
"""
|
||||
|
||||
cwd = Unicode()
|
||||
cmd = Command(minlen=0)
|
||||
|
||||
@@ -115,7 +120,9 @@ class _ServiceSpawner(LocalProcessSpawner):
|
||||
|
||||
self.log.info("Spawning %s", ' '.join(pipes.quote(s) for s in cmd))
|
||||
try:
|
||||
self.proc = Popen(self.cmd, env=env,
|
||||
self.proc = Popen(
|
||||
self.cmd,
|
||||
env=env,
|
||||
preexec_fn=self.make_preexec_fn(self.user.name),
|
||||
start_new_session=True, # don't forward signals
|
||||
cwd=self.cwd or None,
|
||||
@@ -123,8 +130,10 @@ class _ServiceSpawner(LocalProcessSpawner):
|
||||
except PermissionError:
|
||||
# use which to get abspath
|
||||
script = shutil.which(cmd[0]) or cmd[0]
|
||||
self.log.error("Permission denied trying to run %r. Does %s have access to this file?",
|
||||
script, self.user.name,
|
||||
self.log.error(
|
||||
"Permission denied trying to run %r. Does %s have access to this file?",
|
||||
script,
|
||||
self.user.name,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -165,9 +174,9 @@ class Service(LoggingConfigurable):
|
||||
If the service has an http endpoint, it
|
||||
"""
|
||||
).tag(input=True)
|
||||
admin = Bool(False,
|
||||
help="Does the service need admin-access to the Hub API?"
|
||||
).tag(input=True)
|
||||
admin = Bool(False, help="Does the service need admin-access to the Hub API?").tag(
|
||||
input=True
|
||||
)
|
||||
url = Unicode(
|
||||
help="""URL of the service.
|
||||
|
||||
@@ -205,22 +214,23 @@ class Service(LoggingConfigurable):
|
||||
"""
|
||||
return 'managed' if self.managed else 'external'
|
||||
|
||||
command = Command(minlen=0,
|
||||
help="Command to spawn this service, if managed."
|
||||
).tag(input=True)
|
||||
cwd = Unicode(
|
||||
help="""The working directory in which to run the service."""
|
||||
).tag(input=True)
|
||||
command = Command(minlen=0, help="Command to spawn this service, if managed.").tag(
|
||||
input=True
|
||||
)
|
||||
cwd = Unicode(help="""The working directory in which to run the service.""").tag(
|
||||
input=True
|
||||
)
|
||||
environment = Dict(
|
||||
help="""Environment variables to pass to the service.
|
||||
Only used if the Hub is spawning the service.
|
||||
"""
|
||||
).tag(input=True)
|
||||
user = Unicode("",
|
||||
user = Unicode(
|
||||
"",
|
||||
help="""The user to become when launching the service.
|
||||
|
||||
If unspecified, run the service as the same user as the Hub.
|
||||
"""
|
||||
""",
|
||||
).tag(input=True)
|
||||
|
||||
domain = Unicode()
|
||||
@@ -245,6 +255,7 @@ class Service(LoggingConfigurable):
|
||||
Default: `service-<name>`
|
||||
"""
|
||||
).tag(input=True)
|
||||
|
||||
@default('oauth_client_id')
|
||||
def _default_client_id(self):
|
||||
return 'service-%s' % self.name
|
||||
@@ -256,6 +267,7 @@ class Service(LoggingConfigurable):
|
||||
Default: `/services/:name/oauth_callback`
|
||||
"""
|
||||
).tag(input=True)
|
||||
|
||||
@default('oauth_redirect_uri')
|
||||
def _default_redirect_uri(self):
|
||||
if self.server is None:
|
||||
@@ -328,10 +340,7 @@ class Service(LoggingConfigurable):
|
||||
cwd=self.cwd,
|
||||
hub=self.hub,
|
||||
user=_MockUser(
|
||||
name=self.user,
|
||||
service=self,
|
||||
server=self.orm.server,
|
||||
host=self.host,
|
||||
name=self.user, service=self, server=self.orm.server, host=self.host
|
||||
),
|
||||
internal_ssl=self.app.internal_ssl,
|
||||
internal_certs_location=self.app.internal_certs_location,
|
||||
@@ -344,7 +353,9 @@ class Service(LoggingConfigurable):
|
||||
|
||||
def _proc_stopped(self):
|
||||
"""Called when the service process unexpectedly exits"""
|
||||
self.log.error("Service %s exited with status %i", self.name, self.proc.returncode)
|
||||
self.log.error(
|
||||
"Service %s exited with status %i", self.name, self.proc.returncode
|
||||
)
|
||||
self.start()
|
||||
|
||||
async def stop(self):
|
||||
@@ -357,4 +368,4 @@ class Service(LoggingConfigurable):
|
||||
self.db.delete(self.orm.server)
|
||||
self.db.commit()
|
||||
self.spawner.stop_polling()
|
||||
return (await self.spawner.stop())
|
||||
return await self.spawner.stop()
|
||||
|
@@ -1,23 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
"""Extend regular notebook server to be aware of multiuser things."""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from textwrap import dedent
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from jinja2 import ChoiceLoader, FunctionLoader
|
||||
|
||||
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
|
||||
from jinja2 import ChoiceLoader
|
||||
from jinja2 import FunctionLoader
|
||||
from tornado import gen
|
||||
from tornado import ioloop
|
||||
from tornado.web import HTTPError, RequestHandler
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpclient import HTTPRequest
|
||||
from tornado.web import HTTPError
|
||||
from tornado.web import RequestHandler
|
||||
|
||||
try:
|
||||
import notebook
|
||||
@@ -60,7 +61,9 @@ class HubAuthenticatedHandler(HubOAuthenticated):
|
||||
|
||||
@property
|
||||
def allow_admin(self):
|
||||
return self.settings.get('admin_access', os.getenv('JUPYTERHUB_ADMIN_ACCESS') or False)
|
||||
return self.settings.get(
|
||||
'admin_access', os.getenv('JUPYTERHUB_ADMIN_ACCESS') or False
|
||||
)
|
||||
|
||||
@property
|
||||
def hub_auth(self):
|
||||
@@ -79,6 +82,7 @@ class HubAuthenticatedHandler(HubOAuthenticated):
|
||||
|
||||
class JupyterHubLoginHandler(LoginHandler):
|
||||
"""LoginHandler that hooks up Hub authentication"""
|
||||
|
||||
@staticmethod
|
||||
def login_available(settings):
|
||||
return True
|
||||
@@ -113,12 +117,14 @@ class JupyterHubLogoutHandler(LogoutHandler):
|
||||
def get(self):
|
||||
self.settings['hub_auth'].clear_cookie(self)
|
||||
self.redirect(
|
||||
self.settings['hub_host'] +
|
||||
url_path_join(self.settings['hub_prefix'], 'logout'))
|
||||
self.settings['hub_host']
|
||||
+ url_path_join(self.settings['hub_prefix'], 'logout')
|
||||
)
|
||||
|
||||
|
||||
class OAuthCallbackHandler(HubOAuthCallbackHandler, IPythonHandler):
|
||||
"""Mixin IPythonHandler to get the right error pages, etc."""
|
||||
|
||||
@property
|
||||
def hub_auth(self):
|
||||
return self.settings['hub_auth']
|
||||
@@ -126,7 +132,8 @@ class OAuthCallbackHandler(HubOAuthCallbackHandler, IPythonHandler):
|
||||
|
||||
# register new hub related command-line aliases
|
||||
aliases = dict(notebook_aliases)
|
||||
aliases.update({
|
||||
aliases.update(
|
||||
{
|
||||
'user': 'SingleUserNotebookApp.user',
|
||||
'group': 'SingleUserNotebookApp.group',
|
||||
'cookie-name': 'HubAuth.cookie_name',
|
||||
@@ -134,15 +141,17 @@ aliases.update({
|
||||
'hub-host': 'SingleUserNotebookApp.hub_host',
|
||||
'hub-api-url': 'SingleUserNotebookApp.hub_api_url',
|
||||
'base-url': 'SingleUserNotebookApp.base_url',
|
||||
})
|
||||
flags = dict(notebook_flags)
|
||||
flags.update({
|
||||
'disable-user-config': ({
|
||||
'SingleUserNotebookApp': {
|
||||
'disable_user_config': True
|
||||
}
|
||||
}, "Disable user-controlled configuration of the notebook server.")
|
||||
})
|
||||
)
|
||||
flags = dict(notebook_flags)
|
||||
flags.update(
|
||||
{
|
||||
'disable-user-config': (
|
||||
{'SingleUserNotebookApp': {'disable_user_config': True}},
|
||||
"Disable user-controlled configuration of the notebook server.",
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
page_template = """
|
||||
{% extends "templates/page.html" %}
|
||||
@@ -209,11 +218,14 @@ def _exclude_home(path_list):
|
||||
|
||||
class SingleUserNotebookApp(NotebookApp):
|
||||
"""A Subclass of the regular NotebookApp that is aware of the parent multiuser context."""
|
||||
description = dedent("""
|
||||
|
||||
description = dedent(
|
||||
"""
|
||||
Single-user server for JupyterHub. Extends the Jupyter Notebook server.
|
||||
|
||||
Meant to be invoked by JupyterHub Spawners, and not directly.
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
examples = ""
|
||||
subcommands = {}
|
||||
@@ -229,6 +241,7 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
# ensures that each spawn clears any cookies from previous session,
|
||||
# triggering OAuth again
|
||||
cookie_secret = Bytes()
|
||||
|
||||
def _cookie_secret_default(self):
|
||||
return os.urandom(32)
|
||||
|
||||
@@ -320,14 +333,17 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
trust_xheaders = True
|
||||
login_handler_class = JupyterHubLoginHandler
|
||||
logout_handler_class = JupyterHubLogoutHandler
|
||||
port_retries = 0 # disable port-retries, since the Spawner will tell us what port to use
|
||||
port_retries = (
|
||||
0
|
||||
) # disable port-retries, since the Spawner will tell us what port to use
|
||||
|
||||
disable_user_config = Bool(False,
|
||||
disable_user_config = Bool(
|
||||
False,
|
||||
help="""Disable user configuration of single-user server.
|
||||
|
||||
Prevents user-writable files that normally configure the single-user server
|
||||
from being loaded, ensuring admins have full control of configuration.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
@validate('notebook_dir')
|
||||
@@ -394,22 +410,15 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
# create dynamic default http client,
|
||||
# configured with any relevant ssl config
|
||||
hub_http_client = Any()
|
||||
|
||||
@default('hub_http_client')
|
||||
def _default_client(self):
|
||||
ssl_context = make_ssl_context(
|
||||
self.keyfile,
|
||||
self.certfile,
|
||||
cafile=self.client_ca,
|
||||
)
|
||||
AsyncHTTPClient.configure(
|
||||
None,
|
||||
defaults={
|
||||
"ssl_options": ssl_context,
|
||||
},
|
||||
self.keyfile, self.certfile, cafile=self.client_ca
|
||||
)
|
||||
AsyncHTTPClient.configure(None, defaults={"ssl_options": ssl_context})
|
||||
return AsyncHTTPClient()
|
||||
|
||||
|
||||
async def check_hub_version(self):
|
||||
"""Test a connection to my Hub
|
||||
|
||||
@@ -422,8 +431,12 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
try:
|
||||
resp = await client.fetch(self.hub_api_url)
|
||||
except Exception:
|
||||
self.log.exception("Failed to connect to my Hub at %s (attempt %i/%i). Is it running?",
|
||||
self.hub_api_url, i, RETRIES)
|
||||
self.log.exception(
|
||||
"Failed to connect to my Hub at %s (attempt %i/%i). Is it running?",
|
||||
self.hub_api_url,
|
||||
i,
|
||||
RETRIES,
|
||||
)
|
||||
await gen.sleep(min(2 ** i, 16))
|
||||
else:
|
||||
break
|
||||
@@ -434,14 +447,15 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
_check_version(hub_version, __version__, self.log)
|
||||
|
||||
server_name = Unicode()
|
||||
|
||||
@default('server_name')
|
||||
def _server_name_default(self):
|
||||
return os.environ.get('JUPYTERHUB_SERVER_NAME', '')
|
||||
|
||||
hub_activity_url = Unicode(
|
||||
config=True,
|
||||
help="URL for sending JupyterHub activity updates",
|
||||
config=True, help="URL for sending JupyterHub activity updates"
|
||||
)
|
||||
|
||||
@default('hub_activity_url')
|
||||
def _default_activity_url(self):
|
||||
return os.environ.get('JUPYTERHUB_ACTIVITY_URL', '')
|
||||
@@ -452,8 +466,9 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
help="""
|
||||
Interval (in seconds) on which to update the Hub
|
||||
with our latest activity.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
@default('hub_activity_interval')
|
||||
def _default_activity_interval(self):
|
||||
env_value = os.environ.get('JUPYTERHUB_ACTIVITY_INTERVAL')
|
||||
@@ -478,10 +493,7 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
self.log.warning("last activity is using naïve timestamps")
|
||||
last_activity = last_activity.replace(tzinfo=timezone.utc)
|
||||
|
||||
if (
|
||||
self._last_activity_sent
|
||||
and last_activity < self._last_activity_sent
|
||||
):
|
||||
if self._last_activity_sent and last_activity < self._last_activity_sent:
|
||||
self.log.debug("No activity since %s", self._last_activity_sent)
|
||||
return
|
||||
|
||||
@@ -496,14 +508,14 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
"Authorization": "token {}".format(self.hub_auth.api_token),
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body=json.dumps({
|
||||
body=json.dumps(
|
||||
{
|
||||
'servers': {
|
||||
self.server_name: {
|
||||
'last_activity': last_activity_timestamp,
|
||||
},
|
||||
self.server_name: {'last_activity': last_activity_timestamp}
|
||||
},
|
||||
'last_activity': last_activity_timestamp,
|
||||
})
|
||||
}
|
||||
),
|
||||
)
|
||||
try:
|
||||
await client.fetch(req)
|
||||
@@ -526,8 +538,8 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
if not self.hub_activity_url or not self.hub_activity_interval:
|
||||
self.log.warning("Activity events disabled")
|
||||
return
|
||||
self.log.info("Updating Hub with activity every %s seconds",
|
||||
self.hub_activity_interval
|
||||
self.log.info(
|
||||
"Updating Hub with activity every %s seconds", self.hub_activity_interval
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
@@ -561,7 +573,9 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
api_token = os.environ['JUPYTERHUB_API_TOKEN']
|
||||
|
||||
if not api_token:
|
||||
self.exit("JUPYTERHUB_API_TOKEN env is required to run jupyterhub-singleuser. Did you launch it manually?")
|
||||
self.exit(
|
||||
"JUPYTERHUB_API_TOKEN env is required to run jupyterhub-singleuser. Did you launch it manually?"
|
||||
)
|
||||
self.hub_auth = HubOAuth(
|
||||
parent=self,
|
||||
api_token=api_token,
|
||||
@@ -586,21 +600,23 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
s['hub_prefix'] = self.hub_prefix
|
||||
s['hub_host'] = self.hub_host
|
||||
s['hub_auth'] = self.hub_auth
|
||||
csp_report_uri = s['csp_report_uri'] = self.hub_host + url_path_join(self.hub_prefix, 'security/csp-report')
|
||||
csp_report_uri = s['csp_report_uri'] = self.hub_host + url_path_join(
|
||||
self.hub_prefix, 'security/csp-report'
|
||||
)
|
||||
headers = s.setdefault('headers', {})
|
||||
headers['X-JupyterHub-Version'] = __version__
|
||||
# set CSP header directly to workaround bugs in jupyter/notebook 5.0
|
||||
headers.setdefault('Content-Security-Policy', ';'.join([
|
||||
"frame-ancestors 'self'",
|
||||
"report-uri " + csp_report_uri,
|
||||
]))
|
||||
headers.setdefault(
|
||||
'Content-Security-Policy',
|
||||
';'.join(["frame-ancestors 'self'", "report-uri " + csp_report_uri]),
|
||||
)
|
||||
super(SingleUserNotebookApp, self).init_webapp()
|
||||
|
||||
# add OAuth callback
|
||||
self.web_app.add_handlers(r".*$", [(
|
||||
urlparse(self.hub_auth.oauth_redirect_uri).path,
|
||||
OAuthCallbackHandler
|
||||
)])
|
||||
self.web_app.add_handlers(
|
||||
r".*$",
|
||||
[(urlparse(self.hub_auth.oauth_redirect_uri).path, OAuthCallbackHandler)],
|
||||
)
|
||||
|
||||
# apply X-JupyterHub-Version to *all* request handlers (even redirects)
|
||||
self.patch_default_headers()
|
||||
@@ -610,6 +626,7 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
if hasattr(RequestHandler, '_orig_set_default_headers'):
|
||||
return
|
||||
RequestHandler._orig_set_default_headers = RequestHandler.set_default_headers
|
||||
|
||||
def set_jupyterhub_header(self):
|
||||
self._orig_set_default_headers()
|
||||
self.set_header('X-JupyterHub-Version', __version__)
|
||||
@@ -619,13 +636,16 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
def patch_templates(self):
|
||||
"""Patch page templates to add Hub-related buttons"""
|
||||
|
||||
self.jinja_template_vars['logo_url'] = self.hub_host + url_path_join(self.hub_prefix, 'logo')
|
||||
self.jinja_template_vars['logo_url'] = self.hub_host + url_path_join(
|
||||
self.hub_prefix, 'logo'
|
||||
)
|
||||
self.jinja_template_vars['hub_host'] = self.hub_host
|
||||
self.jinja_template_vars['hub_prefix'] = self.hub_prefix
|
||||
env = self.web_app.settings['jinja2_env']
|
||||
|
||||
env.globals['hub_control_panel_url'] = \
|
||||
self.hub_host + url_path_join(self.hub_prefix, 'home')
|
||||
env.globals['hub_control_panel_url'] = self.hub_host + url_path_join(
|
||||
self.hub_prefix, 'home'
|
||||
)
|
||||
|
||||
# patch jinja env loading to modify page template
|
||||
def get_page(name):
|
||||
@@ -633,10 +653,7 @@ class SingleUserNotebookApp(NotebookApp):
|
||||
return page_template
|
||||
|
||||
orig_loader = env.loader
|
||||
env.loader = ChoiceLoader([
|
||||
FunctionLoader(get_page),
|
||||
orig_loader,
|
||||
])
|
||||
env.loader = ChoiceLoader([FunctionLoader(get_page), orig_loader])
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
|
@@ -1,10 +1,8 @@
|
||||
"""
|
||||
Contains base Spawner class & default implementation
|
||||
"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import errno
|
||||
@@ -18,22 +16,35 @@ import warnings
|
||||
from subprocess import Popen
|
||||
from tempfile import mkdtemp
|
||||
|
||||
# FIXME: remove when we drop Python 3.5 support
|
||||
from async_generator import async_generator, yield_
|
||||
|
||||
from async_generator import async_generator
|
||||
from async_generator import yield_
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from tornado.ioloop import PeriodicCallback
|
||||
|
||||
from traitlets import Any
|
||||
from traitlets import Bool
|
||||
from traitlets import default
|
||||
from traitlets import Dict
|
||||
from traitlets import Float
|
||||
from traitlets import Instance
|
||||
from traitlets import Integer
|
||||
from traitlets import List
|
||||
from traitlets import observe
|
||||
from traitlets import Unicode
|
||||
from traitlets import Union
|
||||
from traitlets import validate
|
||||
from traitlets.config import LoggingConfigurable
|
||||
from traitlets import (
|
||||
Any, Bool, Dict, Instance, Integer, Float, List, Unicode, Union,
|
||||
default, observe, validate,
|
||||
)
|
||||
|
||||
from .objects import Server
|
||||
from .traitlets import Command, ByteSpecification, Callable
|
||||
from .utils import iterate_until, maybe_future, random_port, url_path_join, exponential_backoff
|
||||
from .traitlets import ByteSpecification
|
||||
from .traitlets import Callable
|
||||
from .traitlets import Command
|
||||
from .utils import exponential_backoff
|
||||
from .utils import iterate_until
|
||||
from .utils import maybe_future
|
||||
from .utils import random_port
|
||||
from .utils import url_path_join
|
||||
|
||||
# FIXME: remove when we drop Python 3.5 support
|
||||
|
||||
|
||||
def _quote_safe(s):
|
||||
@@ -53,6 +64,7 @@ def _quote_safe(s):
|
||||
# to avoid getting interpreted by traitlets
|
||||
return repr(s)
|
||||
|
||||
|
||||
class Spawner(LoggingConfigurable):
|
||||
"""Base class for spawning single-user notebook servers.
|
||||
|
||||
@@ -146,8 +158,12 @@ class Spawner(LoggingConfigurable):
|
||||
missing.append(attr)
|
||||
|
||||
if missing:
|
||||
raise NotImplementedError("class `{}` needs to redefine the `start`,"
|
||||
"`stop` and `poll` methods. `{}` not redefined.".format(cls.__name__, '`, `'.join(missing)))
|
||||
raise NotImplementedError(
|
||||
"class `{}` needs to redefine the `start`,"
|
||||
"`stop` and `poll` methods. `{}` not redefined.".format(
|
||||
cls.__name__, '`, `'.join(missing)
|
||||
)
|
||||
)
|
||||
|
||||
proxy_spec = Unicode()
|
||||
|
||||
@@ -180,6 +196,7 @@ class Spawner(LoggingConfigurable):
|
||||
if self.orm_spawner:
|
||||
return self.orm_spawner.name
|
||||
return ''
|
||||
|
||||
hub = Any()
|
||||
authenticator = Any()
|
||||
internal_ssl = Bool(False)
|
||||
@@ -191,7 +208,8 @@ class Spawner(LoggingConfigurable):
|
||||
oauth_client_id = Unicode()
|
||||
handler = Any()
|
||||
|
||||
will_resume = Bool(False,
|
||||
will_resume = Bool(
|
||||
False,
|
||||
help="""Whether the Spawner will resume on next start
|
||||
|
||||
|
||||
@@ -199,18 +217,20 @@ class Spawner(LoggingConfigurable):
|
||||
If True, an existing Spawner will resume instead of starting anew
|
||||
(e.g. resuming a Docker container),
|
||||
and API tokens in use when the Spawner stops will not be deleted.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
ip = Unicode('',
|
||||
ip = Unicode(
|
||||
'',
|
||||
help="""
|
||||
The IP address (or hostname) the single-user server should listen on.
|
||||
|
||||
The JupyterHub proxy implementation should be able to send packets to this interface.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
port = Integer(0,
|
||||
port = Integer(
|
||||
0,
|
||||
help="""
|
||||
The port for single-user servers to listen on.
|
||||
|
||||
@@ -221,7 +241,7 @@ class Spawner(LoggingConfigurable):
|
||||
e.g. in containers.
|
||||
|
||||
New in version 0.7.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
consecutive_failure_limit = Integer(
|
||||
@@ -237,47 +257,48 @@ class Spawner(LoggingConfigurable):
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
start_timeout = Integer(60,
|
||||
start_timeout = Integer(
|
||||
60,
|
||||
help="""
|
||||
Timeout (in seconds) before giving up on starting of single-user server.
|
||||
|
||||
This is the timeout for start to return, not the timeout for the server to respond.
|
||||
Callers of spawner.start will assume that startup has failed if it takes longer than this.
|
||||
start should return when the server process is started and its location is known.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
http_timeout = Integer(30,
|
||||
http_timeout = Integer(
|
||||
30,
|
||||
help="""
|
||||
Timeout (in seconds) before giving up on a spawned HTTP server
|
||||
|
||||
Once a server has successfully been spawned, this is the amount of time
|
||||
we wait before assuming that the server is unable to accept
|
||||
connections.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
poll_interval = Integer(30,
|
||||
poll_interval = Integer(
|
||||
30,
|
||||
help="""
|
||||
Interval (in seconds) on which to poll the spawner for single-user server's status.
|
||||
|
||||
At every poll interval, each spawner's `.poll` method is called, which checks
|
||||
if the single-user server is still running. If it isn't running, then JupyterHub modifies
|
||||
its own state accordingly and removes appropriate routes from the configurable proxy.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
_callbacks = List()
|
||||
_poll_callback = Any()
|
||||
|
||||
debug = Bool(False,
|
||||
help="Enable debug-logging of the single-user server"
|
||||
).tag(config=True)
|
||||
debug = Bool(False, help="Enable debug-logging of the single-user server").tag(
|
||||
config=True
|
||||
)
|
||||
|
||||
options_form = Union([
|
||||
Unicode(),
|
||||
Callable()
|
||||
],
|
||||
options_form = Union(
|
||||
[Unicode(), Callable()],
|
||||
help="""
|
||||
An HTML form for options a user can specify on launching their server.
|
||||
|
||||
@@ -303,7 +324,8 @@ class Spawner(LoggingConfigurable):
|
||||
be called asynchronously if it returns a future, rather than a str. Note that
|
||||
the interface of the spawner class is not deemed stable across versions,
|
||||
so using this functionality might cause your JupyterHub upgrades to break.
|
||||
""").tag(config=True)
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
async def get_options_form(self):
|
||||
"""Get the options form
|
||||
@@ -341,9 +363,11 @@ class Spawner(LoggingConfigurable):
|
||||
|
||||
These user options are usually provided by the `options_form` displayed to the user when they start
|
||||
their server.
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
env_keep = List([
|
||||
env_keep = List(
|
||||
[
|
||||
'PATH',
|
||||
'PYTHONPATH',
|
||||
'CONDA_ROOT',
|
||||
@@ -357,14 +381,16 @@ class Spawner(LoggingConfigurable):
|
||||
|
||||
This whitelist is used to ensure that sensitive information in the JupyterHub process's environment
|
||||
(such as `CONFIGPROXY_AUTH_TOKEN`) is not passed to the single-user server's process.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
env = Dict(help="""Deprecated: use Spawner.get_env or Spawner.environment
|
||||
env = Dict(
|
||||
help="""Deprecated: use Spawner.get_env or Spawner.environment
|
||||
|
||||
- extend Spawner.get_env for adding required env in Spawner subclasses
|
||||
- Spawner.environment for config-specified env
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
environment = Dict(
|
||||
help="""
|
||||
@@ -386,7 +412,8 @@ class Spawner(LoggingConfigurable):
|
||||
"""
|
||||
).tag(config=True)
|
||||
|
||||
cmd = Command(['jupyterhub-singleuser'],
|
||||
cmd = Command(
|
||||
['jupyterhub-singleuser'],
|
||||
allow_none=True,
|
||||
help="""
|
||||
The command used for starting the single-user server.
|
||||
@@ -399,16 +426,17 @@ class Spawner(LoggingConfigurable):
|
||||
|
||||
Some spawners allow shell-style expansion here, allowing you to use environment variables.
|
||||
Most, including the default, do not. Consult the documentation for your spawner to verify!
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
args = List(Unicode(),
|
||||
args = List(
|
||||
Unicode(),
|
||||
help="""
|
||||
Extra arguments to be passed to the single-user server.
|
||||
|
||||
Some spawners allow shell-style expansion here, allowing you to use environment variables here.
|
||||
Most, including the default, do not. Consult the documentation for your spawner to verify!
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
notebook_dir = Unicode(
|
||||
@@ -446,14 +474,16 @@ class Spawner(LoggingConfigurable):
|
||||
def _deprecate_percent_u(self, proposal):
|
||||
v = proposal['value']
|
||||
if '%U' in v:
|
||||
self.log.warning("%%U for username in %s is deprecated in JupyterHub 0.7, use {username}",
|
||||
self.log.warning(
|
||||
"%%U for username in %s is deprecated in JupyterHub 0.7, use {username}",
|
||||
proposal['trait'].name,
|
||||
)
|
||||
v = v.replace('%U', '{username}')
|
||||
self.log.warning("Converting %r to %r", proposal['value'], v)
|
||||
return v
|
||||
|
||||
disable_user_config = Bool(False,
|
||||
disable_user_config = Bool(
|
||||
False,
|
||||
help="""
|
||||
Disable per-user configuration of single-user servers.
|
||||
|
||||
@@ -462,10 +492,11 @@ class Spawner(LoggingConfigurable):
|
||||
|
||||
Note: a user could circumvent this if the user modifies their Python environment, such as when
|
||||
they have their own conda environments / virtualenvs / containers.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
mem_limit = ByteSpecification(None,
|
||||
mem_limit = ByteSpecification(
|
||||
None,
|
||||
help="""
|
||||
Maximum number of bytes a single-user notebook server is allowed to use.
|
||||
|
||||
@@ -484,10 +515,11 @@ class Spawner(LoggingConfigurable):
|
||||
for the limit to work.** The default spawner, `LocalProcessSpawner`,
|
||||
does **not** implement this support. A custom spawner **must** add
|
||||
support for this setting for it to be enforced.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
cpu_limit = Float(None,
|
||||
cpu_limit = Float(
|
||||
None,
|
||||
allow_none=True,
|
||||
help="""
|
||||
Maximum number of cpu-cores a single-user notebook server is allowed to use.
|
||||
@@ -503,10 +535,11 @@ class Spawner(LoggingConfigurable):
|
||||
for the limit to work.** The default spawner, `LocalProcessSpawner`,
|
||||
does **not** implement this support. A custom spawner **must** add
|
||||
support for this setting for it to be enforced.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
mem_guarantee = ByteSpecification(None,
|
||||
mem_guarantee = ByteSpecification(
|
||||
None,
|
||||
help="""
|
||||
Minimum number of bytes a single-user notebook server is guaranteed to have available.
|
||||
|
||||
@@ -520,10 +553,11 @@ class Spawner(LoggingConfigurable):
|
||||
for the limit to work.** The default spawner, `LocalProcessSpawner`,
|
||||
does **not** implement this support. A custom spawner **must** add
|
||||
support for this setting for it to be enforced.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
cpu_guarantee = Float(None,
|
||||
cpu_guarantee = Float(
|
||||
None,
|
||||
allow_none=True,
|
||||
help="""
|
||||
Minimum number of cpu-cores a single-user notebook server is guaranteed to have available.
|
||||
@@ -535,7 +569,7 @@ class Spawner(LoggingConfigurable):
|
||||
for the limit to work.** The default spawner, `LocalProcessSpawner`,
|
||||
does **not** implement this support. A custom spawner **must** add
|
||||
support for this setting for it to be enforced.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
pre_spawn_hook = Any(
|
||||
@@ -621,7 +655,9 @@ class Spawner(LoggingConfigurable):
|
||||
"""
|
||||
env = {}
|
||||
if self.env:
|
||||
warnings.warn("Spawner.env is deprecated, found %s" % self.env, DeprecationWarning)
|
||||
warnings.warn(
|
||||
"Spawner.env is deprecated, found %s" % self.env, DeprecationWarning
|
||||
)
|
||||
env.update(self.env)
|
||||
|
||||
for key in self.env_keep:
|
||||
@@ -648,8 +684,9 @@ class Spawner(LoggingConfigurable):
|
||||
if self.cookie_options:
|
||||
env['JUPYTERHUB_COOKIE_OPTIONS'] = json.dumps(self.cookie_options)
|
||||
env['JUPYTERHUB_HOST'] = self.hub.public_host
|
||||
env['JUPYTERHUB_OAUTH_CALLBACK_URL'] = \
|
||||
url_path_join(self.user.url, self.name, 'oauth_callback')
|
||||
env['JUPYTERHUB_OAUTH_CALLBACK_URL'] = url_path_join(
|
||||
self.user.url, self.name, 'oauth_callback'
|
||||
)
|
||||
|
||||
# Info previously passed on args
|
||||
env['JUPYTERHUB_USER'] = self.user.name
|
||||
@@ -749,7 +786,7 @@ class Spawner(LoggingConfigurable):
|
||||
|
||||
May be set in config if all spawners should have the same value(s),
|
||||
or set at runtime by Spawner that know their names.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
@default('ssl_alt_names')
|
||||
@@ -793,6 +830,7 @@ class Spawner(LoggingConfigurable):
|
||||
to the host by either IP or DNS name (note the `default_names` below).
|
||||
"""
|
||||
from certipy import Certipy
|
||||
|
||||
default_names = ["DNS:localhost", "IP:127.0.0.1"]
|
||||
alt_names = []
|
||||
alt_names.extend(self.ssl_alt_names)
|
||||
@@ -800,10 +838,7 @@ class Spawner(LoggingConfigurable):
|
||||
if self.ssl_alt_names_include_local:
|
||||
alt_names = default_names + alt_names
|
||||
|
||||
self.log.info("Creating certs for %s: %s",
|
||||
self._log_name,
|
||||
';'.join(alt_names),
|
||||
)
|
||||
self.log.info("Creating certs for %s: %s", self._log_name, ';'.join(alt_names))
|
||||
|
||||
common_name = self.user.name or 'service'
|
||||
certipy = Certipy(store_dir=self.internal_certs_location)
|
||||
@@ -812,7 +847,7 @@ class Spawner(LoggingConfigurable):
|
||||
'user-' + common_name,
|
||||
notebook_component,
|
||||
alt_names=alt_names,
|
||||
overwrite=True
|
||||
overwrite=True,
|
||||
)
|
||||
paths = {
|
||||
"keyfile": notebook_key_pair['files']['key'],
|
||||
@@ -862,7 +897,9 @@ class Spawner(LoggingConfigurable):
|
||||
if self.port:
|
||||
args.append('--port=%i' % self.port)
|
||||
elif self.server and self.server.port:
|
||||
self.log.warning("Setting port from user.server is deprecated as of JupyterHub 0.7.")
|
||||
self.log.warning(
|
||||
"Setting port from user.server is deprecated as of JupyterHub 0.7."
|
||||
)
|
||||
args.append('--port=%i' % self.server.port)
|
||||
|
||||
if self.notebook_dir:
|
||||
@@ -903,13 +940,12 @@ class Spawner(LoggingConfigurable):
|
||||
This method is always an async generator and will always yield at least one event.
|
||||
"""
|
||||
if not self._spawn_pending:
|
||||
self.log.warning("Spawn not pending, can't generate progress for %s", self._log_name)
|
||||
self.log.warning(
|
||||
"Spawn not pending, can't generate progress for %s", self._log_name
|
||||
)
|
||||
return
|
||||
|
||||
await yield_({
|
||||
"progress": 0,
|
||||
"message": "Server requested",
|
||||
})
|
||||
await yield_({"progress": 0, "message": "Server requested"})
|
||||
from async_generator import aclosing
|
||||
|
||||
async with aclosing(self.progress()) as progress:
|
||||
@@ -940,10 +976,7 @@ class Spawner(LoggingConfigurable):
|
||||
|
||||
.. versionadded:: 0.9
|
||||
"""
|
||||
await yield_({
|
||||
"progress": 50,
|
||||
"message": "Spawning server...",
|
||||
})
|
||||
await yield_({"progress": 50, "message": "Spawning server..."})
|
||||
|
||||
async def start(self):
|
||||
"""Start the single-user server
|
||||
@@ -954,7 +987,9 @@ class Spawner(LoggingConfigurable):
|
||||
.. versionchanged:: 0.7
|
||||
Return ip, port instead of setting on self.user.server directly.
|
||||
"""
|
||||
raise NotImplementedError("Override in subclass. Must be a Tornado gen.coroutine.")
|
||||
raise NotImplementedError(
|
||||
"Override in subclass. Must be a Tornado gen.coroutine."
|
||||
)
|
||||
|
||||
async def stop(self, now=False):
|
||||
"""Stop the single-user server
|
||||
@@ -967,7 +1002,9 @@ class Spawner(LoggingConfigurable):
|
||||
|
||||
Must be a coroutine.
|
||||
"""
|
||||
raise NotImplementedError("Override in subclass. Must be a Tornado gen.coroutine.")
|
||||
raise NotImplementedError(
|
||||
"Override in subclass. Must be a Tornado gen.coroutine."
|
||||
)
|
||||
|
||||
async def poll(self):
|
||||
"""Check if the single-user process is running
|
||||
@@ -993,7 +1030,9 @@ class Spawner(LoggingConfigurable):
|
||||
process has not yet completed.
|
||||
|
||||
"""
|
||||
raise NotImplementedError("Override in subclass. Must be a Tornado gen.coroutine.")
|
||||
raise NotImplementedError(
|
||||
"Override in subclass. Must be a Tornado gen.coroutine."
|
||||
)
|
||||
|
||||
def add_poll_callback(self, callback, *args, **kwargs):
|
||||
"""Add a callback to fire when the single-user server stops"""
|
||||
@@ -1023,8 +1062,7 @@ class Spawner(LoggingConfigurable):
|
||||
self.stop_polling()
|
||||
|
||||
self._poll_callback = PeriodicCallback(
|
||||
self.poll_and_notify,
|
||||
1e3 * self.poll_interval
|
||||
self.poll_and_notify, 1e3 * self.poll_interval
|
||||
)
|
||||
self._poll_callback.start()
|
||||
|
||||
@@ -1048,8 +1086,10 @@ class Spawner(LoggingConfigurable):
|
||||
return status
|
||||
|
||||
death_interval = Float(0.1)
|
||||
|
||||
async def wait_for_death(self, timeout=10):
|
||||
"""Wait for the single-user server to die, up to timeout seconds"""
|
||||
|
||||
async def _wait_for_death():
|
||||
status = await self.poll()
|
||||
return status is not None
|
||||
@@ -1093,6 +1133,7 @@ def set_user_setuid(username, chdir=True):
|
||||
"""
|
||||
import grp
|
||||
import pwd
|
||||
|
||||
user = pwd.getpwnam(username)
|
||||
uid = user.pw_uid
|
||||
gid = user.pw_gid
|
||||
@@ -1132,29 +1173,32 @@ class LocalProcessSpawner(Spawner):
|
||||
Note: This spawner does not implement CPU / memory guarantees and limits.
|
||||
"""
|
||||
|
||||
interrupt_timeout = Integer(10,
|
||||
interrupt_timeout = Integer(
|
||||
10,
|
||||
help="""
|
||||
Seconds to wait for single-user server process to halt after SIGINT.
|
||||
|
||||
If the process has not exited cleanly after this many seconds, a SIGTERM is sent.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
term_timeout = Integer(5,
|
||||
term_timeout = Integer(
|
||||
5,
|
||||
help="""
|
||||
Seconds to wait for single-user server process to halt after SIGTERM.
|
||||
|
||||
If the process does not exit cleanly after this many seconds of SIGTERM, a SIGKILL is sent.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
kill_timeout = Integer(5,
|
||||
kill_timeout = Integer(
|
||||
5,
|
||||
help="""
|
||||
Seconds to wait for process to halt after SIGKILL before giving up.
|
||||
|
||||
If the process does not exit cleanly after this many seconds of SIGKILL, it becomes a zombie
|
||||
process. The hub process will log a warning and then give up.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
popen_kwargs = Dict(
|
||||
@@ -1168,7 +1212,8 @@ class LocalProcessSpawner(Spawner):
|
||||
|
||||
"""
|
||||
).tag(config=True)
|
||||
shell_cmd = Command(minlen=0,
|
||||
shell_cmd = Command(
|
||||
minlen=0,
|
||||
help="""Specify a shell command to launch.
|
||||
|
||||
The single-user command will be appended to this list,
|
||||
@@ -1185,20 +1230,23 @@ class LocalProcessSpawner(Spawner):
|
||||
Using shell_cmd gives users control over PATH, etc.,
|
||||
which could change what the jupyterhub-singleuser launch command does.
|
||||
Only use this for trusted users.
|
||||
"""
|
||||
""",
|
||||
).tag(config=True)
|
||||
|
||||
proc = Instance(Popen,
|
||||
proc = Instance(
|
||||
Popen,
|
||||
allow_none=True,
|
||||
help="""
|
||||
The process representing the single-user server process spawned for current user.
|
||||
|
||||
Is None if no process has been spawned yet.
|
||||
""")
|
||||
pid = Integer(0,
|
||||
""",
|
||||
)
|
||||
pid = Integer(
|
||||
0,
|
||||
help="""
|
||||
The process id (pid) of the single-user server process spawned for current user.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
def make_preexec_fn(self, name):
|
||||
@@ -1236,6 +1284,7 @@ class LocalProcessSpawner(Spawner):
|
||||
def user_env(self, env):
|
||||
"""Augment environment of spawned process with user specific env variables."""
|
||||
import pwd
|
||||
|
||||
env['USER'] = self.user.name
|
||||
home = pwd.getpwnam(self.user.name).pw_dir
|
||||
shell = pwd.getpwnam(self.user.name).pw_shell
|
||||
@@ -1267,6 +1316,7 @@ class LocalProcessSpawner(Spawner):
|
||||
and make them readable by the user.
|
||||
"""
|
||||
import pwd
|
||||
|
||||
key = paths['keyfile']
|
||||
cert = paths['certfile']
|
||||
ca = paths['cafile']
|
||||
@@ -1324,8 +1374,10 @@ class LocalProcessSpawner(Spawner):
|
||||
except PermissionError:
|
||||
# use which to get abspath
|
||||
script = shutil.which(cmd[0]) or cmd[0]
|
||||
self.log.error("Permission denied trying to run %r. Does %s have access to this file?",
|
||||
script, self.user.name,
|
||||
self.log.error(
|
||||
"Permission denied trying to run %r. Does %s have access to this file?",
|
||||
script,
|
||||
self.user.name,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -1445,24 +1497,25 @@ class SimpleLocalProcessSpawner(LocalProcessSpawner):
|
||||
help="""
|
||||
Template to expand to set the user home.
|
||||
{username} is expanded to the jupyterhub username.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
home_dir = Unicode(help="The home directory for the user")
|
||||
|
||||
@default('home_dir')
|
||||
def _default_home_dir(self):
|
||||
return self.home_dir_template.format(
|
||||
username=self.user.name,
|
||||
)
|
||||
return self.home_dir_template.format(username=self.user.name)
|
||||
|
||||
def make_preexec_fn(self, name):
|
||||
home = self.home_dir
|
||||
|
||||
def preexec():
|
||||
try:
|
||||
os.makedirs(home, 0o755, exist_ok=True)
|
||||
os.chdir(home)
|
||||
except Exception as e:
|
||||
self.log.exception("Error in preexec for %s", name)
|
||||
|
||||
return preexec
|
||||
|
||||
def user_env(self, env):
|
||||
@@ -1474,4 +1527,3 @@ class SimpleLocalProcessSpawner(LocalProcessSpawner):
|
||||
def move_certs(self, paths):
|
||||
"""No-op for installing certs"""
|
||||
return paths
|
||||
|
||||
|
@@ -23,34 +23,33 @@ Fixtures to add functionality or spawning behavior
|
||||
- `slow_bad_spawn`
|
||||
|
||||
"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import asyncio
|
||||
from getpass import getuser
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from getpass import getuser
|
||||
from subprocess import TimeoutExpired
|
||||
from unittest import mock
|
||||
|
||||
from pytest import fixture, raises
|
||||
from tornado import ioloop, gen
|
||||
from pytest import fixture
|
||||
from pytest import raises
|
||||
from tornado import gen
|
||||
from tornado import ioloop
|
||||
from tornado.httpclient import HTTPError
|
||||
from tornado.platform.asyncio import AsyncIOMainLoop
|
||||
|
||||
from .. import orm
|
||||
from .. import crypto
|
||||
from ..utils import random_port
|
||||
|
||||
from . import mocking
|
||||
from .mocking import MockHub
|
||||
from .utils import ssl_setup, add_user
|
||||
from .test_services import mockservice_cmd
|
||||
|
||||
import jupyterhub.services.service
|
||||
from . import mocking
|
||||
from .. import crypto
|
||||
from .. import orm
|
||||
from ..utils import random_port
|
||||
from .mocking import MockHub
|
||||
from .test_services import mockservice_cmd
|
||||
from .utils import add_user
|
||||
from .utils import ssl_setup
|
||||
|
||||
# global db session object
|
||||
_db = None
|
||||
@@ -78,12 +77,7 @@ def app(request, io_loop, ssl_tmpdir):
|
||||
ssl_enabled = getattr(request.module, "ssl_enabled", False)
|
||||
kwargs = dict()
|
||||
if ssl_enabled:
|
||||
kwargs.update(
|
||||
dict(
|
||||
internal_ssl=True,
|
||||
internal_certs_location=str(ssl_tmpdir),
|
||||
)
|
||||
)
|
||||
kwargs.update(dict(internal_ssl=True, internal_certs_location=str(ssl_tmpdir)))
|
||||
|
||||
mocked_app = MockHub.instance(**kwargs)
|
||||
|
||||
@@ -107,9 +101,7 @@ def app(request, io_loop, ssl_tmpdir):
|
||||
|
||||
@fixture
|
||||
def auth_state_enabled(app):
|
||||
app.authenticator.auth_state = {
|
||||
'who': 'cares',
|
||||
}
|
||||
app.authenticator.auth_state = {'who': 'cares'}
|
||||
app.authenticator.enable_auth_state = True
|
||||
ck = crypto.CryptKeeper.instance()
|
||||
before_keys = ck.keys
|
||||
@@ -128,9 +120,7 @@ def db():
|
||||
global _db
|
||||
if _db is None:
|
||||
_db = orm.new_session_factory('sqlite:///:memory:')()
|
||||
user = orm.User(
|
||||
name=getuser(),
|
||||
)
|
||||
user = orm.User(name=getuser())
|
||||
_db.add(user)
|
||||
_db.commit()
|
||||
return _db
|
||||
@@ -221,12 +211,12 @@ def admin_user(app, username):
|
||||
yield user
|
||||
|
||||
|
||||
|
||||
class MockServiceSpawner(jupyterhub.services.service._ServiceSpawner):
|
||||
"""mock services for testing.
|
||||
|
||||
Shorter intervals, etc.
|
||||
"""
|
||||
|
||||
poll_interval = 1
|
||||
|
||||
|
||||
@@ -237,11 +227,7 @@ def _mockservice(request, app, url=False):
|
||||
global _mock_service_counter
|
||||
_mock_service_counter += 1
|
||||
name = 'mock-service-%i' % _mock_service_counter
|
||||
spec = {
|
||||
'name': name,
|
||||
'command': mockservice_cmd,
|
||||
'admin': True,
|
||||
}
|
||||
spec = {'name': name, 'command': mockservice_cmd, 'admin': True}
|
||||
if url:
|
||||
if app.internal_ssl:
|
||||
spec['url'] = 'https://127.0.0.1:%i' % random_port()
|
||||
@@ -250,22 +236,29 @@ def _mockservice(request, app, url=False):
|
||||
|
||||
io_loop = app.io_loop
|
||||
|
||||
with mock.patch.object(jupyterhub.services.service, '_ServiceSpawner', MockServiceSpawner):
|
||||
with mock.patch.object(
|
||||
jupyterhub.services.service, '_ServiceSpawner', MockServiceSpawner
|
||||
):
|
||||
app.services = [spec]
|
||||
app.init_services()
|
||||
assert name in app._service_map
|
||||
service = app._service_map[name]
|
||||
|
||||
@gen.coroutine
|
||||
def start():
|
||||
# wait for proxy to be updated before starting the service
|
||||
yield app.proxy.add_all_services(app._service_map)
|
||||
service.start()
|
||||
|
||||
io_loop.run_sync(start)
|
||||
|
||||
def cleanup():
|
||||
import asyncio
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(service.stop())
|
||||
app.services[:] = []
|
||||
app._service_map.clear()
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
# ensure process finishes starting
|
||||
with raises(TimeoutExpired):
|
||||
@@ -290,47 +283,44 @@ def mockservice_url(request, app):
|
||||
@fixture
|
||||
def admin_access(app):
|
||||
"""Grant admin-access with this fixture"""
|
||||
with mock.patch.dict(app.tornado_settings,
|
||||
{'admin_access': True}):
|
||||
with mock.patch.dict(app.tornado_settings, {'admin_access': True}):
|
||||
yield
|
||||
|
||||
|
||||
@fixture
|
||||
def no_patience(app):
|
||||
"""Set slow-spawning timeouts to zero"""
|
||||
with mock.patch.dict(app.tornado_settings,
|
||||
{'slow_spawn_timeout': 0.1,
|
||||
'slow_stop_timeout': 0.1}):
|
||||
with mock.patch.dict(
|
||||
app.tornado_settings, {'slow_spawn_timeout': 0.1, 'slow_stop_timeout': 0.1}
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@fixture
|
||||
def slow_spawn(app):
|
||||
"""Fixture enabling SlowSpawner"""
|
||||
with mock.patch.dict(app.tornado_settings,
|
||||
{'spawner_class': mocking.SlowSpawner}):
|
||||
with mock.patch.dict(app.tornado_settings, {'spawner_class': mocking.SlowSpawner}):
|
||||
yield
|
||||
|
||||
|
||||
@fixture
|
||||
def never_spawn(app):
|
||||
"""Fixture enabling NeverSpawner"""
|
||||
with mock.patch.dict(app.tornado_settings,
|
||||
{'spawner_class': mocking.NeverSpawner}):
|
||||
with mock.patch.dict(app.tornado_settings, {'spawner_class': mocking.NeverSpawner}):
|
||||
yield
|
||||
|
||||
|
||||
@fixture
|
||||
def bad_spawn(app):
|
||||
"""Fixture enabling BadSpawner"""
|
||||
with mock.patch.dict(app.tornado_settings,
|
||||
{'spawner_class': mocking.BadSpawner}):
|
||||
with mock.patch.dict(app.tornado_settings, {'spawner_class': mocking.BadSpawner}):
|
||||
yield
|
||||
|
||||
|
||||
@fixture
|
||||
def slow_bad_spawn(app):
|
||||
"""Fixture enabling SlowBadSpawner"""
|
||||
with mock.patch.dict(app.tornado_settings,
|
||||
{'spawner_class': mocking.SlowBadSpawner}):
|
||||
with mock.patch.dict(
|
||||
app.tornado_settings, {'spawner_class': mocking.SlowBadSpawner}
|
||||
):
|
||||
yield
|
||||
|
@@ -26,32 +26,36 @@ Other components
|
||||
- public_url
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import os
|
||||
import sys
|
||||
from tempfile import NamedTemporaryFile
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tempfile import NamedTemporaryFile
|
||||
from unittest import mock
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pamela import PAMError
|
||||
from tornado import gen
|
||||
from tornado.concurrent import Future
|
||||
from tornado.ioloop import IOLoop
|
||||
from traitlets import Bool
|
||||
from traitlets import default
|
||||
from traitlets import Dict
|
||||
|
||||
from traitlets import Bool, Dict, default
|
||||
|
||||
from .. import orm
|
||||
from ..app import JupyterHub
|
||||
from ..auth import PAMAuthenticator
|
||||
from .. import orm
|
||||
from ..objects import Server
|
||||
from ..spawner import LocalProcessSpawner, SimpleLocalProcessSpawner
|
||||
from ..singleuser import SingleUserNotebookApp
|
||||
from ..utils import random_port, url_path_join
|
||||
from .utils import async_requests, ssl_setup, public_host, public_url
|
||||
|
||||
from pamela import PAMError
|
||||
from ..spawner import LocalProcessSpawner
|
||||
from ..spawner import SimpleLocalProcessSpawner
|
||||
from ..utils import random_port
|
||||
from ..utils import url_path_join
|
||||
from .utils import async_requests
|
||||
from .utils import public_host
|
||||
from .utils import public_url
|
||||
from .utils import ssl_setup
|
||||
|
||||
|
||||
def mock_authenticate(username, password, service, encoding):
|
||||
@@ -79,6 +83,7 @@ class MockSpawner(SimpleLocalProcessSpawner):
|
||||
- disables user-switching that we need root permissions to do
|
||||
- spawns `jupyterhub.tests.mocksu` instead of a full single-user server
|
||||
"""
|
||||
|
||||
def user_env(self, env):
|
||||
env = super().user_env(env)
|
||||
if self.handler:
|
||||
@@ -90,6 +95,7 @@ class MockSpawner(SimpleLocalProcessSpawner):
|
||||
return [sys.executable, '-m', 'jupyterhub.tests.mocksu']
|
||||
|
||||
use_this_api_token = None
|
||||
|
||||
def start(self):
|
||||
if self.use_this_api_token:
|
||||
self.api_token = self.use_this_api_token
|
||||
@@ -103,6 +109,7 @@ class SlowSpawner(MockSpawner):
|
||||
|
||||
delay = 2
|
||||
_start_future = None
|
||||
|
||||
@gen.coroutine
|
||||
def start(self):
|
||||
(ip, port) = yield super().start()
|
||||
@@ -140,6 +147,7 @@ class NeverSpawner(MockSpawner):
|
||||
|
||||
class BadSpawner(MockSpawner):
|
||||
"""Spawner that fails immediately"""
|
||||
|
||||
def start(self):
|
||||
raise RuntimeError("I don't work!")
|
||||
|
||||
@@ -154,6 +162,7 @@ class SlowBadSpawner(MockSpawner):
|
||||
|
||||
class FormSpawner(MockSpawner):
|
||||
"""A spawner that has an options form defined"""
|
||||
|
||||
options_form = "IMAFORM"
|
||||
|
||||
def options_from_form(self, form_data):
|
||||
@@ -167,6 +176,7 @@ class FormSpawner(MockSpawner):
|
||||
options['hello'] = form_data['hello_file'][0]
|
||||
return options
|
||||
|
||||
|
||||
class FalsyCallableFormSpawner(FormSpawner):
|
||||
"""A spawner that has a callable options form defined returning a falsy value"""
|
||||
|
||||
@@ -181,6 +191,7 @@ class MockStructGroup:
|
||||
self.gr_mem = members
|
||||
self.gr_gid = gid
|
||||
|
||||
|
||||
class MockStructPasswd:
|
||||
"""Mock pwd.struct_passwd"""
|
||||
|
||||
@@ -193,6 +204,7 @@ class MockPAMAuthenticator(PAMAuthenticator):
|
||||
auth_state = None
|
||||
# If true, return admin users marked as admin.
|
||||
return_admin = False
|
||||
|
||||
@default('admin_users')
|
||||
def _admin_users_default(self):
|
||||
return {'admin'}
|
||||
@@ -203,20 +215,20 @@ class MockPAMAuthenticator(PAMAuthenticator):
|
||||
|
||||
@gen.coroutine
|
||||
def authenticate(self, *args, **kwargs):
|
||||
with mock.patch.multiple('pamela',
|
||||
with mock.patch.multiple(
|
||||
'pamela',
|
||||
authenticate=mock_authenticate,
|
||||
open_session=mock_open_session,
|
||||
close_session=mock_open_session,
|
||||
check_account=mock_check_account,
|
||||
):
|
||||
username = yield super(MockPAMAuthenticator, self).authenticate(*args, **kwargs)
|
||||
username = yield super(MockPAMAuthenticator, self).authenticate(
|
||||
*args, **kwargs
|
||||
)
|
||||
if username is None:
|
||||
return
|
||||
elif self.auth_state:
|
||||
return {
|
||||
'name': username,
|
||||
'auth_state': self.auth_state,
|
||||
}
|
||||
return {'name': username, 'auth_state': self.auth_state}
|
||||
else:
|
||||
return username
|
||||
|
||||
@@ -349,11 +361,9 @@ class MockHub(JupyterHub):
|
||||
external_ca = None
|
||||
if self.internal_ssl:
|
||||
external_ca = self.external_certs['files']['ca']
|
||||
r = yield async_requests.post(base_url + 'hub/login',
|
||||
data={
|
||||
'username': name,
|
||||
'password': name,
|
||||
},
|
||||
r = yield async_requests.post(
|
||||
base_url + 'hub/login',
|
||||
data={'username': name, 'password': name},
|
||||
allow_redirects=False,
|
||||
verify=external_ca,
|
||||
)
|
||||
@@ -364,6 +374,7 @@ class MockHub(JupyterHub):
|
||||
|
||||
# single-user-server mocking:
|
||||
|
||||
|
||||
class MockSingleUserServer(SingleUserNotebookApp):
|
||||
"""Mock-out problematic parts of single-user server when run in a thread
|
||||
|
||||
@@ -378,7 +389,9 @@ class MockSingleUserServer(SingleUserNotebookApp):
|
||||
|
||||
class StubSingleUserSpawner(MockSpawner):
|
||||
"""Spawner that starts a MockSingleUserServer in a thread."""
|
||||
|
||||
_thread = None
|
||||
|
||||
@gen.coroutine
|
||||
def start(self):
|
||||
ip = self.ip = '127.0.0.1'
|
||||
@@ -387,6 +400,7 @@ class StubSingleUserSpawner(MockSpawner):
|
||||
args = self.get_args()
|
||||
evt = threading.Event()
|
||||
print(args, env)
|
||||
|
||||
def _run():
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
io_loop = IOLoop()
|
||||
@@ -420,4 +434,3 @@ class StubSingleUserSpawner(MockSpawner):
|
||||
return None
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
@@ -12,28 +12,33 @@ Handlers and their purpose include:
|
||||
- WhoAmIHandler: returns name of user making a request (deprecated cookie login)
|
||||
- OWhoAmIHandler: returns name of user making a request (OAuth login)
|
||||
"""
|
||||
|
||||
import json
|
||||
import pprint
|
||||
import os
|
||||
import pprint
|
||||
import sys
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from tornado import web, httpserver, ioloop
|
||||
from tornado import httpserver
|
||||
from tornado import ioloop
|
||||
from tornado import web
|
||||
|
||||
from jupyterhub.services.auth import HubAuthenticated, HubOAuthenticated, HubOAuthCallbackHandler
|
||||
from jupyterhub.services.auth import HubAuthenticated
|
||||
from jupyterhub.services.auth import HubOAuthCallbackHandler
|
||||
from jupyterhub.services.auth import HubOAuthenticated
|
||||
from jupyterhub.utils import make_ssl_context
|
||||
|
||||
|
||||
class EchoHandler(web.RequestHandler):
|
||||
"""Reply to an HTTP request with the path of the request."""
|
||||
|
||||
def get(self):
|
||||
self.write(self.request.path)
|
||||
|
||||
|
||||
class EnvHandler(web.RequestHandler):
|
||||
"""Reply to an HTTP request with the service's environment as JSON."""
|
||||
|
||||
def get(self):
|
||||
self.set_header('Content-Type', 'application/json')
|
||||
self.write(json.dumps(dict(os.environ)))
|
||||
@@ -41,11 +46,12 @@ class EnvHandler(web.RequestHandler):
|
||||
|
||||
class APIHandler(web.RequestHandler):
|
||||
"""Relay API requests to the Hub's API using the service's API token."""
|
||||
|
||||
def get(self, path):
|
||||
api_token = os.environ['JUPYTERHUB_API_TOKEN']
|
||||
api_url = os.environ['JUPYTERHUB_API_URL']
|
||||
r = requests.get(api_url + path,
|
||||
headers={'Authorization': 'token %s' % api_token},
|
||||
r = requests.get(
|
||||
api_url + path, headers={'Authorization': 'token %s' % api_token}
|
||||
)
|
||||
r.raise_for_status()
|
||||
self.set_header('Content-Type', 'application/json')
|
||||
@@ -57,6 +63,7 @@ class WhoAmIHandler(HubAuthenticated, web.RequestHandler):
|
||||
|
||||
Uses "deprecated" cookie login
|
||||
"""
|
||||
|
||||
@web.authenticated
|
||||
def get(self):
|
||||
self.write(self.get_current_user())
|
||||
@@ -67,6 +74,7 @@ class OWhoAmIHandler(HubOAuthenticated, web.RequestHandler):
|
||||
|
||||
Uses OAuth login flow
|
||||
"""
|
||||
|
||||
@web.authenticated
|
||||
def get(self):
|
||||
self.write(self.get_current_user())
|
||||
@@ -77,14 +85,17 @@ def main():
|
||||
|
||||
if os.getenv('JUPYTERHUB_SERVICE_URL'):
|
||||
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
|
||||
app = web.Application([
|
||||
app = web.Application(
|
||||
[
|
||||
(r'.*/env', EnvHandler),
|
||||
(r'.*/api/(.*)', APIHandler),
|
||||
(r'.*/whoami/?', WhoAmIHandler),
|
||||
(r'.*/owhoami/?', OWhoAmIHandler),
|
||||
(r'.*/oauth_callback', HubOAuthCallbackHandler),
|
||||
(r'.*', EchoHandler),
|
||||
], cookie_secret=os.urandom(32))
|
||||
],
|
||||
cookie_secret=os.urandom(32),
|
||||
)
|
||||
|
||||
ssl_context = None
|
||||
key = os.environ.get('JUPYTERHUB_SSL_KEYFILE') or ''
|
||||
@@ -92,11 +103,7 @@ def main():
|
||||
ca = os.environ.get('JUPYTERHUB_SSL_CLIENT_CA') or ''
|
||||
|
||||
if key and cert and ca:
|
||||
ssl_context = make_ssl_context(
|
||||
key,
|
||||
cert,
|
||||
cafile = ca,
|
||||
check_hostname = False)
|
||||
ssl_context = make_ssl_context(key, cert, cafile=ca, check_hostname=False)
|
||||
|
||||
server = httpserver.HTTPServer(app, ssl_options=ssl_context)
|
||||
server.listen(url.port, url.hostname)
|
||||
@@ -108,5 +115,6 @@ def main():
|
||||
|
||||
if __name__ == '__main__':
|
||||
from tornado.options import parse_command_line
|
||||
|
||||
parse_command_line()
|
||||
main()
|
||||
|
@@ -13,28 +13,32 @@ Handlers and their purpose include:
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tornado import httpserver
|
||||
from tornado import ioloop
|
||||
from tornado import web
|
||||
|
||||
from tornado import web, httpserver, ioloop
|
||||
from .mockservice import EnvHandler
|
||||
from ..utils import make_ssl_context
|
||||
from .mockservice import EnvHandler
|
||||
|
||||
|
||||
class EchoHandler(web.RequestHandler):
|
||||
def get(self):
|
||||
self.write(self.request.path)
|
||||
|
||||
|
||||
class ArgsHandler(web.RequestHandler):
|
||||
def get(self):
|
||||
self.write(json.dumps(sys.argv))
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
app = web.Application([
|
||||
(r'.*/args', ArgsHandler),
|
||||
(r'.*/env', EnvHandler),
|
||||
(r'.*', EchoHandler),
|
||||
])
|
||||
app = web.Application(
|
||||
[(r'.*/args', ArgsHandler), (r'.*/env', EnvHandler), (r'.*', EchoHandler)]
|
||||
)
|
||||
|
||||
ssl_context = None
|
||||
key = os.environ.get('JUPYTERHUB_SSL_KEYFILE') or ''
|
||||
@@ -42,12 +46,7 @@ def main(args):
|
||||
ca = os.environ.get('JUPYTERHUB_SSL_CLIENT_CA') or ''
|
||||
|
||||
if key and cert and ca:
|
||||
ssl_context = make_ssl_context(
|
||||
key,
|
||||
cert,
|
||||
cafile = ca,
|
||||
check_hostname = False
|
||||
)
|
||||
ssl_context = make_ssl_context(key, cert, cafile=ca, check_hostname=False)
|
||||
|
||||
server = httpserver.HTTPServer(app, ssl_options=ssl_context)
|
||||
server.listen(args.port)
|
||||
@@ -56,6 +55,7 @@ def main(args):
|
||||
except KeyboardInterrupt:
|
||||
print('\nInterrupted')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--port', type=int)
|
||||
|
@@ -4,9 +4,8 @@ Run with old versions of jupyterhub to test upgrade/downgrade
|
||||
|
||||
used in test_db.py
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import jupyterhub
|
||||
from jupyterhub import orm
|
||||
@@ -90,6 +89,7 @@ def populate_db(url):
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
url = sys.argv[1]
|
||||
else:
|
||||
|
@@ -1,30 +1,32 @@
|
||||
"""Tests for the REST API."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from concurrent.futures import Future
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from unittest import mock
|
||||
from urllib.parse import urlparse, quote
|
||||
import uuid
|
||||
from async_generator import async_generator, yield_
|
||||
from concurrent.futures import Future
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from unittest import mock
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from async_generator import async_generator
|
||||
from async_generator import yield_
|
||||
from pytest import mark
|
||||
from tornado import gen
|
||||
|
||||
import jupyterhub
|
||||
from .. import orm
|
||||
from ..utils import url_path_join as ujoin, utcnow
|
||||
from .mocking import public_host, public_url
|
||||
from .utils import (
|
||||
add_user,
|
||||
api_request,
|
||||
async_requests,
|
||||
auth_header,
|
||||
find_user,
|
||||
)
|
||||
from ..utils import url_path_join as ujoin
|
||||
from ..utils import utcnow
|
||||
from .mocking import public_host
|
||||
from .mocking import public_url
|
||||
from .utils import add_user
|
||||
from .utils import api_request
|
||||
from .utils import async_requests
|
||||
from .utils import auth_header
|
||||
from .utils import find_user
|
||||
|
||||
|
||||
# --------------------
|
||||
@@ -48,12 +50,15 @@ async def test_auth_api(app):
|
||||
assert reply['name'] == user.name
|
||||
|
||||
# check fail
|
||||
r = await api_request(app, 'authorizations/token', api_token,
|
||||
headers={'Authorization': 'no sir'},
|
||||
r = await api_request(
|
||||
app, 'authorizations/token', api_token, headers={'Authorization': 'no sir'}
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
r = await api_request(app, 'authorizations/token', api_token,
|
||||
r = await api_request(
|
||||
app,
|
||||
'authorizations/token',
|
||||
api_token,
|
||||
headers={'Authorization': 'token: %s' % user.cookie_id},
|
||||
)
|
||||
assert r.status_code == 403
|
||||
@@ -67,37 +72,39 @@ async def test_referer_check(app):
|
||||
user = add_user(app.db, name='admin', admin=True)
|
||||
cookies = await app.login_user('admin')
|
||||
|
||||
r = await api_request(app, 'users',
|
||||
headers={
|
||||
'Authorization': '',
|
||||
'Referer': 'null',
|
||||
}, cookies=cookies,
|
||||
r = await api_request(
|
||||
app, 'users', headers={'Authorization': '', 'Referer': 'null'}, cookies=cookies
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
r = await api_request(app, 'users',
|
||||
r = await api_request(
|
||||
app,
|
||||
'users',
|
||||
headers={
|
||||
'Authorization': '',
|
||||
'Referer': 'http://attack.com/csrf/vulnerability',
|
||||
}, cookies=cookies,
|
||||
},
|
||||
cookies=cookies,
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
r = await api_request(app, 'users',
|
||||
headers={
|
||||
'Authorization': '',
|
||||
'Referer': url,
|
||||
'Host': host,
|
||||
}, cookies=cookies,
|
||||
r = await api_request(
|
||||
app,
|
||||
'users',
|
||||
headers={'Authorization': '', 'Referer': url, 'Host': host},
|
||||
cookies=cookies,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
r = await api_request(app, 'users',
|
||||
r = await api_request(
|
||||
app,
|
||||
'users',
|
||||
headers={
|
||||
'Authorization': '',
|
||||
'Referer': ujoin(url, 'foo/bar/baz/bat'),
|
||||
'Host': host,
|
||||
}, cookies=cookies,
|
||||
},
|
||||
cookies=cookies,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
@@ -106,6 +113,7 @@ async def test_referer_check(app):
|
||||
# User API tests
|
||||
# --------------
|
||||
|
||||
|
||||
def normalize_timestamp(ts):
|
||||
"""Normalize a timestamp
|
||||
|
||||
@@ -128,12 +136,16 @@ def normalize_user(user):
|
||||
for server in user['servers'].values():
|
||||
for key in ('started', 'last_activity'):
|
||||
server[key] = normalize_timestamp(server[key])
|
||||
server['progress_url'] = re.sub(r'.*/hub/api', 'PREFIX/hub/api', server['progress_url'])
|
||||
if (isinstance(server['state'], dict)
|
||||
and isinstance(server['state'].get('pid', None), int)):
|
||||
server['progress_url'] = re.sub(
|
||||
r'.*/hub/api', 'PREFIX/hub/api', server['progress_url']
|
||||
)
|
||||
if isinstance(server['state'], dict) and isinstance(
|
||||
server['state'].get('pid', None), int
|
||||
):
|
||||
server['state']['pid'] = 0
|
||||
return user
|
||||
|
||||
|
||||
def fill_user(model):
|
||||
"""Fill a default user model
|
||||
|
||||
@@ -153,6 +165,7 @@ def fill_user(model):
|
||||
|
||||
TIMESTAMP = normalize_timestamp(datetime.now().isoformat() + 'Z')
|
||||
|
||||
|
||||
@mark.user
|
||||
async def test_get_users(app):
|
||||
db = app.db
|
||||
@@ -162,20 +175,11 @@ async def test_get_users(app):
|
||||
users = sorted(r.json(), key=lambda d: d['name'])
|
||||
users = [normalize_user(u) for u in users]
|
||||
assert users == [
|
||||
fill_user({
|
||||
'name': 'admin',
|
||||
'admin': True,
|
||||
}),
|
||||
fill_user({
|
||||
'name': 'user',
|
||||
'admin': False,
|
||||
'last_activity': None,
|
||||
}),
|
||||
fill_user({'name': 'admin', 'admin': True}),
|
||||
fill_user({'name': 'user', 'admin': False, 'last_activity': None}),
|
||||
]
|
||||
|
||||
r = await api_request(app, 'users',
|
||||
headers=auth_header(db, 'user'),
|
||||
)
|
||||
r = await api_request(app, 'users', headers=auth_header(db, 'user'))
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
@@ -202,17 +206,13 @@ async def test_get_self(app):
|
||||
)
|
||||
db.add(oauth_token)
|
||||
db.commit()
|
||||
r = await api_request(app, 'user', headers={
|
||||
'Authorization': 'token ' + token,
|
||||
})
|
||||
r = await api_request(app, 'user', headers={'Authorization': 'token ' + token})
|
||||
r.raise_for_status()
|
||||
model = r.json()
|
||||
assert model['name'] == u.name
|
||||
|
||||
# invalid auth gets 403
|
||||
r = await api_request(app, 'user', headers={
|
||||
'Authorization': 'token notvalid',
|
||||
})
|
||||
r = await api_request(app, 'user', headers={'Authorization': 'token notvalid'})
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
@@ -251,8 +251,11 @@ async def test_add_multi_user_bad(app):
|
||||
@mark.user
|
||||
async def test_add_multi_user_invalid(app):
|
||||
app.authenticator.username_pattern = r'w.*'
|
||||
r = await api_request(app, 'users', method='post',
|
||||
data=json.dumps({'usernames': ['Willow', 'Andrew', 'Tara']})
|
||||
r = await api_request(
|
||||
app,
|
||||
'users',
|
||||
method='post',
|
||||
data=json.dumps({'usernames': ['Willow', 'Andrew', 'Tara']}),
|
||||
)
|
||||
app.authenticator.username_pattern = ''
|
||||
assert r.status_code == 400
|
||||
@@ -263,8 +266,8 @@ async def test_add_multi_user_invalid(app):
|
||||
async def test_add_multi_user(app):
|
||||
db = app.db
|
||||
names = ['a', 'b']
|
||||
r = await api_request(app, 'users', method='post',
|
||||
data=json.dumps({'usernames': names}),
|
||||
r = await api_request(
|
||||
app, 'users', method='post', data=json.dumps({'usernames': names})
|
||||
)
|
||||
assert r.status_code == 201
|
||||
reply = r.json()
|
||||
@@ -278,16 +281,16 @@ async def test_add_multi_user(app):
|
||||
assert not user.admin
|
||||
|
||||
# try to create the same users again
|
||||
r = await api_request(app, 'users', method='post',
|
||||
data=json.dumps({'usernames': names}),
|
||||
r = await api_request(
|
||||
app, 'users', method='post', data=json.dumps({'usernames': names})
|
||||
)
|
||||
assert r.status_code == 409
|
||||
|
||||
names = ['a', 'b', 'ab']
|
||||
|
||||
# try to create the same users again
|
||||
r = await api_request(app, 'users', method='post',
|
||||
data=json.dumps({'usernames': names}),
|
||||
r = await api_request(
|
||||
app, 'users', method='post', data=json.dumps({'usernames': names})
|
||||
)
|
||||
assert r.status_code == 201
|
||||
reply = r.json()
|
||||
@@ -299,7 +302,10 @@ async def test_add_multi_user(app):
|
||||
async def test_add_multi_user_admin(app):
|
||||
db = app.db
|
||||
names = ['c', 'd']
|
||||
r = await api_request(app, 'users', method='post',
|
||||
r = await api_request(
|
||||
app,
|
||||
'users',
|
||||
method='post',
|
||||
data=json.dumps({'usernames': names, 'admin': True}),
|
||||
)
|
||||
assert r.status_code == 201
|
||||
@@ -340,8 +346,8 @@ async def test_add_user_duplicate(app):
|
||||
async def test_add_admin(app):
|
||||
db = app.db
|
||||
name = 'newadmin'
|
||||
r = await api_request(app, 'users', name, method='post',
|
||||
data=json.dumps({'admin': True}),
|
||||
r = await api_request(
|
||||
app, 'users', name, method='post', data=json.dumps({'admin': True})
|
||||
)
|
||||
assert r.status_code == 201
|
||||
user = find_user(db, name)
|
||||
@@ -369,8 +375,8 @@ async def test_make_admin(app):
|
||||
assert user.name == name
|
||||
assert not user.admin
|
||||
|
||||
r = await api_request(app, 'users', name, method='patch',
|
||||
data=json.dumps({'admin': True})
|
||||
r = await api_request(
|
||||
app, 'users', name, method='patch', data=json.dumps({'admin': True})
|
||||
)
|
||||
assert r.status_code == 200
|
||||
user = find_user(db, name)
|
||||
@@ -388,8 +394,8 @@ async def test_set_auth_state(app, auth_state_enabled):
|
||||
assert user is not None
|
||||
assert user.name == name
|
||||
|
||||
r = await api_request(app, 'users', name, method='patch',
|
||||
data=json.dumps({'auth_state': auth_state})
|
||||
r = await api_request(
|
||||
app, 'users', name, method='patch', data=json.dumps({'auth_state': auth_state})
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
@@ -409,7 +415,10 @@ async def test_user_set_auth_state(app, auth_state_enabled):
|
||||
assert user_auth_state is None
|
||||
|
||||
r = await api_request(
|
||||
app, 'users', name, method='patch',
|
||||
app,
|
||||
'users',
|
||||
name,
|
||||
method='patch',
|
||||
data=json.dumps({'auth_state': auth_state}),
|
||||
headers=auth_header(app.db, name),
|
||||
)
|
||||
@@ -446,8 +455,7 @@ async def test_user_get_auth_state(app, auth_state_enabled):
|
||||
assert user.name == name
|
||||
await user.save_auth_state(auth_state)
|
||||
|
||||
r = await api_request(app, 'users', name,
|
||||
headers=auth_header(app.db, name))
|
||||
r = await api_request(app, 'users', name, headers=auth_header(app.db, name))
|
||||
|
||||
assert r.status_code == 200
|
||||
assert 'auth_state' not in r.json()
|
||||
@@ -457,13 +465,10 @@ async def test_spawn(app):
|
||||
db = app.db
|
||||
name = 'wash'
|
||||
user = add_user(db, app=app, name=name)
|
||||
options = {
|
||||
's': ['value'],
|
||||
'i': 5,
|
||||
}
|
||||
options = {'s': ['value'], 'i': 5}
|
||||
before_servers = sorted(db.query(orm.Server), key=lambda s: s.url)
|
||||
r = await api_request(app, 'users', name, 'server', method='post',
|
||||
data=json.dumps(options),
|
||||
r = await api_request(
|
||||
app, 'users', name, 'server', method='post', data=json.dumps(options)
|
||||
)
|
||||
assert r.status_code == 201
|
||||
assert 'pid' in user.orm_spawners[''].state
|
||||
@@ -520,7 +525,9 @@ async def test_spawn_handler(app):
|
||||
app_user = app.users[name]
|
||||
|
||||
# spawn via API with ?foo=bar
|
||||
r = await api_request(app, 'users', name, 'server', method='post', params={'foo': 'bar'})
|
||||
r = await api_request(
|
||||
app, 'users', name, 'server', method='post', params={'foo': 'bar'}
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
# verify that request params got passed down
|
||||
@@ -640,6 +647,7 @@ def next_event(it):
|
||||
if line.startswith('data:'):
|
||||
return json.loads(line.split(':', 1)[1])
|
||||
|
||||
|
||||
@mark.slow
|
||||
async def test_progress(request, app, no_patience, slow_spawn):
|
||||
db = app.db
|
||||
@@ -655,15 +663,9 @@ async def test_progress(request, app, no_patience, slow_spawn):
|
||||
ex = async_requests.executor
|
||||
line_iter = iter(r.iter_lines(decode_unicode=True))
|
||||
evt = await ex.submit(next_event, line_iter)
|
||||
assert evt == {
|
||||
'progress': 0,
|
||||
'message': 'Server requested',
|
||||
}
|
||||
assert evt == {'progress': 0, 'message': 'Server requested'}
|
||||
evt = await ex.submit(next_event, line_iter)
|
||||
assert evt == {
|
||||
'progress': 50,
|
||||
'message': 'Spawning server...',
|
||||
}
|
||||
assert evt == {'progress': 50, 'message': 'Spawning server...'}
|
||||
evt = await ex.submit(next_event, line_iter)
|
||||
url = app_user.url
|
||||
assert evt == {
|
||||
@@ -769,10 +771,7 @@ async def test_progress_bad_slow(request, app, no_patience, slow_bad_spawn):
|
||||
async def progress_forever():
|
||||
"""progress function that yields messages forever"""
|
||||
for i in range(1, 10):
|
||||
await yield_({
|
||||
'progress': i,
|
||||
'message': 'Stage %s' % i,
|
||||
})
|
||||
await yield_({'progress': i, 'message': 'Stage %s' % i})
|
||||
# wait a long time before the next event
|
||||
await gen.sleep(10)
|
||||
|
||||
@@ -781,7 +780,8 @@ if sys.version_info >= (3, 6):
|
||||
# additional progress_forever defined as native
|
||||
# async generator
|
||||
# to test for issues with async_generator wrappers
|
||||
exec("""
|
||||
exec(
|
||||
"""
|
||||
async def progress_forever_native():
|
||||
for i in range(1, 10):
|
||||
yield {
|
||||
@@ -790,7 +790,9 @@ async def progress_forever_native():
|
||||
}
|
||||
# wait a long time before the next event
|
||||
await gen.sleep(10)
|
||||
""", globals())
|
||||
""",
|
||||
globals(),
|
||||
)
|
||||
|
||||
|
||||
async def test_spawn_progress_cutoff(request, app, no_patience, slow_spawn):
|
||||
@@ -818,18 +820,14 @@ async def test_spawn_progress_cutoff(request, app, no_patience, slow_spawn):
|
||||
evt = await ex.submit(next_event, line_iter)
|
||||
assert evt['progress'] == 0
|
||||
evt = await ex.submit(next_event, line_iter)
|
||||
assert evt == {
|
||||
'progress': 1,
|
||||
'message': 'Stage 1',
|
||||
}
|
||||
assert evt == {'progress': 1, 'message': 'Stage 1'}
|
||||
evt = await ex.submit(next_event, line_iter)
|
||||
assert evt['progress'] == 100
|
||||
|
||||
|
||||
async def test_spawn_limit(app, no_patience, slow_spawn, request):
|
||||
db = app.db
|
||||
p = mock.patch.dict(app.tornado_settings,
|
||||
{'concurrent_spawn_limit': 2})
|
||||
p = mock.patch.dict(app.tornado_settings, {'concurrent_spawn_limit': 2})
|
||||
p.start()
|
||||
request.addfinalizer(p.stop)
|
||||
|
||||
@@ -875,11 +873,11 @@ async def test_spawn_limit(app, no_patience, slow_spawn, request):
|
||||
while any(u.spawner.active for u in users):
|
||||
await gen.sleep(0.1)
|
||||
|
||||
|
||||
@mark.slow
|
||||
async def test_active_server_limit(app, request):
|
||||
db = app.db
|
||||
p = mock.patch.dict(app.tornado_settings,
|
||||
{'active_server_limit': 2})
|
||||
p = mock.patch.dict(app.tornado_settings, {'active_server_limit': 2})
|
||||
p.start()
|
||||
request.addfinalizer(p.stop)
|
||||
|
||||
@@ -932,6 +930,7 @@ async def test_active_server_limit(app, request):
|
||||
assert counts['ready'] == 0
|
||||
assert counts['pending'] == 0
|
||||
|
||||
|
||||
@mark.slow
|
||||
async def test_start_stop_race(app, no_patience, slow_spawn):
|
||||
user = add_user(app.db, app, name='panda')
|
||||
@@ -996,22 +995,18 @@ async def test_cookie(app):
|
||||
cookie_name = app.hub.cookie_name
|
||||
# cookie jar gives '"cookie-value"', we want 'cookie-value'
|
||||
cookie = cookies[cookie_name][1:-1]
|
||||
r = await api_request(app, 'authorizations/cookie',
|
||||
cookie_name, "nothintoseehere",
|
||||
)
|
||||
r = await api_request(app, 'authorizations/cookie', cookie_name, "nothintoseehere")
|
||||
assert r.status_code == 404
|
||||
|
||||
r = await api_request(app, 'authorizations/cookie',
|
||||
cookie_name, quote(cookie, safe=''),
|
||||
r = await api_request(
|
||||
app, 'authorizations/cookie', cookie_name, quote(cookie, safe='')
|
||||
)
|
||||
r.raise_for_status()
|
||||
reply = r.json()
|
||||
assert reply['name'] == name
|
||||
|
||||
# deprecated cookie in body:
|
||||
r = await api_request(app, 'authorizations/cookie',
|
||||
cookie_name, data=cookie,
|
||||
)
|
||||
r = await api_request(app, 'authorizations/cookie', cookie_name, data=cookie)
|
||||
r.raise_for_status()
|
||||
reply = r.json()
|
||||
assert reply['name'] == name
|
||||
@@ -1035,15 +1030,11 @@ async def test_check_token(app):
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
@mark.parametrize("headers, status", [
|
||||
({}, 200),
|
||||
({'Authorization': 'token bad'}, 403),
|
||||
])
|
||||
@mark.parametrize("headers, status", [({}, 200), ({'Authorization': 'token bad'}, 403)])
|
||||
async def test_get_new_token_deprecated(app, headers, status):
|
||||
# request a new token
|
||||
r = await api_request(app, 'authorizations', 'token',
|
||||
method='post',
|
||||
headers=headers,
|
||||
r = await api_request(
|
||||
app, 'authorizations', 'token', method='post', headers=headers
|
||||
)
|
||||
assert r.status_code == status
|
||||
if status != 200:
|
||||
@@ -1058,11 +1049,11 @@ async def test_get_new_token_deprecated(app, headers, status):
|
||||
|
||||
async def test_token_formdata_deprecated(app):
|
||||
"""Create a token for a user with formdata and no auth header"""
|
||||
data = {
|
||||
'username': 'fake',
|
||||
'password': 'fake',
|
||||
}
|
||||
r = await api_request(app, 'authorizations', 'token',
|
||||
data = {'username': 'fake', 'password': 'fake'}
|
||||
r = await api_request(
|
||||
app,
|
||||
'authorizations',
|
||||
'token',
|
||||
method='post',
|
||||
data=json.dumps(data) if data else None,
|
||||
noauth=True,
|
||||
@@ -1076,22 +1067,26 @@ async def test_token_formdata_deprecated(app):
|
||||
assert reply['name'] == data['username']
|
||||
|
||||
|
||||
@mark.parametrize("as_user, for_user, status", [
|
||||
@mark.parametrize(
|
||||
"as_user, for_user, status",
|
||||
[
|
||||
('admin', 'other', 200),
|
||||
('admin', 'missing', 400),
|
||||
('user', 'other', 403),
|
||||
('user', 'user', 200),
|
||||
])
|
||||
],
|
||||
)
|
||||
async def test_token_as_user_deprecated(app, as_user, for_user, status):
|
||||
# ensure both users exist
|
||||
u = add_user(app.db, app, name=as_user)
|
||||
if for_user != 'missing':
|
||||
add_user(app.db, app, name=for_user)
|
||||
data = {'username': for_user}
|
||||
headers = {
|
||||
'Authorization': 'token %s' % u.new_api_token(),
|
||||
}
|
||||
r = await api_request(app, 'authorizations', 'token',
|
||||
headers = {'Authorization': 'token %s' % u.new_api_token()}
|
||||
r = await api_request(
|
||||
app,
|
||||
'authorizations',
|
||||
'token',
|
||||
method='post',
|
||||
data=json.dumps(data),
|
||||
headers=headers,
|
||||
@@ -1107,11 +1102,14 @@ async def test_token_as_user_deprecated(app, as_user, for_user, status):
|
||||
assert reply['name'] == data['username']
|
||||
|
||||
|
||||
@mark.parametrize("headers, status, note, expires_in", [
|
||||
@mark.parametrize(
|
||||
"headers, status, note, expires_in",
|
||||
[
|
||||
({}, 200, 'test note', None),
|
||||
({}, 200, '', 100),
|
||||
({'Authorization': 'token bad'}, 403, '', None),
|
||||
])
|
||||
],
|
||||
)
|
||||
async def test_get_new_token(app, headers, status, note, expires_in):
|
||||
options = {}
|
||||
if note:
|
||||
@@ -1123,10 +1121,8 @@ async def test_get_new_token(app, headers, status, note, expires_in):
|
||||
else:
|
||||
body = ''
|
||||
# request a new token
|
||||
r = await api_request(app, 'users/admin/tokens',
|
||||
method='post',
|
||||
headers=headers,
|
||||
data=body,
|
||||
r = await api_request(
|
||||
app, 'users/admin/tokens', method='post', headers=headers, data=body
|
||||
)
|
||||
assert r.status_code == status
|
||||
if status != 200:
|
||||
@@ -1157,30 +1153,34 @@ async def test_get_new_token(app, headers, status, note, expires_in):
|
||||
assert normalize_token(reply) == initial
|
||||
|
||||
# delete the token
|
||||
r = await api_request(app, 'users/admin/tokens', token_id,
|
||||
method='delete')
|
||||
r = await api_request(app, 'users/admin/tokens', token_id, method='delete')
|
||||
assert r.status_code == 204
|
||||
# verify deletion
|
||||
r = await api_request(app, 'users/admin/tokens', token_id)
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
@mark.parametrize("as_user, for_user, status", [
|
||||
@mark.parametrize(
|
||||
"as_user, for_user, status",
|
||||
[
|
||||
('admin', 'other', 200),
|
||||
('admin', 'missing', 404),
|
||||
('user', 'other', 403),
|
||||
('user', 'user', 200),
|
||||
])
|
||||
],
|
||||
)
|
||||
async def test_token_for_user(app, as_user, for_user, status):
|
||||
# ensure both users exist
|
||||
u = add_user(app.db, app, name=as_user)
|
||||
if for_user != 'missing':
|
||||
add_user(app.db, app, name=for_user)
|
||||
data = {'username': for_user}
|
||||
headers = {
|
||||
'Authorization': 'token %s' % u.new_api_token(),
|
||||
}
|
||||
r = await api_request(app, 'users', for_user, 'tokens',
|
||||
headers = {'Authorization': 'token %s' % u.new_api_token()}
|
||||
r = await api_request(
|
||||
app,
|
||||
'users',
|
||||
for_user,
|
||||
'tokens',
|
||||
method='post',
|
||||
data=json.dumps(data),
|
||||
headers=headers,
|
||||
@@ -1191,9 +1191,7 @@ async def test_token_for_user(app, as_user, for_user, status):
|
||||
return
|
||||
assert 'token' in reply
|
||||
token_id = reply['id']
|
||||
r = await api_request(app, 'users', for_user, 'tokens', token_id,
|
||||
headers=headers,
|
||||
)
|
||||
r = await api_request(app, 'users', for_user, 'tokens', token_id, headers=headers)
|
||||
r.raise_for_status()
|
||||
reply = r.json()
|
||||
assert reply['user'] == for_user
|
||||
@@ -1203,30 +1201,25 @@ async def test_token_for_user(app, as_user, for_user, status):
|
||||
note = 'Requested via api by user %s' % as_user
|
||||
assert reply['note'] == note
|
||||
|
||||
|
||||
# delete the token
|
||||
r = await api_request(app, 'users', for_user, 'tokens', token_id,
|
||||
method='delete',
|
||||
headers=headers,
|
||||
r = await api_request(
|
||||
app, 'users', for_user, 'tokens', token_id, method='delete', headers=headers
|
||||
)
|
||||
|
||||
assert r.status_code == 204
|
||||
r = await api_request(app, 'users', for_user, 'tokens', token_id,
|
||||
headers=headers,
|
||||
)
|
||||
r = await api_request(app, 'users', for_user, 'tokens', token_id, headers=headers)
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
async def test_token_authenticator_noauth(app):
|
||||
"""Create a token for a user relying on Authenticator.authenticate and no auth header"""
|
||||
name = 'user'
|
||||
data = {
|
||||
'auth': {
|
||||
'username': name,
|
||||
'password': name,
|
||||
},
|
||||
}
|
||||
r = await api_request(app, 'users', name, 'tokens',
|
||||
data = {'auth': {'username': name, 'password': name}}
|
||||
r = await api_request(
|
||||
app,
|
||||
'users',
|
||||
name,
|
||||
'tokens',
|
||||
method='post',
|
||||
data=json.dumps(data) if data else None,
|
||||
noauth=True,
|
||||
@@ -1242,17 +1235,14 @@ async def test_token_authenticator_noauth(app):
|
||||
|
||||
async def test_token_authenticator_dict_noauth(app):
|
||||
"""Create a token for a user relying on Authenticator.authenticate and no auth header"""
|
||||
app.authenticator.auth_state = {
|
||||
'who': 'cares',
|
||||
}
|
||||
app.authenticator.auth_state = {'who': 'cares'}
|
||||
name = 'user'
|
||||
data = {
|
||||
'auth': {
|
||||
'username': name,
|
||||
'password': name,
|
||||
},
|
||||
}
|
||||
r = await api_request(app, 'users', name, 'tokens',
|
||||
data = {'auth': {'username': name, 'password': name}}
|
||||
r = await api_request(
|
||||
app,
|
||||
'users',
|
||||
name,
|
||||
'tokens',
|
||||
method='post',
|
||||
data=json.dumps(data) if data else None,
|
||||
noauth=True,
|
||||
@@ -1266,22 +1256,21 @@ async def test_token_authenticator_dict_noauth(app):
|
||||
assert reply['name'] == name
|
||||
|
||||
|
||||
@mark.parametrize("as_user, for_user, status", [
|
||||
@mark.parametrize(
|
||||
"as_user, for_user, status",
|
||||
[
|
||||
('admin', 'other', 200),
|
||||
('admin', 'missing', 404),
|
||||
('user', 'other', 403),
|
||||
('user', 'user', 200),
|
||||
])
|
||||
],
|
||||
)
|
||||
async def test_token_list(app, as_user, for_user, status):
|
||||
u = add_user(app.db, app, name=as_user)
|
||||
if for_user != 'missing':
|
||||
for_user_obj = add_user(app.db, app, name=for_user)
|
||||
headers = {
|
||||
'Authorization': 'token %s' % u.new_api_token(),
|
||||
}
|
||||
r = await api_request(app, 'users', for_user, 'tokens',
|
||||
headers=headers,
|
||||
)
|
||||
headers = {'Authorization': 'token %s' % u.new_api_token()}
|
||||
r = await api_request(app, 'users', for_user, 'tokens', headers=headers)
|
||||
assert r.status_code == status
|
||||
if status != 200:
|
||||
return
|
||||
@@ -1292,8 +1281,8 @@ async def test_token_list(app, as_user, for_user, status):
|
||||
assert all(token['user'] == for_user for token in reply['oauth_tokens'])
|
||||
# validate individual token ids
|
||||
for token in reply['api_tokens'] + reply['oauth_tokens']:
|
||||
r = await api_request(app, 'users', for_user, 'tokens', token['id'],
|
||||
headers=headers,
|
||||
r = await api_request(
|
||||
app, 'users', for_user, 'tokens', token['id'], headers=headers
|
||||
)
|
||||
r.raise_for_status()
|
||||
reply = r.json()
|
||||
@@ -1320,19 +1309,15 @@ async def test_groups_list(app):
|
||||
r = await api_request(app, 'groups')
|
||||
r.raise_for_status()
|
||||
reply = r.json()
|
||||
assert reply == [{
|
||||
'kind': 'group',
|
||||
'name': 'alphaflight',
|
||||
'users': []
|
||||
}]
|
||||
assert reply == [{'kind': 'group', 'name': 'alphaflight', 'users': []}]
|
||||
|
||||
|
||||
@mark.group
|
||||
async def test_add_multi_group(app):
|
||||
db = app.db
|
||||
names = ['group1', 'group2']
|
||||
r = await api_request(app, 'groups', method='post',
|
||||
data=json.dumps({'groups': names}),
|
||||
r = await api_request(
|
||||
app, 'groups', method='post', data=json.dumps({'groups': names})
|
||||
)
|
||||
assert r.status_code == 201
|
||||
reply = r.json()
|
||||
@@ -1340,8 +1325,8 @@ async def test_add_multi_group(app):
|
||||
assert names == r_names
|
||||
|
||||
# try to create the same groups again
|
||||
r = await api_request(app, 'groups', method='post',
|
||||
data=json.dumps({'groups': names}),
|
||||
r = await api_request(
|
||||
app, 'groups', method='post', data=json.dumps({'groups': names})
|
||||
)
|
||||
assert r.status_code == 409
|
||||
|
||||
@@ -1359,11 +1344,7 @@ async def test_group_get(app):
|
||||
r = await api_request(app, 'groups/alphaflight')
|
||||
r.raise_for_status()
|
||||
reply = r.json()
|
||||
assert reply == {
|
||||
'kind': 'group',
|
||||
'name': 'alphaflight',
|
||||
'users': ['sasquatch']
|
||||
}
|
||||
assert reply == {'kind': 'group', 'name': 'alphaflight', 'users': ['sasquatch']}
|
||||
|
||||
|
||||
@mark.group
|
||||
@@ -1372,13 +1353,16 @@ async def test_group_create_delete(app):
|
||||
r = await api_request(app, 'groups/runaways', method='delete')
|
||||
assert r.status_code == 404
|
||||
|
||||
r = await api_request(app, 'groups/new', method='post',
|
||||
data=json.dumps({'users': ['doesntexist']}),
|
||||
r = await api_request(
|
||||
app, 'groups/new', method='post', data=json.dumps({'users': ['doesntexist']})
|
||||
)
|
||||
assert r.status_code == 400
|
||||
assert orm.Group.find(db, name='new') is None
|
||||
|
||||
r = await api_request(app, 'groups/omegaflight', method='post',
|
||||
r = await api_request(
|
||||
app,
|
||||
'groups/omegaflight',
|
||||
method='post',
|
||||
data=json.dumps({'users': ['sasquatch']}),
|
||||
)
|
||||
r.raise_for_status()
|
||||
@@ -1410,10 +1394,15 @@ async def test_group_add_users(app):
|
||||
assert r.status_code == 400
|
||||
|
||||
names = ['aurora', 'guardian', 'northstar', 'sasquatch', 'shaman', 'snowbird']
|
||||
users = [ find_user(db, name=name) or add_user(db, app=app, name=name) for name in names ]
|
||||
r = await api_request(app, 'groups/alphaflight/users', method='post', data=json.dumps({
|
||||
'users': names,
|
||||
}))
|
||||
users = [
|
||||
find_user(db, name=name) or add_user(db, app=app, name=name) for name in names
|
||||
]
|
||||
r = await api_request(
|
||||
app,
|
||||
'groups/alphaflight/users',
|
||||
method='post',
|
||||
data=json.dumps({'users': names}),
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
for user in users:
|
||||
@@ -1433,9 +1422,12 @@ async def test_group_delete_users(app):
|
||||
|
||||
names = ['aurora', 'guardian', 'northstar', 'sasquatch', 'shaman', 'snowbird']
|
||||
users = [find_user(db, name=name) for name in names]
|
||||
r = await api_request(app, 'groups/alphaflight/users', method='delete', data=json.dumps({
|
||||
'users': names[:2],
|
||||
}))
|
||||
r = await api_request(
|
||||
app,
|
||||
'groups/alphaflight/users',
|
||||
method='delete',
|
||||
data=json.dumps({'users': names[:2]}),
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
for user in users[:2]:
|
||||
@@ -1473,9 +1465,7 @@ async def test_get_services(app, mockservice_url):
|
||||
}
|
||||
}
|
||||
|
||||
r = await api_request(app, 'services',
|
||||
headers=auth_header(db, 'user'),
|
||||
)
|
||||
r = await api_request(app, 'services', headers=auth_header(db, 'user'))
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
@@ -1498,14 +1488,14 @@ async def test_get_service(app, mockservice_url):
|
||||
'info': {},
|
||||
}
|
||||
|
||||
r = await api_request(app, 'services/%s' % mockservice.name,
|
||||
headers={
|
||||
'Authorization': 'token %s' % mockservice.api_token
|
||||
}
|
||||
r = await api_request(
|
||||
app,
|
||||
'services/%s' % mockservice.name,
|
||||
headers={'Authorization': 'token %s' % mockservice.api_token},
|
||||
)
|
||||
r.raise_for_status()
|
||||
r = await api_request(app, 'services/%s' % mockservice.name,
|
||||
headers=auth_header(db, 'user'),
|
||||
r = await api_request(
|
||||
app, 'services/%s' % mockservice.name, headers=auth_header(db, 'user')
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
@@ -1519,9 +1509,7 @@ async def test_root_api(app):
|
||||
kwargs["verify"] = app.internal_ssl_ca
|
||||
r = await async_requests.get(url, **kwargs)
|
||||
r.raise_for_status()
|
||||
expected = {
|
||||
'version': jupyterhub.__version__
|
||||
}
|
||||
expected = {'version': jupyterhub.__version__}
|
||||
assert r.json() == expected
|
||||
|
||||
|
||||
@@ -1662,15 +1650,20 @@ def test_shutdown(app):
|
||||
# which makes gen_test unhappy. So we run the loop ourselves.
|
||||
|
||||
async def shutdown():
|
||||
r = await api_request(app, 'shutdown', method='post',
|
||||
data=json.dumps({'servers': True, 'proxy': True,}),
|
||||
r = await api_request(
|
||||
app,
|
||||
'shutdown',
|
||||
method='post',
|
||||
data=json.dumps({'servers': True, 'proxy': True}),
|
||||
)
|
||||
return r
|
||||
|
||||
real_stop = loop.stop
|
||||
|
||||
def stop():
|
||||
stop.called = True
|
||||
loop.call_later(1, real_stop)
|
||||
|
||||
with mock.patch.object(loop, 'stop', stop):
|
||||
r = loop.run_sync(shutdown, timeout=5)
|
||||
r.raise_for_status()
|
||||
|
@@ -1,25 +1,30 @@
|
||||
"""Test the JupyterHub entry point"""
|
||||
|
||||
import binascii
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from subprocess import check_output, Popen, PIPE
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
from subprocess import check_output
|
||||
from subprocess import PIPE
|
||||
from subprocess import Popen
|
||||
from tempfile import NamedTemporaryFile
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from tornado import gen
|
||||
from traitlets.config import Config
|
||||
|
||||
from .. import orm
|
||||
from ..app import COOKIE_SECRET_BYTES
|
||||
from ..app import JupyterHub
|
||||
from .mocking import MockHub
|
||||
from .test_api import add_user
|
||||
from .. import orm
|
||||
from ..app import COOKIE_SECRET_BYTES, JupyterHub
|
||||
|
||||
|
||||
def test_help_all():
|
||||
out = check_output([sys.executable, '-m', 'jupyterhub', '--help-all']).decode('utf8', 'replace')
|
||||
out = check_output([sys.executable, '-m', 'jupyterhub', '--help-all']).decode(
|
||||
'utf8', 'replace'
|
||||
)
|
||||
assert '--ip' in out
|
||||
assert '--JupyterHub.ip' in out
|
||||
|
||||
@@ -39,9 +44,11 @@ def test_generate_config():
|
||||
cfg_file = tf.name
|
||||
with open(cfg_file, 'w') as f:
|
||||
f.write("c.A = 5")
|
||||
p = Popen([sys.executable, '-m', 'jupyterhub',
|
||||
'--generate-config', '-f', cfg_file],
|
||||
stdout=PIPE, stdin=PIPE)
|
||||
p = Popen(
|
||||
[sys.executable, '-m', 'jupyterhub', '--generate-config', '-f', cfg_file],
|
||||
stdout=PIPE,
|
||||
stdin=PIPE,
|
||||
)
|
||||
out, _ = p.communicate(b'n')
|
||||
out = out.decode('utf8', 'replace')
|
||||
assert os.path.exists(cfg_file)
|
||||
@@ -49,9 +56,11 @@ def test_generate_config():
|
||||
cfg_text = f.read()
|
||||
assert cfg_text == 'c.A = 5'
|
||||
|
||||
p = Popen([sys.executable, '-m', 'jupyterhub',
|
||||
'--generate-config', '-f', cfg_file],
|
||||
stdout=PIPE, stdin=PIPE)
|
||||
p = Popen(
|
||||
[sys.executable, '-m', 'jupyterhub', '--generate-config', '-f', cfg_file],
|
||||
stdout=PIPE,
|
||||
stdin=PIPE,
|
||||
)
|
||||
out, _ = p.communicate(b'x\ny')
|
||||
out = out.decode('utf8', 'replace')
|
||||
assert os.path.exists(cfg_file)
|
||||
@@ -192,9 +201,13 @@ async def test_load_groups(tmpdir, request):
|
||||
|
||||
async def test_resume_spawners(tmpdir, request):
|
||||
if not os.getenv('JUPYTERHUB_TEST_DB_URL'):
|
||||
p = patch.dict(os.environ, {
|
||||
'JUPYTERHUB_TEST_DB_URL': 'sqlite:///%s' % tmpdir.join('jupyterhub.sqlite'),
|
||||
})
|
||||
p = patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'JUPYTERHUB_TEST_DB_URL': 'sqlite:///%s'
|
||||
% tmpdir.join('jupyterhub.sqlite')
|
||||
},
|
||||
)
|
||||
p.start()
|
||||
request.addfinalizer(p.stop)
|
||||
|
||||
@@ -253,32 +266,18 @@ async def test_resume_spawners(tmpdir, request):
|
||||
@pytest.mark.parametrize(
|
||||
'hub_config, expected',
|
||||
[
|
||||
(
|
||||
{'ip': '0.0.0.0'},
|
||||
{'bind_url': 'http://0.0.0.0:8000/'},
|
||||
),
|
||||
({'ip': '0.0.0.0'}, {'bind_url': 'http://0.0.0.0:8000/'}),
|
||||
(
|
||||
{'port': 123, 'base_url': '/prefix'},
|
||||
{
|
||||
'bind_url': 'http://:123/prefix/',
|
||||
'base_url': '/prefix/',
|
||||
},
|
||||
),
|
||||
(
|
||||
{'bind_url': 'http://0.0.0.0:12345/sub'},
|
||||
{'base_url': '/sub/'},
|
||||
{'bind_url': 'http://:123/prefix/', 'base_url': '/prefix/'},
|
||||
),
|
||||
({'bind_url': 'http://0.0.0.0:12345/sub'}, {'base_url': '/sub/'}),
|
||||
(
|
||||
# no config, test defaults
|
||||
{},
|
||||
{
|
||||
'base_url': '/',
|
||||
'bind_url': 'http://:8000',
|
||||
'ip': '',
|
||||
'port': 8000,
|
||||
},
|
||||
{'base_url': '/', 'bind_url': 'http://:8000', 'ip': '', 'port': 8000},
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_url_config(hub_config, expected):
|
||||
# construct the config object
|
||||
|
@@ -1,59 +1,59 @@
|
||||
"""Tests for PAM authentication"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from requests import HTTPError
|
||||
|
||||
from jupyterhub import auth, crypto, orm
|
||||
|
||||
from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd
|
||||
from .mocking import MockPAMAuthenticator
|
||||
from .mocking import MockStructGroup
|
||||
from .mocking import MockStructPasswd
|
||||
from .utils import add_user
|
||||
from jupyterhub import auth
|
||||
from jupyterhub import crypto
|
||||
from jupyterhub import orm
|
||||
|
||||
|
||||
async def test_pam_auth():
|
||||
authenticator = MockPAMAuthenticator()
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'match',
|
||||
'password': 'match',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'match', 'password': 'match'}
|
||||
)
|
||||
assert authorized['name'] == 'match'
|
||||
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'match',
|
||||
'password': 'nomatch',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'match', 'password': 'nomatch'}
|
||||
)
|
||||
assert authorized is None
|
||||
|
||||
# Account check is on by default for increased security
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'notallowedmatch',
|
||||
'password': 'notallowedmatch',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'notallowedmatch', 'password': 'notallowedmatch'}
|
||||
)
|
||||
assert authorized is None
|
||||
|
||||
|
||||
async def test_pam_auth_account_check_disabled():
|
||||
authenticator = MockPAMAuthenticator(check_account=False)
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'allowedmatch',
|
||||
'password': 'allowedmatch',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'allowedmatch', 'password': 'allowedmatch'}
|
||||
)
|
||||
assert authorized['name'] == 'allowedmatch'
|
||||
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'notallowedmatch',
|
||||
'password': 'notallowedmatch',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'notallowedmatch', 'password': 'notallowedmatch'}
|
||||
)
|
||||
assert authorized['name'] == 'notallowedmatch'
|
||||
|
||||
|
||||
async def test_pam_auth_admin_groups():
|
||||
jh_users = MockStructGroup('jh_users', ['group_admin', 'also_group_admin', 'override_admin', 'non_admin'], 1234)
|
||||
jh_users = MockStructGroup(
|
||||
'jh_users',
|
||||
['group_admin', 'also_group_admin', 'override_admin', 'non_admin'],
|
||||
1234,
|
||||
)
|
||||
jh_admins = MockStructGroup('jh_admins', ['group_admin'], 5678)
|
||||
wheel = MockStructGroup('wheel', ['also_group_admin'], 9999)
|
||||
system_groups = [jh_users, jh_admins, wheel]
|
||||
@@ -68,7 +68,7 @@ async def test_pam_auth_admin_groups():
|
||||
'group_admin': [jh_users.gr_gid, jh_admins.gr_gid],
|
||||
'also_group_admin': [jh_users.gr_gid, wheel.gr_gid],
|
||||
'override_admin': [jh_users.gr_gid],
|
||||
'non_admin': [jh_users.gr_gid]
|
||||
'non_admin': [jh_users.gr_gid],
|
||||
}
|
||||
|
||||
def getgrnam(name):
|
||||
@@ -80,76 +80,78 @@ async def test_pam_auth_admin_groups():
|
||||
def getgrouplist(name, group):
|
||||
return user_group_map[name]
|
||||
|
||||
authenticator = MockPAMAuthenticator(admin_groups={'jh_admins', 'wheel'},
|
||||
admin_users={'override_admin'})
|
||||
authenticator = MockPAMAuthenticator(
|
||||
admin_groups={'jh_admins', 'wheel'}, admin_users={'override_admin'}
|
||||
)
|
||||
|
||||
# Check admin_group applies as expected
|
||||
with mock.patch.multiple(authenticator,
|
||||
with mock.patch.multiple(
|
||||
authenticator,
|
||||
_getgrnam=getgrnam,
|
||||
_getpwnam=getpwnam,
|
||||
_getgrouplist=getgrouplist):
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'group_admin',
|
||||
'password': 'group_admin'
|
||||
})
|
||||
_getgrouplist=getgrouplist,
|
||||
):
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'group_admin', 'password': 'group_admin'}
|
||||
)
|
||||
assert authorized['name'] == 'group_admin'
|
||||
assert authorized['admin'] is True
|
||||
|
||||
# Check multiple groups work, just in case.
|
||||
with mock.patch.multiple(authenticator,
|
||||
with mock.patch.multiple(
|
||||
authenticator,
|
||||
_getgrnam=getgrnam,
|
||||
_getpwnam=getpwnam,
|
||||
_getgrouplist=getgrouplist):
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'also_group_admin',
|
||||
'password': 'also_group_admin'
|
||||
})
|
||||
_getgrouplist=getgrouplist,
|
||||
):
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'also_group_admin', 'password': 'also_group_admin'}
|
||||
)
|
||||
assert authorized['name'] == 'also_group_admin'
|
||||
assert authorized['admin'] is True
|
||||
|
||||
# Check admin_users still applies correctly
|
||||
with mock.patch.multiple(authenticator,
|
||||
with mock.patch.multiple(
|
||||
authenticator,
|
||||
_getgrnam=getgrnam,
|
||||
_getpwnam=getpwnam,
|
||||
_getgrouplist=getgrouplist):
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'override_admin',
|
||||
'password': 'override_admin'
|
||||
})
|
||||
_getgrouplist=getgrouplist,
|
||||
):
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'override_admin', 'password': 'override_admin'}
|
||||
)
|
||||
assert authorized['name'] == 'override_admin'
|
||||
assert authorized['admin'] is True
|
||||
|
||||
# Check it doesn't admin everyone
|
||||
with mock.patch.multiple(authenticator,
|
||||
with mock.patch.multiple(
|
||||
authenticator,
|
||||
_getgrnam=getgrnam,
|
||||
_getpwnam=getpwnam,
|
||||
_getgrouplist=getgrouplist):
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'non_admin',
|
||||
'password': 'non_admin'
|
||||
})
|
||||
_getgrouplist=getgrouplist,
|
||||
):
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'non_admin', 'password': 'non_admin'}
|
||||
)
|
||||
assert authorized['name'] == 'non_admin'
|
||||
assert authorized['admin'] is False
|
||||
|
||||
|
||||
async def test_pam_auth_whitelist():
|
||||
authenticator = MockPAMAuthenticator(whitelist={'wash', 'kaylee'})
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'kaylee',
|
||||
'password': 'kaylee',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'kaylee', 'password': 'kaylee'}
|
||||
)
|
||||
assert authorized['name'] == 'kaylee'
|
||||
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'wash',
|
||||
'password': 'nomatch',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'wash', 'password': 'nomatch'}
|
||||
)
|
||||
assert authorized is None
|
||||
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'mal',
|
||||
'password': 'mal',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'mal', 'password': 'mal'}
|
||||
)
|
||||
assert authorized is None
|
||||
|
||||
|
||||
@@ -160,80 +162,78 @@ async def test_pam_auth_group_whitelist():
|
||||
authenticator = MockPAMAuthenticator(group_whitelist={'group'})
|
||||
|
||||
with mock.patch.object(authenticator, '_getgrnam', getgrnam):
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'kaylee',
|
||||
'password': 'kaylee',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'kaylee', 'password': 'kaylee'}
|
||||
)
|
||||
assert authorized['name'] == 'kaylee'
|
||||
|
||||
with mock.patch.object(authenticator, '_getgrnam', getgrnam):
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'mal',
|
||||
'password': 'mal',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'mal', 'password': 'mal'}
|
||||
)
|
||||
assert authorized is None
|
||||
|
||||
|
||||
async def test_pam_auth_blacklist():
|
||||
# Null case compared to next case
|
||||
authenticator = MockPAMAuthenticator()
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'wash',
|
||||
'password': 'wash',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'wash', 'password': 'wash'}
|
||||
)
|
||||
assert authorized['name'] == 'wash'
|
||||
|
||||
# Blacklist basics
|
||||
authenticator = MockPAMAuthenticator(blacklist={'wash'})
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'wash',
|
||||
'password': 'wash',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'wash', 'password': 'wash'}
|
||||
)
|
||||
assert authorized is None
|
||||
|
||||
# User in both white and blacklists: default deny. Make error someday?
|
||||
authenticator = MockPAMAuthenticator(blacklist={'wash'}, whitelist={'wash', 'kaylee'})
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'wash',
|
||||
'password': 'wash',
|
||||
})
|
||||
authenticator = MockPAMAuthenticator(
|
||||
blacklist={'wash'}, whitelist={'wash', 'kaylee'}
|
||||
)
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'wash', 'password': 'wash'}
|
||||
)
|
||||
assert authorized is None
|
||||
|
||||
# User not in blacklist can log in
|
||||
authenticator = MockPAMAuthenticator(blacklist={'wash'}, whitelist={'wash', 'kaylee'})
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'kaylee',
|
||||
'password': 'kaylee',
|
||||
})
|
||||
authenticator = MockPAMAuthenticator(
|
||||
blacklist={'wash'}, whitelist={'wash', 'kaylee'}
|
||||
)
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'kaylee', 'password': 'kaylee'}
|
||||
)
|
||||
assert authorized['name'] == 'kaylee'
|
||||
|
||||
# User in whitelist, blacklist irrelevent
|
||||
authenticator = MockPAMAuthenticator(blacklist={'mal'}, whitelist={'wash', 'kaylee'})
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'wash',
|
||||
'password': 'wash',
|
||||
})
|
||||
authenticator = MockPAMAuthenticator(
|
||||
blacklist={'mal'}, whitelist={'wash', 'kaylee'}
|
||||
)
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'wash', 'password': 'wash'}
|
||||
)
|
||||
assert authorized['name'] == 'wash'
|
||||
|
||||
# User in neither list
|
||||
authenticator = MockPAMAuthenticator(blacklist={'mal'}, whitelist={'wash', 'kaylee'})
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'simon',
|
||||
'password': 'simon',
|
||||
})
|
||||
authenticator = MockPAMAuthenticator(
|
||||
blacklist={'mal'}, whitelist={'wash', 'kaylee'}
|
||||
)
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'simon', 'password': 'simon'}
|
||||
)
|
||||
assert authorized is None
|
||||
|
||||
# blacklist == {}
|
||||
authenticator = MockPAMAuthenticator(blacklist=set(), whitelist={'wash', 'kaylee'})
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'kaylee',
|
||||
'password': 'kaylee',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'kaylee', 'password': 'kaylee'}
|
||||
)
|
||||
assert authorized['name'] == 'kaylee'
|
||||
|
||||
|
||||
async def test_deprecated_signatures():
|
||||
|
||||
def deprecated_xlist(self, username):
|
||||
return True
|
||||
|
||||
@@ -244,20 +244,18 @@ async def test_deprecated_signatures():
|
||||
check_blacklist=deprecated_xlist,
|
||||
):
|
||||
deprecated_authenticator = MockPAMAuthenticator()
|
||||
authorized = await deprecated_authenticator.get_authenticated_user(None, {
|
||||
'username': 'test',
|
||||
'password': 'test'
|
||||
})
|
||||
authorized = await deprecated_authenticator.get_authenticated_user(
|
||||
None, {'username': 'test', 'password': 'test'}
|
||||
)
|
||||
|
||||
assert authorized is not None
|
||||
|
||||
|
||||
async def test_pam_auth_no_such_group():
|
||||
authenticator = MockPAMAuthenticator(group_whitelist={'nosuchcrazygroup'})
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'kaylee',
|
||||
'password': 'kaylee',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'kaylee', 'password': 'kaylee'}
|
||||
)
|
||||
assert authorized is None
|
||||
|
||||
|
||||
@@ -302,6 +300,7 @@ async def test_add_system_user():
|
||||
authenticator.add_user_cmd = ['echo', '/home/USERNAME']
|
||||
|
||||
record = {}
|
||||
|
||||
class DummyPopen:
|
||||
def __init__(self, cmd, *args, **kwargs):
|
||||
record['cmd'] = cmd
|
||||
@@ -402,44 +401,38 @@ async def test_auth_state_disabled(app, auth_state_unavailable):
|
||||
|
||||
async def test_normalize_names():
|
||||
a = MockPAMAuthenticator()
|
||||
authorized = await a.get_authenticated_user(None, {
|
||||
'username': 'ZOE',
|
||||
'password': 'ZOE',
|
||||
})
|
||||
authorized = await a.get_authenticated_user(
|
||||
None, {'username': 'ZOE', 'password': 'ZOE'}
|
||||
)
|
||||
assert authorized['name'] == 'zoe'
|
||||
|
||||
authorized = await a.get_authenticated_user(None, {
|
||||
'username': 'Glenn',
|
||||
'password': 'Glenn',
|
||||
})
|
||||
authorized = await a.get_authenticated_user(
|
||||
None, {'username': 'Glenn', 'password': 'Glenn'}
|
||||
)
|
||||
assert authorized['name'] == 'glenn'
|
||||
|
||||
authorized = await a.get_authenticated_user(None, {
|
||||
'username': 'hExi',
|
||||
'password': 'hExi',
|
||||
})
|
||||
authorized = await a.get_authenticated_user(
|
||||
None, {'username': 'hExi', 'password': 'hExi'}
|
||||
)
|
||||
assert authorized['name'] == 'hexi'
|
||||
|
||||
authorized = await a.get_authenticated_user(None, {
|
||||
'username': 'Test',
|
||||
'password': 'Test',
|
||||
})
|
||||
authorized = await a.get_authenticated_user(
|
||||
None, {'username': 'Test', 'password': 'Test'}
|
||||
)
|
||||
assert authorized['name'] == 'test'
|
||||
|
||||
|
||||
async def test_username_map():
|
||||
a = MockPAMAuthenticator(username_map={'wash': 'alpha'})
|
||||
authorized = await a.get_authenticated_user(None, {
|
||||
'username': 'WASH',
|
||||
'password': 'WASH',
|
||||
})
|
||||
authorized = await a.get_authenticated_user(
|
||||
None, {'username': 'WASH', 'password': 'WASH'}
|
||||
)
|
||||
|
||||
assert authorized['name'] == 'alpha'
|
||||
|
||||
authorized = await a.get_authenticated_user(None, {
|
||||
'username': 'Inara',
|
||||
'password': 'Inara',
|
||||
})
|
||||
authorized = await a.get_authenticated_user(
|
||||
None, {'username': 'Inara', 'password': 'Inara'}
|
||||
)
|
||||
assert authorized['name'] == 'inara'
|
||||
|
||||
|
||||
@@ -463,9 +456,8 @@ def test_post_auth_hook():
|
||||
|
||||
a = MockPAMAuthenticator(post_auth_hook=test_auth_hook)
|
||||
|
||||
authorized = yield a.get_authenticated_user(None, {
|
||||
'username': 'test_user',
|
||||
'password': 'test_user'
|
||||
})
|
||||
authorized = yield a.get_authenticated_user(
|
||||
None, {'username': 'test_user', 'password': 'test_user'}
|
||||
)
|
||||
|
||||
assert authorized['testkey'] == 'testvalue'
|
||||
|
@@ -7,15 +7,16 @@ authentication can expire in a number of ways:
|
||||
- doesn't need refresh
|
||||
- needs refresh and cannot be refreshed without new login
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from contextlib import contextmanager
|
||||
from unittest import mock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
|
||||
from .utils import api_request, get_page
|
||||
from .utils import api_request
|
||||
from .utils import get_page
|
||||
|
||||
|
||||
async def refresh_expired(authenticator, user):
|
||||
|
@@ -1,24 +1,29 @@
|
||||
from binascii import b2a_hex, b2a_base64
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from binascii import b2a_base64
|
||||
from binascii import b2a_hex
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import crypto
|
||||
from ..crypto import encrypt, decrypt
|
||||
from ..crypto import decrypt
|
||||
from ..crypto import encrypt
|
||||
|
||||
keys = [('%i' % i).encode('ascii') * 32 for i in range(3)]
|
||||
hex_keys = [b2a_hex(key).decode('ascii') for key in keys]
|
||||
b64_keys = [b2a_base64(key).decode('ascii').strip() for key in keys]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("key_env, keys", [
|
||||
@pytest.mark.parametrize(
|
||||
"key_env, keys",
|
||||
[
|
||||
(hex_keys[0], [keys[0]]),
|
||||
(';'.join([b64_keys[0], hex_keys[1]]), keys[:2]),
|
||||
(';'.join([hex_keys[0], b64_keys[1], '']), keys[:2]),
|
||||
('', []),
|
||||
(';', []),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_env_constructor(key_env, keys):
|
||||
with patch.dict(os.environ, {crypto.KEY_ENV: key_env}):
|
||||
ck = crypto.CryptKeeper()
|
||||
@@ -29,12 +34,15 @@ def test_env_constructor(key_env, keys):
|
||||
assert ck.fernet is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("key", [
|
||||
@pytest.mark.parametrize(
|
||||
"key",
|
||||
[
|
||||
'a' * 44, # base64, not 32 bytes
|
||||
('%44s' % 'notbase64'), # not base64
|
||||
b'x' * 64, # not hex
|
||||
b'short', # not 32 bytes
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_bad_keys(key):
|
||||
ck = crypto.CryptKeeper()
|
||||
with pytest.raises(ValueError):
|
||||
@@ -76,4 +84,3 @@ async def test_missing_keys(crypt_keeper):
|
||||
|
||||
with pytest.raises(crypto.NoEncryptionKeys):
|
||||
await decrypt(b'whatever')
|
||||
|
||||
|
@@ -1,14 +1,16 @@
|
||||
from glob import glob
|
||||
import os
|
||||
from subprocess import check_call
|
||||
import sys
|
||||
import tempfile
|
||||
from glob import glob
|
||||
from subprocess import check_call
|
||||
|
||||
import pytest
|
||||
from pytest import raises
|
||||
from traitlets.config import Config
|
||||
|
||||
from ..app import NewToken, UpgradeDB, JupyterHub
|
||||
from ..app import JupyterHub
|
||||
from ..app import NewToken
|
||||
from ..app import UpgradeDB
|
||||
|
||||
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
@@ -33,13 +35,7 @@ def generate_old_db(env_dir, hub_version, db_url):
|
||||
check_call([env_py, populate_db, db_url])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'hub_version',
|
||||
[
|
||||
'0.7.2',
|
||||
'0.8.1',
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize('hub_version', ['0.7.2', '0.8.1'])
|
||||
async def test_upgrade(tmpdir, hub_version):
|
||||
db_url = os.getenv('JUPYTERHUB_TEST_DB_URL')
|
||||
if db_url:
|
||||
|
@@ -1,52 +1,47 @@
|
||||
"""Tests for dummy authentication"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import pytest
|
||||
|
||||
from jupyterhub.auth import DummyAuthenticator
|
||||
|
||||
|
||||
async def test_dummy_auth_without_global_password():
|
||||
authenticator = DummyAuthenticator()
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'test_user',
|
||||
'password': 'test_pass',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'test_user', 'password': 'test_pass'}
|
||||
)
|
||||
assert authorized['name'] == 'test_user'
|
||||
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'test_user',
|
||||
'password': '',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'test_user', 'password': ''}
|
||||
)
|
||||
assert authorized['name'] == 'test_user'
|
||||
|
||||
|
||||
async def test_dummy_auth_without_username():
|
||||
authenticator = DummyAuthenticator()
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': '',
|
||||
'password': 'test_pass',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': '', 'password': 'test_pass'}
|
||||
)
|
||||
assert authorized is None
|
||||
|
||||
|
||||
async def test_dummy_auth_with_global_password():
|
||||
authenticator = DummyAuthenticator()
|
||||
authenticator.password = "test_password"
|
||||
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'test_user',
|
||||
'password': 'test_password',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'test_user', 'password': 'test_password'}
|
||||
)
|
||||
assert authorized['name'] == 'test_user'
|
||||
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'test_user',
|
||||
'password': 'qwerty',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'test_user', 'password': 'qwerty'}
|
||||
)
|
||||
assert authorized is None
|
||||
|
||||
authorized = await authenticator.get_authenticated_user(None, {
|
||||
'username': 'some_other_user',
|
||||
'password': 'test_password',
|
||||
})
|
||||
authorized = await authenticator.get_authenticated_user(
|
||||
None, {'username': 'some_other_user', 'password': 'test_password'}
|
||||
)
|
||||
assert authorized['name'] == 'some_other_user'
|
||||
|
@@ -1,7 +1,6 @@
|
||||
"""Tests for the SSL enabled REST API."""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from jupyterhub.tests.test_api import *
|
||||
|
||||
ssl_enabled = True
|
||||
|
@@ -1,10 +1,9 @@
|
||||
"""Test the JupyterHub entry point with internal ssl"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import sys
|
||||
import jupyterhub.tests.mocking
|
||||
|
||||
import jupyterhub.tests.mocking
|
||||
from jupyterhub.tests.test_app import *
|
||||
|
||||
ssl_enabled = True
|
||||
|
@@ -1,19 +1,17 @@
|
||||
"""Tests for jupyterhub internal_ssl connections"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
from subprocess import check_output
|
||||
import sys
|
||||
from unittest import mock
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import SSLError
|
||||
from tornado import gen
|
||||
|
||||
import jupyterhub
|
||||
from tornado import gen
|
||||
from unittest import mock
|
||||
from requests.exceptions import SSLError
|
||||
|
||||
from .utils import async_requests
|
||||
from .test_api import add_user
|
||||
from .utils import async_requests
|
||||
|
||||
ssl_enabled = True
|
||||
|
||||
@@ -25,8 +23,10 @@ def wait_for_spawner(spawner, timeout=10):
|
||||
polling at shorter intervals for early termination
|
||||
"""
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
def wait():
|
||||
return spawner.server.wait_up(timeout=1, http=True)
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
status = yield spawner.poll()
|
||||
assert status is None
|
||||
@@ -59,7 +59,7 @@ async def test_connection_notebook_wrong_certs(app):
|
||||
"""Connecting to a notebook fails without correct certs"""
|
||||
with mock.patch.dict(
|
||||
app.config.LocalProcessSpawner,
|
||||
{'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']}
|
||||
{'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']},
|
||||
):
|
||||
user = add_user(app.db, app, name='foo')
|
||||
await user.spawn()
|
||||
|
@@ -1,7 +1,6 @@
|
||||
"""Tests for process spawning with internal_ssl"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from jupyterhub.tests.test_spawner import *
|
||||
|
||||
ssl_enabled = True
|
||||
|
@@ -5,15 +5,21 @@ from unittest import mock
|
||||
import pytest
|
||||
|
||||
from ..utils import url_path_join
|
||||
|
||||
from .test_api import api_request, add_user, fill_user, normalize_user, TIMESTAMP
|
||||
from .mocking import public_url
|
||||
from .test_api import add_user
|
||||
from .test_api import api_request
|
||||
from .test_api import fill_user
|
||||
from .test_api import normalize_user
|
||||
from .test_api import TIMESTAMP
|
||||
from .utils import async_requests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def named_servers(app):
|
||||
with mock.patch.dict(app.tornado_settings,
|
||||
{'allow_named_servers': True, 'named_server_limit_per_user': 2}):
|
||||
with mock.patch.dict(
|
||||
app.tornado_settings,
|
||||
{'allow_named_servers': True, 'named_server_limit_per_user': 2},
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -30,7 +36,8 @@ async def test_default_server(app, named_servers):
|
||||
|
||||
user_model = normalize_user(r.json())
|
||||
print(user_model)
|
||||
assert user_model == fill_user({
|
||||
assert user_model == fill_user(
|
||||
{
|
||||
'name': username,
|
||||
'auth_state': None,
|
||||
'server': user.url,
|
||||
@@ -42,12 +49,14 @@ async def test_default_server(app, named_servers):
|
||||
'url': user.url,
|
||||
'pending': None,
|
||||
'ready': True,
|
||||
'progress_url': 'PREFIX/hub/api/users/{}/server/progress'.format(username),
|
||||
'progress_url': 'PREFIX/hub/api/users/{}/server/progress'.format(
|
||||
username
|
||||
),
|
||||
'state': {'pid': 0},
|
||||
}
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
# now stop the server
|
||||
r = await api_request(app, 'users', username, 'server', method='delete')
|
||||
@@ -58,11 +67,9 @@ async def test_default_server(app, named_servers):
|
||||
r.raise_for_status()
|
||||
|
||||
user_model = normalize_user(r.json())
|
||||
assert user_model == fill_user({
|
||||
'name': username,
|
||||
'servers': {},
|
||||
'auth_state': None,
|
||||
})
|
||||
assert user_model == fill_user(
|
||||
{'name': username, 'servers': {}, 'auth_state': None}
|
||||
)
|
||||
|
||||
|
||||
async def test_create_named_server(app, named_servers):
|
||||
@@ -89,7 +96,8 @@ async def test_create_named_server(app, named_servers):
|
||||
r.raise_for_status()
|
||||
|
||||
user_model = normalize_user(r.json())
|
||||
assert user_model == fill_user({
|
||||
assert user_model == fill_user(
|
||||
{
|
||||
'name': username,
|
||||
'auth_state': None,
|
||||
'servers': {
|
||||
@@ -101,12 +109,14 @@ async def test_create_named_server(app, named_servers):
|
||||
'pending': None,
|
||||
'ready': True,
|
||||
'progress_url': 'PREFIX/hub/api/users/{}/servers/{}/progress'.format(
|
||||
username, servername),
|
||||
username, servername
|
||||
),
|
||||
'state': {'pid': 0},
|
||||
}
|
||||
for name in [servername]
|
||||
},
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def test_delete_named_server(app, named_servers):
|
||||
@@ -119,7 +129,9 @@ async def test_delete_named_server(app, named_servers):
|
||||
r.raise_for_status()
|
||||
assert r.status_code == 201
|
||||
|
||||
r = await api_request(app, 'users', username, 'servers', servername, method='delete')
|
||||
r = await api_request(
|
||||
app, 'users', username, 'servers', servername, method='delete'
|
||||
)
|
||||
r.raise_for_status()
|
||||
assert r.status_code == 204
|
||||
|
||||
@@ -127,18 +139,20 @@ async def test_delete_named_server(app, named_servers):
|
||||
r.raise_for_status()
|
||||
|
||||
user_model = normalize_user(r.json())
|
||||
assert user_model == fill_user({
|
||||
'name': username,
|
||||
'auth_state': None,
|
||||
'servers': {},
|
||||
})
|
||||
assert user_model == fill_user(
|
||||
{'name': username, 'auth_state': None, 'servers': {}}
|
||||
)
|
||||
# wrapper Spawner is gone
|
||||
assert servername not in user.spawners
|
||||
# low-level record still exists
|
||||
assert servername in user.orm_spawners
|
||||
|
||||
r = await api_request(
|
||||
app, 'users', username, 'servers', servername,
|
||||
app,
|
||||
'users',
|
||||
username,
|
||||
'servers',
|
||||
servername,
|
||||
method='delete',
|
||||
data=json.dumps({'remove': True}),
|
||||
)
|
||||
@@ -153,7 +167,9 @@ async def test_named_server_disabled(app):
|
||||
servername = 'okay'
|
||||
r = await api_request(app, 'users', username, 'servers', servername, method='post')
|
||||
assert r.status_code == 400
|
||||
r = await api_request(app, 'users', username, 'servers', servername, method='delete')
|
||||
r = await api_request(
|
||||
app, 'users', username, 'servers', servername, method='delete'
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
@@ -180,7 +196,10 @@ async def test_named_server_limit(app, named_servers):
|
||||
servername3 = 'bar-3'
|
||||
r = await api_request(app, 'users', username, 'servers', servername3, method='post')
|
||||
assert r.status_code == 400
|
||||
assert r.json() == {"status": 400, "message": "User foo already has the maximum of 2 named servers. One must be deleted before a new server can be created"}
|
||||
assert r.json() == {
|
||||
"status": 400,
|
||||
"message": "User foo already has the maximum of 2 named servers. One must be deleted before a new server can be created",
|
||||
}
|
||||
|
||||
# Create default server
|
||||
r = await api_request(app, 'users', username, 'server', method='post')
|
||||
@@ -189,7 +208,11 @@ async def test_named_server_limit(app, named_servers):
|
||||
|
||||
# Delete 1st named server
|
||||
r = await api_request(
|
||||
app, 'users', username, 'servers', servername1,
|
||||
app,
|
||||
'users',
|
||||
username,
|
||||
'servers',
|
||||
servername1,
|
||||
method='delete',
|
||||
data=json.dumps({'remove': True}),
|
||||
)
|
||||
|
@@ -1,6 +1,6 @@
|
||||
"""Tests for basic object-wrappers"""
|
||||
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
|
||||
from jupyterhub.objects import Server
|
||||
@@ -16,7 +16,7 @@ from jupyterhub.objects import Server
|
||||
'port': 123,
|
||||
'host': 'http://abc:123',
|
||||
'url': 'http://abc:123/x/',
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
'https://abc',
|
||||
@@ -26,9 +26,9 @@ from jupyterhub.objects import Server
|
||||
'proto': 'https',
|
||||
'host': 'https://abc:443',
|
||||
'url': 'https://abc:443/x/',
|
||||
}
|
||||
},
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_bind_url(bind_url, attrs):
|
||||
s = Server(bind_url=bind_url, base_url='/x/')
|
||||
@@ -43,26 +43,28 @@ _hostname = socket.gethostname()
|
||||
'ip, port, attrs',
|
||||
[
|
||||
(
|
||||
'', 123,
|
||||
'',
|
||||
123,
|
||||
{
|
||||
'ip': '',
|
||||
'port': 123,
|
||||
'host': 'http://{}:123'.format(_hostname),
|
||||
'url': 'http://{}:123/x/'.format(_hostname),
|
||||
'bind_url': 'http://*:123/x/',
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
'127.0.0.1', 999,
|
||||
'127.0.0.1',
|
||||
999,
|
||||
{
|
||||
'ip': '127.0.0.1',
|
||||
'port': 999,
|
||||
'host': 'http://127.0.0.1:999',
|
||||
'url': 'http://127.0.0.1:999/x/',
|
||||
'bind_url': 'http://127.0.0.1:999/x/',
|
||||
}
|
||||
},
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_ip_port(ip, port, attrs):
|
||||
s = Server(ip=ip, port=port, base_url='/x/')
|
||||
|
@@ -1,22 +1,21 @@
|
||||
"""Tests for the ORM bits"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import socket
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from tornado import gen
|
||||
|
||||
from .. import orm
|
||||
from .. import objects
|
||||
from .. import crypto
|
||||
from .. import objects
|
||||
from .. import orm
|
||||
from ..emptyclass import EmptyClass
|
||||
from ..user import User
|
||||
from .mocking import MockSpawner
|
||||
from ..emptyclass import EmptyClass
|
||||
|
||||
|
||||
def assert_not_found(db, ORMType, id):
|
||||
@@ -124,7 +123,9 @@ def test_token_expiry(db):
|
||||
# approximate range
|
||||
assert orm_token.expires_at > now + timedelta(seconds=50)
|
||||
assert orm_token.expires_at < now + timedelta(seconds=70)
|
||||
the_future = mock.patch('jupyterhub.orm.utcnow', lambda : now + timedelta(seconds=70))
|
||||
the_future = mock.patch(
|
||||
'jupyterhub.orm.utcnow', lambda: now + timedelta(seconds=70)
|
||||
)
|
||||
with the_future:
|
||||
found = orm.APIToken.find(db, token=token)
|
||||
assert found is None
|
||||
@@ -215,11 +216,9 @@ async def test_spawn_fails(db):
|
||||
def start(self):
|
||||
raise RuntimeError("Split the party")
|
||||
|
||||
user = User(orm_user, {
|
||||
'spawner_class': BadSpawner,
|
||||
'config': None,
|
||||
'statsd': EmptyClass(),
|
||||
})
|
||||
user = User(
|
||||
orm_user, {'spawner_class': BadSpawner, 'config': None, 'statsd': EmptyClass()}
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
await user.spawn()
|
||||
@@ -346,9 +345,7 @@ def test_user_delete_cascade(db):
|
||||
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
|
||||
db.add(oauth_code)
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
client=oauth_client,
|
||||
user=user,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
|
||||
)
|
||||
db.add(oauth_token)
|
||||
db.commit()
|
||||
@@ -384,9 +381,7 @@ def test_oauth_client_delete_cascade(db):
|
||||
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
|
||||
db.add(oauth_code)
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
client=oauth_client,
|
||||
user=user,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
|
||||
)
|
||||
db.add(oauth_token)
|
||||
db.commit()
|
||||
@@ -477,6 +472,3 @@ def test_group_delete_cascade(db):
|
||||
db.delete(user1)
|
||||
db.commit()
|
||||
assert user1 not in group1.users
|
||||
|
||||
|
||||
|
||||
|
@@ -1,30 +1,27 @@
|
||||
"""Tests for HTML pages"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from unittest import mock
|
||||
from urllib.parse import urlencode, urlparse
|
||||
from urllib.parse import urlencode
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
from bs4 import BeautifulSoup
|
||||
from tornado import gen
|
||||
from tornado.httputil import url_concat
|
||||
|
||||
from ..handlers import BaseHandler
|
||||
from ..utils import url_path_join as ujoin
|
||||
from .. import orm
|
||||
from ..auth import Authenticator
|
||||
|
||||
import pytest
|
||||
|
||||
from .mocking import FormSpawner, FalsyCallableFormSpawner
|
||||
from .utils import (
|
||||
async_requests,
|
||||
api_request,
|
||||
add_user,
|
||||
get_page,
|
||||
public_url,
|
||||
public_host,
|
||||
)
|
||||
from ..handlers import BaseHandler
|
||||
from ..utils import url_path_join as ujoin
|
||||
from .mocking import FalsyCallableFormSpawner
|
||||
from .mocking import FormSpawner
|
||||
from .utils import add_user
|
||||
from .utils import api_request
|
||||
from .utils import async_requests
|
||||
from .utils import get_page
|
||||
from .utils import public_host
|
||||
from .utils import public_url
|
||||
|
||||
|
||||
async def test_root_no_auth(app):
|
||||
@@ -53,8 +50,7 @@ async def test_root_redirect(app):
|
||||
|
||||
|
||||
async def test_root_default_url_noauth(app):
|
||||
with mock.patch.dict(app.tornado_settings,
|
||||
{'default_url': '/foo/bar'}):
|
||||
with mock.patch.dict(app.tornado_settings, {'default_url': '/foo/bar'}):
|
||||
r = await get_page('/', app, allow_redirects=False)
|
||||
r.raise_for_status()
|
||||
url = r.headers.get('Location', '')
|
||||
@@ -65,8 +61,7 @@ async def test_root_default_url_noauth(app):
|
||||
async def test_root_default_url_auth(app):
|
||||
name = 'wash'
|
||||
cookies = await app.login_user(name)
|
||||
with mock.patch.dict(app.tornado_settings,
|
||||
{'default_url': '/foo/bar'}):
|
||||
with mock.patch.dict(app.tornado_settings, {'default_url': '/foo/bar'}):
|
||||
r = await get_page('/', app, cookies=cookies, allow_redirects=False)
|
||||
r.raise_for_status()
|
||||
url = r.headers.get('Location', '')
|
||||
@@ -106,12 +101,7 @@ async def test_admin(app):
|
||||
assert r.url.endswith('/admin')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('sort', [
|
||||
'running',
|
||||
'last_activity',
|
||||
'admin',
|
||||
'name',
|
||||
])
|
||||
@pytest.mark.parametrize('sort', ['running', 'last_activity', 'admin', 'name'])
|
||||
async def test_admin_sort(app, sort):
|
||||
cookies = await app.login_user('admin')
|
||||
r = await get_page('admin?sort=%s' % sort, app, cookies=cookies)
|
||||
@@ -146,7 +136,9 @@ async def test_spawn_redirect(app):
|
||||
assert path == ujoin(app.base_url, '/user/%s/' % name)
|
||||
|
||||
# stop server to ensure /user/name is handled by the Hub
|
||||
r = await api_request(app, 'users', name, 'server', method='delete', cookies=cookies)
|
||||
r = await api_request(
|
||||
app, 'users', name, 'server', method='delete', cookies=cookies
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
# test handing of trailing slash on `/user/name`
|
||||
@@ -208,7 +200,9 @@ async def test_spawn_page(app):
|
||||
|
||||
|
||||
async def test_spawn_page_falsy_callable(app):
|
||||
with mock.patch.dict(app.users.settings, {'spawner_class': FalsyCallableFormSpawner}):
|
||||
with mock.patch.dict(
|
||||
app.users.settings, {'spawner_class': FalsyCallableFormSpawner}
|
||||
):
|
||||
cookies = await app.login_user('erik')
|
||||
r = await get_page('spawn', app, cookies=cookies)
|
||||
assert 'user/erik' in r.url
|
||||
@@ -276,22 +270,22 @@ async def test_spawn_form_with_file(app):
|
||||
u = app.users[orm_u]
|
||||
await u.stop()
|
||||
|
||||
r = await async_requests.post(ujoin(base_url, 'spawn'),
|
||||
r = await async_requests.post(
|
||||
ujoin(base_url, 'spawn'),
|
||||
cookies=cookies,
|
||||
data={
|
||||
'bounds': ['-1', '1'],
|
||||
'energy': '511keV',
|
||||
},
|
||||
files={'hello': ('hello.txt', b'hello world\n')}
|
||||
data={'bounds': ['-1', '1'], 'energy': '511keV'},
|
||||
files={'hello': ('hello.txt', b'hello world\n')},
|
||||
)
|
||||
r.raise_for_status()
|
||||
assert u.spawner.user_options == {
|
||||
'energy': '511keV',
|
||||
'bounds': [-1, 1],
|
||||
'notspecified': 5,
|
||||
'hello': {'filename': 'hello.txt',
|
||||
'hello': {
|
||||
'filename': 'hello.txt',
|
||||
'body': b'hello world\n',
|
||||
'content_type': 'application/unknown'},
|
||||
'content_type': 'application/unknown',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -305,9 +299,9 @@ async def test_user_redirect(app):
|
||||
path = urlparse(r.url).path
|
||||
assert path == ujoin(app.base_url, '/hub/login')
|
||||
query = urlparse(r.url).query
|
||||
assert query == urlencode({
|
||||
'next': ujoin(app.hub.base_url, '/user-redirect/tree/top/')
|
||||
})
|
||||
assert query == urlencode(
|
||||
{'next': ujoin(app.hub.base_url, '/user-redirect/tree/top/')}
|
||||
)
|
||||
|
||||
r = await get_page('/user-redirect/notebooks/test.ipynb', app, cookies=cookies)
|
||||
r.raise_for_status()
|
||||
@@ -339,19 +333,17 @@ async def test_user_redirect_deprecated(app):
|
||||
path = urlparse(r.url).path
|
||||
assert path == ujoin(app.base_url, '/hub/login')
|
||||
query = urlparse(r.url).query
|
||||
assert query == urlencode({
|
||||
'next': ujoin(app.base_url, '/hub/user/baduser/test.ipynb')
|
||||
})
|
||||
assert query == urlencode(
|
||||
{'next': ujoin(app.base_url, '/hub/user/baduser/test.ipynb')}
|
||||
)
|
||||
|
||||
|
||||
async def test_login_fail(app):
|
||||
name = 'wash'
|
||||
base_url = public_url(app)
|
||||
r = await async_requests.post(base_url + 'hub/login',
|
||||
data={
|
||||
'username': name,
|
||||
'password': 'wrong',
|
||||
},
|
||||
r = await async_requests.post(
|
||||
base_url + 'hub/login',
|
||||
data={'username': name, 'password': 'wrong'},
|
||||
allow_redirects=False,
|
||||
)
|
||||
assert not r.cookies
|
||||
@@ -359,20 +351,17 @@ async def test_login_fail(app):
|
||||
|
||||
async def test_login_strip(app):
|
||||
"""Test that login form doesn't strip whitespace from passwords"""
|
||||
form_data = {
|
||||
'username': 'spiff',
|
||||
'password': ' space man ',
|
||||
}
|
||||
form_data = {'username': 'spiff', 'password': ' space man '}
|
||||
base_url = public_url(app)
|
||||
called_with = []
|
||||
|
||||
@gen.coroutine
|
||||
def mock_authenticate(handler, data):
|
||||
called_with.append(data)
|
||||
|
||||
with mock.patch.object(app.authenticator, 'authenticate', mock_authenticate):
|
||||
await async_requests.post(base_url + 'hub/login',
|
||||
data=form_data,
|
||||
allow_redirects=False,
|
||||
await async_requests.post(
|
||||
base_url + 'hub/login', data=form_data, allow_redirects=False
|
||||
)
|
||||
|
||||
assert called_with == [form_data]
|
||||
@@ -389,12 +378,11 @@ async def test_login_strip(app):
|
||||
(False, '/user/other', '/hub/user/other'),
|
||||
(False, '/absolute', '/absolute'),
|
||||
(False, '/has?query#andhash', '/has?query#andhash'),
|
||||
|
||||
# next_url outside is not allowed
|
||||
(False, 'https://other.domain', ''),
|
||||
(False, 'ftp://other.domain', ''),
|
||||
(False, '//other.domain', ''),
|
||||
]
|
||||
],
|
||||
)
|
||||
async def test_login_redirect(app, running, next_url, location):
|
||||
cookies = await app.login_user('river')
|
||||
@@ -427,10 +415,11 @@ async def test_auto_login(app, request):
|
||||
class DummyLoginHandler(BaseHandler):
|
||||
def get(self):
|
||||
self.write('ok!')
|
||||
|
||||
base_url = public_url(app) + '/'
|
||||
app.tornado_application.add_handlers(".*$", [
|
||||
(ujoin(app.hub.base_url, 'dummy'), DummyLoginHandler),
|
||||
])
|
||||
app.tornado_application.add_handlers(
|
||||
".*$", [(ujoin(app.hub.base_url, 'dummy'), DummyLoginHandler)]
|
||||
)
|
||||
# no auto_login: end up at /hub/login
|
||||
r = await async_requests.get(base_url)
|
||||
assert r.url == public_url(app, path='hub/login')
|
||||
@@ -438,9 +427,7 @@ async def test_auto_login(app, request):
|
||||
authenticator = Authenticator(auto_login=True)
|
||||
authenticator.login_url = lambda base_url: ujoin(base_url, 'dummy')
|
||||
|
||||
with mock.patch.dict(app.tornado_settings, {
|
||||
'authenticator': authenticator,
|
||||
}):
|
||||
with mock.patch.dict(app.tornado_settings, {'authenticator': authenticator}):
|
||||
r = await async_requests.get(base_url)
|
||||
assert r.url == public_url(app, path='hub/dummy')
|
||||
|
||||
@@ -449,10 +436,12 @@ async def test_auto_login_logout(app):
|
||||
name = 'burnham'
|
||||
cookies = await app.login_user(name)
|
||||
|
||||
with mock.patch.dict(app.tornado_settings, {
|
||||
'authenticator': Authenticator(auto_login=True),
|
||||
}):
|
||||
r = await async_requests.get(public_host(app) + app.tornado_settings['logout_url'], cookies=cookies)
|
||||
with mock.patch.dict(
|
||||
app.tornado_settings, {'authenticator': Authenticator(auto_login=True)}
|
||||
):
|
||||
r = await async_requests.get(
|
||||
public_host(app) + app.tornado_settings['logout_url'], cookies=cookies
|
||||
)
|
||||
r.raise_for_status()
|
||||
logout_url = public_host(app) + app.tornado_settings['logout_url']
|
||||
assert r.url == logout_url
|
||||
@@ -462,7 +451,9 @@ async def test_auto_login_logout(app):
|
||||
async def test_logout(app):
|
||||
name = 'wash'
|
||||
cookies = await app.login_user(name)
|
||||
r = await async_requests.get(public_host(app) + app.tornado_settings['logout_url'], cookies=cookies)
|
||||
r = await async_requests.get(
|
||||
public_host(app) + app.tornado_settings['logout_url'], cookies=cookies
|
||||
)
|
||||
r.raise_for_status()
|
||||
login_url = public_host(app) + app.tornado_settings['login_url']
|
||||
assert r.url == login_url
|
||||
@@ -489,12 +480,11 @@ async def test_shutdown_on_logout(app, shutdown_on_logout):
|
||||
assert spawner.active
|
||||
|
||||
# logout
|
||||
with mock.patch.dict(app.tornado_settings, {
|
||||
'shutdown_on_logout': shutdown_on_logout,
|
||||
}):
|
||||
with mock.patch.dict(
|
||||
app.tornado_settings, {'shutdown_on_logout': shutdown_on_logout}
|
||||
):
|
||||
r = await async_requests.get(
|
||||
public_host(app) + app.tornado_settings['logout_url'],
|
||||
cookies=cookies,
|
||||
public_host(app) + app.tornado_settings['logout_url'], cookies=cookies
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -549,7 +539,9 @@ async def test_oauth_token_page(app):
|
||||
user = app.users[orm.User.find(app.db, name)]
|
||||
client = orm.OAuthClient(identifier='token')
|
||||
app.db.add(client)
|
||||
oauth_token = orm.OAuthAccessToken(client=client, user=user, grant_type=orm.GrantType.authorization_code)
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
client=client, user=user, grant_type=orm.GrantType.authorization_code
|
||||
)
|
||||
app.db.add(oauth_token)
|
||||
app.db.commit()
|
||||
r = await get_page('token', app, cookies=cookies)
|
||||
@@ -557,23 +549,14 @@ async def test_oauth_token_page(app):
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.parametrize("error_status", [
|
||||
503,
|
||||
404,
|
||||
])
|
||||
@pytest.mark.parametrize("error_status", [503, 404])
|
||||
async def test_proxy_error(app, error_status):
|
||||
r = await get_page('/error/%i' % error_status, app)
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"announcements",
|
||||
[
|
||||
"",
|
||||
"spawn",
|
||||
"spawn,home,login",
|
||||
"login,logout",
|
||||
]
|
||||
"announcements", ["", "spawn", "spawn,home,login", "login,logout"]
|
||||
)
|
||||
async def test_announcements(app, announcements):
|
||||
"""Test announcements on various pages"""
|
||||
@@ -618,16 +601,13 @@ async def test_announcements(app, announcements):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"params",
|
||||
[
|
||||
"",
|
||||
"redirect_uri=/noexist",
|
||||
"redirect_uri=ok&client_id=nosuchthing",
|
||||
]
|
||||
"params", ["", "redirect_uri=/noexist", "redirect_uri=ok&client_id=nosuchthing"]
|
||||
)
|
||||
async def test_bad_oauth_get(app, params):
|
||||
cookies = await app.login_user("authorizer")
|
||||
r = await get_page("hub/api/oauth2/authorize?" + params, app, hub=False, cookies=cookies)
|
||||
r = await get_page(
|
||||
"hub/api/oauth2/authorize?" + params, app, hub=False, cookies=cookies
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
@@ -637,11 +617,14 @@ async def test_token_page(app):
|
||||
r = await get_page("token", app, cookies=cookies)
|
||||
r.raise_for_status()
|
||||
assert urlparse(r.url).path.endswith('/hub/token')
|
||||
|
||||
def extract_body(r):
|
||||
soup = BeautifulSoup(r.text, "html5lib")
|
||||
import re
|
||||
|
||||
# trim empty lines
|
||||
return re.sub(r"(\n\s*)+", "\n", soup.body.find(class_="container").text)
|
||||
|
||||
body = extract_body(r)
|
||||
assert "Request new API token" in body, body
|
||||
# no tokens yet, no lists
|
||||
|
@@ -1,20 +1,21 @@
|
||||
"""Test a proxy being started before the Hub"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
import json
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from queue import Queue
|
||||
from subprocess import Popen
|
||||
from urllib.parse import urlparse, quote
|
||||
|
||||
from traitlets.config import Config
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
from traitlets.config import Config
|
||||
|
||||
from .. import orm
|
||||
from ..utils import url_path_join as ujoin
|
||||
from ..utils import wait_for_http_server
|
||||
from .mocking import MockHub
|
||||
from .test_api import api_request, add_user
|
||||
from ..utils import wait_for_http_server, url_path_join as ujoin
|
||||
from .test_api import add_user
|
||||
from .test_api import api_request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -52,25 +53,30 @@ async def test_external_proxy(request):
|
||||
env['CONFIGPROXY_AUTH_TOKEN'] = auth_token
|
||||
cmd = [
|
||||
'configurable-http-proxy',
|
||||
'--ip', app.ip,
|
||||
'--port', str(app.port),
|
||||
'--api-ip', proxy_ip,
|
||||
'--api-port', str(proxy_port),
|
||||
'--ip',
|
||||
app.ip,
|
||||
'--port',
|
||||
str(app.port),
|
||||
'--api-ip',
|
||||
proxy_ip,
|
||||
'--api-port',
|
||||
str(proxy_port),
|
||||
'--log-level=debug',
|
||||
]
|
||||
if app.subdomain_host:
|
||||
cmd.append('--host-routing')
|
||||
proxy = Popen(cmd, env=env)
|
||||
|
||||
|
||||
def _cleanup_proxy():
|
||||
if proxy.poll() is None:
|
||||
proxy.terminate()
|
||||
proxy.wait(timeout=10)
|
||||
|
||||
request.addfinalizer(_cleanup_proxy)
|
||||
|
||||
def wait_for_proxy():
|
||||
return wait_for_http_server('http://%s:%i' % (proxy_ip, proxy_port))
|
||||
|
||||
await wait_for_proxy()
|
||||
|
||||
await app.initialize([])
|
||||
@@ -84,8 +90,9 @@ async def test_external_proxy(request):
|
||||
# add user to the db and start a single user server
|
||||
name = 'river'
|
||||
add_user(app.db, app, name=name)
|
||||
r = await api_request(app, 'users', name, 'server', method='post',
|
||||
bypass_proxy=True)
|
||||
r = await api_request(
|
||||
app, 'users', name, 'server', method='post', bypass_proxy=True
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
routes = await app.proxy.get_all_routes()
|
||||
@@ -122,12 +129,18 @@ async def test_external_proxy(request):
|
||||
new_auth_token = 'different!'
|
||||
env['CONFIGPROXY_AUTH_TOKEN'] = new_auth_token
|
||||
proxy_port = 55432
|
||||
cmd = ['configurable-http-proxy',
|
||||
'--ip', app.ip,
|
||||
'--port', str(app.port),
|
||||
'--api-ip', proxy_ip,
|
||||
'--api-port', str(proxy_port),
|
||||
'--default-target', 'http://%s:%i' % (app.hub_ip, app.hub_port),
|
||||
cmd = [
|
||||
'configurable-http-proxy',
|
||||
'--ip',
|
||||
app.ip,
|
||||
'--port',
|
||||
str(app.port),
|
||||
'--api-ip',
|
||||
proxy_ip,
|
||||
'--api-port',
|
||||
str(proxy_port),
|
||||
'--default-target',
|
||||
'http://%s:%i' % (app.hub_ip, app.hub_port),
|
||||
]
|
||||
if app.subdomain_host:
|
||||
cmd.append('--host-routing')
|
||||
@@ -140,10 +153,7 @@ async def test_external_proxy(request):
|
||||
app,
|
||||
'proxy',
|
||||
method='patch',
|
||||
data=json.dumps({
|
||||
'api_url': new_api_url,
|
||||
'auth_token': new_auth_token,
|
||||
}),
|
||||
data=json.dumps({'api_url': new_api_url, 'auth_token': new_auth_token}),
|
||||
bypass_proxy=True,
|
||||
)
|
||||
r.raise_for_status()
|
||||
@@ -156,13 +166,7 @@ async def test_external_proxy(request):
|
||||
assert sorted(routes.keys()) == [app.hub.routespec, user_spec]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("username", [
|
||||
'zoe',
|
||||
'50fia',
|
||||
'秀樹',
|
||||
'~TestJH',
|
||||
'has@',
|
||||
])
|
||||
@pytest.mark.parametrize("username", ['zoe', '50fia', '秀樹', '~TestJH', 'has@'])
|
||||
async def test_check_routes(app, username, disable_check_routes):
|
||||
proxy = app.proxy
|
||||
test_user = add_user(app.db, app, name=username)
|
||||
@@ -191,14 +195,17 @@ async def test_check_routes(app, username, disable_check_routes):
|
||||
assert before == after
|
||||
|
||||
|
||||
@pytest.mark.parametrize("routespec", [
|
||||
@pytest.mark.parametrize(
|
||||
"routespec",
|
||||
[
|
||||
'/has%20space/foo/',
|
||||
'/missing-trailing/slash',
|
||||
'/has/@/',
|
||||
'/has/' + quote('üñîçø∂é'),
|
||||
'host.name/path/',
|
||||
'other.host/path/no/slash',
|
||||
])
|
||||
],
|
||||
)
|
||||
async def test_add_get_delete(app, routespec, disable_check_routes):
|
||||
arg = routespec
|
||||
if not routespec.endswith('/'):
|
||||
@@ -207,6 +214,7 @@ async def test_add_get_delete(app, routespec, disable_check_routes):
|
||||
# host-routes when not host-routing raises an error
|
||||
# and vice versa
|
||||
expect_value_error = bool(app.subdomain_host) ^ (not routespec.startswith('/'))
|
||||
|
||||
@contextmanager
|
||||
def context():
|
||||
if expect_value_error:
|
||||
|
@@ -1,22 +1,26 @@
|
||||
"""Tests for services"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from binascii import hexlify
|
||||
from contextlib import contextmanager
|
||||
import os
|
||||
from subprocess import Popen
|
||||
import sys
|
||||
from threading import Event
|
||||
import time
|
||||
|
||||
from async_generator import asynccontextmanager, async_generator, yield_
|
||||
import pytest
|
||||
import requests
|
||||
from async_generator import async_generator
|
||||
from async_generator import asynccontextmanager
|
||||
from async_generator import yield_
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
|
||||
from ..utils import maybe_future
|
||||
from ..utils import random_port
|
||||
from ..utils import url_path_join
|
||||
from ..utils import wait_for_http_server
|
||||
from .mocking import public_url
|
||||
from ..utils import url_path_join, wait_for_http_server, random_port, maybe_future
|
||||
from .utils import async_requests
|
||||
|
||||
mockservice_path = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -80,12 +84,14 @@ async def test_proxy_service(app, mockservice_url):
|
||||
async def test_external_service(app):
|
||||
name = 'external'
|
||||
async with external_service(app, name=name) as env:
|
||||
app.services = [{
|
||||
app.services = [
|
||||
{
|
||||
'name': name,
|
||||
'admin': True,
|
||||
'url': env['JUPYTERHUB_SERVICE_URL'],
|
||||
'api_token': env['JUPYTERHUB_API_TOKEN'],
|
||||
}]
|
||||
}
|
||||
]
|
||||
await maybe_future(app.init_services())
|
||||
await app.init_api_tokens()
|
||||
await app.proxy.add_all_services(app._service_map)
|
||||
|
@@ -1,36 +1,43 @@
|
||||
"""Tests for service authentication"""
|
||||
import asyncio
|
||||
from binascii import hexlify
|
||||
import copy
|
||||
from functools import partial
|
||||
import json
|
||||
import os
|
||||
from queue import Queue
|
||||
import sys
|
||||
from binascii import hexlify
|
||||
from functools import partial
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
from unittest import mock
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
from pytest import raises
|
||||
import requests
|
||||
import requests_mock
|
||||
|
||||
from tornado.ioloop import IOLoop
|
||||
from pytest import raises
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.web import RequestHandler, Application, authenticated, HTTPError
|
||||
from tornado.httputil import url_concat
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.web import Application
|
||||
from tornado.web import authenticated
|
||||
from tornado.web import HTTPError
|
||||
from tornado.web import RequestHandler
|
||||
|
||||
from .. import orm
|
||||
from ..services.auth import _ExpiringDict, HubAuth, HubAuthenticated
|
||||
from ..services.auth import _ExpiringDict
|
||||
from ..services.auth import HubAuth
|
||||
from ..services.auth import HubAuthenticated
|
||||
from ..utils import url_path_join
|
||||
from .mocking import public_url, public_host
|
||||
from .mocking import public_host
|
||||
from .mocking import public_url
|
||||
from .test_api import add_user
|
||||
from .utils import async_requests, AsyncSession
|
||||
from .utils import async_requests
|
||||
from .utils import AsyncSession
|
||||
|
||||
# mock for sending monotonic counter way into the future
|
||||
monotonic_future = mock.patch('time.monotonic', lambda: sys.maxsize)
|
||||
|
||||
|
||||
def test_expiring_dict():
|
||||
cache = _ExpiringDict(max_age=30)
|
||||
cache['key'] = 'cached value'
|
||||
@@ -69,9 +76,7 @@ def test_expiring_dict():
|
||||
|
||||
def test_hub_auth():
|
||||
auth = HubAuth(cookie_name='foo')
|
||||
mock_model = {
|
||||
'name': 'onyxia'
|
||||
}
|
||||
mock_model = {'name': 'onyxia'}
|
||||
url = url_path_join(auth.api_url, "authorizations/cookie/foo/bar")
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(url, text=json.dumps(mock_model))
|
||||
@@ -87,9 +92,7 @@ def test_hub_auth():
|
||||
assert user_model is None
|
||||
|
||||
# invalidate cache with timer
|
||||
mock_model = {
|
||||
'name': 'willow'
|
||||
}
|
||||
mock_model = {'name': 'willow'}
|
||||
with monotonic_future, requests_mock.Mocker() as m:
|
||||
m.get(url, text=json.dumps(mock_model))
|
||||
user_model = auth.user_for_cookie('bar')
|
||||
@@ -110,16 +113,14 @@ def test_hub_auth():
|
||||
|
||||
def test_hub_authenticated(request):
|
||||
auth = HubAuth(cookie_name='jubal')
|
||||
mock_model = {
|
||||
'name': 'jubalearly',
|
||||
'groups': ['lions'],
|
||||
}
|
||||
mock_model = {'name': 'jubalearly', 'groups': ['lions']}
|
||||
cookie_url = url_path_join(auth.api_url, "authorizations/cookie", auth.cookie_name)
|
||||
good_url = url_path_join(cookie_url, "early")
|
||||
bad_url = url_path_join(cookie_url, "late")
|
||||
|
||||
class TestHandler(HubAuthenticated, RequestHandler):
|
||||
hub_auth = auth
|
||||
|
||||
@authenticated
|
||||
def get(self):
|
||||
self.finish(self.get_current_user())
|
||||
@@ -127,11 +128,10 @@ def test_hub_authenticated(request):
|
||||
# start hub-authenticated service in a thread:
|
||||
port = 50505
|
||||
q = Queue()
|
||||
|
||||
def run():
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
app = Application([
|
||||
('/*', TestHandler),
|
||||
], login_url=auth.login_url)
|
||||
app = Application([('/*', TestHandler)], login_url=auth.login_url)
|
||||
|
||||
http_server = HTTPServer(app)
|
||||
http_server.listen(port)
|
||||
@@ -146,6 +146,7 @@ def test_hub_authenticated(request):
|
||||
loop.add_callback(loop.stop)
|
||||
t.join(timeout=30)
|
||||
assert not t.is_alive()
|
||||
|
||||
request.addfinalizer(finish_thread)
|
||||
|
||||
# wait for thread to start
|
||||
@@ -153,16 +154,15 @@ def test_hub_authenticated(request):
|
||||
|
||||
with requests_mock.Mocker(real_http=True) as m:
|
||||
# no cookie
|
||||
r = requests.get('http://127.0.0.1:%i' % port,
|
||||
allow_redirects=False,
|
||||
)
|
||||
r = requests.get('http://127.0.0.1:%i' % port, allow_redirects=False)
|
||||
r.raise_for_status()
|
||||
assert r.status_code == 302
|
||||
assert auth.login_url in r.headers['Location']
|
||||
|
||||
# wrong cookie
|
||||
m.get(bad_url, status_code=404)
|
||||
r = requests.get('http://127.0.0.1:%i' % port,
|
||||
r = requests.get(
|
||||
'http://127.0.0.1:%i' % port,
|
||||
cookies={'jubal': 'late'},
|
||||
allow_redirects=False,
|
||||
)
|
||||
@@ -176,7 +176,8 @@ def test_hub_authenticated(request):
|
||||
|
||||
# upstream 403
|
||||
m.get(bad_url, status_code=403)
|
||||
r = requests.get('http://127.0.0.1:%i' % port,
|
||||
r = requests.get(
|
||||
'http://127.0.0.1:%i' % port,
|
||||
cookies={'jubal': 'late'},
|
||||
allow_redirects=False,
|
||||
)
|
||||
@@ -185,7 +186,8 @@ def test_hub_authenticated(request):
|
||||
m.get(good_url, text=json.dumps(mock_model))
|
||||
|
||||
# no whitelist
|
||||
r = requests.get('http://127.0.0.1:%i' % port,
|
||||
r = requests.get(
|
||||
'http://127.0.0.1:%i' % port,
|
||||
cookies={'jubal': 'early'},
|
||||
allow_redirects=False,
|
||||
)
|
||||
@@ -194,7 +196,8 @@ def test_hub_authenticated(request):
|
||||
|
||||
# pass whitelist
|
||||
TestHandler.hub_users = {'jubalearly'}
|
||||
r = requests.get('http://127.0.0.1:%i' % port,
|
||||
r = requests.get(
|
||||
'http://127.0.0.1:%i' % port,
|
||||
cookies={'jubal': 'early'},
|
||||
allow_redirects=False,
|
||||
)
|
||||
@@ -203,7 +206,8 @@ def test_hub_authenticated(request):
|
||||
|
||||
# no pass whitelist
|
||||
TestHandler.hub_users = {'kaylee'}
|
||||
r = requests.get('http://127.0.0.1:%i' % port,
|
||||
r = requests.get(
|
||||
'http://127.0.0.1:%i' % port,
|
||||
cookies={'jubal': 'early'},
|
||||
allow_redirects=False,
|
||||
)
|
||||
@@ -211,7 +215,8 @@ def test_hub_authenticated(request):
|
||||
|
||||
# pass group whitelist
|
||||
TestHandler.hub_groups = {'lions'}
|
||||
r = requests.get('http://127.0.0.1:%i' % port,
|
||||
r = requests.get(
|
||||
'http://127.0.0.1:%i' % port,
|
||||
cookies={'jubal': 'early'},
|
||||
allow_redirects=False,
|
||||
)
|
||||
@@ -220,7 +225,8 @@ def test_hub_authenticated(request):
|
||||
|
||||
# no pass group whitelist
|
||||
TestHandler.hub_groups = {'tigers'}
|
||||
r = requests.get('http://127.0.0.1:%i' % port,
|
||||
r = requests.get(
|
||||
'http://127.0.0.1:%i' % port,
|
||||
cookies={'jubal': 'early'},
|
||||
allow_redirects=False,
|
||||
)
|
||||
@@ -230,15 +236,14 @@ def test_hub_authenticated(request):
|
||||
async def test_hubauth_cookie(app, mockservice_url):
|
||||
"""Test HubAuthenticated service with user cookies"""
|
||||
cookies = await app.login_user('badger')
|
||||
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/', cookies=cookies)
|
||||
r = await async_requests.get(
|
||||
public_url(app, mockservice_url) + '/whoami/', cookies=cookies
|
||||
)
|
||||
r.raise_for_status()
|
||||
print(r.text)
|
||||
reply = r.json()
|
||||
sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']}
|
||||
assert sub_reply == {
|
||||
'name': 'badger',
|
||||
'admin': False,
|
||||
}
|
||||
assert sub_reply == {'name': 'badger', 'admin': False}
|
||||
|
||||
|
||||
async def test_hubauth_token(app, mockservice_url):
|
||||
@@ -248,28 +253,25 @@ async def test_hubauth_token(app, mockservice_url):
|
||||
app.db.commit()
|
||||
|
||||
# token in Authorization header
|
||||
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/',
|
||||
headers={
|
||||
'Authorization': 'token %s' % token,
|
||||
})
|
||||
r = await async_requests.get(
|
||||
public_url(app, mockservice_url) + '/whoami/',
|
||||
headers={'Authorization': 'token %s' % token},
|
||||
)
|
||||
reply = r.json()
|
||||
sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']}
|
||||
assert sub_reply == {
|
||||
'name': 'river',
|
||||
'admin': False,
|
||||
}
|
||||
assert sub_reply == {'name': 'river', 'admin': False}
|
||||
|
||||
# token in ?token parameter
|
||||
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=%s' % token)
|
||||
r = await async_requests.get(
|
||||
public_url(app, mockservice_url) + '/whoami/?token=%s' % token
|
||||
)
|
||||
r.raise_for_status()
|
||||
reply = r.json()
|
||||
sub_reply = {key: reply.get(key, 'missing') for key in ['name', 'admin']}
|
||||
assert sub_reply == {
|
||||
'name': 'river',
|
||||
'admin': False,
|
||||
}
|
||||
assert sub_reply == {'name': 'river', 'admin': False}
|
||||
|
||||
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=no-such-token',
|
||||
r = await async_requests.get(
|
||||
public_url(app, mockservice_url) + '/whoami/?token=no-such-token',
|
||||
allow_redirects=False,
|
||||
)
|
||||
assert r.status_code == 302
|
||||
@@ -288,30 +290,25 @@ async def test_hubauth_service_token(app, mockservice_url):
|
||||
await app.init_api_tokens()
|
||||
|
||||
# token in Authorization header
|
||||
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/',
|
||||
headers={
|
||||
'Authorization': 'token %s' % token,
|
||||
})
|
||||
r = await async_requests.get(
|
||||
public_url(app, mockservice_url) + '/whoami/',
|
||||
headers={'Authorization': 'token %s' % token},
|
||||
)
|
||||
r.raise_for_status()
|
||||
reply = r.json()
|
||||
assert reply == {
|
||||
'kind': 'service',
|
||||
'name': name,
|
||||
'admin': False,
|
||||
}
|
||||
assert reply == {'kind': 'service', 'name': name, 'admin': False}
|
||||
assert not r.cookies
|
||||
|
||||
# token in ?token parameter
|
||||
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=%s' % token)
|
||||
r = await async_requests.get(
|
||||
public_url(app, mockservice_url) + '/whoami/?token=%s' % token
|
||||
)
|
||||
r.raise_for_status()
|
||||
reply = r.json()
|
||||
assert reply == {
|
||||
'kind': 'service',
|
||||
'name': name,
|
||||
'admin': False,
|
||||
}
|
||||
assert reply == {'kind': 'service', 'name': name, 'admin': False}
|
||||
|
||||
r = await async_requests.get(public_url(app, mockservice_url) + '/whoami/?token=no-such-token',
|
||||
r = await async_requests.get(
|
||||
public_url(app, mockservice_url) + '/whoami/?token=no-such-token',
|
||||
allow_redirects=False,
|
||||
)
|
||||
assert r.status_code == 302
|
||||
@@ -351,10 +348,7 @@ async def test_oauth_service(app, mockservice_url):
|
||||
assert r.status_code == 200
|
||||
reply = r.json()
|
||||
sub_reply = {key: reply.get(key, 'missing') for key in ('kind', 'name')}
|
||||
assert sub_reply == {
|
||||
'name': 'link',
|
||||
'kind': 'user',
|
||||
}
|
||||
assert sub_reply == {'name': 'link', 'kind': 'user'}
|
||||
|
||||
# token-authenticated request to HubOAuth
|
||||
token = app.users[name].new_api_token()
|
||||
@@ -367,11 +361,7 @@ async def test_oauth_service(app, mockservice_url):
|
||||
# verify that ?token= requests set a cookie
|
||||
assert len(r.cookies) != 0
|
||||
# ensure cookie works in future requests
|
||||
r = await async_requests.get(
|
||||
url,
|
||||
cookies=r.cookies,
|
||||
allow_redirects=False,
|
||||
)
|
||||
r = await async_requests.get(url, cookies=r.cookies, allow_redirects=False)
|
||||
r.raise_for_status()
|
||||
assert r.url == url
|
||||
reply = r.json()
|
||||
@@ -408,9 +398,7 @@ async def test_oauth_cookie_collision(app, mockservice_url):
|
||||
# finish oauth 2
|
||||
# submit the oauth form to complete authorization
|
||||
r = await s.post(
|
||||
oauth_2.url,
|
||||
data={'scopes': ['identify']},
|
||||
headers={'Referer': oauth_2.url},
|
||||
oauth_2.url, data={'scopes': ['identify']}, headers={'Referer': oauth_2.url}
|
||||
)
|
||||
r.raise_for_status()
|
||||
assert r.url == url
|
||||
@@ -422,9 +410,7 @@ async def test_oauth_cookie_collision(app, mockservice_url):
|
||||
|
||||
# finish oauth 1
|
||||
r = await s.post(
|
||||
oauth_1.url,
|
||||
data={'scopes': ['identify']},
|
||||
headers={'Referer': oauth_1.url},
|
||||
oauth_1.url, data={'scopes': ['identify']}, headers={'Referer': oauth_1.url}
|
||||
)
|
||||
r.raise_for_status()
|
||||
assert r.url == url
|
||||
@@ -455,11 +441,13 @@ async def test_oauth_logout(app, mockservice_url):
|
||||
s = AsyncSession()
|
||||
name = 'propha'
|
||||
app_user = add_user(app.db, app=app, name=name)
|
||||
|
||||
def auth_tokens():
|
||||
"""Return list of OAuth access tokens for the user"""
|
||||
return list(
|
||||
app.db.query(orm.OAuthAccessToken).filter(
|
||||
orm.OAuthAccessToken.user_id == app_user.id)
|
||||
orm.OAuthAccessToken.user_id == app_user.id
|
||||
)
|
||||
)
|
||||
|
||||
# ensure we start empty
|
||||
@@ -480,14 +468,8 @@ async def test_oauth_logout(app, mockservice_url):
|
||||
r.raise_for_status()
|
||||
assert r.status_code == 200
|
||||
reply = r.json()
|
||||
sub_reply = {
|
||||
key: reply.get(key, 'missing')
|
||||
for key in ('kind', 'name')
|
||||
}
|
||||
assert sub_reply == {
|
||||
'name': name,
|
||||
'kind': 'user',
|
||||
}
|
||||
sub_reply = {key: reply.get(key, 'missing') for key in ('kind', 'name')}
|
||||
assert sub_reply == {'name': name, 'kind': 'user'}
|
||||
# save cookies to verify cache
|
||||
saved_cookies = copy.deepcopy(s.cookies)
|
||||
session_id = s.cookies['jupyterhub-session-id']
|
||||
@@ -522,11 +504,5 @@ async def test_oauth_logout(app, mockservice_url):
|
||||
r.raise_for_status()
|
||||
assert r.status_code == 200
|
||||
reply = r.json()
|
||||
sub_reply = {
|
||||
key: reply.get(key, 'missing')
|
||||
for key in ('kind', 'name')
|
||||
}
|
||||
assert sub_reply == {
|
||||
'name': name,
|
||||
'kind': 'user',
|
||||
}
|
||||
sub_reply = {key: reply.get(key, 'missing') for key in ('kind', 'name')}
|
||||
assert sub_reply == {'name': name, 'kind': 'user'}
|
||||
|
@@ -1,16 +1,16 @@
|
||||
"""Tests for jupyterhub.singleuser"""
|
||||
|
||||
from subprocess import check_output
|
||||
import sys
|
||||
from subprocess import check_output
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
|
||||
import jupyterhub
|
||||
from .mocking import StubSingleUserSpawner, public_url
|
||||
from ..utils import url_path_join
|
||||
|
||||
from .utils import async_requests, AsyncSession
|
||||
from .mocking import public_url
|
||||
from .mocking import StubSingleUserSpawner
|
||||
from .utils import async_requests
|
||||
from .utils import AsyncSession
|
||||
|
||||
|
||||
async def test_singleuser_auth(app):
|
||||
@@ -47,11 +47,7 @@ async def test_singleuser_auth(app):
|
||||
r = await s.get(url)
|
||||
assert urlparse(r.url).path.endswith('/oauth2/authorize')
|
||||
# submit the oauth form to complete authorization
|
||||
r = await s.post(
|
||||
r.url,
|
||||
data={'scopes': ['identify']},
|
||||
headers={'Referer': r.url},
|
||||
)
|
||||
r = await s.post(r.url, data={'scopes': ['identify']}, headers={'Referer': r.url})
|
||||
assert urlparse(r.url).path.rstrip('/').endswith('/user/nandy/tree')
|
||||
# user isn't authorized, should raise 403
|
||||
assert r.status_code == 403
|
||||
@@ -85,11 +81,14 @@ async def test_disable_user_config(app):
|
||||
|
||||
|
||||
def test_help_output():
|
||||
out = check_output([sys.executable, '-m', 'jupyterhub.singleuser', '--help-all']).decode('utf8', 'replace')
|
||||
out = check_output(
|
||||
[sys.executable, '-m', 'jupyterhub.singleuser', '--help-all']
|
||||
).decode('utf8', 'replace')
|
||||
assert 'JupyterHub' in out
|
||||
|
||||
|
||||
def test_version():
|
||||
out = check_output([sys.executable, '-m', 'jupyterhub.singleuser', '--version']).decode('utf8', 'replace')
|
||||
out = check_output(
|
||||
[sys.executable, '-m', 'jupyterhub.singleuser', '--version']
|
||||
).decode('utf8', 'replace')
|
||||
assert jupyterhub.__version__ in out
|
||||
|
||||
|
@@ -1,27 +1,28 @@
|
||||
"""Tests for process spawning"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
from subprocess import Popen
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from subprocess import Popen
|
||||
from unittest import mock
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
from tornado import gen
|
||||
|
||||
from ..objects import Hub, Server
|
||||
from .. import orm
|
||||
from .. import spawner as spawnermod
|
||||
from ..spawner import LocalProcessSpawner, Spawner
|
||||
from ..objects import Hub
|
||||
from ..objects import Server
|
||||
from ..spawner import LocalProcessSpawner
|
||||
from ..spawner import Spawner
|
||||
from ..user import User
|
||||
from ..utils import new_token, url_path_join
|
||||
from ..utils import new_token
|
||||
from ..utils import url_path_join
|
||||
from .mocking import public_url
|
||||
from .test_api import add_user
|
||||
from .utils import async_requests
|
||||
@@ -84,8 +85,10 @@ async def wait_for_spawner(spawner, timeout=10):
|
||||
polling at shorter intervals for early termination
|
||||
"""
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
def wait():
|
||||
return spawner.server.wait_up(timeout=1, http=True)
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
status = await spawner.poll()
|
||||
assert status is None
|
||||
@@ -187,11 +190,13 @@ def test_setcwd():
|
||||
os.chdir(cwd)
|
||||
chdir = os.chdir
|
||||
temp_root = os.path.realpath(os.path.abspath(tempfile.gettempdir()))
|
||||
|
||||
def raiser(path):
|
||||
path = os.path.realpath(os.path.abspath(path))
|
||||
if not path.startswith(temp_root):
|
||||
raise OSError(path)
|
||||
chdir(path)
|
||||
|
||||
with mock.patch('os.chdir', raiser):
|
||||
spawnermod._try_setcwd(cwd)
|
||||
assert os.getcwd().startswith(temp_root)
|
||||
@@ -209,6 +214,7 @@ def test_string_formatting(db):
|
||||
|
||||
async def test_popen_kwargs(db):
|
||||
mock_proc = mock.Mock(spec=Popen)
|
||||
|
||||
def mock_popen(*args, **kwargs):
|
||||
mock_proc.args = args
|
||||
mock_proc.kwargs = kwargs
|
||||
@@ -226,7 +232,8 @@ async def test_popen_kwargs(db):
|
||||
async def test_shell_cmd(db, tmpdir, request):
|
||||
f = tmpdir.join('bashrc')
|
||||
f.write('export TESTVAR=foo\n')
|
||||
s = new_spawner(db,
|
||||
s = new_spawner(
|
||||
db,
|
||||
cmd=[sys.executable, '-m', 'jupyterhub.tests.mocksu'],
|
||||
shell_cmd=['bash', '--rcfile', str(f), '-i', '-c'],
|
||||
)
|
||||
@@ -253,6 +260,7 @@ def test_inherit_overwrite():
|
||||
"""
|
||||
if sys.version_info >= (3, 6):
|
||||
with pytest.raises(NotImplementedError):
|
||||
|
||||
class S(Spawner):
|
||||
pass
|
||||
|
||||
@@ -372,19 +380,14 @@ async def test_spawner_delete_server(app):
|
||||
assert spawner.server is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name",
|
||||
[
|
||||
"has@x",
|
||||
"has~x",
|
||||
"has%x",
|
||||
"has%40x",
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize("name", ["has@x", "has~x", "has%x", "has%40x"])
|
||||
async def test_spawner_routing(app, name):
|
||||
"""Test routing of names with special characters"""
|
||||
db = app.db
|
||||
with mock.patch.dict(app.config.LocalProcessSpawner, {'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']}):
|
||||
with mock.patch.dict(
|
||||
app.config.LocalProcessSpawner,
|
||||
{'cmd': [sys.executable, '-m', 'jupyterhub.tests.mocksu']},
|
||||
):
|
||||
user = add_user(app.db, app, name=name)
|
||||
await user.spawn()
|
||||
await wait_for_spawner(user.spawner)
|
||||
|
@@ -1,12 +1,16 @@
|
||||
import pytest
|
||||
from traitlets import HasTraits, TraitError
|
||||
from traitlets import HasTraits
|
||||
from traitlets import TraitError
|
||||
|
||||
from jupyterhub.traitlets import URLPrefix, Command, ByteSpecification
|
||||
from jupyterhub.traitlets import ByteSpecification
|
||||
from jupyterhub.traitlets import Command
|
||||
from jupyterhub.traitlets import URLPrefix
|
||||
|
||||
|
||||
def test_url_prefix():
|
||||
class C(HasTraits):
|
||||
url = URLPrefix()
|
||||
|
||||
c = C()
|
||||
c.url = '/a/b/c/'
|
||||
assert c.url == '/a/b/c/'
|
||||
@@ -20,6 +24,7 @@ def test_command():
|
||||
class C(HasTraits):
|
||||
cmd = Command('default command')
|
||||
cmd2 = Command(['default_cmd'])
|
||||
|
||||
c = C()
|
||||
assert c.cmd == ['default command']
|
||||
assert c.cmd2 == ['default_cmd']
|
||||
|
@@ -1,9 +1,11 @@
|
||||
"""Tests for utilities"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from async_generator import aclosing, async_generator, yield_
|
||||
import pytest
|
||||
from async_generator import aclosing
|
||||
from async_generator import async_generator
|
||||
from async_generator import yield_
|
||||
|
||||
from ..utils import iterate_until
|
||||
|
||||
|
||||
@@ -26,12 +28,15 @@ def schedule_future(io_loop, *, delay, result=None):
|
||||
return f
|
||||
|
||||
|
||||
@pytest.mark.parametrize("deadline, n, delay, expected", [
|
||||
@pytest.mark.parametrize(
|
||||
"deadline, n, delay, expected",
|
||||
[
|
||||
(0, 3, 1, []),
|
||||
(0, 3, 0, [0, 1, 2]),
|
||||
(5, 3, 0.01, [0, 1, 2]),
|
||||
(0.5, 10, 0.2, [0, 1]),
|
||||
])
|
||||
],
|
||||
)
|
||||
async def test_iterate_until(io_loop, deadline, n, delay, expected):
|
||||
f = schedule_future(io_loop, delay=deadline)
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user