fix errors, remove pep8 corrections

This commit is contained in:
Christian Barra
2017-01-03 14:10:46 +01:00
parent dbe8bf5428
commit 02090c953b
5 changed files with 118 additions and 145 deletions

View File

@@ -146,18 +146,18 @@ class NewToken(Application):
class UpgradeDB(Application): class UpgradeDB(Application):
"""Upgrade the JupyterHub database schema.""" """Upgrade the JupyterHub database schema."""
name = 'jupyterhub-upgrade-db' name = 'jupyterhub-upgrade-db'
version = jupyterhub.__version__ version = jupyterhub.__version__
description = """Upgrade the JupyterHub database to the current schema. description = """Upgrade the JupyterHub database to the current schema.
Usage: Usage:
jupyterhub upgrade-db jupyterhub upgrade-db
""" """
aliases = common_aliases aliases = common_aliases
classes = [] classes = []
def _backup_db_file(self, db_file): def _backup_db_file(self, db_file):
"""Backup a database file""" """Backup a database file"""
if not os.path.exists(db_file): if not os.path.exists(db_file):
@@ -171,7 +171,7 @@ class UpgradeDB(Application):
backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i) backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i)
if os.path.exists(backup_db_file): if os.path.exists(backup_db_file):
self.exit("backup db file already exists: %s" % backup_db_file) self.exit("backup db file already exists: %s" % backup_db_file)
self.log.info("Backing up %s => %s", db_file, backup_db_file) self.log.info("Backing up %s => %s", db_file, backup_db_file)
shutil.copy(db_file, backup_db_file) shutil.copy(db_file, backup_db_file)
@@ -222,12 +222,12 @@ class JupyterHub(Application):
Authenticator, Authenticator,
PAMAuthenticator, PAMAuthenticator,
]) ])
load_groups = Dict(List(Unicode()), load_groups = Dict(List(Unicode()),
help="""Dict of 'group': ['usernames'] to load at startup. help="""Dict of 'group': ['usernames'] to load at startup.
This strictly *adds* groups and users to groups. This strictly *adds* groups and users to groups.
Loading one set of groups, then starting JupyterHub again with a different Loading one set of groups, then starting JupyterHub again with a different
set will not remove users or groups from previous launches. set will not remove users or groups from previous launches.
That must be done through the API. That must be done through the API.
@@ -414,7 +414,7 @@ class JupyterHub(Application):
api_tokens = Dict(Unicode(), api_tokens = Dict(Unicode(),
help="""PENDING DEPRECATION: consider using service_tokens help="""PENDING DEPRECATION: consider using service_tokens
Dict of token:username to be loaded into the database. Dict of token:username to be loaded into the database.
Allows ahead-of-time generation of API tokens for use by externally managed services, Allows ahead-of-time generation of API tokens for use by externally managed services,
@@ -437,14 +437,14 @@ class JupyterHub(Application):
Allows ahead-of-time generation of API tokens for use by externally managed services. Allows ahead-of-time generation of API tokens for use by externally managed services.
""" """
).tag(config=True) ).tag(config=True)
services = List(Dict(), services = List(Dict(),
help="""List of service specification dictionaries. help="""List of service specification dictionaries.
A service A service
For instance:: For instance::
services = [ services = [
{ {
'name': 'cull_idle', 'name': 'cull_idle',
@@ -454,7 +454,7 @@ class JupyterHub(Application):
'name': 'formgrader', 'name': 'formgrader',
'url': 'http://127.0.0.1:1234', 'url': 'http://127.0.0.1:1234',
'token': 'super-secret', 'token': 'super-secret',
'env': 'env':
} }
] ]
""" """
@@ -608,7 +608,7 @@ class JupyterHub(Application):
Instance(logging.Handler), Instance(logging.Handler),
help="Extra log handlers to set on JupyterHub logger", help="Extra log handlers to set on JupyterHub logger",
).tag(config=True) ).tag(config=True)
statsd = Any(allow_none=False, help="The statsd client, if any. A mock will be used if we aren't using statsd") statsd = Any(allow_none=False, help="The statsd client, if any. A mock will be used if we aren't using statsd")
@default('statsd') @default('statsd')
def _statsd(self): def _statsd(self):
@@ -919,7 +919,7 @@ class JupyterHub(Application):
# The whitelist set and the users in the db are now the same. # The whitelist set and the users in the db are now the same.
# From this point on, any user changes should be done simultaneously # From this point on, any user changes should be done simultaneously
# to the whitelist set and user db, unless the whitelist is empty (all users allowed). # to the whitelist set and user db, unless the whitelist is empty (all users allowed).
@gen.coroutine @gen.coroutine
def init_groups(self): def init_groups(self):
"""Load predefined groups into the database""" """Load predefined groups into the database"""
@@ -941,7 +941,7 @@ class JupyterHub(Application):
db.add(user) db.add(user)
group.users.append(user) group.users.append(user)
db.commit() db.commit()
@gen.coroutine @gen.coroutine
def _add_tokens(self, token_dict, kind): def _add_tokens(self, token_dict, kind):
"""Add tokens for users or services to the database""" """Add tokens for users or services to the database"""
@@ -982,13 +982,13 @@ class JupyterHub(Application):
else: else:
self.log.debug("Not duplicating token %s", orm_token) self.log.debug("Not duplicating token %s", orm_token)
db.commit() db.commit()
@gen.coroutine @gen.coroutine
def init_api_tokens(self): def init_api_tokens(self):
"""Load predefined API tokens (for services) into database""" """Load predefined API tokens (for services) into database"""
yield self._add_tokens(self.service_tokens, kind='service') yield self._add_tokens(self.service_tokens, kind='service')
yield self._add_tokens(self.api_tokens, kind='user') yield self._add_tokens(self.api_tokens, kind='user')
def init_services(self): def init_services(self):
self._service_map.clear() self._service_map.clear()
if self.domain: if self.domain:
@@ -1095,7 +1095,9 @@ class JupyterHub(Application):
# if user.server is defined. # if user.server is defined.
log = self.log.warning if user.server else self.log.debug log = self.log.warning if user.server else self.log.debug
log("%s not running.", user.name) log("%s not running.", user.name)
user.server = None for server in user.servers:
db.delete(server)
db.commit()
user_summaries.append(_user_summary(user)) user_summaries.append(_user_summary(user))
@@ -1458,7 +1460,7 @@ class JupyterHub(Application):
except Exception as e: except Exception as e:
self.log.critical("Failed to start proxy", exc_info=True) self.log.critical("Failed to start proxy", exc_info=True)
self.exit(1) self.exit(1)
for service_name, service in self._service_map.items(): for service_name, service in self._service_map.items():
if not service.managed: if not service.managed:
continue continue

View File

@@ -19,7 +19,7 @@ from sqlalchemy import (
from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import sessionmaker, relationship, backref from sqlalchemy.orm import sessionmaker, relationship, backref
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from sqlalchemy.schema import Index from sqlalchemy.schema import Index, UniqueConstraint
from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.sql.expression import bindparam from sqlalchemy.sql.expression import bindparam
from sqlalchemy import create_engine, Table from sqlalchemy import create_engine, Table
@@ -32,11 +32,8 @@ from .utils import (
class JSONDict(TypeDecorator): class JSONDict(TypeDecorator):
"""Represents an immutable structure as a json-encoded string. """Represents an immutable structure as a json-encoded string.
Usage:: Usage::
JSONEncodedDict(255) JSONEncodedDict(255)
""" """
impl = TEXT impl = TEXT
@@ -59,7 +56,6 @@ Base.log = app_log
class Server(Base): class Server(Base):
"""The basic state of a server """The basic state of a server
connection and cookie info connection and cookie info
""" """
__tablename__ = 'servers' __tablename__ = 'servers'
@@ -69,6 +65,7 @@ class Server(Base):
port = Column(Integer, default=random_port) port = Column(Integer, default=random_port)
base_url = Column(Unicode(255), default='/') base_url = Column(Unicode(255), default='/')
cookie_name = Column(Unicode(255), default='cookie') cookie_name = Column(Unicode(255), default='cookie')
# added to handle multi-server feature # added to handle multi-server feature
last_activity = Column(DateTime, default=datetime.utcnow) last_activity = Column(DateTime, default=datetime.utcnow)
@@ -101,10 +98,8 @@ class Server(Base):
@property @property
def bind_url(self): def bind_url(self):
"""representation of URL used for binding """representation of URL used for binding
Never used in APIs, only logging, Never used in APIs, only logging,
since it can be non-connectable value, such as '', since it can be non-connectable value, such as '', meaning all interfaces.
meaning all interfaces.
""" """
if self.ip in {'', '0.0.0.0'}: if self.ip in {'', '0.0.0.0'}:
return self.url.replace('127.0.0.1', self.ip or '*', 1) return self.url.replace('127.0.0.1', self.ip or '*', 1)
@@ -116,8 +111,7 @@ class Server(Base):
if http: if http:
yield wait_for_http_server(self.url, timeout=timeout) yield wait_for_http_server(self.url, timeout=timeout)
else: else:
yield wait_for_server(self.ip or '127.0.0.1', self.port, yield wait_for_server(self.ip or '127.0.0.1', self.port, timeout=timeout)
timeout=timeout)
def is_up(self): def is_up(self):
"""Is the server accepting connections?""" """Is the server accepting connections?"""
@@ -126,7 +120,6 @@ class Server(Base):
class Proxy(Base): class Proxy(Base):
"""A configurable-http-proxy instance. """A configurable-http-proxy instance.
A proxy consists of the API server info and the public-facing server info, A proxy consists of the API server info and the public-facing server info,
plus an auth token for configuring the proxy table. plus an auth token for configuring the proxy table.
""" """
@@ -155,10 +148,10 @@ class Proxy(Base):
body = json.dumps(body) body = json.dumps(body)
self.log.debug("Fetching %s %s", method, url) self.log.debug("Fetching %s %s", method, url)
req = HTTPRequest(url, req = HTTPRequest(url,
method=method, method=method,
headers={'Authorization': 'token {}'.format(self.auth_token)}, headers={'Authorization': 'token {}'.format(self.auth_token)},
body=body, body=body,
) )
return client.fetch(req) return client.fetch(req)
@@ -170,48 +163,48 @@ class Proxy(Base):
"Service %s does not have an http endpoint to add to the proxy.", service.name) "Service %s does not have an http endpoint to add to the proxy.", service.name)
self.log.info("Adding service %s to proxy %s => %s", self.log.info("Adding service %s to proxy %s => %s",
service.name, service.proxy_path, service.server.host, service.name, service.proxy_path, service.server.host,
) )
yield self.api_request(service.proxy_path, yield self.api_request(service.proxy_path,
method='POST', method='POST',
body=dict( body=dict(
target=service.server.host, target=service.server.host,
service=service.name, service=service.name,
), ),
client=client, client=client,
) )
@gen.coroutine @gen.coroutine
def delete_service(self, service, client=None): def delete_service(self, service, client=None):
"""Remove a service's server from the proxy table.""" """Remove a service's server from the proxy table."""
self.log.info("Removing service %s from proxy", service.name) self.log.info("Removing service %s from proxy", service.name)
yield self.api_request(service.proxy_path, yield self.api_request(service.proxy_path,
method='DELETE', method='DELETE',
client=client, client=client,
) )
@gen.coroutine
# FIX-ME # FIX-ME
# we need to add a reference to a specific server # we need to add a reference to a specific server
@gen.coroutine
def add_user(self, user, client=None): def add_user(self, user, client=None):
"""Add a user's server to the proxy table.""" """Add a user's server to the proxy table."""
self.log.info("Adding user %s to proxy %s => %s", self.log.info("Adding user %s to proxy %s => %s",
user.name, user.proxy_path, user.server.host, user.name, user.proxy_path, user.server.host,
) )
if user.spawn_pending: if user.spawn_pending:
raise RuntimeError( raise RuntimeError(
"User %s's spawn is pending, shouldn't be added to the proxy yet!", user.name) "User %s's spawn is pending, shouldn't be added to the proxy yet!", user.name)
yield self.api_request(user.proxy_path, yield self.api_request(user.proxy_path,
method='POST', method='POST',
body=dict( body=dict(
target=user.server.host, target=user.server.host,
user=user.name, user=user.name,
), ),
client=client, client=client,
) )
@gen.coroutine @gen.coroutine
def delete_user(self, user, client=None): def delete_user(self, user, client=None):
@@ -225,7 +218,6 @@ class Proxy(Base):
@gen.coroutine @gen.coroutine
def add_all_services(self, service_dict): def add_all_services(self, service_dict):
"""Update the proxy table from the database. """Update the proxy table from the database.
Used when loading up a new proxy. Used when loading up a new proxy.
""" """
db = inspect(self).session db = inspect(self).session
@@ -241,7 +233,6 @@ class Proxy(Base):
@gen.coroutine @gen.coroutine
def add_all_users(self, user_dict): def add_all_users(self, user_dict):
"""Update the proxy table from the database. """Update the proxy table from the database.
Used when loading up a new proxy. Used when loading up a new proxy.
""" """
db = inspect(self).session db = inspect(self).session
@@ -299,11 +290,10 @@ class Proxy(Base):
yield f yield f
class Hub(Base): class Hub(Base):
"""Bring it all together at the hub. """Bring it all together at the hub.
The Hub is a server, plus its API path suffix The Hub is a server, plus its API path suffix
the api_url is the full URL plus the api_path suffix on the end the api_url is the full URL plus the api_path suffix on the end
of the server base_url. of the server base_url.
""" """
@@ -329,10 +319,9 @@ class Hub(Base):
# user:group many:many mapping table # user:group many:many mapping table
user_group_map = Table('user_group_map', Base.metadata, user_group_map = Table('user_group_map', Base.metadata,
Column('user_id', ForeignKey('users.id'), primary_key=True), Column('user_id', ForeignKey('users.id'), primary_key=True),
Column('group_id', ForeignKey('groups.id'), primary_key=True), Column('group_id', ForeignKey('groups.id'), primary_key=True),
) )
class Group(Base): class Group(Base):
"""User Groups""" """User Groups"""
@@ -345,47 +334,41 @@ class Group(Base):
return "<%s %s (%i users)>" % ( return "<%s %s (%i users)>" % (
self.__class__.__name__, self.name, len(self.users) self.__class__.__name__, self.name, len(self.users)
) )
@classmethod @classmethod
def find(cls, db, name): def find(cls, db, name):
"""Find a group by name. """Find a group by name.
Returns None if not found. Returns None if not found.
""" """
return db.query(cls).filter(cls.name == name).first() return db.query(cls).filter(cls.name==name).first()
class User(Base): class User(Base):
"""The User table """The User table
Each user can have more than a single server,
Each user can have more than one server,
and multiple tokens used for authorization. and multiple tokens used for authorization.
API tokens grant access to the Hub's REST API. API tokens grant access to the Hub's REST API.
These are used by single-user servers to authenticate requests, These are used by single-user servers to authenticate requests,
and external services to manipulate the Hub. and external services to manipulate the Hub.
Cookies are set with a single ID. Cookies are set with a single ID.
Resetting the Cookie ID invalidates all cookies, forcing user to login again. Resetting the Cookie ID invalidates all cookies, forcing user to login again.
`server` returns the first entry for the users' servers.
A `state` column contains a JSON dict, A `state` column contains a JSON dict,
used for restoring state of a Spawner. used for restoring state of a Spawner.
`server` returns the first entry for the users' servers.
'servers' is a list that contains a reference to the ser's Servers.
""" """
__tablename__ = 'users' __tablename__ = 'users'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(1023), unique=True) name = Column(Unicode(1023), unique=True)
admin = Column(Boolean, default=False)
last_activity = Column(DateTime, default=datetime.utcnow)
servers = association_proxy("user_to_servers", "server", creator=lambda server: UserServer(server=server)) servers = association_proxy("user_to_servers", "server", creator=lambda server: UserServer(server=server))
admin = Column(Boolean, default=False)
last_activity = Column(DateTime, default=datetime.utcnow)
api_tokens = relationship("APIToken", backref="user") api_tokens = relationship("APIToken", backref="user")
cookie_id = Column(Unicode(1023), default=new_token) cookie_id = Column(Unicode(1023), default=new_token)
# User.state is actually Spawner state # User.state is actually Spawner state
# We will need to figure something else # We will need to figure something else out if/when we have multiple spawners per user
# out if/when we have multiple spawners per user
state = Column(JSONDict) state = Column(JSONDict)
# Authenticators can store their state here: # Authenticators can store their state here:
auth_state = Column(JSONDict) auth_state = Column(JSONDict)
@@ -397,7 +380,6 @@ class User(Base):
@property @property
def server(self): def server(self):
"""Returns the first element of servers. """Returns the first element of servers.
Returns None if the list is empty. Returns None if the list is empty.
""" """
if len(self.servers) == 0: if len(self.servers) == 0:
@@ -429,24 +411,19 @@ class User(Base):
@classmethod @classmethod
def find(cls, db, name): def find(cls, db, name):
"""Find a user by name. """Find a user by name.
Returns None if not found. Returns None if not found.
""" """
return db.query(cls).filter(cls.name == name).first() return db.query(cls).filter(cls.name==name).first()
class UserServer(Base): class UserServer(Base):
"""The UserServer table """The UserServer table
Each user can have have more than one server, Each user can have have more than one server,
we use this table to mantain the Many-To-Many we use this table to mantain the Many-To-One
relationship between Users and Servers. relationship between Users and Servers tables.
Cookies are set with a single ID. Servers can have only 1 user, this condition is mantained
Resetting the Cookie ID invalidates all cookies, forcing user to login again. by UniqueConstraint
A `state` column contains a JSON dict,
used for restoring state of a Spawner.
""" """
__tablename__ = 'users_servers' __tablename__ = 'users_servers'
@@ -454,9 +431,12 @@ class UserServer(Base):
_server_id = Column(Integer, ForeignKey('servers.id'), primary_key=True) _server_id = Column(Integer, ForeignKey('servers.id'), primary_key=True)
user = relationship(User, backref=backref('user_to_servers', cascade='all, delete-orphan')) user = relationship(User, backref=backref('user_to_servers', cascade='all, delete-orphan'))
server = relationship(Server, backref=backref('server_to_users', cascade='all, delete-orphan')) server = relationship(Server, backref=backref('server_to_users', cascade='all, delete-orphan')
)
__table_args__ = (Index('server_user_index', '_server_id', '_user_id'),) __table_args__ = (
UniqueConstraint('_server_id'),
Index('server_user_index', '_server_id', '_user_id'),)
def __repr__(self): def __repr__(self):
return "<{cls}({name}@{ip}:{port})>".format( return "<{cls}({name}@{ip}:{port})>".format(
@@ -469,21 +449,15 @@ class UserServer(Base):
class Service(Base): class Service(Base):
"""A service run with JupyterHub """A service run with JupyterHub
A service is similar to a User without a Spawner. A service is similar to a User without a Spawner.
A service can have API tokens for accessing the Hub's API A service can have API tokens for accessing the Hub's API
It has: It has:
- name - name
- admin - admin
- api tokens - api tokens
- server (if proxied http endpoint) - server (if proxied http endpoint)
In addition to what it has in common with users, a Service has extra info: In addition to what it has in common with users, a Service has extra info:
- pid: the process id (if managed) - pid: the process id (if managed)
""" """
__tablename__ = 'services' __tablename__ = 'services'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
@@ -501,7 +475,6 @@ class Service(Base):
def new_api_token(self, token=None): def new_api_token(self, token=None):
"""Create a new API token """Create a new API token
If `token` is given, load that token. If `token` is given, load that token.
""" """
return APIToken.new(token=token, service=self) return APIToken.new(token=token, service=self)
@@ -509,7 +482,6 @@ class Service(Base):
@classmethod @classmethod
def find(cls, db, name): def find(cls, db, name):
"""Find a service by name. """Find a service by name.
Returns None if not found. Returns None if not found.
""" """
return db.query(cls).filter(cls.name==name).first() return db.query(cls).filter(cls.name==name).first()
@@ -567,7 +539,6 @@ class APIToken(Base):
@classmethod @classmethod
def find(cls, db, token, *, kind=None): def find(cls, db, token, *, kind=None):
"""Find a token object by value. """Find a token object by value.
Returns None if not found. Returns None if not found.
`kind='user'` only returns API tokens for users `kind='user'` only returns API tokens for users

View File

@@ -30,8 +30,8 @@ def db():
_db = orm.new_session_factory('sqlite:///:memory:', echo=True)() _db = orm.new_session_factory('sqlite:///:memory:', echo=True)()
user = orm.User( user = orm.User(
name=getuser(), name=getuser(),
server=orm.Server(),
) )
user.servers.append(orm.Server())
hub = orm.Hub( hub = orm.Hub(
server=orm.Server(), server=orm.Server(),
) )

View File

@@ -54,7 +54,7 @@ def test_hub(db):
port = 1234, port = 1234,
base_url='/hubtest/', base_url='/hubtest/',
), ),
) )
db.add(hub) db.add(hub)
db.commit() db.commit()
@@ -129,14 +129,14 @@ def test_service_server(db):
service = orm.Service(name='has_servers') service = orm.Service(name='has_servers')
db.add(service) db.add(service)
db.commit() db.commit()
assert service.server is None assert service.server is None
server = service.server = orm.Server() server = service.server = orm.Server()
assert service assert service
assert server.id is None assert server.id is None
db.commit() db.commit()
assert isinstance(server.id, int) assert isinstance(server.id, int)
def test_token_find(db): def test_token_find(db):
service = db.query(orm.Service).first() service = db.query(orm.Service).first()
@@ -172,17 +172,17 @@ def test_spawn_fails(db, io_loop):
orm_user = orm.User(name='aeofel') orm_user = orm.User(name='aeofel')
db.add(orm_user) db.add(orm_user)
db.commit() db.commit()
class BadSpawner(MockSpawner): class BadSpawner(MockSpawner):
@gen.coroutine @gen.coroutine
def start(self): def start(self):
raise RuntimeError("Split the party") raise RuntimeError("Split the party")
user = User(orm_user, { user = User(orm_user, {
'spawner_class': BadSpawner, 'spawner_class': BadSpawner,
'config': None, 'config': None,
}) })
with pytest.raises(Exception) as exc: with pytest.raises(Exception) as exc:
io_loop.run_sync(user.spawn) io_loop.run_sync(user.spawn)
assert user.server is None assert user.server is None
@@ -192,7 +192,7 @@ def test_spawn_fails(db, io_loop):
def test_groups(db): def test_groups(db):
user = orm.User.find(db, name='aeofel') user = orm.User.find(db, name='aeofel')
db.add(user) db.add(user)
group = orm.Group(name='lives') group = orm.Group(name='lives')
db.add(group) db.add(group)
db.commit() db.commit()

View File

@@ -18,23 +18,23 @@ from .spawner import LocalProcessSpawner
class UserDict(dict): class UserDict(dict):
"""Like defaultdict, but for users """Like defaultdict, but for users
Getting by a user id OR an orm.User instance returns a User wrapper around the orm user. Getting by a user id OR an orm.User instance returns a User wrapper around the orm user.
""" """
def __init__(self, db_factory, settings): def __init__(self, db_factory, settings):
self.db_factory = db_factory self.db_factory = db_factory
self.settings = settings self.settings = settings
super().__init__() super().__init__()
@property @property
def db(self): def db(self):
return self.db_factory() return self.db_factory()
def __contains__(self, key): def __contains__(self, key):
if isinstance(key, (User, orm.User)): if isinstance(key, (User, orm.User)):
key = key.id key = key.id
return dict.__contains__(self, key) return dict.__contains__(self, key)
def __getitem__(self, key): def __getitem__(self, key):
if isinstance(key, User): if isinstance(key, User):
key = key.id key = key.id
@@ -63,7 +63,7 @@ class UserDict(dict):
return dict.__getitem__(self, id) return dict.__getitem__(self, id)
else: else:
raise KeyError(repr(key)) raise KeyError(repr(key))
def __delitem__(self, key): def __delitem__(self, key):
user = self[key] user = self[key]
user_id = user.id user_id = user.id
@@ -74,13 +74,13 @@ class UserDict(dict):
class User(HasTraits): class User(HasTraits):
@default('log') @default('log')
def _log_default(self): def _log_default(self):
return app_log return app_log
settings = Dict() settings = Dict()
db = Any(allow_none=True) db = Any(allow_none=True)
@default('db') @default('db')
def _db_default(self): def _db_default(self):
@@ -94,32 +94,32 @@ class User(HasTraits):
id = self.orm_user.id id = self.orm_user.id
self.orm_user = change['new'].query(orm.User).filter(orm.User.id==id).first() self.orm_user = change['new'].query(orm.User).filter(orm.User.id==id).first()
self.spawner.db = self.db self.spawner.db = self.db
orm_user = None orm_user = None
spawner = None spawner = None
spawn_pending = False spawn_pending = False
stop_pending = False stop_pending = False
waiting_for_response = False waiting_for_response = False
@property @property
def authenticator(self): def authenticator(self):
return self.settings.get('authenticator', None) return self.settings.get('authenticator', None)
@property @property
def spawner_class(self): def spawner_class(self):
return self.settings.get('spawner_class', LocalProcessSpawner) return self.settings.get('spawner_class', LocalProcessSpawner)
def __init__(self, orm_user, settings, **kwargs): def __init__(self, orm_user, settings, **kwargs):
self.orm_user = orm_user self.orm_user = orm_user
self.settings = settings self.settings = settings
super().__init__(**kwargs) super().__init__(**kwargs)
hub = self.db.query(orm.Hub).first() hub = self.db.query(orm.Hub).first()
self.cookie_name = '%s-%s' % (hub.server.cookie_name, quote(self.name, safe='')) self.cookie_name = '%s-%s' % (hub.server.cookie_name, quote(self.name, safe=''))
self.base_url = url_path_join( self.base_url = url_path_join(
self.settings.get('base_url', '/'), 'user', self.escaped_name) self.settings.get('base_url', '/'), 'user', self.escaped_name)
self.spawner = self.spawner_class( self.spawner = self.spawner_class(
user=self, user=self,
db=self.db, db=self.db,
@@ -127,24 +127,24 @@ class User(HasTraits):
authenticator=self.authenticator, authenticator=self.authenticator,
config=self.settings.get('config'), config=self.settings.get('config'),
) )
# pass get/setattr to ORM user # pass get/setattr to ORM user
def __getattr__(self, attr): def __getattr__(self, attr):
if hasattr(self.orm_user, attr): if hasattr(self.orm_user, attr):
return getattr(self.orm_user, attr) return getattr(self.orm_user, attr)
else: else:
raise AttributeError(attr) raise AttributeError(attr)
def __setattr__(self, attr, value): def __setattr__(self, attr, value):
if self.orm_user and hasattr(self.orm_user, attr): if self.orm_user and hasattr(self.orm_user, attr):
setattr(self.orm_user, attr, value) setattr(self.orm_user, attr, value)
else: else:
super().__setattr__(attr, value) super().__setattr__(attr, value)
def __repr__(self): def __repr__(self):
return repr(self.orm_user) return repr(self.orm_user)
@property @property
def running(self): def running(self):
"""property for whether a user has a running server""" """property for whether a user has a running server"""
@@ -153,25 +153,25 @@ class User(HasTraits):
if self.server is None: if self.server is None:
return False return False
return True return True
@property @property
def escaped_name(self): def escaped_name(self):
"""My name, escaped for use in URLs, cookies, etc.""" """My name, escaped for use in URLs, cookies, etc."""
return quote(self.name, safe='@') return quote(self.name, safe='@')
@property @property
def proxy_path(self): def proxy_path(self):
if self.settings.get('subdomain_host'): if self.settings.get('subdomain_host'):
return url_path_join('/' + self.domain, self.base_url) return url_path_join('/' + self.domain, self.base_url)
else: else:
return self.base_url return self.base_url
@property @property
def domain(self): def domain(self):
"""Get the domain for my server.""" """Get the domain for my server."""
# FIXME: escaped_name probably isn't escaped enough in general for a domain fragment # FIXME: escaped_name probably isn't escaped enough in general for a domain fragment
return self.escaped_name + '.' + self.settings['domain'] return self.escaped_name + '.' + self.settings['domain']
@property @property
def host(self): def host(self):
"""Get the *host* for my server (proto://domain[:port])""" """Get the *host* for my server (proto://domain[:port])"""
@@ -179,11 +179,11 @@ class User(HasTraits):
parsed = urlparse(self.settings['subdomain_host']) parsed = urlparse(self.settings['subdomain_host'])
h = '%s://%s.%s' % (parsed.scheme, self.escaped_name, parsed.netloc) h = '%s://%s.%s' % (parsed.scheme, self.escaped_name, parsed.netloc)
return h return h
@property @property
def url(self): def url(self):
"""My URL """My URL
Full name.domain/path if using subdomains, otherwise just my /base/url Full name.domain/path if using subdomains, otherwise just my /base/url
""" """
if self.settings.get('subdomain_host'): if self.settings.get('subdomain_host'):
@@ -193,22 +193,22 @@ class User(HasTraits):
) )
else: else:
return self.base_url return self.base_url
@gen.coroutine @gen.coroutine
def spawn(self, options=None): def spawn(self, options=None):
"""Start the user's spawner""" """Start the user's spawner"""
db = self.db db = self.db
server = orm.Server(
self.server = orm.Server(
cookie_name=self.cookie_name, cookie_name=self.cookie_name,
base_url=self.base_url, base_url=self.base_url,
) )
db.add(self.server) self.servers.append(server)
db.add(self)
db.commit() db.commit()
api_token = self.new_api_token() api_token = self.new_api_token()
db.commit() db.commit()
spawner = self.spawner spawner = self.spawner
spawner.user_options = options or {} spawner.user_options = options or {}
# we are starting a new server, make sure it doesn't restore state # we are starting a new server, make sure it doesn't restore state
@@ -294,7 +294,7 @@ class User(HasTraits):
@gen.coroutine @gen.coroutine
def stop(self): def stop(self):
"""Stop the user's spawner """Stop the user's spawner
and cleanup after it. and cleanup after it.
""" """
self.spawn_pending = False self.spawn_pending = False
@@ -316,7 +316,8 @@ class User(HasTraits):
orm_token = orm.APIToken.find(self.db, api_token) orm_token = orm.APIToken.find(self.db, api_token)
if orm_token: if orm_token:
self.db.delete(orm_token) self.db.delete(orm_token)
self.server = None for server in self.servers:
self.db.delete(server)
self.db.commit() self.db.commit()
finally: finally:
self.stop_pending = False self.stop_pending = False
@@ -326,4 +327,3 @@ class User(HasTraits):
yield gen.maybe_future( yield gen.maybe_future(
auth.post_spawn_stop(self, spawner) auth.post_spawn_stop(self, spawner)
) )