"""sqlalchemy ORM tools for the state of the constellation of processes"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import enum
import json
import numbers
import secrets
from base64 import decodebytes, encodebytes
from datetime import timedelta
from functools import lru_cache, partial
from itertools import chain
import alembic.command
import alembic.config
import sqlalchemy
from alembic.script import ScriptDirectory
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
MetaData,
Table,
Unicode,
create_engine,
event,
exc,
inspect,
or_,
select,
text,
)
from sqlalchemy.orm import (
Session,
declarative_base,
declared_attr,
interfaces,
joinedload,
object_session,
relationship,
sessionmaker,
)
from sqlalchemy.pool import StaticPool
from sqlalchemy.types import LargeBinary, Text, TypeDecorator
from tornado.log import app_log
from .utils import compare_token, fmt_ip_url, hash_token, new_token, random_port, utcnow
# top-level variable for easier mocking in tests
utcnow = partial(utcnow, with_tz=False)
class JSONDict(TypeDecorator):
"""Represents an immutable structure as a json-encoded string.
Usage::
JSONDict(255)
"""
impl = Text
def _json_default(self, obj):
"""encode non-jsonable objects as JSON
Currently only bytes are supported
"""
if not isinstance(obj, bytes):
app_log.warning(
"Non-jsonable data in user_options: %r; will persist None.", type(obj)
)
return None
return {"__jupyterhub_bytes__": True, "data": encodebytes(obj).decode('ascii')}
def _object_hook(self, dct):
"""decode non-json objects packed by _json_default"""
if dct.get("__jupyterhub_bytes__", False):
return decodebytes(dct['data'].encode('ascii'))
return dct
def process_bind_param(self, value, dialect):
if value is not None:
value = json.dumps(value, default=self._json_default)
return value
def process_result_value(self, value, dialect):
if value is not None:
value = json.loads(value, object_hook=self._object_hook)
return value
class JSONList(JSONDict):
"""Represents an immutable structure as a json-encoded string (to be used for list type columns).
Accepts lists, tuples, and sets for assignment
Always read as a list
Usage::
JSONList(JSONDict)
"""
def process_bind_param(self, value, dialect):
if isinstance(value, (list, tuple)):
value = json.dumps(value)
if isinstance(value, set):
# serialize sets as ordered lists
value = json.dumps(sorted(value))
return value
def process_result_value(self, value, dialect):
if value is None:
return []
else:
value = json.loads(value)
return value
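# Illustrative sketch (not used by the Hub itself): how these type decorators
# round-trip values. The `dialect` argument is unused by both methods, so None
# can be passed directly.
def _example_json_columns():  # pragma: no cover
    jd = JSONDict()
    # bytes survive the round-trip via the __jupyterhub_bytes__ wrapper
    encoded = jd.process_bind_param({"blob": b"\x00\x01"}, None)
    assert jd.process_result_value(encoded, None) == {"blob": b"\x00\x01"}
    jl = JSONList()
    # sets are serialized as sorted lists, and always read back as a list
    encoded = jl.process_bind_param({"b", "a"}, None)
    assert jl.process_result_value(encoded, None) == ["a", "b"]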
meta = MetaData(
naming_convention={
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
)
Base = declarative_base(metadata=meta)
Base.log = app_log
class Server(Base):
"""The basic state of a server
connection and cookie info
"""
__tablename__ = 'servers'
id = Column(Integer, primary_key=True)
proto = Column(Unicode(15), default='http')
ip = Column(Unicode(255), default='') # could also be a DNS name
port = Column(Integer, default=random_port)
base_url = Column(Unicode(255), default='/')
cookie_name = Column(Unicode(255), default='cookie')
service = relationship("Service", back_populates="server", uselist=False)
spawner = relationship("Spawner", back_populates="server", uselist=False)
def __repr__(self):
return f"<Server({fmt_ip_url(self.ip)}:{self.port})>"
# lots of things have roles
# mapping tables are the same for all of them
_role_associations = {}
for entity in (
'user',
'group',
'service',
):
table = Table(
f'{entity}_role_map',
Base.metadata,
Column(
f'{entity}_id',
ForeignKey(f'{entity}s.id', ondelete='CASCADE'),
primary_key=True,
),
Column(
'role_id',
ForeignKey('roles.id', ondelete='CASCADE'),
primary_key=True,
),
Column('managed_by_auth', Boolean, default=False, nullable=False),
)
_role_associations[entity] = type(
entity.title() + 'RoleMap', (Base,), {'__table__': table}
)
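# Illustrative sketch of what the loop above generates: one association table
# and one mapped class per entity, all with the same three columns.
def _example_role_maps():  # pragma: no cover
    assert sorted(_role_associations) == ["group", "service", "user"]
    user_role_map = _role_associations["user"]
    assert user_role_map.__name__ == "UserRoleMap"
    assert user_role_map.__table__.name == "user_role_map"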
class Role(Base):
"""User Roles"""
__tablename__ = 'roles'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(255), unique=True)
description = Column(Unicode(1023))
scopes = Column(JSONList, default=[])
users = relationship('User', secondary='user_role_map', back_populates='roles')
services = relationship(
'Service', secondary='service_role_map', back_populates='roles'
)
groups = relationship('Group', secondary='group_role_map', back_populates='roles')
managed_by_auth = Column(Boolean, default=False, nullable=False)
def __repr__(self):
return f"<{self.__class__.__name__} {self.name} ({self.description}) - scopes: {self.scopes}>"
@classmethod
def find(cls, db, name):
"""Find a role by name.
Returns None if not found.
"""
return db.query(cls).filter(cls.name == name).first()
# user:group many:many mapping table
user_group_map = Table(
'user_group_map',
Base.metadata,
Column('user_id', ForeignKey('users.id', ondelete='CASCADE'), primary_key=True),
Column('group_id', ForeignKey('groups.id', ondelete='CASCADE'), primary_key=True),
)
class Group(Base):
"""User Groups"""
__tablename__ = 'groups'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(255), unique=True)
users = relationship('User', secondary='user_group_map', back_populates='groups')
properties = Column(JSONDict, default={})
roles = relationship(
'Role', secondary='group_role_map', back_populates='groups', lazy="selectin"
)
shared_with_me = relationship(
"Share",
back_populates="group",
cascade="all, delete-orphan",
lazy="selectin",
)
# used in some model fields to differentiate 'whoami'
kind = "group"
def __repr__(self):
return f"<{self.__class__.__name__} {self.name}>"
@classmethod
def find(cls, db, name):
"""Find a group by name.
Returns None if not found.
"""
return db.query(cls).filter(cls.name == name).first()
class User(Base):
"""The User table
Each user can have one or more single user notebook servers.
Each single user notebook server will have a unique token for authorization.
Therefore, a user with multiple notebook servers will have multiple tokens.
API tokens grant access to the Hub's REST API.
These are used by single-user servers to authenticate requests,
and external services to manipulate the Hub.
Cookies are set with a single ID.
Resetting the Cookie ID invalidates all cookies, forcing the user to log in again.
A `state` column contains a JSON dict,
used for restoring state of a Spawner.
`servers` is a list that contains a reference for each of the user's single user notebook servers.
The method `server` returns the first entry in the user's `servers` list.
"""
__tablename__ = 'users'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(255), unique=True)
roles = relationship(
'Role',
secondary='user_role_map',
back_populates='users',
lazy="selectin",
)
_orm_spawners = relationship(
"Spawner", back_populates="user", cascade="all, delete-orphan"
)
@property
def orm_spawners(self):
return {s.name: s for s in self._orm_spawners}
admin = Column(Boolean(create_constraint=False), default=False)
created = Column(DateTime, default=utcnow)
last_activity = Column(DateTime, nullable=True)
api_tokens = relationship(
"APIToken", back_populates="user", cascade="all, delete-orphan"
)
groups = relationship(
"Group",
secondary='user_group_map',
back_populates="users",
lazy="selectin",
)
oauth_codes = relationship(
"OAuthCode", back_populates="user", cascade="all, delete-orphan"
)
# sharing relationships
shares = relationship(
"Share",
back_populates="owner",
cascade="all, delete-orphan",
foreign_keys="Share.owner_id",
)
share_codes = relationship(
"ShareCode",
back_populates="owner",
cascade="all, delete-orphan",
foreign_keys="ShareCode.owner_id",
)
shared_with_me = relationship(
"Share",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="Share.user_id",
lazy="selectin",
)
@property
def all_shared_with_me(self):
"""return all shares shared with me,
including via group
"""
return list(
chain(
self.shared_with_me,
*[group.shared_with_me for group in self.groups],
)
)
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
# User.state is actually Spawner state
# We will need to figure something else out if/when we have multiple spawners per user
state = Column(JSONDict)
# Authenticators can store their state here:
# Encryption is handled elsewhere
encrypted_auth_state = Column(LargeBinary)
# used in some model fields to differentiate whether an owner or actor
# is a user or service
kind = "user"
def __repr__(self):
return f"<{self.__class__.__name__}({self.name} {sum(bool(s.server) for s in self._orm_spawners)}/{len(self._orm_spawners)} running)>"
def new_api_token(self, token=None, **kwargs):
"""Create a new API token
If `token` is given, load that token.
"""
return APIToken.new(token=token, user=self, **kwargs)
@classmethod
def find(cls, db, name):
"""Find a user by name.
Returns None if not found.
"""
return db.query(cls).filter(cls.name == name).first()
class Spawner(Base):
""" "State about a Spawner"""
__tablename__ = 'spawners'
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
user = relationship("User", back_populates="_orm_spawners")
server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL'))
server = relationship(
Server,
back_populates="spawner",
lazy="joined",
single_parent=True,
cascade="all, delete-orphan",
)
shares = relationship(
"Share", back_populates="spawner", cascade="all, delete-orphan"
)
share_codes = relationship(
"ShareCode", back_populates="spawner", cascade="all, delete-orphan"
)
state = Column(JSONDict)
name = Column(Unicode(255))
started = Column(DateTime)
last_activity = Column(DateTime, nullable=True)
user_options = Column(JSONDict)
# added in 2.0
oauth_client_id = Column(
Unicode(255),
ForeignKey(
'oauth_clients.identifier',
ondelete='SET NULL',
),
)
oauth_client = relationship(
'OAuthClient',
back_populates="spawner",
cascade="all, delete-orphan",
single_parent=True,
)
# properties on the spawner wrapper
# some APIs get these low-level objects
# when the spawner isn't running,
# for which these should all be False
active = running = ready = False
pending = None
@property
def orm_spawner(self):
return self
class Service(Base):
"""A service run with JupyterHub
A service is similar to a User without a Spawner.
A service can have API tokens for accessing the Hub's API
It has:
- name
- admin
- api tokens
- server (if proxied http endpoint)
In addition to what it has in common with users, a Service has extra info:
- pid: the process id (if managed)
"""
__tablename__ = 'services'
id = Column(Integer, primary_key=True, autoincrement=True)
# common user interface:
name = Column(Unicode(255), unique=True)
admin = Column(Boolean(create_constraint=False), default=False)
roles = relationship(
'Role', secondary='service_role_map', back_populates='services', lazy="selectin"
)
url = Column(Unicode(2047), nullable=True)
oauth_client_allowed_scopes = Column(JSONList, nullable=True)
info = Column(JSONDict, nullable=True)
display = Column(Boolean, nullable=True)
oauth_no_confirm = Column(Boolean, nullable=True)
command = Column(JSONList, nullable=True)
cwd = Column(Unicode(4095), nullable=True)
environment = Column(JSONDict, nullable=True)
user = Column(Unicode(255), nullable=True)
from_config = Column(Boolean, default=True)
api_tokens = relationship(
"APIToken", back_populates="service", cascade="all, delete-orphan"
)
# service-specific interface
_server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL'))
server = relationship(
Server,
back_populates="service",
single_parent=True,
cascade="all, delete-orphan",
)
pid = Column(Integer)
# added in 2.0
oauth_client_id = Column(
Unicode(255),
ForeignKey(
'oauth_clients.identifier',
ondelete='SET NULL',
),
)
oauth_client = relationship(
'OAuthClient',
back_populates="service",
cascade="all, delete-orphan",
single_parent=True,
)
# used in some model fields to differentiate 'whoami'
kind = "service"
def new_api_token(self, token=None, **kwargs):
"""Create a new API token
If `token` is given, load that token.
"""
return APIToken.new(token=token, service=self, **kwargs)
@classmethod
def find(cls, db, name):
"""Find a service by name.
Returns None if not found.
"""
return db.query(cls).filter(cls.name == name).first()
class Expiring:
"""Mixin for expiring entries
Subclasses must define at least an `expires_at` property,
which should be a unix timestamp or datetime object
"""
now = utcnow # function, must return float timestamp or datetime
expires_at = None # must be defined
@property
def expires_in(self):
"""Property returning expiration in seconds from now
or None
"""
if self.expires_at:
delta = self.expires_at - self.now()
if isinstance(delta, timedelta):
delta = delta.total_seconds()
return delta
else:
return None
@property
def expired(self):
"""Is this object expired?"""
if not self.expires_at:
return False
else:
return self.expires_in <= 0
@classmethod
def purge_expired(cls, db):
"""Purge expired API Tokens from the database"""
now = cls.now()
deleted = False
for obj in (
db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now)
):
app_log.debug("Purging expired %s", obj)
deleted = True
db.delete(obj)
if deleted:
db.commit()
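# Illustrative sketch of the Expiring contract: a subclass only needs to
# provide `expires_at`; `expires_in` and `expired` are derived from it.
def _example_expiring():  # pragma: no cover
    class _Expires(Expiring):
        def __init__(self, expires_at):
            self.expires_at = expires_at

    assert not _Expires(None).expired  # no expires_at: never expires
    assert _Expires(utcnow() - timedelta(seconds=10)).expired
    assert 0 < _Expires(utcnow() + timedelta(seconds=60)).expires_in <= 60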
class Hashed(Expiring):
"""Mixin for tables with hashed tokens"""
prefix_length = 4
algorithm = "sha512"
rounds = 16384
salt_bytes = 8
min_length = 8
# values to use for internally generated tokens,
# which have good entropy as UUIDs
generated = True
generated_salt_bytes = 8
generated_rounds = 1
@property
def token(self):
raise AttributeError(f"{self.__class__.__name__}.token is write-only")
@token.setter
def token(self, token):
"""Store the hashed value and prefix for a token"""
self.prefix = token[: self.prefix_length]
if self.generated:
# Generated tokens are UUIDs, which have sufficient entropy on their own
# and don't need salt & hash rounds.
# ref: https://security.stackexchange.com/a/151262/155114
rounds = self.generated_rounds
salt_bytes = self.generated_salt_bytes
else:
rounds = self.rounds
salt_bytes = self.salt_bytes
self.hashed = hash_token(
token, rounds=rounds, salt=salt_bytes, algorithm=self.algorithm
)
def match(self, token):
"""Is this my token?"""
return compare_token(self.hashed, token)
@classmethod
def check_token(cls, db, token):
"""Check if a token is acceptable"""
if len(token) < cls.min_length:
raise ValueError(
f"{cls.__name__}.token must be at least {cls.min_length} characters, got {len(token)}: {token[: cls.prefix_length]}..."
)
found = cls.find(db, token)
if found:
raise ValueError(
f"Collision on {cls.__name__}: {token[: cls.prefix_length]}..."
)
@classmethod
def find_prefix(cls, db, token):
"""Start the query for matching token.
Returns an SQLAlchemy query already filtered by prefix-matches.
.. versionchanged:: 1.2
Excludes expired matches.
"""
prefix = token[: cls.prefix_length]
# since we can't filter on hashed values, filter on prefix
# so we aren't comparing with all tokens
prefix_match = db.query(cls).filter_by(prefix=prefix)
prefix_match = prefix_match.filter(
or_(cls.expires_at == None, cls.expires_at >= cls.now())
)
return prefix_match
@classmethod
def find(cls, db, token):
"""Find a token object by value.
Returns None if not found.
`kind='user'` only returns API tokens for users
`kind='service'` only returns API tokens for services
"""
prefix_match = cls.find_prefix(db, token).options(
joinedload(cls.user), joinedload(cls.service)
)
for orm_token in prefix_match:
if orm_token.match(token):
return orm_token
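# Illustrative sketch of the Hashed token flow: assigning `token` stores only
# a prefix and a salted hash (never the raw value); `match` re-hashes to compare.
def _example_hashed():  # pragma: no cover
    class _Token(Hashed):
        pass

    token = new_token()
    t = _Token()
    t.token = token  # sets t.prefix and t.hashed; reading t.token raises
    assert t.prefix == token[: Hashed.prefix_length]
    assert t.match(token)
    assert not t.match(new_token())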
class _Share:
"""Common columns for Share and ShareCode"""
id = Column(Integer, primary_key=True, autoincrement=True)
created_at = Column(DateTime, nullable=False, default=utcnow)
# TODO: owner_id and spawner_id columns don't need `@declared_attr` when we can require sqlalchemy 2
# the owner of the shared server
# this is redundant with spawner.user, but saves a join
@declared_attr
def owner_id(self):
return Column(Integer, ForeignKey('users.id', ondelete="CASCADE"))
@declared_attr
def owner(self):
# the table names ('shares', 'share_codes') happen to match the
# corresponding relationship attributes on User, so __tablename__ works here;
# a more explicit attribute would have the same values
return relationship(
"User",
back_populates=self.__tablename__,
foreign_keys=[self.owner_id],
lazy="selectin",
)
# the spawner the share is for
@declared_attr
def spawner_id(self):
return Column(Integer, ForeignKey('spawners.id', ondelete="CASCADE"))
@declared_attr
def spawner(self):
return relationship(
"Spawner",
back_populates=self.__tablename__,
lazy="selectin",
)
# the permissions granted (!server filter will always be applied)
scopes = Column(JSONList)
expires_at = Column(DateTime, nullable=True)
@classmethod
def apply_filter(cls, scopes, spawner):
"""Apply our filter, ensures all scopes have appropriate !server filter
Any other filters will raise ValueError.
"""
return cls._apply_filter(frozenset(scopes), spawner.user.name, spawner.name)
@staticmethod
@lru_cache
def _apply_filter(scopes, owner_name, server_name):
"""
implementation of Share.apply_filter
Static method so @lru_cache is persisted across instances
"""
filtered_scopes = []
server_filter = f"server={owner_name}/{server_name}"
for scope in scopes:
base_scope, _, filter = scope.partition("!")
if filter and filter != server_filter:
raise ValueError(
f"!{filter} not allowed on sharing {scope}, only !{server_filter}"
)
filtered_scopes.append(f"{base_scope}!{server_filter}")
return frozenset(filtered_scopes)
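# Illustrative sketch: apply_filter pins every scope to one server, appending
# the !server=owner/name filter (the default server has an empty name) and
# rejecting any other filter with ValueError.
def _example_apply_filter():  # pragma: no cover
    scopes = _Share._apply_filter(frozenset({"access:servers"}), "ada", "")
    assert scopes == frozenset({"access:servers!server=ada/"})
    try:
        _Share._apply_filter(frozenset({"read:servers!user=ada"}), "ada", "")
    except ValueError:
        pass  # only the matching !server filter is accepted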
class Share(_Share, Expiring, Base):
"""A single record of a sharing permission
granted by one user to another user (or group)
Restricted to a single server.
"""
__tablename__ = "shares"
# who the share is granted to (user or group)
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
user = relationship(
"User", back_populates="shared_with_me", foreign_keys=[user_id], lazy="selectin"
)
group_id = Column(
Integer, ForeignKey('groups.id', ondelete="CASCADE"), nullable=True
)
group = relationship("Group", back_populates="shared_with_me", lazy="selectin")
def __repr__(self):
if self.user:
kind = "user"
name = self.user.name
elif self.group:
kind = "group"
name = self.group.name
else: # pragma: no cover
kind = "deleted"
name = "unknown"
if self.owner and self.spawner:
server_name = f"{self.owner.name}/{self.spawner.name}"
else:  # pragma: no cover
server_name = "unknown/deleted"
return f"<{self.__class__.__name__}(server={server_name}, scopes={self.scopes}, {kind}={name})>"
@staticmethod
def _share_with_key(share_with):
"""Get the field name for share with
either group_id or user_id, depending on type of share_with
raises TypeError if neither User nor Group
"""
if isinstance(share_with, User):
return "user_id"
elif isinstance(share_with, Group):
return "group_id"
else:
raise TypeError(
f"Can only share with orm.User or orm.Group, not {share_with!r}"
)
@classmethod
def find(cls, db, spawner, share_with):
"""Find an existing
Shares are unique for a given (spawner, user)
"""
filter_by = {
cls._share_with_key(share_with): share_with.id,
"spawner_id": spawner.id,
"owner_id": spawner.user.id,
}
return db.query(Share).filter_by(**filter_by).one_or_none()
@staticmethod
def _get_log_name(spawner, share_with):
"""construct log snippet to refer to the share"""
return (
f"{share_with.kind}:{share_with.name} on {spawner.user.name}/{spawner.name}"
)
@property
def _log_name(self):
return self._get_log_name(self.spawner, self.user or self.group)
@classmethod
def grant(cls, db, spawner, share_with, scopes=None):
"""Grant shared permissions for a server
Updates existing Share if there is one,
otherwise creates a new Share
"""
if scopes is None:
scopes = frozenset(
[f"access:servers!server={spawner.user.name}/{spawner.name}"]
)
scopes = cls._apply_filter(frozenset(scopes), spawner.user.name, spawner.name)
if not scopes:
raise ValueError("Must specify scopes to grant.")
# 1. lookup existing share and update
share = cls.find(db, spawner, share_with)
share_with_log = cls._get_log_name(spawner, share_with)
if share is not None:
# update existing permissions in-place
# extend permissions
existing_scopes = set(share.scopes)
added_scopes = set(scopes).difference(existing_scopes)
if not added_scopes:
app_log.info(f"No new scopes for {share_with_log}")
return share
new_scopes = sorted(existing_scopes | added_scopes)
app_log.info(f"Granting scopes {sorted(added_scopes)} for {share_with_log}")
share.scopes = new_scopes
db.commit()
else:
# no share for (spawner, share_with), create a new one
app_log.info(f"Sharing scopes {sorted(scopes)} for {share_with_log}")
share = cls(
created_at=cls.now(),
# copy shared fields
owner=spawner.user,
spawner=spawner,
scopes=sorted(scopes),
)
if share_with.kind == "user":
share.user = share_with
elif share_with.kind == "group":
share.group = share_with
else:
raise TypeError(f"Expected user or group, got {share_with!r}")
db.add(share)
db.commit()
return share
@classmethod
def revoke(cls, db, spawner, share_with, scopes=None):
"""Revoke permissions for share_with on `spawner`
If scopes are not specified, all scopes are revoked
"""
share = cls.find(db, spawner, share_with)
if share is None:
_log_name = cls._get_log_name(spawner, share_with)
app_log.info(f"No permissions to revoke from {_log_name}")
return
else:
_log_name = share._log_name
if scopes is None:
app_log.info(f"Revoked all permissions from {_log_name}")
db.delete(share)
db.commit()
return None
# update scopes
new_scopes = [scope for scope in share.scopes if scope not in scopes]
revoked_scopes = [scope for scope in scopes if scope in set(share.scopes)]
if new_scopes == share.scopes:
app_log.info(f"No change in scopes for {_log_name}")
return share
elif not new_scopes:
# revoked all scopes, delete the Share
app_log.info(f"Revoked all permissions from {_log_name}")
db.delete(share)
db.commit()
else:
app_log.info(f"Revoked {revoked_scopes} from {_log_name}")
share.scopes = new_scopes
db.commit()
if new_scopes:
return share
else:
return None
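# Usage sketch (requires a live session plus orm Spawner/User objects, so the
# arguments here are hypothetical): grant defaults to access-only scopes,
# re-granting the same scopes is a no-op, and revoking everything deletes the row.
def _example_share_lifecycle(db, spawner, user):  # pragma: no cover
    share = Share.grant(db, spawner, user)
    assert share.scopes == [f"access:servers!server={spawner.user.name}/{spawner.name}"]
    assert Share.grant(db, spawner, user) is share  # no new scopes: unchanged
    assert Share.revoke(db, spawner, user) is None  # all scopes revoked: deleted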
class ShareCode(_Share, Hashed, Base):
"""A code that can be exchanged for a Share
Ultimately, the same as a Share, but has a 'code'
instead of a user or group that it is shared with.
The code can be exchanged to create or update an actual Share.
"""
__tablename__ = "share_codes"
hashed = Column(Unicode(255), unique=True)
prefix = Column(Unicode(16), index=True)
exchange_count = Column(Integer, default=0)
last_exchanged_at = Column(DateTime, nullable=True, default=None)
_code_bytes = 32
default_expires_in = 86400
def __repr__(self):
if self.owner and self.spawner:
server_name = f"{self.owner.name}/{self.spawner.name}"
else:
server_name = "unknown/deleted"
return f"<{self.__class__.__name__}(id={self.id}, server={server_name}, scopes={self.scopes}, expires_at={self.expires_at})>"
@classmethod
def new(
cls,
db,
spawner,
*,
scopes,
expires_in=None,
**kwargs,
):
"""Create a new ShareCode"""
app_log.info(f"Creating share code for {spawner.user.name}/{spawner.name}")
# verify scopes have the necessary filter
kwargs["scopes"] = sorted(cls.apply_filter(scopes, spawner))
if not expires_in:
expires_in = cls.default_expires_in
kwargs["expires_at"] = utcnow() + timedelta(seconds=expires_in)
kwargs["spawner"] = spawner
kwargs["owner"] = spawner.user
code = secrets.token_urlsafe(cls._code_bytes)
# create the ShareCode
share_code = cls(**kwargs)
# setting Hashed.token property sets the `hashed` column in the db
share_code.token = code
# actually put it in the db
db.add(share_code)
db.commit()
return (share_code, code)
@classmethod
def find(cls, db, code, *, spawner=None):
"""Lookup a single ShareCode by code"""
prefix_match = cls.find_prefix(db, code)
if spawner:
prefix_match = prefix_match.filter_by(spawner_id=spawner.id)
for share_code in prefix_match:
if share_code.match(code):
return share_code
def exchange(self, share_with):
"""exchange a ShareCode for a Share
share_with can be a User or a Group.
"""
db = inspect(self).session
share_code_log = f"Share code {self.prefix}..."
if self.expired:
db.delete(self)
db.commit()
raise ValueError(f"{share_code_log} expired")
share_with_log = f"{share_with.kind}:{share_with.name} on {self.owner.name}/{self.spawner.name}"
app_log.info(f"Exchanging {share_code_log} for {share_with_log}")
share = Share.grant(db, self.spawner, share_with, self.scopes)
# note: we count exchanges, even if they don't modify the permissions
# (e.g. one user exchanging the same code twice)
self.exchange_count += 1
self.last_exchanged_at = self.now()
db.commit()
return share
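# Usage sketch (arguments are hypothetical): the raw code is returned exactly
# once from `new` (only its hash is stored), then each recipient exchanges it
# for a real Share.
def _example_share_code(db, spawner, recipient):  # pragma: no cover
    share_code, code = ShareCode.new(
        db,
        spawner,
        scopes=[f"access:servers!server={spawner.user.name}/{spawner.name}"],
    )
    assert ShareCode.find(db, code) is share_code
    share = share_code.exchange(recipient)
    assert share_code.exchange_count == 1
    return share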
# ------------------------------------
# OAuth tables
# ------------------------------------
class GrantType(enum.Enum):
# we only use authorization_code for now
authorization_code = 'authorization_code'
implicit = 'implicit'
password = 'password'
client_credentials = 'client_credentials'
refresh_token = 'refresh_token'
class APIToken(Hashed, Base):
"""An API token"""
__tablename__ = 'api_tokens'
user_id = Column(
Integer,
ForeignKey('users.id', ondelete="CASCADE"),
nullable=True,
)
service_id = Column(
Integer,
ForeignKey('services.id', ondelete="CASCADE"),
nullable=True,
)
user = relationship("User", back_populates="api_tokens")
service = relationship("Service", back_populates="api_tokens")
oauth_client = relationship("OAuthClient", back_populates="access_tokens")
id = Column(Integer, primary_key=True)
hashed = Column(Unicode(255), unique=True)
prefix = Column(Unicode(16), index=True)
@property
def api_id(self):
return f"a{self.id}"
@property
def owner(self):
return self.user or self.service
# added in 2.0
client_id = Column(
Unicode(255),
ForeignKey(
'oauth_clients.identifier',
ondelete='CASCADE',
),
)
# FIXME: refresh_tokens not implemented
# should be a relation to another token table
# refresh_token = Column(
# Integer,
# ForeignKey('refresh_tokens.id', ondelete="CASCADE"),
# nullable=True,
# )
# the browser session id associated with a given token,
# if issued during oauth to be stored in a cookie
session_id = Column(Unicode(255), nullable=True)
# token metadata for bookkeeping
now = utcnow # for expiry
created = Column(DateTime, default=utcnow)
expires_at = Column(DateTime, default=None, nullable=True)
last_activity = Column(DateTime)
note = Column(Unicode(1023))
scopes = Column(JSONList, default=[])
def __repr__(self):
if self.user is not None:
kind = 'user'
name = self.user.name
elif self.service is not None:
kind = 'service'
name = self.service.name
else:
# this shouldn't happen
kind = 'owner'
name = 'unknown'
return f"<{self.__class__.__name__}('{self.prefix}...', {kind}='{name}', client_id={self.client_id!r})>"
@classmethod
def find(cls, db, token, *, kind=None):
"""Find a token object by value.
Returns None if not found.
`kind='user'` only returns API tokens for users
`kind='service'` only returns API tokens for services
"""
prefix_match = cls.find_prefix(db, token)
if kind == 'user':
prefix_match = prefix_match.filter(cls.user_id != None)
elif kind == 'service':
prefix_match = prefix_match.filter(cls.service_id != None)
elif kind is not None:
raise ValueError(f"kind must be 'user', 'service', or None, not {kind!r}")
for orm_token in prefix_match:
if orm_token.match(token):
if not orm_token.client_id:
app_log.warning(
"Deleting stale oauth token for %s with no client",
orm_token.user and orm_token.user.name,
)
db.delete(orm_token)
db.commit()
return
return orm_token
@classmethod
def new(
cls,
token=None,
*,
user=None,
service=None,
roles=None,
scopes=None,
note='',
generated=True,
session_id=None,
expires_in=None,
client_id=None,
oauth_client=None,
):
"""Generate a new API token for a user or service"""
assert user or service
assert not (user and service)
db = inspect(user or service).session
if token is None:
token = new_token()
# Don't need hash + salt rounds on generated tokens,
# which already have good entropy
generated = True
else:
cls.check_token(db, token)
# avoid circular import
from .roles import roles_to_scopes
if scopes is not None and roles is not None:
raise ValueError(
"Can only assign one of scopes or roles when creating tokens."
)
elif scopes is None and roles is None:
# this is the default branch
# use the default 'token' role to specify default permissions for API tokens
default_token_role = Role.find(db, 'token')
if not default_token_role:
scopes = ["inherit"]
else:
scopes = roles_to_scopes([default_token_role])
elif roles is not None:
# evaluate roles to scopes immediately
# TODO: should this be deprecated, or not?
# warnings.warn(
# "Setting roles on tokens is deprecated in JupyterHub 3.0. Use scopes.",
# DeprecationWarning,
# stacklevel=3,
# )
orm_roles = []
for rolename in roles:
role = Role.find(db, name=rolename)
if role is None:
raise ValueError(f"No such role: {rolename}")
orm_roles.append(role)
scopes = roles_to_scopes(orm_roles)
if oauth_client is None:
# lookup oauth client by identifier
if client_id is None:
# default: global 'jupyterhub' client
client_id = "jupyterhub"
oauth_client = db.query(OAuthClient).filter_by(identifier=client_id).one()
if client_id is None:
client_id = oauth_client.identifier
# avoid circular import
from .scopes import _check_scopes_exist, _check_token_scopes
_check_scopes_exist(scopes, who_for="token")
_check_token_scopes(scopes, owner=user or service, oauth_client=oauth_client)
# two stages to ensure orm_token.generated has been set
# before token setter is called
orm_token = cls(
generated=generated,
note=note or '',
client_id=client_id,
session_id=session_id,
scopes=list(scopes),
)
db.add(orm_token)
orm_token.token = token
if user:
assert user.id is not None
orm_token.user = user
else:
assert service.id is not None
orm_token.service = service
if expires_in:
if not isinstance(expires_in, numbers.Real):
raise TypeError(
f"expires_in must be a positive integer or null, not {expires_in!r}"
)
expires_in = int(expires_in)
# tokens must always expire in the future
if expires_in < 1:
raise ValueError(
f"expires_in must be a positive integer or null, not {expires_in!r}"
)
orm_token.expires_at = cls.now() + timedelta(seconds=expires_in)
db.commit()
return token
def update_scopes(self, new_scopes):
"""Set new scopes, checking that they are allowed"""
from .scopes import _check_scopes_exist, _check_token_scopes
_check_scopes_exist(new_scopes, who_for="token")
_check_token_scopes(
new_scopes, owner=self.owner, oauth_client=self.oauth_client
)
self.scopes = new_scopes
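# Usage sketch (assumes the default 'jupyterhub' OAuthClient row exists, as it
# does once the Hub has initialized its database): `new` returns the raw token
# string; only its prefix and hash are persisted.
def _example_api_token(user):  # pragma: no cover
    token = user.new_api_token(note="example", expires_in=3600)
    db = inspect(user).session
    orm_token = APIToken.find(db, token, kind="user")
    assert orm_token.user is user
    assert 0 < orm_token.expires_in <= 3600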
class OAuthCode(Expiring, Base):
__tablename__ = 'oauth_codes'
id = Column(Integer, primary_key=True, autoincrement=True)
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
client = relationship(
"OAuthClient",
back_populates="codes",
)
code = Column(Unicode(36))
expires_at = Column(Integer)
redirect_uri = Column(Unicode(1023))
session_id = Column(Unicode(255))
# state = Column(Unicode(1023))
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
user = relationship(
"User",
back_populates="oauth_codes",
)
scopes = Column(JSONList, default=[])
@staticmethod
def now():
return utcnow(with_tz=True).timestamp()
@classmethod
def find(cls, db, code):
return (
db.query(cls)
.filter(cls.code == code)
.filter(or_(cls.expires_at == None, cls.expires_at >= cls.now()))
.options(
# load user with the code
joinedload(cls.user, innerjoin=True),
)
.first()
)
def __repr__(self):
return (
f"<{self.__class__.__name__}(id={self.id}, client_id={self.client_id!r})>"
)
class OAuthClient(Base):
__tablename__ = 'oauth_clients'
id = Column(Integer, primary_key=True, autoincrement=True)
identifier = Column(Unicode(255), unique=True)
description = Column(Unicode(1023))
secret = Column(Unicode(255))
redirect_uri = Column(Unicode(1023))
@property
def client_id(self):
return self.identifier
spawner = relationship(
"Spawner",
back_populates="oauth_client",
uselist=False,
)
service = relationship(
"Service",
back_populates="oauth_client",
uselist=False,
)
access_tokens = relationship(
APIToken, back_populates='oauth_client', cascade='all, delete-orphan'
)
codes = relationship(
OAuthCode, back_populates='client', cascade='all, delete-orphan'
)
# these are the scopes an oauth client is allowed to request
# *not* the scopes of the client itself
allowed_scopes = Column(JSONList, default=[])
def __repr__(self):
return f"<{self.__class__.__name__}(identifier={self.identifier!r})>"
# General database utilities
class DatabaseSchemaMismatch(Exception):
"""Exception raised when the database schema version does not match
the current version of JupyterHub.
"""
def register_foreign_keys(engine):
"""register PRAGMA foreign_keys=on on connection"""
@event.listens_for(engine, "connect")
def connect(dbapi_con, con_record):
cursor = dbapi_con.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
def _expire_relationship(target, relationship_prop):
"""Expire relationship backrefs
used when an object with relationships is deleted
"""
session = object_session(target)
# get peer objects to be expired
peers = getattr(target, relationship_prop.key)
if peers is None:
# no peer to clear
return
# many-to-many and one-to-many have a list of peers
# many-to-one has only one
if (
relationship_prop.direction is interfaces.MANYTOONE
or not relationship_prop.uselist
):
peers = [peers]
for obj in peers:
if inspect(obj).persistent:
session.expire(obj, [relationship_prop.back_populates])
@event.listens_for(Session, "persistent_to_deleted")
def _notify_deleted_relationships(session, obj):
"""Expire relationships when an object becomes deleted
Needed to keep relationships up to date.
"""
mapper = inspect(obj).mapper
for prop in mapper.relationships:
if prop.back_populates:
_expire_relationship(obj, prop)
def register_ping_connection(engine):
"""Check connections before using them.
Avoids database errors when using stale connections.
From SQLAlchemy docs on pessimistic disconnect handling:
https://docs.sqlalchemy.org/en/rel_1_1/core/pooling.html#disconnect-handling-pessimistic
"""
# listeners are normally registered as a decorator,
# but we need two different signatures to avoid SAWarning:
# The argument signature for the "ConnectionEvents.engine_connect" event listener has changed
# while we support sqla 1.4 and 2.0.
# @event.listens_for(engine, "engine_connect")
def ping_connection(connection):
# turn off "close with result". This flag is only used with
# "connectionless" execution, otherwise will be False in any case
save_should_close_with_result = connection.should_close_with_result
connection.should_close_with_result = False
try:
# run a SELECT 1. use a core select() so that
# the SELECT of a scalar value without a table is
# appropriately formatted for the backend
with connection.begin() as transaction:
connection.scalar(select(1))
except exc.DBAPIError as err:
# catch SQLAlchemy's DBAPIError, which is a wrapper
# for the DBAPI's exception. It includes a .connection_invalidated
# attribute which specifies if this connection is a "disconnect"
# condition, which is based on inspection of the original exception
# by the dialect in use.
if err.connection_invalidated:
app_log.error(
"Database connection error, attempting to reconnect: %s", err
)
# run the same SELECT again - the connection will re-validate
# itself and establish a new connection. The disconnect detection
# here also causes the whole connection pool to be invalidated
# so that all stale connections are discarded.
with connection.begin() as transaction:
connection.scalar(select(1))
else:
raise
finally:
# restore "close with result"
connection.should_close_with_result = save_should_close_with_result
# sqla v1/v2 compatible invocation of @event.listens_for:
def ping_connection_v1(connection, branch=None):
"""sqlalchemy < 2.0 compatibility"""
return ping_connection(connection)
if int(sqlalchemy.__version__.split(".", 1)[0]) >= 2:
listener = ping_connection
else:
listener = ping_connection_v1
event.listens_for(engine, "engine_connect")(listener)
def check_db_revision(engine):
"""Check the JupyterHub database revision
After calling this function, an alembic tag is guaranteed to be stored in the db.
- Checks the alembic tag and raises a ValueError if it's not the current revision
- If no tag is stored (Bug in Hub prior to 0.8),
guess revision based on db contents and tag the revision.
- Empty databases are tagged with the current revision
"""
# Check database schema version
current_table_names = set(inspect(engine).get_table_names())
my_table_names = set(Base.metadata.tables.keys())
from .dbutil import _temp_alembic_ini
# alembic needs the password if it's in the URL
engine_url = engine.url.render_as_string(hide_password=False)
with _temp_alembic_ini(engine_url) as ini:
cfg = alembic.config.Config(ini)
scripts = ScriptDirectory.from_config(cfg)
head = scripts.get_heads()[0]
base = scripts.get_base()
if not my_table_names.intersection(current_table_names):
# no tables have been created, stamp with current revision
app_log.debug("Stamping empty database with alembic revision %s", head)
alembic.command.stamp(cfg, head)
return
if 'alembic_version' not in current_table_names:
# Has not been tagged or upgraded before.
# Prior to 0.8, revisions were only tagged correctly during `upgrade-db`,
# so this should only occur for databases created before JupyterHub 0.8.
msg_t = "Database schema version not found, guessing that JupyterHub %s created this database."
if 'spawners' in current_table_names:
# 0.8
app_log.warning(msg_t, '0.8.dev')
rev = head
elif 'services' in current_table_names:
# services is present, tag for 0.7
app_log.warning(msg_t, '0.7.x')
rev = 'af4cbdb2d13c'
else:
# it's old, mark as first revision
app_log.warning(msg_t, '0.6 or earlier')
rev = base
app_log.debug("Stamping database schema version %s", rev)
alembic.command.stamp(cfg, rev)
# check database schema version
# it should always be defined at this point
with engine.begin() as connection:
alembic_revision = connection.execute(
text('SELECT version_num FROM alembic_version')
).first()[0]
if alembic_revision == head:
app_log.debug("database schema version found: %s", alembic_revision)
else:
raise DatabaseSchemaMismatch(
f"Found database schema version {alembic_revision} != {head}. "
"Backup your database and run `jupyterhub upgrade-db`"
" to upgrade to the latest schema."
)
def mysql_large_prefix_check(engine):
"""Check mysql has innodb_large_prefix set"""
if not str(engine.url).startswith('mysql'):
return False
with engine.begin() as connection:
variables = dict(
connection.execute(
text(
'show variables where variable_name like '
'"innodb_large_prefix" or '
'variable_name like "innodb_file_format";'
)
).fetchall()
)
if (
variables.get('innodb_file_format', 'Barracuda') == 'Barracuda'
and variables.get('innodb_large_prefix', 'ON') == 'ON'
):
return True
else:
return False
def add_row_format(base):
for t in base.metadata.tables.values():
t.dialect_kwargs['mysql_ROW_FORMAT'] = 'DYNAMIC'
def new_session_factory(
url="sqlite:///:memory:", reset=False, expire_on_commit=False, **kwargs
):
"""Create a new session at url"""
if url.startswith('sqlite'):
kwargs.setdefault('connect_args', {'check_same_thread': False})
elif url.startswith('mysql'):
kwargs.setdefault('pool_recycle', 60)
kwargs.setdefault("future", True)
if url.endswith(':memory:'):
# If we're using an in-memory database, ensure that only one connection
# is ever created.
kwargs.setdefault('poolclass', StaticPool)
engine = create_engine(url, **kwargs)
if url.startswith('sqlite'):
register_foreign_keys(engine)
# enable pessimistic disconnect handling
register_ping_connection(engine)
if reset:
Base.metadata.drop_all(engine)
if mysql_large_prefix_check(engine):  # if mysql allows large indexes
add_row_format(Base) # set format on the tables
# check the db revision (will raise, pointing to `upgrade-db` if version doesn't match)
check_db_revision(engine)
Base.metadata.create_all(engine)
# We set expire_on_commit=False, since we don't actually need
# SQLAlchemy to expire objects after committing - we don't expect
# concurrent runs of the hub talking to the same db. Turning
# this off gives us a major performance boost
session_factory = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
return session_factory
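# Usage sketch: an in-memory database as used in the test suite (assumes
# jupyterhub's alembic scripts are importable so check_db_revision can stamp it).
def _example_session_factory():  # pragma: no cover
    session_factory = new_session_factory("sqlite:///:memory:")
    db = session_factory()
    db.add(User(name="ada"))
    db.commit()
    assert User.find(db, "ada") is not None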
def get_class(resource_name):
"""Translates resource string names to ORM classes"""
class_dict = {
'users': User,
'services': Service,
'tokens': APIToken,
'groups': Group,
}
if resource_name not in class_dict:
raise ValueError(
f'Kind must be one of {", ".join(class_dict)}, not {resource_name}'
)
return class_dict[resource_name]
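# Illustrative sketch of the resource-name lookup used by the REST API layer:
def _example_get_class():  # pragma: no cover
    assert get_class("users") is User
    assert get_class("tokens") is APIToken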