Escape named servers when used in URL paths

This commit is contained in:
Simon Li
2022-05-22 23:27:31 +01:00
parent b18a05c2c8
commit 84cb9761e8
3 changed files with 27 additions and 9 deletions

View File

@@ -17,7 +17,13 @@ from .crypto import CryptKeeper, EncryptionUnavailable, InvalidToken, decrypt, e
from .metrics import RUNNING_SERVERS, TOTAL_USERS
from .objects import Server
from .spawner import LocalProcessSpawner
from .utils import AnyTimeoutError, make_ssl_context, maybe_future, url_path_join
from .utils import (
AnyTimeoutError,
make_ssl_context,
maybe_future,
url_escape_path,
url_path_join,
)
# detailed messages about the most common failure-to-start errors,
# which manifest timeouts during start
@@ -410,7 +416,9 @@ class User:
hub=self.settings.get('hub'),
authenticator=self.authenticator,
config=self.settings.get('config'),
proxy_spec=url_path_join(self.proxy_spec, server_name, '/'),
proxy_spec=url_path_join(
self.proxy_spec, url_escape_path(server_name), '/'
),
_deprecated_db_session=self.db,
oauth_client_id=client_id,
cookie_options=self.settings.get('cookie_options', {}),
@@ -494,7 +502,7 @@ class User:
@property
def escaped_name(self):
"""My name, escaped for use in URLs, cookies, etc."""
return quote(self.name, safe='@~')
return url_escape_path(self.name)
@property
def json_escaped_name(self):
@@ -543,13 +551,13 @@ class User:
if not server_name:
return self.url
else:
return url_path_join(self.url, server_name)
return url_path_join(self.url, url_escape_path(server_name))
def progress_url(self, server_name=''):
"""API URL for progress endpoint for a server with a given name"""
url_parts = [self.settings['hub'].base_url, 'api/users', self.escaped_name]
if server_name:
url_parts.extend(['servers', server_name, 'progress'])
url_parts.extend(['servers', url_escape_path(server_name), 'progress'])
else:
url_parts.extend(['server/progress'])
return url_path_join(*url_parts)
@@ -623,7 +631,7 @@ class User:
if handler:
await self.refresh_auth(handler)
base_url = url_path_join(self.base_url, server_name) + '/'
base_url = url_path_join(self.base_url, url_escape_path(server_name)) + '/'
orm_server = orm.Server(base_url=base_url)
db.add(orm_server)
@@ -678,7 +686,8 @@ class User:
oauth_client = oauth_provider.add_client(
client_id,
api_token,
url_path_join(self.url, server_name, 'oauth_callback'),
url_path_join(self.url, url_escape_path(server_name), 'oauth_callback'),
# url_path_join(self.url, server_name, 'oauth_callback'),
allowed_roles=allowed_roles,
description="Server at %s"
% (url_path_join(self.base_url, server_name) + '/'),
@@ -785,7 +794,9 @@ class User:
oauth_provider.add_client(
client_id,
spawner.api_token,
url_path_join(self.url, server_name, 'oauth_callback'),
url_path_join(
self.url, url_escape_path(server_name), 'oauth_callback'
),
)
db.commit()