diff --git a/jupyterhub/oauth/orm.py b/jupyterhub/oauth/orm.py deleted file mode 100644 index f55ed726..00000000 --- a/jupyterhub/oauth/orm.py +++ /dev/null @@ -1,59 +0,0 @@ -"""SQLAlchemy declarations for OAuth2 data stores""" -import enum - -from sqlalchemy.types import TypeDecorator, TEXT -from sqlalchemy import ( - inspect, - Column, Integer, ForeignKey, Unicode, Boolean, - DateTime, Enum, -) -from sqlalchemy.ext.declarative import declarative_base, declared_attr -from sqlalchemy.orm import sessionmaker, relationship, backref -from sqlalchemy.pool import StaticPool -from sqlalchemy.schema import Index, UniqueConstraint -from sqlalchemy.ext.associationproxy import association_proxy -from sqlalchemy.sql.expression import bindparam -from sqlalchemy import create_engine, Table -from ..orm import Base, APIToken, User - - -class GrantType(enum.Enum): - authorization_code = 'authorization_code' - implicit = 'implicit' - password = 'password' - client_credentials = 'client_credentials' - refresh_token = 'refresh_token' - - -class OAuthAccessToken(Base): - __tablename__ = 'oauth_access_tokens' - id = Column(Integer, primary_key=True, autoincrement=True) - - client_id = Column(Unicode(1023)) - grant_type = Column(Enum(GrantType), nullable=False) - expires_at = Column(Integer) - refresh_token = Column(Unicode(36)) - refresh_expires_at = Column(Integer) - user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) - user = relationship(User) - api_token_id = Column(Integer, ForeignKey('api_tokens.id', ondelete='CASCADE')) - api_token = relationship(APIToken, backref='oauth_token') - - -class OAuthCode(Base): - __tablename__ = 'oauth_codes' - id = Column(Integer, primary_key=True, autoincrement=True) - client_id = Column(Unicode(1023)) - code = Column(Unicode(36)) - expires_at = Column(Integer) - redirect_uri = Column(Unicode(1023)) - user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) - - -class OAuthClient(Base): - __tablename__ = 'oauth_clients' - id = Column(Integer, primary_key=True, autoincrement=True) - identifier = Column(Unicode(1023), unique=True) - secret = Column(Unicode(1023)) - redirect_uri = Column(Unicode(1023)) - diff --git a/jupyterhub/oauth/store.py b/jupyterhub/oauth/store.py index c2c4f6f9..d7221ee4 100644 --- a/jupyterhub/oauth/store.py +++ b/jupyterhub/oauth/store.py @@ -16,8 +16,7 @@ from oauth2.tokengenerator import Uuid4 as UUID4 from sqlalchemy.orm import scoped_session from tornado.escape import url_escape -from ..orm import User -from . import orm +from .. import orm from jupyterhub.orm import APIToken from ..utils import url_path_join, hash_token, compare_token @@ -86,7 +85,7 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore): """ - user = self.db.query(User).filter(User.id == access_token.user_id).first() + user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first() token = user.new_api_token(access_token.token) orm_api_token = APIToken.find(self.db, token, kind='user') diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 3576ac19..9c02e830 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -4,6 +4,7 @@ # Distributed under the terms of the Modified BSD License. from datetime import datetime +import enum import json from tornado import gen @@ -14,7 +15,7 @@ from sqlalchemy.types import TypeDecorator, TEXT from sqlalchemy import ( inspect, Column, Integer, ForeignKey, Unicode, Boolean, - DateTime, + DateTime, Enum ) from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.orm import sessionmaker, relationship, backref @@ -609,6 +610,53 @@ class APIToken(Base): return token +#------------------------------------ +# OAuth tables +#------------------------------------ + + +class GrantType(enum.Enum): + authorization_code = 'authorization_code' + implicit = 'implicit' + password = 'password' + client_credentials = 'client_credentials' + refresh_token = 'refresh_token' + + +class OAuthAccessToken(Base): + __tablename__ = 'oauth_access_tokens' + id = Column(Integer, primary_key=True, autoincrement=True) + + client_id = Column(Unicode(1023)) + grant_type = Column(Enum(GrantType), nullable=False) + expires_at = Column(Integer) + refresh_token = Column(Unicode(36)) + refresh_expires_at = Column(Integer) + user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) + user = relationship(User) + api_token_id = Column(Integer, ForeignKey('api_tokens.id', ondelete='CASCADE')) + api_token = relationship(APIToken, backref='oauth_token') + + +class OAuthCode(Base): + __tablename__ = 'oauth_codes' + id = Column(Integer, primary_key=True, autoincrement=True) + client_id = Column(Unicode(1023)) + code = Column(Unicode(36)) + expires_at = Column(Integer) + redirect_uri = Column(Unicode(1023)) + user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) + + +class OAuthClient(Base): + __tablename__ = 'oauth_clients' + id = Column(Integer, primary_key=True, autoincrement=True) + identifier = Column(Unicode(1023), unique=True) + secret = Column(Unicode(1023)) + redirect_uri = Column(Unicode(1023)) + + + def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): """Create a new session at url""" if url.startswith('sqlite'):