mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-13 21:13:01 +00:00

Allows setup/cleanup to be performed by the authenticator; use this to open PAM sessions at spawn and close them at stop, rather than opening them at login and never closing them.
526 lines
16 KiB
Python
526 lines
16 KiB
Python
"""sqlalchemy ORM tools for the state of the constellation of processes"""
|
|
|
|
# Copyright (c) Jupyter Development Team.
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
from datetime import datetime, timedelta
|
|
import errno
|
|
import json
|
|
import socket
|
|
from urllib.parse import quote
|
|
|
|
from tornado import gen
|
|
from tornado.log import app_log
|
|
from tornado.httpclient import HTTPRequest, AsyncHTTPClient, HTTPError
|
|
|
|
from sqlalchemy.types import TypeDecorator, VARCHAR
|
|
from sqlalchemy import (
|
|
inspect,
|
|
Column, Integer, ForeignKey, Unicode, Boolean,
|
|
DateTime,
|
|
)
|
|
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
|
from sqlalchemy.orm import sessionmaker, relationship
|
|
from sqlalchemy.pool import StaticPool
|
|
from sqlalchemy.sql.expression import bindparam
|
|
from sqlalchemy import create_engine
|
|
|
|
from .utils import (
|
|
random_port, url_path_join, wait_for_server, wait_for_http_server,
|
|
new_token, hash_token, compare_token,
|
|
)
|
|
|
|
|
|
class JSONDict(TypeDecorator):
    """Represents an immutable structure as a json-encoded string.

    Values are serialized with ``json.dumps`` when bound to a statement and
    deserialized with ``json.loads`` when loaded from a result row.
    ``None`` is passed through unchanged in both directions.

    Usage::

        JSONDict(255)

    """

    impl = VARCHAR

    def process_bind_param(self, value, dialect):
        """Encode a Python value as a JSON string before it is stored."""
        if value is not None:
            value = json.dumps(value)

        return value

    def process_result_value(self, value, dialect):
        """Decode a stored JSON string back into a Python value."""
        if value is not None:
            value = json.loads(value)
        return value
|
|
|
|
|
|
# Declarative base class shared by every ORM model defined in this module.
Base = declarative_base()
# Attach the tornado application logger as ``log`` on the base class, so
# model methods can log through ``self.log``.
setattr(Base, 'log', app_log)
|
|
|
|
|
|
class Server(Base):
    """The basic state of a server

    connection and cookie info
    """
    __tablename__ = 'servers'
    id = Column(Integer, primary_key=True)
    proto = Column(Unicode, default='http')
    ip = Column(Unicode, default='')
    port = Column(Integer, default=random_port)
    base_url = Column(Unicode, default='/')
    cookie_name = Column(Unicode, default='cookie')

    def __repr__(self):
        return "<Server(%s:%s)>" % (self.ip, self.port)

    @property
    def host(self):
        """``proto://ip:port`` used to connect to this server.

        An empty ip or ``'0.0.0.0'`` is replaced with ``'localhost'``.
        """
        ip = self.ip
        if ip in {'', '0.0.0.0'}:
            # when listening on all interfaces, connect to localhost
            ip = 'localhost'
        return "{proto}://{ip}:{port}".format(
            proto=self.proto,
            ip=ip,
            port=self.port,
        )

    @property
    def url(self):
        """Full URL of this server: ``host`` followed by ``base_url``."""
        return "{host}{uri}".format(
            host=self.host,
            uri=self.base_url,
        )

    @property
    def bind_url(self):
        """representation of URL used for binding

        Never used in APIs, only logging,
        since it can be non-connectable value, such as '', meaning all interfaces.
        """
        if self.ip in {'', '0.0.0.0'}:
            # undo the 'localhost' substitution made by ``host``;
            # an empty ip is shown as '*'
            return self.url.replace('localhost', self.ip or '*', 1)
        return self.url

    @gen.coroutine
    def wait_up(self, timeout=10, http=False):
        """Wait for this server to come up

        If ``http`` is true, wait via ``wait_for_http_server`` on ``url``;
        otherwise wait via ``wait_for_server`` on ip (or 'localhost') and port.
        """
        if http:
            yield wait_for_http_server(self.url, timeout=timeout)
        else:
            yield wait_for_server(self.ip or 'localhost', self.port, timeout=timeout)

    def is_up(self):
        """Is the server accepting connections?

        Returns False if the connection is refused; any other socket error
        is re-raised. The probe connection is closed before returning True.
        """
        try:
            sock = socket.create_connection((self.ip or 'localhost', self.port))
        except socket.error as e:
            if e.errno == errno.ECONNREFUSED:
                return False
            else:
                raise
        else:
            # close the probe connection so the socket is not leaked
            sock.close()
            return True
|
|
|
|
|
|
class Proxy(Base):
    """A configurable-http-proxy instance.

    A proxy consists of the API server info and the public-facing server info,
    plus an auth token for configuring the proxy table.
    """
    __tablename__ = 'proxies'
    id = Column(Integer, primary_key=True)
    # plain class attribute, not a database column
    auth_token = None
    _public_server_id = Column(Integer, ForeignKey('servers.id'))
    public_server = relationship(Server, primaryjoin=_public_server_id == Server.id)
    _api_server_id = Column(Integer, ForeignKey('servers.id'))
    api_server = relationship(Server, primaryjoin=_api_server_id == Server.id)

    def __repr__(self):
        if self.public_server:
            return "<%s %s:%s>" % (
                self.__class__.__name__, self.public_server.ip, self.public_server.port,
            )
        else:
            return "<%s [unconfigured]>" % self.__class__.__name__

    def api_request(self, path, method='GET', body=None, client=None):
        """Make an authenticated API request of the proxy

        ``path`` is joined onto the API server's URL, and a dict ``body`` is
        JSON-encoded. The request carries ``auth_token`` in the Authorization
        header. Returns the future produced by ``client.fetch``.
        """
        client = client or AsyncHTTPClient()
        url = url_path_join(self.api_server.url, path)

        if isinstance(body, dict):
            body = json.dumps(body)
        self.log.debug("Fetching %s %s", method, url)
        req = HTTPRequest(url,
            method=method,
            headers={'Authorization': 'token {}'.format(self.auth_token)},
            body=body,
        )

        return client.fetch(req)

    @gen.coroutine
    def add_user(self, user, client=None):
        """Add a user's server to the proxy table.

        POSTs a route for the user's server base_url targeting its host.
        """
        self.log.info("Adding user %s to proxy %s => %s",
            user.name, user.server.base_url, user.server.host,
        )

        yield self.api_request(user.server.base_url,
            method='POST',
            body=dict(
                target=user.server.host,
                user=user.name,
            ),
            client=client,
        )

    @gen.coroutine
    def delete_user(self, user, client=None):
        """Remove a user's server from the proxy table."""
        self.log.info("Removing user %s from proxy", user.name)
        yield self.api_request(user.server.base_url,
            method='DELETE',
            client=client,
        )

    @gen.coroutine
    def add_all_users(self):
        """Update the proxy table from the database.

        Used when loading up a new proxy. Adds a route for every user
        that has a server.
        """
        db = inspect(self).session
        futures = []
        for user in db.query(User):
            if user.server:
                futures.append(self.add_user(user))
        # wait after submitting them all
        for f in futures:
            yield f

    @gen.coroutine
    def get_routes(self, client=None):
        """Fetch the proxy's routes

        Returns the decoded JSON body of a GET on the proxy API root.
        """
        resp = yield self.api_request('', client=client)
        return json.loads(resp.body.decode('utf8', 'replace'))

    @gen.coroutine
    def check_routes(self, routes=None):
        """Check that all users with a server are properly routed.

        ``routes`` is fetched from the proxy when not given. Any user with a
        server whose name does not appear as a route's 'user' is added back.
        """
        if not routes:
            routes = yield self.get_routes()

        have_routes = {r['user'] for r in routes.values() if 'user' in r}
        futures = []
        db = inspect(self).session
        # SQLAlchemy needs `!= None` (not `is not None`) to emit IS NOT NULL
        for user in db.query(User).filter(User.server != None):
            if user.name not in have_routes:
                self.log.warning("Adding missing route for %s", user.name)
                futures.append(self.add_user(user))
        for f in futures:
            yield f
|
|
|
|
|
|
|
|
class Hub(Base):
    """Bring it all together at the hub.

    The Hub is a server plus an API path suffix: ``api_url`` is the
    server's full URL with ``api`` appended to its base_url.
    """
    __tablename__ = 'hubs'
    id = Column(Integer, primary_key=True)
    _server_id = Column(Integer, ForeignKey('servers.id'))
    server = relationship(Server, primaryjoin=_server_id == Server.id)

    @property
    def api_url(self):
        """The full API URL, including proto://host."""
        base = self.server.url
        return url_path_join(base, 'api')

    def __repr__(self):
        cls_name = type(self).__name__
        srv = self.server
        if not srv:
            return "<%s [unconfigured]>" % cls_name
        return "<%s %s:%s>" % (cls_name, srv.ip, srv.port)
|
|
|
|
|
|
class User(Base):
    """The User table

    Each user has a single server,
    and multiple tokens used for authorization.

    API tokens grant access to the Hub's REST API.
    These are used by single-user servers to authenticate requests,
    and external services to manipulate the Hub.

    Cookies are set with a single ID.
    Resetting the Cookie ID invalidates all cookies, forcing user to login again.

    A `state` column contains a JSON dict,
    used for restoring state of a Spawner.
    """
    __tablename__ = 'users'
    id = Column(Integer, primary_key=True)
    name = Column(Unicode)
    # should we allow multiple servers per user?
    _server_id = Column(Integer, ForeignKey('servers.id'))
    server = relationship(Server, primaryjoin=_server_id == Server.id)
    admin = Column(Boolean, default=False)
    last_activity = Column(DateTime, default=datetime.utcnow)

    api_tokens = relationship("APIToken", backref="user")
    cookie_id = Column(Unicode, default=new_token)
    state = Column(JSONDict)
    # Runtime-only attributes, not database columns. These class-level
    # defaults are immutable, so assigning on an instance shadows them safely.
    spawner = None
    spawn_pending = False
    stop_pending = False

    @property
    def other_user_cookies(self):
        """A set owned by this User instance (not a database column).

        Created lazily on first access, so each instance gets its own set.
        A class-level ``set()`` would be shared by every User, and an
        in-place ``.add()`` on one user would show up on all of them.
        """
        cookies = getattr(self, '_other_user_cookies', None)
        if cookies is None:
            cookies = self._other_user_cookies = set()
        return cookies

    @other_user_cookies.setter
    def other_user_cookies(self, value):
        self._other_user_cookies = value

    def __repr__(self):
        if self.server:
            return "<{cls}({name}@{ip}:{port})>".format(
                cls=self.__class__.__name__,
                name=self.name,
                ip=self.server.ip,
                port=self.server.port,
            )
        else:
            return "<{cls}({name} [unconfigured])>".format(
                cls=self.__class__.__name__,
                name=self.name,
            )

    @property
    def escaped_name(self):
        """My name, escaped for use in URLs, cookies, etc."""
        return quote(self.name, safe='@')

    @property
    def running(self):
        """True only if this user has both a spawner and a server."""
        if self.spawner is None:
            return False
        if self.server is None:
            return False
        return True

    def new_api_token(self):
        """Create a new API token, commit it, and return the raw token string.

        Only the token's prefix and hash are stored; the returned value is
        the only copy of the plaintext token.
        """
        assert self.id is not None
        db = inspect(self).session
        token = new_token()
        orm_token = APIToken(user_id=self.id)
        orm_token.token = token
        db.add(orm_token)
        db.commit()
        return token

    @classmethod
    def find(cls, db, name):
        """Find a user by name.

        Returns None if not found.
        """
        return db.query(cls).filter(cls.name==name).first()

    @gen.coroutine
    def spawn(self, spawner_class, base_url='/', hub=None, authenticator=None, config=None):
        """Start the user's spawner

        This method does the following, in order:

        1. Create a Server record under ``base_url``/user/<escaped name>.
        2. Issue a new API token.
        3. Build a ``spawner_class`` instance.
        4. If an authenticator is given, run its ``pre_spawn_start`` hook.
        5. Start the spawner and wait for the server to answer over HTTP.

        If ``hub`` is None, the first Hub row in the database is used.

        If starting or waiting fails, ``stop()`` is called to clean up; if
        it completes, that runs the authenticator's ``post_spawn_stop`` hook.
        The original exception is then re-raised with a ``reason`` attribute
        of ``'timeout'`` or ``'error'``.

        Returns this User.
        """
        db = inspect(self).session
        if hub is None:
            hub = db.query(Hub).first()

        self.server = Server(
            cookie_name='%s-%s' % (hub.server.cookie_name, quote(self.name, safe='')),
            base_url=url_path_join(base_url, 'user', self.escaped_name),
        )
        db.add(self.server)
        db.commit()

        api_token = self.new_api_token()
        db.commit()

        spawner = self.spawner = spawner_class(
            config=config,
            user=self,
            hub=hub,
            db=db,
            authenticator=authenticator,
        )
        # we are starting a new server, make sure it doesn't restore state
        spawner.clear_state()
        spawner.api_token = api_token

        # trigger pre-spawn hook on authenticator
        if authenticator:
            yield gen.maybe_future(authenticator.pre_spawn_start(self, spawner))
        self.spawn_pending = True
        # wait for spawner.start to return
        try:
            f = spawner.start()
            yield gen.with_timeout(timedelta(seconds=spawner.start_timeout), f)
        except Exception as e:
            if isinstance(e, gen.TimeoutError):
                self.log.warning("{user}'s server failed to start in {s} seconds, giving up".format(
                    user=self.name, s=spawner.start_timeout,
                ))
                e.reason = 'timeout'
            else:
                self.log.error("Unhandled error starting {user}'s server: {error}".format(
                    user=self.name, error=e,
                ))
                e.reason = 'error'
            try:
                yield self.stop()
            except Exception:
                self.log.error("Failed to cleanup {user}'s server that failed to start".format(
                    user=self.name,
                ), exc_info=True)
            # raise original exception
            raise e
        spawner.start_polling()

        # store state
        self.state = spawner.get_state()
        self.last_activity = datetime.utcnow()
        db.commit()
        try:
            yield self.server.wait_up(http=True, timeout=spawner.http_timeout)
        except Exception as e:
            # NOTE(review): this checks the builtin TimeoutError, while the
            # start handler above checks gen.TimeoutError; presumably the
            # wait helper raises the builtin one -- confirm.
            if isinstance(e, TimeoutError):
                self.log.warning(
                    "{user}'s server never showed up at {url} "
                    "after {http_timeout} seconds. Giving up".format(
                        user=self.name,
                        url=self.server.url,
                        http_timeout=spawner.http_timeout,
                    )
                )
                e.reason = 'timeout'
            else:
                e.reason = 'error'
                self.log.error("Unhandled error waiting for {user}'s server to show up at {url}: {error}".format(
                    user=self.name, url=self.server.url, error=e,
                ))
            try:
                yield self.stop()
            except Exception:
                self.log.error("Failed to cleanup {user}'s server that failed to start".format(
                    user=self.name,
                ), exc_info=True)
            # raise the original exception (timeout or otherwise)
            raise e
        self.spawn_pending = False
        return self

    @gen.coroutine
    def stop(self):
        """Stop the user's spawner

        and cleanup after it.

        Does nothing if there is no spawner. Otherwise it does the following:

        1. Stop polling.
        2. Call ``spawner.stop()`` if ``spawner.poll()`` returns None
           (presumably: still running).
        3. Clear and store the spawner state.
        4. Detach the server and commit.

        ``stop_pending`` is True while this runs. The authenticator's
        ``post_spawn_stop`` hook is called afterwards; it is skipped if any
        of the steps above raised.
        """
        self.spawn_pending = False
        spawner = self.spawner
        if spawner is None:
            return
        spawner.stop_polling()
        self.stop_pending = True
        try:
            status = yield spawner.poll()
            if status is None:
                yield self.spawner.stop()
            spawner.clear_state()
            self.state = spawner.get_state()
            self.server = None
            inspect(self).session.commit()
        finally:
            self.stop_pending = False
        # trigger post-spawner hook on authenticator
        auth = spawner.authenticator
        if auth:
            yield gen.maybe_future(
                auth.post_spawn_stop(self, spawner)
            )
|
|
|
|
class APIToken(Base):
    """An API token

    Only a short plaintext prefix and a hash of the token are stored;
    the token itself is write-only.
    """
    __tablename__ = 'api_tokens'

    @declared_attr
    def user_id(cls):
        return Column(Integer, ForeignKey('users.id'))

    id = Column(Integer, primary_key=True)
    hashed = Column(Unicode)
    prefix = Column(Unicode)
    # number of leading token characters stored in ``prefix`` for lookup
    prefix_length = 4
    # the following three are passed to hash_token when storing a token
    algorithm = "sha512"
    rounds = 16384
    salt_bytes = 8

    @property
    def token(self):
        raise AttributeError("token is write-only")

    @token.setter
    def token(self, token):
        """Store the hashed value and prefix for a token"""
        self.prefix = token[:self.prefix_length]
        self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm)

    def __repr__(self):
        # user_id is nullable, so the token may have no user; don't let
        # repr() itself raise in that case
        user = self.user
        return "<{cls}('{pre}...', user='{u}')>".format(
            cls=self.__class__.__name__,
            pre=self.prefix,
            u=user.name if user is not None else None,
        )

    @classmethod
    def find(cls, db, token):
        """Find a token object by value.

        Candidate rows are those whose stored prefix is a prefix of the first
        ``prefix_length`` characters of ``token``; each candidate is then
        checked with ``match``.

        Returns None if not found.
        """
        prefix = token[:cls.prefix_length]
        # since we can't filter on hashed values, filter on prefix
        # so we aren't comparing with all tokens
        prefix_match = db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
        for orm_token in prefix_match:
            if orm_token.match(token):
                return orm_token
        return None

    def match(self, token):
        """Is this my token?"""
        return compare_token(self.hashed, token)
|
|
|
|
|
|
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
    """Create a sessionmaker bound to a new engine for ``url``.

    Extra keyword arguments are forwarded to ``create_engine``. Two defaults
    are filled in only when the caller did not supply them:

    - For SQLite URLs, ``connect_args`` defaults to
      ``{'check_same_thread': False}``.
    - For URLs ending in ``:memory:``, ``poolclass`` defaults to
      ``StaticPool``.

    When ``reset`` is true, all tables are dropped first. Tables are then
    created via ``Base.metadata.create_all``.
    """
    engine_options = kwargs
    if url.startswith('sqlite'):
        engine_options.setdefault('connect_args', {'check_same_thread': False})
    if url.endswith(':memory:'):
        # Each new connection to ':memory:' would open a separate, empty
        # database, so make sure only one connection is ever created.
        engine_options.setdefault('poolclass', StaticPool)

    engine = create_engine(url, **engine_options)
    metadata = Base.metadata
    if reset:
        metadata.drop_all(engine)
    metadata.create_all(engine)
    return sessionmaker(bind=engine)
|