# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

from collections import defaultdict
from datetime import datetime, timedelta
from urllib.parse import quote, urlparse
import warnings

from oauth2.error import ClientNotFoundError
from sqlalchemy import inspect
from tornado import gen
from tornado.log import app_log
from traitlets import HasTraits, Any, Dict, default

from .utils import url_path_join

from . import orm
from ._version import _check_version, __version__
from .objects import Server
from .spawner import LocalProcessSpawner
from .crypto import encrypt, decrypt, CryptKeeper, EncryptionUnavailable, InvalidToken


class UserDict(dict):
    """Like defaultdict, but for users

    Getting by a user id, a user name, an orm.User instance, or a User
    instance returns a (cached) User wrapper around the orm user.
    """

    def __init__(self, db_factory, settings):
        self.db_factory = db_factory
        self.settings = settings
        super().__init__()

    @property
    def db(self):
        """The current database session, obtained from ``db_factory``."""
        return self.db_factory()

    def from_orm(self, orm_user):
        """Wrap an orm.User in a new User with this dict's settings."""
        return User(orm_user, self.settings)

    def add(self, orm_user):
        """Add a user to the UserDict

        Returns the cached User wrapper (creating it if needed).
        """
        if orm_user.id not in self:
            self[orm_user.id] = self.from_orm(orm_user)
        return self[orm_user.id]

    def __contains__(self, key):
        # User and orm.User keys are normalized to their integer id
        if isinstance(key, (User, orm.User)):
            key = key.id
        return dict.__contains__(self, key)

    def __getitem__(self, key):
        """Look up a User wrapper by id (int), name (str), orm.User, or User.

        Raises KeyError if no matching user exists or the key type is unsupported.
        """
        if isinstance(key, User):
            key = key.id
        elif isinstance(key, str):
            orm_user = self.db.query(orm.User).filter(orm.User.name == key).first()
            if orm_user is None:
                raise KeyError("No such user: %s" % key)
            else:
                key = orm_user
        if isinstance(key, orm.User):
            # users[orm_user] returns User(orm_user)
            orm_user = key
            if orm_user.id not in self:
                user = self[orm_user.id] = User(orm_user, self.settings)
                return user
            user = dict.__getitem__(self, orm_user.id)
            # refresh the cached wrapper's session
            user.db = self.db
            return user
        elif isinstance(key, int):
            id = key
            if id not in self:
                orm_user = self.db.query(orm.User).filter(orm.User.id == id).first()
                if orm_user is None:
                    raise KeyError("No such user: %s" % id)
                user = self.add(orm_user)
            else:
                user = dict.__getitem__(self, id)
            return user
        else:
            raise KeyError(repr(key))

    def __delitem__(self, key):
        """Drop a user from the cache and expunge its ORM objects from the session.

        This does not delete anything from the database; see ``delete``.
        """
        user = self[key]
        for orm_spawner in user.orm_user._orm_spawners:
            if orm_spawner in self.db:
                self.db.expunge(orm_spawner)
        if user.orm_user in self.db:
            self.db.expunge(user.orm_user)
        dict.__delitem__(self, user.id)

    def delete(self, key):
        """Delete a user from the cache and the database"""
        user = self[key]
        user_id = user.id
        # delete the mapped ORM object explicitly rather than the wrapper
        self.db.delete(user.orm_user)
        self.db.commit()
        # delete from dict after commit
        del self[user_id]

    def count_active_users(self):
        """Count the number of user servers that are active/pending/ready

        Returns dict with counts of active/pending/ready servers.
        Pending servers are counted under 'pending' and under
        '<pending-kind>_pending' (e.g. 'spawn_pending').
        """
        counts = defaultdict(int)
        for user in self.values():
            for spawner in user.spawners.values():
                pending = spawner.pending
                if pending:
                    counts['pending'] += 1
                    counts[pending + '_pending'] += 1
                if spawner.active:
                    counts['active'] += 1
                if spawner.ready:
                    counts['ready'] += 1

        return counts


class _SpawnerDict(dict):
    """dict of spawners keyed by server name.

    Missing keys are created on access by calling ``spawner_factory(key)``.
    """

    def __init__(self, spawner_factory):
        super().__init__()
        self.spawner_factory = spawner_factory

    def __getitem__(self, key):
        if key not in self:
            self[key] = self.spawner_factory(key)
        return super().__getitem__(key)


class User:
    """High-level wrapper around an orm.User object"""

    # declare instance attributes
    # (class-level defaults; orm_user = None keeps __setattr__ from proxying
    # until the real orm user has been assigned in __init__)
    db = None
    orm_user = None
    log = app_log
    settings = None

    def __init__(self, orm_user, settings=None, **kwargs):
        self.orm_user = orm_user
        # use the session the orm user is attached to
        self.db = inspect(orm_user).session
        self.settings = settings or {}
        # extra keyword arguments become attributes
        # (routed through __setattr__, so ORM columns are written to orm_user)
        for key, value in kwargs.items():
            setattr(self, key, value)

        self.allow_named_servers = self.settings.get('allow_named_servers', False)

        self.base_url = self.prefix = url_path_join(
            self.settings.get('base_url', '/'), 'user', self.escaped_name) + '/'

        self.spawners = _SpawnerDict(self._new_spawner)

    @property
    def authenticator(self):
        """The Authenticator from settings, or None."""
        return self.settings.get('authenticator', None)

    @property
    def spawner_class(self):
        """The Spawner class from settings (default: LocalProcessSpawner)."""
        return self.settings.get('spawner_class', LocalProcessSpawner)

    @gen.coroutine
    def save_auth_state(self, auth_state):
        """Encrypt and store auth_state"""
        if auth_state is None:
            self.encrypted_auth_state = None
        else:
            self.encrypted_auth_state = yield encrypt(auth_state)
        self.db.commit()

    @gen.coroutine
    def get_auth_state(self):
        """Retrieve and decrypt auth_state for the user

        Returns None if nothing is stored or decryption fails (a warning is logged).
        """
        encrypted = self.encrypted_auth_state
        if encrypted is None:
            return None
        try:
            auth_state = yield decrypt(encrypted)
        except (ValueError, InvalidToken, EncryptionUnavailable) as e:
            self.log.warning("Failed to retrieve encrypted auth_state for %s because %s",
                self.name, e,
            )
            return
        # loading auth_state
        if auth_state:
            # Crypt has multiple keys, store again with new key for rotation.
            if len(CryptKeeper.instance().keys) > 1:
                yield self.save_auth_state(auth_state)
        return auth_state

    def _new_spawner(self, name, spawner_class=None, **kwargs):
        """Create a new spawner

        Creates (and commits) the orm.Spawner row for ``name`` if it does not exist yet.
        """
        if spawner_class is None:
            spawner_class = self.spawner_class
        self.log.debug("Creating %s for %s:%s", spawner_class, self.name, name)

        orm_spawner = self.orm_spawners.get(name)
        if orm_spawner is None:
            orm_spawner = orm.Spawner(user=self.orm_user, name=name)
            self.db.add(orm_spawner)
            self.db.commit()
            assert name in self.orm_spawners
        if name == '' and self.state:
            # migrate user.state to spawner.state
            orm_spawner.state = self.state
            self.state = None

        spawn_kwargs = dict(
            user=self,
            orm_spawner=orm_spawner,
            hub=self.settings.get('hub'),
            authenticator=self.authenticator,
            config=self.settings.get('config'),
            proxy_spec=url_path_join(self.proxy_spec, name, '/'),
            db=self.db,
        )
        # update with kwargs. Mainly for testing.
        spawn_kwargs.update(kwargs)
        spawner = spawner_class(**spawn_kwargs)
        spawner.load_state(orm_spawner.state or {})
        return spawner

    # singleton property, self.spawner maps onto spawner with empty server_name
    @property
    def spawner(self):
        return self.spawners['']

    @spawner.setter
    def spawner(self, spawner):
        self.spawners[''] = spawner

    # pass get/setattr to ORM user

    def __getattr__(self, attr):
        # only called when normal lookup on the wrapper fails
        if hasattr(self.orm_user, attr):
            return getattr(self.orm_user, attr)
        else:
            raise AttributeError(attr)

    def __setattr__(self, attr, value):
        # public attributes that exist on the orm user are written through to it
        if not attr.startswith('_') and self.orm_user and hasattr(self.orm_user, attr):
            setattr(self.orm_user, attr, value)
        else:
            super().__setattr__(attr, value)

    def __repr__(self):
        return repr(self.orm_user)

    @property
    def running(self):
        """property for whether the user's default server is running"""
        if not self.spawners:
            return False
        return self.spawner.ready

    @property
    def active(self):
        """True if any server is active"""
        if not self.spawners:
            return False
        return any(s.active for s in self.spawners.values())

    @property
    def spawn_pending(self):
        """Deprecated: use Spawner.pending"""
        warnings.warn("User.spawn_pending is deprecated in JupyterHub 0.8. Use Spawner.pending",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.spawner.pending == 'spawn'

    @property
    def stop_pending(self):
        """Deprecated: use Spawner.pending"""
        warnings.warn("User.stop_pending is deprecated in JupyterHub 0.8. Use Spawner.pending",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.spawner.pending == 'stop'

    @property
    def server(self):
        """The Server of the default spawner"""
        return self.spawner.server

    @property
    def escaped_name(self):
        """My name, escaped for use in URLs, cookies, etc."""
        return quote(self.name, safe='@')

    @property
    def proxy_spec(self):
        """The proxy routespec for my default server"""
        if self.settings.get('subdomain_host'):
            return url_path_join(self.domain, self.base_url, '/')
        else:
            return url_path_join(self.base_url, '/')

    @property
    def domain(self):
        """Get the domain for my server."""
        # FIXME: escaped_name probably isn't escaped enough in general for a domain fragment
        return self.escaped_name + '.' + self.settings['domain']

    @property
    def host(self):
        """Get the *host* for my server (proto://domain[:port])"""
        # FIXME: escaped_name probably isn't escaped enough in general for a domain fragment
        parsed = urlparse(self.settings['subdomain_host'])
        h = '%s://%s.%s' % (parsed.scheme, self.escaped_name, parsed.netloc)
        return h

    @property
    def url(self):
        """My URL

        Full name.domain/path if using subdomains, otherwise just my /base/url
        """
        if self.settings.get('subdomain_host'):
            return '{host}{path}'.format(
                host=self.host,
                path=self.base_url,
            )
        else:
            return self.base_url

    @gen.coroutine
    def spawn(self, server_name='', options=None):
        """Start the user's spawner

        The behavior depends on the value of JupyterHub.allow_named_servers:

        if False:
            JupyterHub expects only one single-server per user;
            the url of the server will be /user/:name

        if True:
            JupyterHub expects more than one single-server per user;
            the url of the server will be /user/:name/:server_name

        On failure, the server named ``server_name`` is stopped and the
        original exception is re-raised with a ``reason`` attribute
        ('timeout' or 'error').
        """
        db = self.db

        base_url = url_path_join(self.base_url, server_name) + '/'

        orm_server = orm.Server(
            base_url=base_url,
        )
        db.add(orm_server)

        api_token = self.new_api_token()
        db.commit()

        spawner = self.spawners[server_name]
        spawner.server = server = Server(orm_server=orm_server)
        assert spawner.orm_spawner.server is orm_server

        # Passing user_options to the spawner
        spawner.user_options = options or {}
        # we are starting a new server, make sure it doesn't restore state
        spawner.clear_state()

        # create API and OAuth tokens
        spawner.api_token = api_token
        spawner.admin_access = self.settings.get('admin_access', False)
        # use fully quoted name for client_id because it will be used in cookie-name
        # self.escaped_name may contain @ which is legal in URLs but not cookie keys
        client_id = 'user-%s' % quote(self.name)
        if server_name:
            client_id = '%s-%s' % (client_id, quote(server_name))
        spawner.oauth_client_id = client_id
        oauth_provider = self.settings.get('oauth_provider')
        if oauth_provider:
            client_store = oauth_provider.client_authenticator.client_store
            try:
                oauth_client = client_store.fetch_by_client_id(client_id)
            except ClientNotFoundError:
                oauth_client = None
            # create a new OAuth client + secret on every launch
            # containers that resume will be updated below
            client_store.add_client(client_id, api_token,
                                    url_path_join(self.url, server_name, 'oauth_callback'),
                                    )
        db.commit()

        # trigger pre-spawn hook on authenticator
        authenticator = self.authenticator
        if authenticator:
            yield gen.maybe_future(authenticator.pre_spawn_start(self, spawner))

        spawner._start_pending = True
        # wait for spawner.start to return
        try:
            # run optional preparation work to bootstrap the notebook
            yield gen.maybe_future(spawner.run_pre_spawn_hook())
            f = spawner.start()
            # commit any changes in spawner.start (always commit db changes before yield)
            db.commit()
            ip_port = yield gen.with_timeout(timedelta(seconds=spawner.start_timeout), f)
            if ip_port:
                # get ip, port info from return value of start()
                server.ip, server.port = ip_port
            else:
                # prior to 0.7, spawners had to store this info in user.server themselves.
                # Handle < 0.7 behavior with a warning, assuming info was stored in db by the Spawner.
                self.log.warning("DEPRECATION: Spawner.start should return (ip, port) in JupyterHub >= 0.7")
            if spawner.api_token and spawner.api_token != api_token:
                # Spawner re-used an API token, discard the unused api_token
                orm_token = orm.APIToken.find(self.db, api_token)
                if orm_token is not None:
                    self.db.delete(orm_token)
                    self.db.commit()
                # check if the re-used API token is valid
                found = orm.APIToken.find(self.db, spawner.api_token)
                if found:
                    if found.user is not self.orm_user:
                        self.log.error("%s's server is using %s's token! Revoking this token.",
                            self.name, (found.user or found.service).name)
                        self.db.delete(found)
                        self.db.commit()
                        raise ValueError("Invalid token for %s!" % self.name)
                else:
                    # Spawner.api_token has changed, but isn't in the db.
                    # What happened? Maybe something unclean in a resumed container.
                    self.log.warning("%s's server specified its own API token that's not in the database",
                        self.name
                    )
                    # use generated=False because we don't trust this token
                    # to have been generated properly
                    self.new_api_token(spawner.api_token, generated=False)
                # update OAuth client secret with updated API token
                if oauth_provider:
                    client_store = oauth_provider.client_authenticator.client_store
                    client_store.add_client(client_id, spawner.api_token,
                                            url_path_join(self.url, server_name, 'oauth_callback'),
                                            )
                    db.commit()

        except Exception as e:
            if isinstance(e, gen.TimeoutError):
                self.log.warning("{user}'s server failed to start in {s} seconds, giving up".format(
                    user=self.name, s=spawner.start_timeout,
                ))
                e.reason = 'timeout'
                self.settings['statsd'].incr('spawner.failure.timeout')
            else:
                self.log.error("Unhandled error starting {user}'s server: {error}".format(
                    user=self.name, error=e,
                ))
                self.settings['statsd'].incr('spawner.failure.error')
                e.reason = 'error'
            try:
                # clean up the server that failed, not the default server
                yield self.stop(server_name)
            except Exception:
                self.log.error("Failed to cleanup {user}'s server that failed to start".format(
                    user=self.name,
                ), exc_info=True)
            # raise original exception
            spawner._start_pending = False
            raise e
        spawner.start_polling()

        # store state
        if self.state is None:
            self.state = {}
        spawner.orm_spawner.state = spawner.get_state()
        self.last_activity = spawner.orm_spawner.last_activity = datetime.utcnow()
        db.commit()
        spawner._waiting_for_response = True
        try:
            resp = yield server.wait_up(http=True, timeout=spawner.http_timeout)
        except Exception as e:
            if isinstance(e, TimeoutError):
                self.log.warning(
                    "{user}'s server never showed up at {url} "
                    "after {http_timeout} seconds. Giving up".format(
                        user=self.name,
                        url=server.url,
                        http_timeout=spawner.http_timeout,
                    )
                )
                e.reason = 'timeout'
                self.settings['statsd'].incr('spawner.failure.http_timeout')
            else:
                e.reason = 'error'
                self.log.error("Unhandled error waiting for {user}'s server to show up at {url}: {error}".format(
                    user=self.name,
                    url=server.url,
                    error=e,
                ))
                self.settings['statsd'].incr('spawner.failure.http_error')
            try:
                # clean up the server that failed, not the default server
                yield self.stop(server_name)
            except Exception:
                self.log.error("Failed to cleanup {user}'s server that failed to start".format(
                    user=self.name,
                ), exc_info=True)
            # raise original exception
            raise e
        else:
            server_version = resp.headers.get('X-JupyterHub-Version')
            _check_version(__version__, server_version, self.log)
            # record the Spawner version for better error messages
            # if it doesn't work
            spawner._jupyterhub_version = server_version
        finally:
            spawner._waiting_for_response = False
            spawner._start_pending = False
        return self

    @gen.coroutine
    def stop(self, server_name=''):
        """Stop the user's spawner and cleanup after it.

        Clears the spawner's state, removes its server entry, deletes its
        API token unless the spawner will resume, runs the authenticator's
        post_spawn_stop hook, and finally removes the spawner from
        ``self.spawners``.
        """
        spawner = self.spawners[server_name]
        spawner._spawn_pending = False
        spawner._start_pending = False
        spawner.stop_polling()
        spawner._stop_pending = True
        try:
            api_token = spawner.api_token
            status = yield spawner.poll()
            # poll() returning None means the process is still running
            if status is None:
                yield spawner.stop()
            spawner.clear_state()
            spawner.orm_spawner.state = spawner.get_state()
            self.last_activity = spawner.orm_spawner.last_activity = datetime.utcnow()
            # remove server entry from db
            spawner.server = None
            if not spawner.will_resume:
                # find and remove the API token if the spawner isn't
                # going to re-use it next time
                orm_token = orm.APIToken.find(self.db, api_token)
                if orm_token:
                    self.db.delete(orm_token)
            self.db.commit()
        finally:
            # trigger post-spawner hook on authenticator
            auth = spawner.authenticator
            try:
                if auth:
                    yield gen.maybe_future(
                        auth.post_spawn_stop(self, spawner)
                    )
            except Exception:
                self.log.exception("Error in Authenticator.post_spawn_stop for %s", self)
            spawner._stop_pending = False
            # pop the Spawner object
            # (default avoids a KeyError here masking an in-flight exception)
            self.spawners.pop(server_name, None)