remove separate oauth tokens

- merge oauth token fields into APITokens
- create oauth client 'jupyterhub' which owns current API tokens
- db upgrade is currently to drop both token tables, and force recreation on next start
This commit is contained in:
Min RK
2021-03-10 16:41:37 +01:00
parent 2fdf820fe5
commit 0b56fd9e62
15 changed files with 264 additions and 202 deletions

View File

@@ -277,9 +277,6 @@ class User(Base):
last_activity = Column(DateTime, nullable=True)
api_tokens = relationship("APIToken", backref="user", cascade="all, delete-orphan")
oauth_tokens = relationship(
"OAuthAccessToken", backref="user", cascade="all, delete-orphan"
)
oauth_codes = relationship(
"OAuthCode", backref="user", cascade="all, delete-orphan"
)
@@ -485,7 +482,9 @@ class Hashed(Expiring):
@classmethod
def check_token(cls, db, token):
"""Check if a token is acceptable"""
print("checking", cls, token, len(token), cls.min_length)
if len(token) < cls.min_length:
print("raising")
raise ValueError(
"Tokens must be at least %i characters, got %r"
% (cls.min_length, token)
@@ -530,6 +529,20 @@ class Hashed(Expiring):
return orm_token
# ------------------------------------
# OAuth tables
# ------------------------------------
class GrantType(enum.Enum):
# we only use authorization_code for now
authorization_code = 'authorization_code'
implicit = 'implicit'
password = 'password'
client_credentials = 'client_credentials'
refresh_token = 'refresh_token'
class APIToken(Hashed, Base):
"""An API token"""
@@ -548,6 +561,15 @@ class APIToken(Hashed, Base):
def api_id(self):
return 'a%i' % self.id
# added in 2.0
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
grant_type = Column(Enum(GrantType), nullable=False)
refresh_token = Column(Unicode(255))
# the browser session id associated with a given token
session_id = Column(Unicode(255))
# token metadata for bookkeeping
now = datetime.utcnow # for expiry
created = Column(DateTime, default=datetime.utcnow)
@@ -566,8 +588,12 @@ class APIToken(Hashed, Base):
# this shouldn't happen
kind = 'owner'
name = 'unknown'
return "<{cls}('{pre}...', {kind}='{name}')>".format(
cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name
return "<{cls}('{pre}...', {kind}='{name}', client_id={client_id!r})>".format(
cls=self.__class__.__name__,
pre=self.prefix,
kind=kind,
name=name,
client_id=self.client_id,
)
@classmethod
@@ -588,6 +614,14 @@ class APIToken(Hashed, Base):
raise ValueError("kind must be 'user', 'service', or None, not %r" % kind)
for orm_token in prefix_match:
if orm_token.match(token):
if not orm_token.client_id:
app_log.warning(
"Deleting stale oauth token for %s with no client",
orm_token.user and orm_token.user.name,
)
db.delete(orm_token)
db.commit()
return
return orm_token
@classmethod
@@ -600,6 +634,7 @@ class APIToken(Hashed, Base):
note='',
generated=True,
expires_in=None,
client_id='jupyterhub',
):
"""Generate a new API token for a user or service"""
assert user or service
@@ -614,7 +649,12 @@ class APIToken(Hashed, Base):
cls.check_token(db, token)
# two stages to ensure orm_token.generated has been set
# before token setter is called
orm_token = cls(generated=generated, note=note or '')
orm_token = cls(
generated=generated,
note=note or '',
grant_type=GrantType.authorization_code,
client_id=client_id,
)
orm_token.token = token
if user:
assert user.id is not None
@@ -641,76 +681,6 @@ class APIToken(Hashed, Base):
return token
# ------------------------------------
# OAuth tables
# ------------------------------------
class GrantType(enum.Enum):
# we only use authorization_code for now
authorization_code = 'authorization_code'
implicit = 'implicit'
password = 'password'
client_credentials = 'client_credentials'
refresh_token = 'refresh_token'
class OAuthAccessToken(Hashed, Base):
__tablename__ = 'oauth_access_tokens'
id = Column(Integer, primary_key=True, autoincrement=True)
@staticmethod
def now():
return datetime.utcnow().timestamp()
@property
def api_id(self):
return 'o%i' % self.id
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
grant_type = Column(Enum(GrantType), nullable=False)
expires_at = Column(Integer)
refresh_token = Column(Unicode(255))
# TODO: drop refresh_expires_at. Refresh tokens shouldn't expire
refresh_expires_at = Column(Integer)
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
service = None # for API-equivalence with APIToken
# the browser session id associated with a given token
session_id = Column(Unicode(255))
# from Hashed
hashed = Column(Unicode(255), unique=True)
prefix = Column(Unicode(16), index=True)
created = Column(DateTime, default=datetime.utcnow)
last_activity = Column(DateTime, nullable=True)
def __repr__(self):
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}, expires_in={expires_in}>".format(
cls=self.__class__.__name__,
client_id=self.client_id,
user=self.user and self.user.name,
prefix=self.prefix,
expires_in=self.expires_in,
)
@classmethod
def find(cls, db, token):
orm_token = super().find(db, token)
if orm_token and not orm_token.client_id:
app_log.warning(
"Deleting stale oauth token for %s with no client",
orm_token.user and orm_token.user.name,
)
db.delete(orm_token)
db.commit()
return
return orm_token
class OAuthCode(Expiring, Base):
__tablename__ = 'oauth_codes'
@@ -752,7 +722,7 @@ class OAuthClient(Base):
return self.identifier
access_tokens = relationship(
OAuthAccessToken, backref='client', cascade='all, delete-orphan'
APIToken, backref='client', cascade='all, delete-orphan'
)
codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan')