mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-12 12:33:02 +00:00
494 lines
17 KiB
Python
494 lines
17 KiB
Python
# Copyright (c) Jupyter Development Team.
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
from collections import defaultdict
|
|
from datetime import datetime, timedelta
|
|
from urllib.parse import quote, urlparse
|
|
import warnings
|
|
|
|
from oauth2.error import ClientNotFoundError
|
|
from sqlalchemy import inspect
|
|
from tornado import gen
|
|
from tornado.log import app_log
|
|
from traitlets import HasTraits, Any, Dict, default
|
|
|
|
from .utils import url_path_join, default_server_name
|
|
|
|
from . import orm
|
|
from ._version import _check_version, __version__
|
|
from .objects import Server
|
|
from .spawner import LocalProcessSpawner
|
|
from .crypto import encrypt, decrypt, CryptKeeper, EncryptionUnavailable, InvalidToken
|
|
|
|
class UserDict(dict):
    """Like defaultdict, but for users

    Getting by a user id OR an orm.User instance returns a User wrapper around the orm user.
    Getting by a user name (str) looks the orm.User up in the database first.
    Wrappers are cached in this dict, keyed by the integer user id.
    """

    def __init__(self, db_factory, settings):
        # db_factory: zero-argument callable returning the database session
        #             (called on every access of ``self.db``)
        # settings:   settings dict handed to every User wrapper created here
        self.db_factory = db_factory
        self.settings = settings
        super().__init__()

    @property
    def db(self):
        """The database session, obtained fresh from ``db_factory`` on each access."""
        return self.db_factory()

    def __contains__(self, key):
        """Membership test by user id, User wrapper, or orm.User."""
        if isinstance(key, (User, orm.User)):
            key = key.id
        return dict.__contains__(self, key)

    def __getitem__(self, key):
        """Return the cached User wrapper for ``key``, creating it if needed.

        ``key`` may be a User, an orm.User, a user name (str) or a user id (int).

        Raises:
            KeyError: if no such user exists in the database, or ``key`` is of
                an unsupported type.
        """
        if isinstance(key, User):
            key = key.id
        elif isinstance(key, str):
            orm_user = self.db.query(orm.User).filter(orm.User.name == key).first()
            if orm_user is None:
                raise KeyError("No such user: %s" % key)
            else:
                key = orm_user
        # note: a separate ``if`` chain, so a name resolved above falls through
        # to the orm.User branch and a User wrapper falls through to the int branch
        if isinstance(key, orm.User):
            # users[orm_user] returns User(orm_user)
            orm_user = key
            if orm_user.id not in self:
                user = self[orm_user.id] = User(orm_user, self.settings)
                return user
            user = dict.__getitem__(self, orm_user.id)
            # refresh the cached wrapper's session reference
            user.db = self.db
            return user
        elif isinstance(key, int):
            user_id = key
            if user_id not in self:
                orm_user = self.db.query(orm.User).filter(orm.User.id == user_id).first()
                if orm_user is None:
                    raise KeyError("No such user: %s" % user_id)
                self[user_id] = User(orm_user, self.settings)
            return dict.__getitem__(self, user_id)
        else:
            raise KeyError(repr(key))

    def __delitem__(self, key):
        """Delete the user from the database and drop its cached wrapper."""
        user = self[key]
        user_id = user.id
        db = self.db
        db.delete(user.orm_user)
        db.commit()
        dict.__delitem__(self, user_id)

    def count_active_users(self):
        """Count the number of user servers that are active/pending/ready

        Iterates over every spawner of every cached user.

        Returns a defaultdict(int) with keys 'pending', '<kind>_pending'
        (one per distinct truthy ``spawner.pending`` value), 'active' and 'ready'.
        Keys that never occur read as 0.
        """
        counts = defaultdict(int)
        for user in self.values():
            for spawner in user.spawners.values():
                pending = spawner.pending
                if pending:
                    counts['pending'] += 1
                    counts[pending + '_pending'] += 1
                if spawner.active:
                    counts['active'] += 1
                if spawner.ready:
                    counts['ready'] += 1

        return counts
|
|
|
|
|
|
class _SpawnerDict(dict):
    """Dict of spawners keyed by server name, populated lazily.

    Indexing a name that is not present calls ``spawner_factory(name)``,
    stores the result under that name and returns it; present names are
    returned as-is.
    """

    def __init__(self, spawner_factory):
        # callable taking a server name and returning a new spawner
        self.spawner_factory = spawner_factory

    def __getitem__(self, key):
        if key in self:
            return super().__getitem__(key)
        # first access for this name: build the spawner and cache it
        self[key] = self.spawner_factory(key)
        return super().__getitem__(key)
|
|
|
|
|
|
class User(HasTraits):
    """High-level wrapper around an orm.User.

    Attribute reads that are not found on this object fall through to the
    wrapped ``orm_user``. Public (non-underscore) attribute writes for names
    the orm user already has are forwarded to it. Everything else is stored
    on the wrapper.

    Spawners live in ``self.spawners``, keyed by server name, with '' being
    the default server. They are created on first access via ``_new_spawner``.
    """

    # declared so that the @default('log') handler below has a trait to
    # provide the default for; without it ``self.log`` has nothing backing it
    log = Any()

    @default('log')
    def _log_default(self):
        return app_log

    # replaced with a _SpawnerDict in __init__
    spawners = None
    settings = Dict()

    db = Any(allow_none=True)

    @default('db')
    def _db_default(self):
        # default to the session the wrapped orm user is attached to
        if self.orm_user:
            return inspect(self.orm_user).session

    orm_user = Any(allow_none=True)

    @gen.coroutine
    def save_auth_state(self, auth_state):
        """Encrypt and store auth_state

        Passing None clears the stored state. Commits the db session.
        """
        if auth_state is None:
            self.encrypted_auth_state = None
        else:
            self.encrypted_auth_state = yield encrypt(auth_state)
        self.db.commit()

    @gen.coroutine
    def get_auth_state(self):
        """Retrieve and decrypt auth_state for the user

        Returns None if nothing is stored or decryption fails (a warning is
        logged in the latter case).
        """
        encrypted = self.encrypted_auth_state
        if encrypted is None:
            return None
        try:
            auth_state = yield decrypt(encrypted)
        except (ValueError, InvalidToken, EncryptionUnavailable) as e:
            self.log.warning("Failed to retrieve encrypted auth_state for %s because %s",
                self.name, e,
            )
            return
        # loading auth_state
        if auth_state:
            # Crypt has multiple keys, store again with new key for rotation.
            if len(CryptKeeper.instance().keys) > 1:
                yield self.save_auth_state(auth_state)
        return auth_state

    @property
    def authenticator(self):
        """The Authenticator from settings, or None."""
        return self.settings.get('authenticator', None)

    @property
    def spawner_class(self):
        """The Spawner class from settings (LocalProcessSpawner by default)."""
        return self.settings.get('spawner_class', LocalProcessSpawner)

    def __init__(self, orm_user, settings=None, **kwargs):
        if settings:
            kwargs['settings'] = settings
        kwargs['orm_user'] = orm_user
        super().__init__(**kwargs)

        self.allow_named_servers = self.settings.get('allow_named_servers', False)

        self.base_url = self.prefix = url_path_join(
            self.settings.get('base_url', '/'), 'user', self.escaped_name) + '/'

        self.spawners = _SpawnerDict(self._new_spawner)
        # load existing named spawners
        for name in self.orm_spawners:
            self.spawners[name] = self._new_spawner(name)

    def _new_spawner(self, name, spawner_class=None, **kwargs):
        """Create a new spawner

        Ensures an orm.Spawner row exists for ``name`` (creating and
        committing one if needed), then instantiates ``spawner_class``
        (default: ``self.spawner_class``) and loads its saved state.
        Extra ``kwargs`` override the constructor arguments (mainly for tests).
        """
        if spawner_class is None:
            spawner_class = self.spawner_class
        self.log.debug("Creating %s for %s:%s", spawner_class, self.name, name)

        orm_spawner = self.orm_spawners.get(name)
        if orm_spawner is None:
            orm_spawner = orm.Spawner(user=self.orm_user, name=name)
            self.db.add(orm_spawner)
            self.db.commit()
            assert name in self.orm_spawners
        if name == '' and self.state:
            # migrate user.state to spawner.state
            orm_spawner.state = self.state
            self.state = None

        spawn_kwargs = dict(
            user=self,
            orm_spawner=orm_spawner,
            hub=self.settings.get('hub'),
            authenticator=self.authenticator,
            config=self.settings.get('config'),
            proxy_spec=url_path_join(self.proxy_spec, name, '/'),
        )
        # update with kwargs. Mainly for testing.
        spawn_kwargs.update(kwargs)
        spawner = spawner_class(**spawn_kwargs)
        spawner.load_state(orm_spawner.state or {})
        return spawner

    # singleton property, self.spawner maps onto spawner with empty server_name
    @property
    def spawner(self):
        return self.spawners['']

    @spawner.setter
    def spawner(self, spawner):
        self.spawners[''] = spawner

    # pass get/setattr to ORM user
    def __getattr__(self, attr):
        if hasattr(self.orm_user, attr):
            return getattr(self.orm_user, attr)
        else:
            raise AttributeError(attr)

    def __setattr__(self, attr, value):
        if not attr.startswith('_') and self.orm_user and hasattr(self.orm_user, attr):
            setattr(self.orm_user, attr, value)
        else:
            super().__setattr__(attr, value)

    def __repr__(self):
        return repr(self.orm_user)

    @property
    def running(self):
        """property for whether the user's default server is running"""
        return self.spawner.ready

    @property
    def active(self):
        """True if any server is active"""
        return any(s.active for s in self.spawners.values())

    @property
    def spawn_pending(self):
        warnings.warn("User.spawn_pending is deprecated in JupyterHub 0.8. Use Spawner.pending",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.spawner.pending == 'spawn'

    @property
    def stop_pending(self):
        warnings.warn("User.stop_pending is deprecated in JupyterHub 0.8. Use Spawner.pending",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.spawner.pending == 'stop'

    @property
    def server(self):
        """The Server of the default spawner."""
        return self.spawner.server

    @property
    def escaped_name(self):
        """My name, escaped for use in URLs, cookies, etc."""
        return quote(self.name, safe='@')

    @property
    def proxy_spec(self):
        """The proxy routespec for my default server"""
        if self.settings.get('subdomain_host'):
            return url_path_join(self.domain, self.base_url, '/')
        else:
            return url_path_join(self.base_url, '/')

    @property
    def domain(self):
        """Get the domain for my server."""
        # FIXME: escaped_name probably isn't escaped enough in general for a domain fragment
        return self.escaped_name + '.' + self.settings['domain']

    @property
    def host(self):
        """Get the *host* for my server (proto://domain[:port])"""
        # FIXME: escaped_name probably isn't escaped enough in general for a domain fragment
        parsed = urlparse(self.settings['subdomain_host'])
        h = '%s://%s.%s' % (parsed.scheme, self.escaped_name, parsed.netloc)
        return h

    @property
    def url(self):
        """My URL

        Full name.domain/path if using subdomains, otherwise just my /base/url
        """
        if self.settings.get('subdomain_host'):
            return '{host}{path}'.format(
                host=self.host,
                path=self.base_url,
            )
        else:
            return self.base_url

    @gen.coroutine
    def spawn(self, server_name='', options=None):
        """Start the user's spawner

        depending from the value of JupyterHub.allow_named_servers

        if False:
        JupyterHub expects only one single-server per user
        url of the server will be /user/:name

        if True:
        JupyterHub expects more than one single-server per user
        url of the server will be /user/:name/:server_name

        On failure the server being started (``server_name``) is stopped and
        the original exception is re-raised with ``e.reason`` set to
        'timeout' or 'error'.
        """
        db = self.db
        if self.allow_named_servers and not server_name:
            server_name = default_server_name(self)

        base_url = url_path_join(self.base_url, server_name) + '/'

        orm_server = orm.Server(
            base_url=base_url,
        )
        db.add(orm_server)

        api_token = self.new_api_token()
        db.commit()

        spawner = self.spawners[server_name]
        spawner.server = server = Server(orm_server=orm_server)
        assert spawner.orm_spawner.server is orm_server

        # Passing user_options to the spawner
        spawner.user_options = options or {}
        # we are starting a new server, make sure it doesn't restore state
        spawner.clear_state()

        # create API and OAuth tokens
        spawner.api_token = api_token
        spawner.admin_access = self.settings.get('admin_access', False)
        # use fully quoted name for client_id because it will be used in cookie-name
        # self.escaped_name may contain @ which is legal in URLs but not cookie keys
        client_id = 'user-%s' % quote(self.name)
        if server_name:
            client_id = '%s-%s' % (client_id, quote(server_name))
        spawner.oauth_client_id = client_id
        oauth_provider = self.settings.get('oauth_provider')
        if oauth_provider:
            client_store = oauth_provider.client_authenticator.client_store
            try:
                oauth_client = client_store.fetch_by_client_id(client_id)
            except ClientNotFoundError:
                oauth_client = None
            # create a new OAuth client + secret on every launch,
            # except for resuming containers.
            if oauth_client is None or not spawner.will_resume:
                # the callback lives under this server's own base_url
                # (url_path_join drops the empty server_name of the default server)
                client_store.add_client(client_id, api_token,
                                        url_path_join(self.url, server_name, 'oauth_callback'),
                                        )
        db.commit()

        # trigger pre-spawn hook on authenticator
        authenticator = self.authenticator
        if authenticator:
            yield gen.maybe_future(authenticator.pre_spawn_start(self, spawner))

        spawner._spawn_pending = True
        # wait for spawner.start to return
        try:
            # run optional preparation work to bootstrap the notebook
            # (on the spawner being started, not necessarily the default one)
            yield gen.maybe_future(spawner.run_pre_spawn_hook())
            f = spawner.start()
            # commit any changes in spawner.start (always commit db changes before yield)
            db.commit()
            ip_port = yield gen.with_timeout(timedelta(seconds=spawner.start_timeout), f)
            if ip_port:
                # get ip, port info from return value of start()
                server.ip, server.port = ip_port
            else:
                # prior to 0.7, spawners had to store this info in user.server themselves.
                # Handle < 0.7 behavior with a warning, assuming info was stored in db by the Spawner.
                self.log.warning("DEPRECATION: Spawner.start should return (ip, port) in JupyterHub >= 0.7")
            if spawner.api_token != api_token:
                # Spawner re-used an API token, discard the unused api_token
                orm_token = orm.APIToken.find(self.db, api_token)
                if orm_token is not None:
                    self.db.delete(orm_token)
                    self.db.commit()
        except Exception as e:
            if isinstance(e, gen.TimeoutError):
                self.log.warning("{user}'s server failed to start in {s} seconds, giving up".format(
                    user=self.name, s=spawner.start_timeout,
                ))
                e.reason = 'timeout'
            else:
                self.log.error("Unhandled error starting {user}'s server: {error}".format(
                    user=self.name, error=e,
                ))
                e.reason = 'error'
            try:
                # clean up the server that failed, not the default one
                yield self.stop(server_name)
            except Exception:
                self.log.error("Failed to cleanup {user}'s server that failed to start".format(
                    user=self.name,
                ), exc_info=True)
            # raise original exception
            raise e
        spawner.start_polling()

        # store state
        if self.state is None:
            self.state = {}
        spawner.orm_spawner.state = spawner.get_state()
        self.last_activity = spawner.orm_spawner.last_activity = datetime.utcnow()
        db.commit()
        spawner._waiting_for_response = True
        try:
            resp = yield server.wait_up(http=True, timeout=spawner.http_timeout)
        except Exception as e:
            if isinstance(e, (TimeoutError, gen.TimeoutError)):
                self.log.warning(
                    "{user}'s server never showed up at {url} "
                    "after {http_timeout} seconds. Giving up".format(
                        user=self.name,
                        url=server.url,
                        http_timeout=spawner.http_timeout,
                    )
                )
                e.reason = 'timeout'
            else:
                e.reason = 'error'
                self.log.error("Unhandled error waiting for {user}'s server to show up at {url}: {error}".format(
                    user=self.name, url=server.url, error=e,
                ))
            try:
                # clean up the server that failed, not the default one
                yield self.stop(server_name)
            except Exception:
                self.log.error("Failed to cleanup {user}'s server that failed to start".format(
                    user=self.name,
                ), exc_info=True)
            # raise original TimeoutError
            raise e
        else:
            server_version = resp.headers.get('X-JupyterHub-Version')
            _check_version(__version__, server_version, self.log)
        finally:
            spawner._waiting_for_response = False
            spawner._spawn_pending = False
        return self

    @gen.coroutine
    def stop(self, server_name=''):
        """Stop the user's spawner

        and cleanup after it.

        Stops the spawner for ``server_name`` if it is still running, clears
        and saves its state, and drops its server. The API token is deleted
        unless the spawner will resume. The authenticator's post_spawn_stop
        hook runs even if stopping fails; errors from that hook are logged,
        not raised.
        """
        spawner = self.spawners[server_name]
        spawner._spawn_pending = False
        spawner.stop_polling()
        spawner._stop_pending = True
        try:
            api_token = spawner.api_token
            status = yield spawner.poll()
            if status is None:
                # poll() returning None means still running
                yield spawner.stop()
            spawner.clear_state()
            spawner.orm_spawner.state = spawner.get_state()
            self.last_activity = spawner.orm_spawner.last_activity = datetime.utcnow()
            # remove server entry from db
            spawner.server = None
            if not spawner.will_resume:
                # find and remove the API token if the spawner isn't
                # going to re-use it next time
                orm_token = orm.APIToken.find(self.db, api_token)
                if orm_token:
                    self.db.delete(orm_token)
            self.db.commit()
        finally:
            # trigger post-spawner hook on authenticator
            auth = spawner.authenticator
            try:
                if auth:
                    yield gen.maybe_future(
                        auth.post_spawn_stop(self, spawner)
                    )
            except Exception:
                self.log.exception("Error in Authenticator.post_spawn_stop for %s", self)
            spawner._stop_pending = False