mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-16 14:33:00 +00:00
move oauth tables to top-level orm
This commit is contained in:
@@ -1,59 +0,0 @@
|
|||||||
"""SQLAlchemy declarations for OAuth2 data stores"""
|
|
||||||
import enum
|
|
||||||
|
|
||||||
from sqlalchemy.types import TypeDecorator, TEXT
|
|
||||||
from sqlalchemy import (
|
|
||||||
inspect,
|
|
||||||
Column, Integer, ForeignKey, Unicode, Boolean,
|
|
||||||
DateTime, Enum,
|
|
||||||
)
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
|
||||||
from sqlalchemy.orm import sessionmaker, relationship, backref
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
from sqlalchemy.schema import Index, UniqueConstraint
|
|
||||||
from sqlalchemy.ext.associationproxy import association_proxy
|
|
||||||
from sqlalchemy.sql.expression import bindparam
|
|
||||||
from sqlalchemy import create_engine, Table
|
|
||||||
from ..orm import Base, APIToken, User
|
|
||||||
|
|
||||||
|
|
||||||
class GrantType(enum.Enum):
|
|
||||||
authorization_code = 'authorization_code'
|
|
||||||
implicit = 'implicit'
|
|
||||||
password = 'password'
|
|
||||||
client_credentials = 'client_credentials'
|
|
||||||
refresh_token = 'refresh_token'
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthAccessToken(Base):
|
|
||||||
__tablename__ = 'oauth_access_tokens'
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
|
|
||||||
client_id = Column(Unicode(1023))
|
|
||||||
grant_type = Column(Enum(GrantType), nullable=False)
|
|
||||||
expires_at = Column(Integer)
|
|
||||||
refresh_token = Column(Unicode(36))
|
|
||||||
refresh_expires_at = Column(Integer)
|
|
||||||
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
|
||||||
user = relationship(User)
|
|
||||||
api_token_id = Column(Integer, ForeignKey('api_tokens.id', ondelete='CASCADE'))
|
|
||||||
api_token = relationship(APIToken, backref='oauth_token')
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthCode(Base):
|
|
||||||
__tablename__ = 'oauth_codes'
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
client_id = Column(Unicode(1023))
|
|
||||||
code = Column(Unicode(36))
|
|
||||||
expires_at = Column(Integer)
|
|
||||||
redirect_uri = Column(Unicode(1023))
|
|
||||||
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthClient(Base):
|
|
||||||
__tablename__ = 'oauth_clients'
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
identifier = Column(Unicode(1023), unique=True)
|
|
||||||
secret = Column(Unicode(1023))
|
|
||||||
redirect_uri = Column(Unicode(1023))
|
|
||||||
|
|
@@ -16,8 +16,7 @@ from oauth2.tokengenerator import Uuid4 as UUID4
|
|||||||
from sqlalchemy.orm import scoped_session
|
from sqlalchemy.orm import scoped_session
|
||||||
from tornado.escape import url_escape
|
from tornado.escape import url_escape
|
||||||
|
|
||||||
from ..orm import User
|
from .. import orm
|
||||||
from . import orm
|
|
||||||
from jupyterhub.orm import APIToken
|
from jupyterhub.orm import APIToken
|
||||||
from ..utils import url_path_join, hash_token, compare_token
|
from ..utils import url_path_join, hash_token, compare_token
|
||||||
|
|
||||||
@@ -86,7 +85,7 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
user = self.db.query(User).filter(User.id == access_token.user_id).first()
|
user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first()
|
||||||
token = user.new_api_token(access_token.token)
|
token = user.new_api_token(access_token.token)
|
||||||
orm_api_token = APIToken.find(self.db, token, kind='user')
|
orm_api_token = APIToken.find(self.db, token, kind='user')
|
||||||
|
|
||||||
|
@@ -4,6 +4,7 @@
|
|||||||
# Distributed under the terms of the Modified BSD License.
|
# Distributed under the terms of the Modified BSD License.
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import enum
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from tornado import gen
|
from tornado import gen
|
||||||
@@ -14,7 +15,7 @@ from sqlalchemy.types import TypeDecorator, TEXT
|
|||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
inspect,
|
inspect,
|
||||||
Column, Integer, ForeignKey, Unicode, Boolean,
|
Column, Integer, ForeignKey, Unicode, Boolean,
|
||||||
DateTime,
|
DateTime, Enum
|
||||||
)
|
)
|
||||||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||||
from sqlalchemy.orm import sessionmaker, relationship, backref
|
from sqlalchemy.orm import sessionmaker, relationship, backref
|
||||||
@@ -609,6 +610,53 @@ class APIToken(Base):
|
|||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
#------------------------------------
|
||||||
|
# OAuth tables
|
||||||
|
#------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class GrantType(enum.Enum):
|
||||||
|
authorization_code = 'authorization_code'
|
||||||
|
implicit = 'implicit'
|
||||||
|
password = 'password'
|
||||||
|
client_credentials = 'client_credentials'
|
||||||
|
refresh_token = 'refresh_token'
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccessToken(Base):
|
||||||
|
__tablename__ = 'oauth_access_tokens'
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
|
||||||
|
client_id = Column(Unicode(1023))
|
||||||
|
grant_type = Column(Enum(GrantType), nullable=False)
|
||||||
|
expires_at = Column(Integer)
|
||||||
|
refresh_token = Column(Unicode(36))
|
||||||
|
refresh_expires_at = Column(Integer)
|
||||||
|
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
||||||
|
user = relationship(User)
|
||||||
|
api_token_id = Column(Integer, ForeignKey('api_tokens.id', ondelete='CASCADE'))
|
||||||
|
api_token = relationship(APIToken, backref='oauth_token')
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthCode(Base):
|
||||||
|
__tablename__ = 'oauth_codes'
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
client_id = Column(Unicode(1023))
|
||||||
|
code = Column(Unicode(36))
|
||||||
|
expires_at = Column(Integer)
|
||||||
|
redirect_uri = Column(Unicode(1023))
|
||||||
|
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClient(Base):
|
||||||
|
__tablename__ = 'oauth_clients'
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
identifier = Column(Unicode(1023), unique=True)
|
||||||
|
secret = Column(Unicode(1023))
|
||||||
|
redirect_uri = Column(Unicode(1023))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
|
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
|
||||||
"""Create a new session at url"""
|
"""Create a new session at url"""
|
||||||
if url.startswith('sqlite'):
|
if url.startswith('sqlite'):
|
||||||
|
Reference in New Issue
Block a user