mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-11 20:13:02 +00:00
support oauth in services
fix bugs caught by tests
This commit is contained in:
@@ -1041,6 +1041,7 @@ class JupyterHub(Application):
|
|||||||
host = '%s://services.%s' % (parsed.scheme, parsed.netloc)
|
host = '%s://services.%s' % (parsed.scheme, parsed.netloc)
|
||||||
else:
|
else:
|
||||||
domain = host = ''
|
domain = host = ''
|
||||||
|
client_store = self.oauth_provider.client_authenticator.client_store
|
||||||
for spec in self.services:
|
for spec in self.services:
|
||||||
if 'name' not in spec:
|
if 'name' not in spec:
|
||||||
raise ValueError('service spec must have a name: %r' % spec)
|
raise ValueError('service spec must have a name: %r' % spec)
|
||||||
@@ -1082,6 +1083,12 @@ class JupyterHub(Application):
|
|||||||
base_url=service.prefix,
|
base_url=service.prefix,
|
||||||
)
|
)
|
||||||
self.db.add(server)
|
self.db.add(server)
|
||||||
|
|
||||||
|
client_store.add_client(
|
||||||
|
client_id=service.oauth_client_id,
|
||||||
|
client_secret=service.oauth_client_secret,
|
||||||
|
redirect_uri=host + url_path_join(service.prefix, 'oauth_callback'),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
service.orm.server = None
|
service.orm.server = None
|
||||||
|
|
||||||
@@ -1378,11 +1385,11 @@ class JupyterHub(Application):
|
|||||||
self.init_db()
|
self.init_db()
|
||||||
self.init_hub()
|
self.init_hub()
|
||||||
self.init_proxy()
|
self.init_proxy()
|
||||||
|
self.init_oauth()
|
||||||
yield self.init_users()
|
yield self.init_users()
|
||||||
yield self.init_groups()
|
yield self.init_groups()
|
||||||
self.init_services()
|
self.init_services()
|
||||||
yield self.init_api_tokens()
|
yield self.init_api_tokens()
|
||||||
self.init_oauth()
|
|
||||||
self.init_tornado_settings()
|
self.init_tornado_settings()
|
||||||
yield self.init_spawners()
|
yield self.init_spawners()
|
||||||
self.init_handlers()
|
self.init_handlers()
|
||||||
|
@@ -616,6 +616,7 @@ class APIToken(Base):
|
|||||||
|
|
||||||
|
|
||||||
class GrantType(enum.Enum):
|
class GrantType(enum.Enum):
|
||||||
|
# we only use authorization_code for now
|
||||||
authorization_code = 'authorization_code'
|
authorization_code = 'authorization_code'
|
||||||
implicit = 'implicit'
|
implicit = 'implicit'
|
||||||
password = 'password'
|
password = 'password'
|
||||||
@@ -656,7 +657,6 @@ class OAuthClient(Base):
|
|||||||
redirect_uri = Column(Unicode(1023))
|
redirect_uri = Column(Unicode(1023))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
|
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
|
||||||
"""Create a new session at url"""
|
"""Create a new session at url"""
|
||||||
if url.startswith('sqlite'):
|
if url.startswith('sqlite'):
|
||||||
|
@@ -643,11 +643,11 @@ class JupyterHubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
|||||||
if not code:
|
if not code:
|
||||||
raise HTTPError(400, "oauth callback made without a token")
|
raise HTTPError(400, "oauth callback made without a token")
|
||||||
# TODO: make async (in a Thread?)
|
# TODO: make async (in a Thread?)
|
||||||
token_reply = self.hub_auth.token_for_code(code)
|
token = self.hub_auth.token_for_code(code)
|
||||||
user_model = self.hub_auth.user_for_token(token)
|
user_model = self.hub_auth.user_for_token(token)
|
||||||
self.log.info("Logged-in user %s", user_model)
|
app_log.info("Logged-in user %s", user_model)
|
||||||
self.hub_auth.set_cookie(self, user_model)
|
self.hub_auth.set_cookie(self, user_model)
|
||||||
next_url = self.get_argument('next', '') or self.base_url
|
next_url = self.get_argument('next', '') or self.hub_auth.base_url
|
||||||
self.redirect(next_url)
|
self.redirect(next_url)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -57,7 +57,7 @@ from traitlets.config import LoggingConfigurable
|
|||||||
from .. import orm
|
from .. import orm
|
||||||
from ..traitlets import Command
|
from ..traitlets import Command
|
||||||
from ..spawner import LocalProcessSpawner, set_user_setuid
|
from ..spawner import LocalProcessSpawner, set_user_setuid
|
||||||
from ..utils import url_path_join
|
from ..utils import url_path_join, new_token
|
||||||
|
|
||||||
class _MockUser(HasTraits):
|
class _MockUser(HasTraits):
|
||||||
name = Unicode()
|
name = Unicode()
|
||||||
@@ -198,6 +198,30 @@ class Service(LoggingConfigurable):
|
|||||||
db = Any()
|
db = Any()
|
||||||
orm = Any()
|
orm = Any()
|
||||||
|
|
||||||
|
oauth_provider = Any()
|
||||||
|
|
||||||
|
oauth_client_id = Unicode(
|
||||||
|
help="""OAuth client ID for this service.
|
||||||
|
|
||||||
|
You shouldn't generally need to change this.
|
||||||
|
Default: `service-<name>`
|
||||||
|
"""
|
||||||
|
).tag(input=True)
|
||||||
|
@default('oauth_client_id')
|
||||||
|
def _default_client_id(self):
|
||||||
|
return 'service-%s' % self.name
|
||||||
|
|
||||||
|
oauth_client_secret = Unicode(
|
||||||
|
help="""OAuth client secret for this service.
|
||||||
|
|
||||||
|
Default: Generated on each launch.
|
||||||
|
"""
|
||||||
|
).tag(input=True)
|
||||||
|
@default('oauth_client_secret')
|
||||||
|
def _default_client_secret(self):
|
||||||
|
self.log.debug("Generating new OAuth secret for service %s", self.name)
|
||||||
|
return new_token()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def server(self):
|
def server(self):
|
||||||
return self.orm.server
|
return self.orm.server
|
||||||
@@ -242,6 +266,8 @@ class Service(LoggingConfigurable):
|
|||||||
cmd=self.command,
|
cmd=self.command,
|
||||||
environment=env,
|
environment=env,
|
||||||
api_token=self.api_token,
|
api_token=self.api_token,
|
||||||
|
oauth_client_id=self.oauth_client_id,
|
||||||
|
oauth_client_secret=self.oauth_client_secret,
|
||||||
cwd=self.cwd,
|
cwd=self.cwd,
|
||||||
user=_MockUser(
|
user=_MockUser(
|
||||||
name=self.user,
|
name=self.user,
|
||||||
|
@@ -11,13 +11,15 @@ Handlers allow:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import pprint
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from tornado import web, httpserver, ioloop
|
from tornado import web, httpserver, ioloop
|
||||||
|
|
||||||
from jupyterhub.services.auth import HubAuthenticated
|
from jupyterhub.services.auth import HubAuthenticated, HubOAuthenticated, JupyterHubOAuthCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
class EchoHandler(web.RequestHandler):
|
class EchoHandler(web.RequestHandler):
|
||||||
@@ -47,21 +49,37 @@ class APIHandler(web.RequestHandler):
|
|||||||
|
|
||||||
|
|
||||||
class WhoAmIHandler(HubAuthenticated, web.RequestHandler):
|
class WhoAmIHandler(HubAuthenticated, web.RequestHandler):
|
||||||
"""Reply with the name of the user who made the request."""
|
"""Reply with the name of the user who made the request.
|
||||||
|
|
||||||
|
Uses deprecated cookie login
|
||||||
|
"""
|
||||||
|
@web.authenticated
|
||||||
|
def get(self):
|
||||||
|
self.write(self.get_current_user())
|
||||||
|
|
||||||
|
class OWhoAmIHandler(HubOAuthenticated, web.RequestHandler):
|
||||||
|
"""Reply with the name of the user who made the request.
|
||||||
|
|
||||||
|
Uses OAuth login flow
|
||||||
|
"""
|
||||||
@web.authenticated
|
@web.authenticated
|
||||||
def get(self):
|
def get(self):
|
||||||
self.write(self.get_current_user())
|
self.write(self.get_current_user())
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
pprint.pprint(dict(os.environ), stream=sys.stderr)
|
||||||
|
|
||||||
if os.environ['JUPYTERHUB_SERVICE_URL']:
|
if os.environ['JUPYTERHUB_SERVICE_URL']:
|
||||||
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
|
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
|
||||||
app = web.Application([
|
app = web.Application([
|
||||||
(r'.*/env', EnvHandler),
|
(r'.*/env', EnvHandler),
|
||||||
(r'.*/api/(.*)', APIHandler),
|
(r'.*/api/(.*)', APIHandler),
|
||||||
(r'.*/whoami/?', WhoAmIHandler),
|
(r'.*/whoami/?', WhoAmIHandler),
|
||||||
|
(r'.*/owhoami/?', OWhoAmIHandler),
|
||||||
|
(r'.*/oauth_callback', JupyterHubOAuthCallbackHandler),
|
||||||
(r'.*', EchoHandler),
|
(r'.*', EchoHandler),
|
||||||
])
|
], cookie_secret=os.urandom(32))
|
||||||
|
|
||||||
server = httpserver.HTTPServer(app)
|
server = httpserver.HTTPServer(app)
|
||||||
server.listen(url.port, url.hostname)
|
server.listen(url.port, url.hostname)
|
||||||
@@ -70,6 +88,7 @@ def main():
|
|||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print('\nInterrupted')
|
print('\nInterrupted')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from tornado.options import parse_command_line
|
from tornado.options import parse_command_line
|
||||||
parse_command_line()
|
parse_command_line()
|
||||||
|
@@ -18,7 +18,7 @@ from tornado.web import RequestHandler, Application, authenticated, HTTPError
|
|||||||
|
|
||||||
from ..services.auth import _ExpiringDict, HubAuth, HubAuthenticated
|
from ..services.auth import _ExpiringDict, HubAuth, HubAuthenticated
|
||||||
from ..utils import url_path_join
|
from ..utils import url_path_join
|
||||||
from .mocking import public_url
|
from .mocking import public_url, public_host
|
||||||
from .test_api import add_user
|
from .test_api import add_user
|
||||||
|
|
||||||
# mock for sending monotonic counter way into the future
|
# mock for sending monotonic counter way into the future
|
||||||
@@ -244,7 +244,6 @@ def test_hubauth_token(app, mockservice_url):
|
|||||||
headers={
|
headers={
|
||||||
'Authorization': 'token %s' % token,
|
'Authorization': 'token %s' % token,
|
||||||
})
|
})
|
||||||
r.raise_for_status()
|
|
||||||
reply = r.json()
|
reply = r.json()
|
||||||
sub_reply = { key: reply.get(key, 'missing') for key in ['name', 'admin']}
|
sub_reply = { key: reply.get(key, 'missing') for key in ['name', 'admin']}
|
||||||
assert sub_reply == {
|
assert sub_reply == {
|
||||||
@@ -312,3 +311,23 @@ def test_hubauth_service_token(app, mockservice_url, io_loop):
|
|||||||
path = urlparse(location).path
|
path = urlparse(location).path
|
||||||
assert path.endswith('/hub/login')
|
assert path.endswith('/hub/login')
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_service(app, mockservice_url):
|
||||||
|
url = url_path_join(public_url(app, mockservice_url) + 'owhoami/')
|
||||||
|
# first request is only going to set login cookie
|
||||||
|
# FIXME: redirect to originating URL (OAuth loses this info)
|
||||||
|
s = requests.Session()
|
||||||
|
s.cookies = app.login_user('link')
|
||||||
|
r = s.get(url)
|
||||||
|
r.raise_for_status()
|
||||||
|
# second request should be authenticated
|
||||||
|
r = s.get(url, allow_redirects=False)
|
||||||
|
r.raise_for_status()
|
||||||
|
assert r.status_code == 200
|
||||||
|
reply = r.json()
|
||||||
|
sub_reply = { key:reply.get(key, 'missing') for key in ('kind', 'name') }
|
||||||
|
assert sub_reply == {
|
||||||
|
'name': 'link',
|
||||||
|
'kind': 'user',
|
||||||
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user