move auth_state encryption outside the ORM

privy is used for encryption

- db only has blob column, no knowledge of encryption
- add CryptKeeper for handling encryption
- use privy for encryption, so we have fewer choices to make
- storing/loading encrypted auth_state runs in a ThreadPool
This commit is contained in:
Min RK
2017-07-28 13:44:37 +02:00
parent 32a9b38d26
commit 90e8e1a8aa
8 changed files with 256 additions and 176 deletions

View File

@@ -1,7 +1,7 @@
-r requirements.txt -r requirements.txt
mock mock
codecov codecov
cryptography privy
pytest-cov pytest-cov
pytest-tornado pytest-tornado
pytest>=2.8 pytest>=2.8

123
jupyterhub/crypto.py Normal file
View File

@@ -0,0 +1,123 @@
from concurrent.futures import ThreadPoolExecutor
import json
import os
from traitlets.config import SingletonConfigurable, Config
from traitlets import Any, Dict, Integer, List, default, validate
try:
import privy
except ImportError:
privy = None
KEY_ENV = 'JUPYTERHUB_CRYPT_KEY'
class EncryptionUnavailable(Exception):
    """Base class for errors raised when encryption cannot be used."""
class PrivyUnavailable(EncryptionUnavailable):
    """Encryption is unavailable because the privy library is missing."""

    def __str__(self):
        message = "privy library is required for encryption"
        return message
class NoEncryptionKeys(EncryptionUnavailable):
    """Encryption is unavailable because no keys have been configured."""

    def __str__(self):
        # KEY_ENV names the environment variable the keys are read from
        message = "Encryption keys must be specified in %s env" % KEY_ENV
        return message
class CryptKeeper(SingletonConfigurable):
    """Encapsulate encryption configuration and run privy in a thread pool.

    Use via the ``CryptKeeper.instance()`` singleton, or through the
    module-level ``encrypt`` / ``decrypt`` helpers.
    """

    privy_kwargs = Dict({'server': True},
        help="""Keyword arguments to pass to privy.hide.""",
    )

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to one thread instead of failing on max(None, 1).
    n_threads = Integer(os.cpu_count() or 1,
        help="The number of threads to allocate for encryption",
        config=True,
    )

    @default('config')
    def _config_default(self):
        """Use the running JupyterHub application's config, if there is one."""
        # imported here rather than at module level
        from .app import JupyterHub
        if JupyterHub.initialized():
            return JupyterHub.instance().config
        else:
            return Config()

    executor = Any()

    @default('executor')
    def _executor_default(self):
        """Create the thread pool in which encryption and decryption run."""
        return ThreadPoolExecutor(self.n_threads)

    keys = List(config=True)

    @default('keys')
    def _keys_default(self):
        """Load keys from the KEY_ENV environment variable (empty if unset)."""
        if KEY_ENV not in os.environ:
            return []
        # key can be a ;-separated sequence for key rotation.
        # First item in the list is used for encryption.
        return [k.encode('ascii') for k in os.environ[KEY_ENV].split(';') if k.strip()]

    @validate('keys')
    def _ensure_bytes(self, proposal):
        """Cast any str keys to bytes; other values are kept as given."""
        return [(k.encode('ascii') if isinstance(k, str) else k) for k in proposal.value]

    def check_available(self):
        """Raise if encryption cannot be performed.

        Raises:
            PrivyUnavailable: the privy library could not be imported.
            NoEncryptionKeys: no keys are configured.
        """
        if privy is None:
            raise PrivyUnavailable()
        if not self.keys:
            raise NoEncryptionKeys()

    def _encrypt(self, data):
        """Actually do the encryption. Runs in a background thread.

        data is serialized to bytes with JSON (utf8-encoded)
        and hidden with the first key in ``self.keys``.
        bytes are returned.
        """
        serialized = json.dumps(data).encode('utf8')
        hidden = privy.hide(serialized, self.keys[0], **self.privy_kwargs)
        return hidden.encode('ascii')

    def encrypt(self, data):
        """Encrypt an object with privy.

        Availability is checked immediately, so EncryptionUnavailable
        is raised by this call rather than through the returned Future.
        Returns a Future whose result is the encrypted bytes.
        """
        self.check_available()
        return self.executor.submit(self._encrypt, data)

    def _decrypt(self, encrypted):
        """Actually do the decryption. Runs in a background thread.

        Each key is tried in order, so data hidden with an older key
        can still be read after key rotation.

        Raises:
            ValueError: if no configured key can decrypt the data.
        """
        for key in self.keys:
            try:
                decrypted = privy.peek(encrypted, key)
            except ValueError:
                # this key failed, try the next one
                continue
            else:
                break
        else:
            raise ValueError("Failed to decrypt %r" % encrypted)
        return json.loads(decrypted.decode('utf8'))

    def decrypt(self, encrypted):
        """Decrypt an object with privy.

        Availability is checked immediately, so EncryptionUnavailable
        is raised by this call rather than through the returned Future.
        Returns a Future whose result is the decrypted, deserialized data.
        """
        self.check_available()
        return self.executor.submit(self._decrypt, encrypted)
def encrypt(data):
    """Encrypt some data with the crypt keeper.

    data will be serialized with JSON.
    Returns a Future whose result will be bytes.
    """
    return CryptKeeper.instance().encrypt(data)
def decrypt(data):
    """Decrypt some data with the global CryptKeeper instance.

    Returns a Future whose result will be the decrypted, deserialized data.
    """
    keeper = CryptKeeper.instance()
    return keeper.decrypt(data)

View File

@@ -334,7 +334,7 @@ class BaseHandler(RequestHandler):
# always set auth_state and commit, # always set auth_state and commit,
# because there could be key-rotation or clearing of previous values # because there could be key-rotation or clearing of previous values
# going on. # going on.
user.auth_state = auth_state yield user.save_auth_state(auth_state)
self.db.commit() self.db.commit()
self.set_login_cookie(user) self.set_login_cookie(user)
self.statsd.incr('login.success') self.statsd.incr('login.success')

View File

@@ -3,20 +3,13 @@
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import base64
from datetime import datetime from datetime import datetime
import enum import enum
import os
import json import json
try:
import cryptography
except ImportError:
cryptography = None
from tornado.log import app_log from tornado.log import app_log
from sqlalchemy.types import TypeDecorator, TEXT from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
from sqlalchemy import ( from sqlalchemy import (
inspect, inspect,
Column, Integer, ForeignKey, Unicode, Boolean, Column, Integer, ForeignKey, Unicode, Boolean,
@@ -26,11 +19,8 @@ from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from sqlalchemy.sql.expression import bindparam from sqlalchemy.sql.expression import bindparam
from sqlalchemy_utils.types.encrypted import EncryptedType, FernetEngine
from sqlalchemy import create_engine, Table from sqlalchemy import create_engine, Table
from traitlets import HasTraits, List
from .utils import ( from .utils import (
random_port, random_port,
new_token, hash_token, compare_token, new_token, hash_token, compare_token,
@@ -42,7 +32,7 @@ class JSONDict(TypeDecorator):
Usage:: Usage::
JSONDict(255) JSONEncodedDict(255)
""" """
@@ -60,104 +50,6 @@ class JSONDict(TypeDecorator):
return value return value
def _fernet_key(key):
"""Generate a Fernet key from a secret
Fernet keys are 32 bytes encoded in url-safe base64 (44 characters).
If a given key is not already a fernet key,
it will be passed through HKDF to generate the 32 bytes.
"""
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
if isinstance(key, str):
key = key.encode()
if len(key) == 44:
# already a fernet key, pass it along
try:
base64.urlsafe_b64decode(key)
except Exception:
pass
else:
return key
elif len(key) != 32:
# not the right size, pass through HKDF
kdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=None,
info=b'jupyterhub auth state',
backend=default_backend(),
)
key = kdf.derive(key)
return base64.urlsafe_b64encode(key)
class MultiFernetEngine(FernetEngine):
    """SQLAlchemy-Utils FernetEngine built on MultiFernet,

    which supports key rotation.
    """

    key_list = None

    def _update_key(self, key):
        # only rebuild the engine when the list of keys has changed
        if key != self.key_list:
            return self._initialize_engine(key)

    def _initialize_engine(self, parent_class_key):
        from cryptography.fernet import MultiFernet, Fernet
        # parent_class_key is a *list* of keys, not a single key
        self.key_list = parent_class_key
        fernets = [Fernet(_fernet_key(secret)) for secret in parent_class_key]
        self.fernet = MultiFernet(fernets)
class EncryptionUnavailable(Exception):
    """Raised when auth_state encryption cannot be used."""
class EncryptionConfig(HasTraits):
    """Hold the encryption keys and report whether encryption can be used.

    Use via the encryption_config singleton below.
    """

    key_list = List()

    def _key_list_default(self):
        # AUTH_STATE_KEY can be a ;-separated sequence for key rotation.
        # First item in the list is used for encryption.
        env_value = os.environ.get('AUTH_STATE_KEY')
        if env_value is None:
            return []
        return env_value.split(';')

    @property
    def available(self):
        """True when keys are configured and cryptography was imported."""
        return bool(self.key_list) and cryptography is not None
encryption_config = EncryptionConfig()
class Encrypted(EncryptedType):
    """EncryptedType using MultiFernetEngine and the keys in encryption_config."""

    def __init__(self, type_in=None, key=None, **kwargs):
        # the `key` argument is not used: keys are always looked up
        # from encryption_config at the time the callable is invoked
        def current_keys():
            return encryption_config.key_list

        super().__init__(type_in, key=current_keys, engine=MultiFernetEngine, **kwargs)
class CantEncrypt(TypeDecorator):
    """Stand-in for Encrypted when Encrypted types can't even be instantiated (crypto unavailable)"""

    def _none_or_raise(self, value):
        # None needs no encryption; any other value cannot be handled
        if value is not None:
            raise EncryptionUnavailable("cryptography library is unavailable")
        return value

    def process_bind_param(self, value, dialect):
        return self._none_or_raise(value)

    def process_result_value(self, value, dialect):
        return self._none_or_raise(value)
# if cryptography library is unavailable, use CantEncrypt
if cryptography is None:
Encrypted = CantEncrypt
Base = declarative_base() Base = declarative_base()
Base.log = app_log Base.log = app_log
@@ -250,38 +142,8 @@ class User(Base):
# We will need to figure something else out if/when we have multiple spawners per user # We will need to figure something else out if/when we have multiple spawners per user
state = Column(JSONDict) state = Column(JSONDict)
# Authenticators can store their state here: # Authenticators can store their state here:
_auth_state = Column('auth_state', Encrypted(JSONDict)) # Encryption is handled elsewhere
encrypted_auth_state = Column(LargeBinary)
# check for availability of encryption on a property
# to get better errors than raising in the TypeDecorator methods,
# which won't raise until `db.commit()`
@property
def auth_state(self):
    """The decrypted auth_state, or None if it is unset or unreadable."""
    # TODO: handle decryption failure
    try:
        stored = self._auth_state
    except Exception as e:
        why = str(e) if encryption_config.available else "encryption is unavailable"
        app_log.warning("Failed to retrieve encrypted auth_state for %s because %s",
            self.name, why)
        return None
    if stored is None:
        return stored
    if not encryption_config.available:
        raise EncryptionUnavailable("auth_state requires cryptography library and AUTH_STATE_KEY")
    return stored

@auth_state.setter
def auth_state(self, value):
    # clearing with None is always allowed; anything else needs encryption
    if value is not None and not encryption_config.available:
        raise EncryptionUnavailable("auth_state requires cryptography library and AUTH_STATE_KEY")
    self._auth_state = value
# group mapping # group mapping
groups = relationship('Group', secondary='user_group_map', back_populates='users') groups = relationship('Group', secondary='user_group_map', back_populates='users')

View File

@@ -0,0 +1,55 @@
import os
import pytest
from unittest.mock import patch
from .. import crypto
from ..crypto import encrypt, decrypt
@pytest.mark.parametrize("key_env, keys", [
    ("secret", [b'secret']),
    ("secret1;secret2", [b'secret1', b'secret2']),
    ("secret1;secret2;", [b'secret1', b'secret2']),
    ("", []),
])
def test_env_constructor(key_env, keys):
    """Keys are loaded from the KEY_ENV environment variable."""
    env_patch = patch.dict(os.environ, {crypto.KEY_ENV: key_env})
    with env_patch:
        keeper = crypto.CryptKeeper()
        # read keys while the environment is still patched
        assert keeper.keys == keys
@pytest.fixture
def crypt_keeper():
    """Fixture giving the global CryptKeeper instance two random keys"""
    keeper = crypto.CryptKeeper.instance()
    original_keys = keeper.keys
    keeper.keys = [os.urandom(32) for _ in range(2)]
    try:
        yield keeper
    finally:
        # put back whatever keys were configured before the test
        keeper.keys = original_keys
@pytest.mark.gen_test
def test_roundtrip(crypt_keeper):
    """Decrypting encrypted data gives back the original data."""
    original = {'key': 'value'}
    ciphertext = yield encrypt(original)
    recovered = yield decrypt(ciphertext)
    assert recovered == original
@pytest.mark.gen_test
def test_missing_privy(crypt_keeper):
    """encrypt and decrypt raise PrivyUnavailable when privy is None."""
    with patch.object(crypto, 'privy', None):
        for operation, payload in ((encrypt, {}), (decrypt, b'whatever')):
            with pytest.raises(crypto.PrivyUnavailable):
                yield operation(payload)
@pytest.mark.gen_test
def test_missing_keys(crypt_keeper):
    """encrypt and decrypt raise NoEncryptionKeys when no keys are set."""
    crypt_keeper.keys = []
    for operation, payload in ((encrypt, {}), (decrypt, b'whatever')):
        with pytest.raises(crypto.NoEncryptionKeys):
            yield operation(payload)

View File

@@ -3,9 +3,6 @@
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import base64
import cryptography
import os
import socket import socket
import pytest import pytest
@@ -13,6 +10,7 @@ from tornado import gen
from .. import orm from .. import orm
from .. import objects from .. import objects
from .. import crypto
from ..user import User from ..user import User
from .mocking import MockSpawner from .mocking import MockSpawner
@@ -180,53 +178,65 @@ def test_groups(db):
assert user.groups == [group] assert user.groups == [group]
def test_auth_state(db): @pytest.mark.gen_test
user = orm.User(name='eve') def test_auth_state_crypto(db):
db.add(user) user = User(orm.User(name='eve'))
db.add(user.orm_user)
db.commit() db.commit()
ck = crypto.CryptKeeper.instance()
# starts empty # starts empty
assert user.auth_state is None assert user.encrypted_auth_state is None
# can't set auth_state without keys # can't set auth_state without keys
state = {'key': 'value'} state = {'key': 'value'}
orm.encryption_config.key_list = [] ck.keys = []
with pytest.raises(orm.EncryptionUnavailable): with pytest.raises(crypto.EncryptionUnavailable):
user.auth_state = state yield user.save_auth_state(state)
db.commit()
assert user.auth_state is None assert user.encrypted_auth_state is None
# saving/loading None doesn't require keys
# yield user.save_auth_state(None)
current = yield user.get_auth_state()
assert current is None
first_key = 'first-key' first_key = 'first-key'
second_key = 'second-key' second_key = 'second-key'
orm.encryption_config.key_list = [first_key] ck.keys = [first_key]
user.auth_state = state yield user.save_auth_state(state)
db.commit() assert user.encrypted_auth_state is not None
assert user.auth_state == state decrypted_state = yield user.get_auth_state()
assert decrypted_state == state
# can't read auth_state without keys # can't read auth_state without keys
orm.encryption_config.key_list = [] ck.keys = []
with pytest.raises(orm.EncryptionUnavailable): auth_state = yield user.get_auth_state()
print(user.auth_state) assert auth_state is None
# key rotation works # key rotation works
db.rollback() db.rollback()
orm.encryption_config.key_list = [second_key, first_key] ck.keys = [second_key, first_key]
assert user.auth_state == state decrypted_state = yield user.get_auth_state()
assert decrypted_state == state
user.auth_state = new_state = {'key': 'newvalue'} new_state = {'key': 'newvalue'}
yield user.save_auth_state(new_state)
db.commit() db.commit()
orm.encryption_config.key_list = [first_key] ck.keys = [first_key]
db.rollback() db.rollback()
# can't read anymore with new-key after encrypting with second-key # can't read anymore with new-key after encrypting with second-key
assert user.auth_state is None decrypted_state = yield user.get_auth_state()
assert decrypted_state is None
user.auth_state = new_state yield user.save_auth_state(new_state)
db.commit() decrypted_state = yield user.get_auth_state()
assert user.auth_state == new_state assert decrypted_state == new_state
orm.encryption_config.key_list = [] ck.keys = []
db.rollback() db.rollback()
assert user.auth_state is None decrypted_state = yield user.get_auth_state()
assert decrypted_state is None

View File

@@ -15,6 +15,7 @@ from . import orm
from ._version import _check_version, __version__ from ._version import _check_version, __version__
from traitlets import HasTraits, Any, Dict, observe, default from traitlets import HasTraits, Any, Dict, observe, default
from .spawner import LocalProcessSpawner from .spawner import LocalProcessSpawner
from .crypto import encrypt, decrypt, CryptKeeper, EncryptionUnavailable
class UserDict(dict): class UserDict(dict):
"""Like defaultdict, but for users """Like defaultdict, but for users
@@ -82,6 +83,7 @@ class _SpawnerDict(dict):
self[key] = self.spawner_factory(key) self[key] = self.spawner_factory(key)
return super().__getitem__(key) return super().__getitem__(key)
class User(HasTraits): class User(HasTraits):
@default('log') @default('log')
@@ -100,6 +102,35 @@ class User(HasTraits):
orm_user = Any(allow_none=True) orm_user = Any(allow_none=True)
@gen.coroutine
def save_auth_state(self, auth_state):
    """Encrypt auth_state, store it on the user, and commit.

    None is stored as-is, without encryption.
    """
    encrypted = None
    if auth_state is not None:
        encrypted = yield encrypt(auth_state)
    self.encrypted_auth_state = encrypted
    self.db.commit()
@gen.coroutine
def get_auth_state(self):
    """Load and decrypt the user's auth_state.

    Returns None when nothing is stored, or when decryption fails
    (the failure is logged as a warning).
    """
    blob = self.encrypted_auth_state
    if blob is None:
        return None

    try:
        state = yield decrypt(blob)
    except (ValueError, EncryptionUnavailable) as e:
        self.log.warning("Failed to retrieve encrypted auth_state for %s because %s",
            self.name, e,
        )
        return

    # Crypt has multiple keys, store again with new key for rotation.
    if state and len(CryptKeeper.instance().keys) > 1:
        yield self.save_auth_state(state)
    return state
@property @property
def authenticator(self): def authenticator(self):
return self.settings.get('authenticator', None) return self.settings.get('authenticator', None)

View File

@@ -6,4 +6,3 @@ pamela
python-oauth2>=1.0 python-oauth2>=1.0
sqlalchemy>=1.0 sqlalchemy>=1.0
requests requests
SQLAlchemy-Utils