move oauth tables to top-level orm

This commit is contained in:
Min RK
2017-02-06 10:45:55 +01:00
parent d0eb4e0946
commit 453d1daf8b
3 changed files with 51 additions and 63 deletions

View File

@@ -1,59 +0,0 @@
"""SQLAlchemy declarations for OAuth2 data stores"""
import enum
from sqlalchemy.types import TypeDecorator, TEXT
from sqlalchemy import (
inspect,
Column, Integer, ForeignKey, Unicode, Boolean,
DateTime, Enum,
)
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import sessionmaker, relationship, backref
from sqlalchemy.pool import StaticPool
from sqlalchemy.schema import Index, UniqueConstraint
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.sql.expression import bindparam
from sqlalchemy import create_engine, Table
from ..orm import Base, APIToken, User
class GrantType(enum.Enum):
authorization_code = 'authorization_code'
implicit = 'implicit'
password = 'password'
client_credentials = 'client_credentials'
refresh_token = 'refresh_token'
class OAuthAccessToken(Base):
__tablename__ = 'oauth_access_tokens'
id = Column(Integer, primary_key=True, autoincrement=True)
client_id = Column(Unicode(1023))
grant_type = Column(Enum(GrantType), nullable=False)
expires_at = Column(Integer)
refresh_token = Column(Unicode(36))
refresh_expires_at = Column(Integer)
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
user = relationship(User)
api_token_id = Column(Integer, ForeignKey('api_tokens.id', ondelete='CASCADE'))
api_token = relationship(APIToken, backref='oauth_token')
class OAuthCode(Base):
__tablename__ = 'oauth_codes'
id = Column(Integer, primary_key=True, autoincrement=True)
client_id = Column(Unicode(1023))
code = Column(Unicode(36))
expires_at = Column(Integer)
redirect_uri = Column(Unicode(1023))
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
class OAuthClient(Base):
__tablename__ = 'oauth_clients'
id = Column(Integer, primary_key=True, autoincrement=True)
identifier = Column(Unicode(1023), unique=True)
secret = Column(Unicode(1023))
redirect_uri = Column(Unicode(1023))

View File

@@ -16,8 +16,7 @@ from oauth2.tokengenerator import Uuid4 as UUID4
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session
from tornado.escape import url_escape from tornado.escape import url_escape
from ..orm import User from .. import orm
from . import orm
from jupyterhub.orm import APIToken from jupyterhub.orm import APIToken
from ..utils import url_path_join, hash_token, compare_token from ..utils import url_path_join, hash_token, compare_token
@@ -86,7 +85,7 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
""" """
user = self.db.query(User).filter(User.id == access_token.user_id).first() user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first()
token = user.new_api_token(access_token.token) token = user.new_api_token(access_token.token)
orm_api_token = APIToken.find(self.db, token, kind='user') orm_api_token = APIToken.find(self.db, token, kind='user')

View File

@@ -4,6 +4,7 @@
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from datetime import datetime from datetime import datetime
import enum
import json import json
from tornado import gen from tornado import gen
@@ -14,7 +15,7 @@ from sqlalchemy.types import TypeDecorator, TEXT
from sqlalchemy import ( from sqlalchemy import (
inspect, inspect,
Column, Integer, ForeignKey, Unicode, Boolean, Column, Integer, ForeignKey, Unicode, Boolean,
DateTime, DateTime, Enum
) )
from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import sessionmaker, relationship, backref from sqlalchemy.orm import sessionmaker, relationship, backref
@@ -609,6 +610,53 @@ class APIToken(Base):
return token return token
#------------------------------------
# OAuth tables
#------------------------------------
class GrantType(enum.Enum):
authorization_code = 'authorization_code'
implicit = 'implicit'
password = 'password'
client_credentials = 'client_credentials'
refresh_token = 'refresh_token'
class OAuthAccessToken(Base):
__tablename__ = 'oauth_access_tokens'
id = Column(Integer, primary_key=True, autoincrement=True)
client_id = Column(Unicode(1023))
grant_type = Column(Enum(GrantType), nullable=False)
expires_at = Column(Integer)
refresh_token = Column(Unicode(36))
refresh_expires_at = Column(Integer)
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
user = relationship(User)
api_token_id = Column(Integer, ForeignKey('api_tokens.id', ondelete='CASCADE'))
api_token = relationship(APIToken, backref='oauth_token')
class OAuthCode(Base):
__tablename__ = 'oauth_codes'
id = Column(Integer, primary_key=True, autoincrement=True)
client_id = Column(Unicode(1023))
code = Column(Unicode(36))
expires_at = Column(Integer)
redirect_uri = Column(Unicode(1023))
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
class OAuthClient(Base):
__tablename__ = 'oauth_clients'
id = Column(Integer, primary_key=True, autoincrement=True)
identifier = Column(Unicode(1023), unique=True)
secret = Column(Unicode(1023))
redirect_uri = Column(Unicode(1023))
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
"""Create a new session at url""" """Create a new session at url"""
if url.startswith('sqlite'): if url.startswith('sqlite'):