diff --git a/dev-requirements.txt b/dev-requirements.txt index b5b9d73c..cc9bd8b7 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,7 +1,7 @@ -r requirements.txt mock codecov -cryptography +privy pytest-cov pytest-tornado pytest>=2.8 diff --git a/jupyterhub/crypto.py b/jupyterhub/crypto.py new file mode 100644 index 00000000..7c9f8dc4 --- /dev/null +++ b/jupyterhub/crypto.py @@ -0,0 +1,123 @@ + +from concurrent.futures import ThreadPoolExecutor +import json +import os + +from traitlets.config import SingletonConfigurable, Config +from traitlets import Any, Dict, Integer, List, default, validate + +try: + import privy +except ImportError: + privy = None + + +KEY_ENV = 'JUPYTERHUB_CRYPT_KEY' + +class EncryptionUnavailable(Exception): + pass + +class PrivyUnavailable(EncryptionUnavailable): + def __str__(self): + return "privy library is required for encryption" + +class NoEncryptionKeys(EncryptionUnavailable): + def __str__(self): + return "Encryption keys must be specified in %s env" % KEY_ENV + +class CryptKeeper(SingletonConfigurable): + """Encapsulate encryption configuration + + Use via the encryption_config singleton below. + """ + + privy_kwargs = Dict({'server': True}, + help="""Keyword arguments to pass to privy.hide. + + For example, to + """ + ) + + n_threads = Integer(max(os.cpu_count(), 1), + help="The number of threads to allocate for encryption", + config=True, + ) + + @default('config') + def _config_default(self): + # load application config by default + from .app import JupyterHub + if JupyterHub.initialized(): + return JupyterHub.instance().config + else: + return Config() + + executor = Any() + def _executor_default(self): + return ThreadPoolExecutor(self.n_threads) + + keys = List(config=True) + def _keys_default(self): + if KEY_ENV not in os.environ: + return [] + # key can be a ;-separated sequence for key rotation. + # First item in the list is used for encryption. + return [ k.encode('ascii') for k in os.environ[KEY_ENV].split(';') if k.strip() ] + + @validate('keys') + def _ensure_bytes(self, proposal): + # cast str to bytes + return [ (k.encode('ascii') if isinstance(k, str) else k) for k in proposal.value ] + + def check_available(self): + if privy is None: + raise PrivyUnavailable() + if not self.keys: + raise NoEncryptionKeys() + + def _encrypt(self, data): + """Actually do the encryption. Runs in a background thread. + + data is serialized to bytes with pickle. + bytes are returned. + """ + return privy.hide(json.dumps(data).encode('utf8'), self.keys[0], **self.privy_kwargs).encode('ascii') + + def encrypt(self, data): + """Encrypt an object with privy""" + self.check_available() + return self.executor.submit(self._encrypt, data) + + def _decrypt(self, encrypted): + for key in self.keys: + try: + decrypted = privy.peek(encrypted, key) + except ValueError as e: + continue + else: + break + else: + raise ValueError("Failed to decrypt %r" % encrypted) + return json.loads(decrypted.decode('utf8')) + + def decrypt(self, encrypted): + """Decrypt an object with privy""" + self.check_available() + return self.executor.submit(self._decrypt, encrypted) + + +def encrypt(data): + """encrypt some data with the crypt keeper. + + data will be serialized with pickle. + Returns a Future whose result will be bytes. + """ + return CryptKeeper.instance().encrypt(data) + +def decrypt(data): + """decrypt some data with the crypt keeper + + Returns a Future whose result will be the decrypted, deserialized data. + """ + return CryptKeeper.instance().decrypt(data) + \ No newline at end of file diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index 3655a6be..75fb161e 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -334,7 +334,7 @@ class BaseHandler(RequestHandler): # always set auth_state and commit, # because there could be key-rotation or clearing of previous values # going on. - user.auth_state = auth_state + yield user.save_auth_state(auth_state) self.db.commit() self.set_login_cookie(user) self.statsd.incr('login.success') diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 37b03db9..66c98330 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -3,20 +3,13 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -import base64 from datetime import datetime import enum -import os import json -try: - import cryptography -except ImportError: - cryptography = None - from tornado.log import app_log -from sqlalchemy.types import TypeDecorator, TEXT +from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary from sqlalchemy import ( inspect, Column, Integer, ForeignKey, Unicode, Boolean, @@ -26,11 +19,8 @@ from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.pool import StaticPool from sqlalchemy.sql.expression import bindparam -from sqlalchemy_utils.types.encrypted import EncryptedType, FernetEngine from sqlalchemy import create_engine, Table -from traitlets import HasTraits, List - from .utils import ( random_port, new_token, hash_token, compare_token, @@ -42,7 +32,7 @@ class JSONDict(TypeDecorator): Usage:: - JSONDict(255) + JSONEncodedDict(255) """ @@ -60,104 +50,6 @@ class JSONDict(TypeDecorator): return value -def _fernet_key(key): - """Generate a Fernet key from a secret - - Fernet keys are 32 bytes encoded in url-safe base64 (44 characters). - - If a given key is not already a fernet key, - it will be passed through HKDF to generate the 32 bytes. - """ - from cryptography.hazmat.primitives import hashes - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives.kdf.hkdf import HKDF - if isinstance(key, str): - key = key.encode() - if len(key) == 44: - # already a fernet key, pass it along - try: - base64.urlsafe_b64decode(key) - except Exception: - pass - else: - return key - elif len(key) != 32: - # not the right size, pass through HKDF - kdf = HKDF( - algorithm=hashes.SHA256(), - length=32, - salt=None, - info=b'jupyterhub auth state', - backend=default_backend(), - ) - key = kdf.derive(key) - return base64.urlsafe_b64encode(key) - - -class MultiFernetEngine(FernetEngine): - """Extend SQLAlchemy-Utils FernetEngine to use MultiFernet, - - which supports key rotation. - """ - key_list = None - - def _update_key(self, key): - if key == self.key_list: - return - return self._initialize_engine(key) - - def _initialize_engine(self, parent_class_key): - from cryptography.fernet import MultiFernet, Fernet - # key will be a *list* of keys - self.key_list = parent_class_key - self.fernet = MultiFernet([Fernet(_fernet_key(key)) for key in self.key_list]) - -class EncryptionUnavailable(Exception): - pass - -class EncryptionConfig(HasTraits): - """Encapsulate encryption configuration - - Use via the encryption_config singleton below. - """ - key_list = List() - def _key_list_default(self): - if 'AUTH_STATE_KEY' not in os.environ: - return [] - # key can be a ;-separated sequence for key rotation. - # First item in the list is used for encryption. - return os.environ['AUTH_STATE_KEY'].split(';') - - @property - def available(self): - if not self.key_list: - return False - return cryptography is not None - -encryption_config = EncryptionConfig() - -class Encrypted(EncryptedType): - def __init__(self, type_in=None, key=None, **kwargs): - super().__init__(type_in, key=lambda : encryption_config.key_list, engine=MultiFernetEngine, **kwargs) - - -class CantEncrypt(TypeDecorator): - """Use in place of Encrypted when Encrypted types can't even be instantiated (crypto unavailable)""" - def process_bind_param(self, value, dialect): - if value is None: - return value - raise EncryptionUnavailable("cryptography library is unavailable") - - def process_result_value(self, value, dialect): - if value is None: - return value - raise EncryptionUnavailable("cryptography library is unavailable") - - -# if cryptography library is unavailable, use CantEncrypt -if cryptography is None: - Encrypted = CantEncrypt - Base = declarative_base() Base.log = app_log @@ -250,38 +142,8 @@ class User(Base): # We will need to figure something else out if/when we have multiple spawners per user state = Column(JSONDict) # Authenticators can store their state here: - _auth_state = Column('auth_state', Encrypted(JSONDict)) - - # check for availability of encryption on a property - # to get better errors than raising in the TypeDecorator methods, - # which won't raise until `db.commit()` - - @property - def auth_state(self): - # TODO: handle decryption failure - try: - value = self._auth_state - except Exception as e: - if encryption_config.available: - why = str(e) - else: - why = "encryption is unavailable" - app_log.warning("Failed to retrieve encrypted auth_state for %s because %s", - self.name, why) - return None - if value is not None and not encryption_config.available: - raise EncryptionUnavailable("auth_state requires cryptography library and AUTH_STATE_KEY") - return value - - @auth_state.setter - def auth_state(self, value): - if value is None: - self._auth_state = value - return - if value is not None and not encryption_config.available: - raise EncryptionUnavailable("auth_state requires cryptography library and AUTH_STATE_KEY") - self._auth_state = value - + # Encryption is handled elsewhere + encrypted_auth_state = Column(LargeBinary) # group mapping groups = relationship('Group', secondary='user_group_map', back_populates='users') diff --git a/jupyterhub/tests/test_crypto.py b/jupyterhub/tests/test_crypto.py new file mode 100644 index 00000000..2c62b317 --- /dev/null +++ b/jupyterhub/tests/test_crypto.py @@ -0,0 +1,55 @@ +import os + +import pytest +from unittest.mock import patch + +from .. import crypto +from ..crypto import encrypt, decrypt + +@pytest.mark.parametrize("key_env, keys", [ + ("secret", [b'secret']), + ("secret1;secret2", [b'secret1', b'secret2']), + ("secret1;secret2;", [b'secret1', b'secret2']), + ("", []), +]) +def test_env_constructor(key_env, keys): + with patch.dict(os.environ, {crypto.KEY_ENV: key_env}): + ck = crypto.CryptKeeper() + assert ck.keys == keys + +@pytest.fixture +def crypt_keeper(): + """Fixture configuring and returning the global CryptKeeper instance""" + ck = crypto.CryptKeeper.instance() + save_keys = ck.keys + ck.keys = [os.urandom(32), os.urandom(32)] + try: + yield ck + finally: + ck.keys = save_keys + +@pytest.mark.gen_test +def test_roundtrip(crypt_keeper): + data = {'key': 'value'} + encrypted = yield encrypt(data) + decrypted = yield decrypt(encrypted) + assert decrypted == data + +@pytest.mark.gen_test +def test_missing_privy(crypt_keeper): + with patch.object(crypto, 'privy', None): + with pytest.raises(crypto.PrivyUnavailable): + yield encrypt({}) + + with pytest.raises(crypto.PrivyUnavailable): + yield decrypt(b'whatever') + +@pytest.mark.gen_test +def test_missing_keys(crypt_keeper): + crypt_keeper.keys = [] + with pytest.raises(crypto.NoEncryptionKeys): + yield encrypt({}) + + with pytest.raises(crypto.NoEncryptionKeys): + yield decrypt(b'whatever') + diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index de9172c3..d0a558ff 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -3,9 +3,6 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -import base64 -import cryptography -import os import socket import pytest @@ -13,6 +10,7 @@ from tornado import gen from .. import orm from .. import objects +from .. import crypto from ..user import User from .mocking import MockSpawner @@ -180,53 +178,65 @@ def test_groups(db): assert user.groups == [group] -def test_auth_state(db): - user = orm.User(name='eve') - db.add(user) +@pytest.mark.gen_test +def test_auth_state_crypto(db): + user = User(orm.User(name='eve')) + db.add(user.orm_user) db.commit() + + ck = crypto.CryptKeeper.instance() + # starts empty - assert user.auth_state is None + assert user.encrypted_auth_state is None # can't set auth_state without keys state = {'key': 'value'} - orm.encryption_config.key_list = [] - with pytest.raises(orm.EncryptionUnavailable): - user.auth_state = state - db.commit() - assert user.auth_state is None - - # + ck.keys = [] + with pytest.raises(crypto.EncryptionUnavailable): + yield user.save_auth_state(state) + + assert user.encrypted_auth_state is None + # saving/loading None doesn't require keys + yield user.save_auth_state(None) + current = yield user.get_auth_state() + assert current is None + first_key = 'first-key' second_key = 'second-key' - orm.encryption_config.key_list = [first_key] - user.auth_state = state - db.commit() - assert user.auth_state == state + ck.keys = [first_key] + yield user.save_auth_state(state) + assert user.encrypted_auth_state is not None + decrypted_state = yield user.get_auth_state() + assert decrypted_state == state # can't read auth_state without keys - orm.encryption_config.key_list = [] - with pytest.raises(orm.EncryptionUnavailable): - print(user.auth_state) + ck.keys = [] + auth_state = yield user.get_auth_state() + assert auth_state is None # key rotation works db.rollback() - orm.encryption_config.key_list = [second_key, first_key] - assert user.auth_state == state + ck.keys = [second_key, first_key] + decrypted_state = yield user.get_auth_state() + assert decrypted_state == state - user.auth_state = new_state = {'key': 'newvalue'} + new_state = {'key': 'newvalue'} + yield user.save_auth_state(new_state) db.commit() - orm.encryption_config.key_list = [first_key] + ck.keys = [first_key] db.rollback() # can't read anymore with new-key after encrypting with second-key - assert user.auth_state is None + decrypted_state = yield user.get_auth_state() + assert decrypted_state is None - user.auth_state = new_state - db.commit() - assert user.auth_state == new_state + yield user.save_auth_state(new_state) + decrypted_state = yield user.get_auth_state() + assert decrypted_state == new_state - orm.encryption_config.key_list = [] + ck.keys = [] db.rollback() - assert user.auth_state is None + decrypted_state = yield user.get_auth_state() + assert decrypted_state is None diff --git a/jupyterhub/user.py b/jupyterhub/user.py index 7d04185d..f9b7f655 100644 --- a/jupyterhub/user.py +++ b/jupyterhub/user.py @@ -15,6 +15,7 @@ from . import orm from ._version import _check_version, __version__ from traitlets import HasTraits, Any, Dict, observe, default from .spawner import LocalProcessSpawner +from .crypto import encrypt, decrypt, CryptKeeper, EncryptionUnavailable class UserDict(dict): """Like defaultdict, but for users @@ -82,6 +83,7 @@ class _SpawnerDict(dict): self[key] = self.spawner_factory(key) return super().__getitem__(key) + class User(HasTraits): @default('log') @@ -100,6 +102,35 @@ class User(HasTraits): orm_user = Any(allow_none=True) + @gen.coroutine + def save_auth_state(self, auth_state): + """Encrypt and store auth_state""" + if auth_state is None: + self.encrypted_auth_state = None + else: + self.encrypted_auth_state = yield encrypt(auth_state) + self.db.commit() + + @gen.coroutine + def get_auth_state(self): + """Retrieve and decrypt auth_state for the user""" + encrypted = self.encrypted_auth_state + if encrypted is None: + return None + try: + auth_state = yield decrypt(encrypted) + except (ValueError, EncryptionUnavailable) as e: + self.log.warning("Failed to retrieve encrypted auth_state for %s because %s", + self.name, e, + ) + return + # loading auth_state + if auth_state: + # Crypt has multiple keys, store again with new key for rotation. + if len(CryptKeeper.instance().keys) > 1: + yield self.save_auth_state(auth_state) + return auth_state + @property def authenticator(self): return self.settings.get('authenticator', None) diff --git a/requirements.txt b/requirements.txt index 8171f706..c11ca810 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,3 @@ pamela python-oauth2>=1.0 sqlalchemy>=1.0 requests -SQLAlchemy-Utils