Test to ensure fernet key rotation results in new key sets

This amends the existing rotation test to ensure that after each key
rotation, a completely new set of keys exists on disk. This is just a
paranoid test to ensure that the key rotation strategy isn't something
terribly dumb like "shuffle all key files on disk without creating any
new ones."

Change-Id: Ie897bd9d41ff96ed6819cbee0053853701afc9d1
(cherry picked from commit 3183a801c7)
This commit is contained in:
Dolph Mathews 2015-06-17 18:17:39 +00:00
parent 2f580e4adb
commit 0ed33e3d25
1 changed files with 64 additions and 7 deletions

View File

@ -11,6 +11,7 @@
# under the License.
import datetime
import hashlib
import os
import uuid
@ -337,6 +338,13 @@ class TestPayloads(tests.TestCase):
class TestFernetKeyRotation(tests.TestCase):
def setUp(self):
super(TestFernetKeyRotation, self).setUp()
# A collection of all previously-seen signatures of the key
# repository's contents.
self.key_repo_signatures = set()
@property
def keys(self):
"""Key files converted to numbers."""
@ -348,6 +356,49 @@ class TestFernetKeyRotation(tests.TestCase):
"""The number of keys in the key repository."""
return len(self.keys)
@property
def key_repository_signature(self):
"""Create a "thumbprint" of the current key repository.
Because key files are renamed, this produces a hash of the contents of
the key files, ignoring their filenames.
The resulting signature can be used, for example, to ensure that you
have a unique set of keys after you perform a key rotation (taking a
static set of keys, and simply shuffling them, would fail such a test).
"""
# Load the keys into a list.
keys = fernet_utils.load_keys()
# Sort the list of keys by the keys themselves (they were previously
# sorted by filename).
keys.sort()
# Create the thumbprint using all keys in the repository.
signature = hashlib.sha1()
for key in keys:
signature.update(key)
return signature.hexdigest()
def assertRepositoryState(self, expected_size):
"""Validate the state of the key repository."""
self.assertEqual(expected_size, self.key_repository_size)
self.assertUniqueRepositoryState()
def assertUniqueRepositoryState(self):
"""Ensures that the current key repo state has not been seen before."""
# This is assigned to a variable because it takes some work to
# calculate.
signature = self.key_repository_signature
# Ensure the signature is not in the set of previously seen signatures.
self.assertNotIn(signature, self.key_repo_signatures)
# Add the signature to the set of repository signatures to validate
# that we don't see it again later.
self.key_repo_signatures.add(signature)
def test_rotation(self):
# Initializing a key repository results in this many keys. We don't
# support max_active_keys being set any lower.
@ -363,31 +414,37 @@ class TestFernetKeyRotation(tests.TestCase):
# active keys.
self.useFixture(ksfixtures.KeyRepository(self.config_fixture))
# Validate the initial repository state.
self.assertRepositoryState(expected_size=min_active_keys)
# The repository should be initialized with a staged key (0) and a
# primary key (1). The next key is just auto-incremented.
exp_keys = [0, 1]
key_no = 2 # keep track of next key
next_key_number = exp_keys[-1] + 1 # keep track of next key
self.assertEqual(exp_keys, self.keys)
# Rotate the keys just enough times to fully populate the key
# repository.
for rotation in range(max_active_keys - min_active_keys):
fernet_utils.rotate_keys()
self.assertEqual(rotation + 3, self.key_repository_size)
self.assertRepositoryState(expected_size=rotation + 3)
exp_keys.append(key_no)
key_no += 1
exp_keys.append(next_key_number)
next_key_number += 1
self.assertEqual(exp_keys, self.keys)
# We should have a fully populated key repository now.
self.assertEqual(max_active_keys, self.key_repository_size)
# Rotate an additional number of times to ensure that we maintain
# the desired number of active keys.
for rotation in range(10):
fernet_utils.rotate_keys()
self.assertEqual(max_active_keys, self.key_repository_size)
self.assertRepositoryState(expected_size=max_active_keys)
exp_keys.pop(1)
exp_keys.append(key_no)
key_no += 1
exp_keys.append(next_key_number)
next_key_number += 1
self.assertEqual(exp_keys, self.keys)
def test_non_numeric_files(self):