Merge: Liberty/Mitaka support

This commit is contained in:
Bilal Baqar 2016-04-25 11:21:09 +02:00
commit 2212d1eae1
63 changed files with 7301 additions and 587 deletions

View File

@ -4,7 +4,7 @@ PYTHON := /usr/bin/env python
virtualenv:
virtualenv .venv
.venv/bin/pip install flake8 nose coverage mock pyyaml netifaces \
netaddr jinja2
netaddr jinja2 pyflakes pep8 six pbr funcsigs psutil
lint: virtualenv
.venv/bin/flake8 --exclude hooks/charmhelpers hooks unit_tests tests --ignore E402

253
bin/charm_helpers_sync.py Normal file
View File

@ -0,0 +1,253 @@
#!/usr/bin/python
# Copyright 2014-2015 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
# Authors:
# Adam Gandelman <adamg@ubuntu.com>
import logging
import optparse
import os
import subprocess
import shutil
import sys
import tempfile
import yaml
from fnmatch import fnmatch
import six
CHARM_HELPERS_BRANCH = 'lp:charm-helpers'
def parse_config(conf_file):
if not os.path.isfile(conf_file):
logging.error('Invalid config file: %s.' % conf_file)
return False
return yaml.load(open(conf_file).read())
def clone_helpers(work_dir, branch):
dest = os.path.join(work_dir, 'charm-helpers')
logging.info('Checking out %s to %s.' % (branch, dest))
cmd = ['bzr', 'checkout', '--lightweight', branch, dest]
subprocess.check_call(cmd)
return dest
def _module_path(module):
return os.path.join(*module.split('.'))
def _src_path(src, module):
return os.path.join(src, 'charmhelpers', _module_path(module))
def _dest_path(dest, module):
return os.path.join(dest, _module_path(module))
def _is_pyfile(path):
return os.path.isfile(path + '.py')
def ensure_init(path):
'''
ensure directories leading up to path are importable, omitting
parent directory, eg path='/hooks/helpers/foo'/:
hooks/
hooks/helpers/__init__.py
hooks/helpers/foo/__init__.py
'''
for d, dirs, files in os.walk(os.path.join(*path.split('/')[:2])):
_i = os.path.join(d, '__init__.py')
if not os.path.exists(_i):
logging.info('Adding missing __init__.py: %s' % _i)
open(_i, 'wb').close()
def sync_pyfile(src, dest):
src = src + '.py'
src_dir = os.path.dirname(src)
logging.info('Syncing pyfile: %s -> %s.' % (src, dest))
if not os.path.exists(dest):
os.makedirs(dest)
shutil.copy(src, dest)
if os.path.isfile(os.path.join(src_dir, '__init__.py')):
shutil.copy(os.path.join(src_dir, '__init__.py'),
dest)
ensure_init(dest)
def get_filter(opts=None):
opts = opts or []
if 'inc=*' in opts:
# do not filter any files, include everything
return None
def _filter(dir, ls):
incs = [opt.split('=').pop() for opt in opts if 'inc=' in opt]
_filter = []
for f in ls:
_f = os.path.join(dir, f)
if not os.path.isdir(_f) and not _f.endswith('.py') and incs:
if True not in [fnmatch(_f, inc) for inc in incs]:
logging.debug('Not syncing %s, does not match include '
'filters (%s)' % (_f, incs))
_filter.append(f)
else:
logging.debug('Including file, which matches include '
'filters (%s): %s' % (incs, _f))
elif (os.path.isfile(_f) and not _f.endswith('.py')):
logging.debug('Not syncing file: %s' % f)
_filter.append(f)
elif (os.path.isdir(_f) and not
os.path.isfile(os.path.join(_f, '__init__.py'))):
logging.debug('Not syncing directory: %s' % f)
_filter.append(f)
return _filter
return _filter
def sync_directory(src, dest, opts=None):
if os.path.exists(dest):
logging.debug('Removing existing directory: %s' % dest)
shutil.rmtree(dest)
logging.info('Syncing directory: %s -> %s.' % (src, dest))
shutil.copytree(src, dest, ignore=get_filter(opts))
ensure_init(dest)
def sync(src, dest, module, opts=None):
# Sync charmhelpers/__init__.py for bootstrap code.
sync_pyfile(_src_path(src, '__init__'), dest)
# Sync other __init__.py files in the path leading to module.
m = []
steps = module.split('.')[:-1]
while steps:
m.append(steps.pop(0))
init = '.'.join(m + ['__init__'])
sync_pyfile(_src_path(src, init),
os.path.dirname(_dest_path(dest, init)))
# Sync the module, or maybe a .py file.
if os.path.isdir(_src_path(src, module)):
sync_directory(_src_path(src, module), _dest_path(dest, module), opts)
elif _is_pyfile(_src_path(src, module)):
sync_pyfile(_src_path(src, module),
os.path.dirname(_dest_path(dest, module)))
else:
logging.warn('Could not sync: %s. Neither a pyfile or directory, '
'does it even exist?' % module)
def parse_sync_options(options):
if not options:
return []
return options.split(',')
def extract_options(inc, global_options=None):
global_options = global_options or []
if global_options and isinstance(global_options, six.string_types):
global_options = [global_options]
if '|' not in inc:
return (inc, global_options)
inc, opts = inc.split('|')
return (inc, parse_sync_options(opts) + global_options)
def sync_helpers(include, src, dest, options=None):
if not os.path.isdir(dest):
os.makedirs(dest)
global_options = parse_sync_options(options)
for inc in include:
if isinstance(inc, str):
inc, opts = extract_options(inc, global_options)
sync(src, dest, inc, opts)
elif isinstance(inc, dict):
# could also do nested dicts here.
for k, v in six.iteritems(inc):
if isinstance(v, list):
for m in v:
inc, opts = extract_options(m, global_options)
sync(src, dest, '%s.%s' % (k, inc), opts)
if __name__ == '__main__':
parser = optparse.OptionParser()
parser.add_option('-c', '--config', action='store', dest='config',
default=None, help='helper config file')
parser.add_option('-D', '--debug', action='store_true', dest='debug',
default=False, help='debug')
parser.add_option('-b', '--branch', action='store', dest='branch',
help='charm-helpers bzr branch (overrides config)')
parser.add_option('-d', '--destination', action='store', dest='dest_dir',
help='sync destination dir (overrides config)')
(opts, args) = parser.parse_args()
if opts.debug:
logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
if opts.config:
logging.info('Loading charm helper config from %s.' % opts.config)
config = parse_config(opts.config)
if not config:
logging.error('Could not parse config from %s.' % opts.config)
sys.exit(1)
else:
config = {}
if 'branch' not in config:
config['branch'] = CHARM_HELPERS_BRANCH
if opts.branch:
config['branch'] = opts.branch
if opts.dest_dir:
config['destination'] = opts.dest_dir
if 'destination' not in config:
logging.error('No destination dir. specified as option or config.')
sys.exit(1)
if 'include' not in config:
if not args:
logging.error('No modules to sync specified as option or config.')
sys.exit(1)
config['include'] = []
[config['include'].append(a) for a in args]
sync_options = None
if 'options' in config:
sync_options = config['options']
tmpd = tempfile.mkdtemp()
try:
checkout = clone_helpers(tmpd, config['branch'])
sync_helpers(config['include'], checkout, config['destination'],
options=sync_options)
except Exception as e:
logging.error("Could not sync: %s" % e)
raise e
finally:
logging.debug('Cleaning up %s' % tmpd)
shutil.rmtree(tmpd)

View File

@ -51,7 +51,8 @@ class AmuletDeployment(object):
if 'units' not in this_service:
this_service['units'] = 1
self.d.add(this_service['name'], units=this_service['units'])
self.d.add(this_service['name'], units=this_service['units'],
constraints=this_service.get('constraints'))
for svc in other_services:
if 'location' in svc:
@ -64,7 +65,8 @@ class AmuletDeployment(object):
if 'units' not in svc:
svc['units'] = 1
self.d.add(svc['name'], charm=branch_location, units=svc['units'])
self.d.add(svc['name'], charm=branch_location, units=svc['units'],
constraints=svc.get('constraints'))
def _add_relations(self, relations):
"""Add all of the relations for the services."""

View File

@ -14,17 +14,25 @@
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import amulet
import ConfigParser
import distro_info
import io
import json
import logging
import os
import re
import six
import socket
import subprocess
import sys
import time
import urlparse
import uuid
import amulet
import distro_info
import six
from six.moves import configparser
if six.PY3:
from urllib import parse as urlparse
else:
import urlparse
class AmuletUtils(object):
@ -108,7 +116,7 @@ class AmuletUtils(object):
# /!\ DEPRECATION WARNING (beisner):
# New and existing tests should be rewritten to use
# validate_services_by_name() as it is aware of init systems.
self.log.warn('/!\\ DEPRECATION WARNING: use '
self.log.warn('DEPRECATION WARNING: use '
'validate_services_by_name instead of validate_services '
'due to init system differences.')
@ -142,19 +150,23 @@ class AmuletUtils(object):
for service_name in services_list:
if (self.ubuntu_releases.index(release) >= systemd_switch or
service_name == "rabbitmq-server"):
# init is systemd
service_name in ['rabbitmq-server', 'apache2']):
# init is systemd (or regular sysv)
cmd = 'sudo service {} status'.format(service_name)
output, code = sentry_unit.run(cmd)
service_running = code == 0
elif self.ubuntu_releases.index(release) < systemd_switch:
# init is upstart
cmd = 'sudo status {}'.format(service_name)
output, code = sentry_unit.run(cmd)
service_running = code == 0 and "start/running" in output
output, code = sentry_unit.run(cmd)
self.log.debug('{} `{}` returned '
'{}'.format(sentry_unit.info['unit_name'],
cmd, code))
if code != 0:
return "command `{}` returned {}".format(cmd, str(code))
if not service_running:
return u"command `{}` returned {} {}".format(
cmd, output, str(code))
return None
def _get_config(self, unit, filename):
@ -164,7 +176,7 @@ class AmuletUtils(object):
# NOTE(beisner): by default, ConfigParser does not handle options
# with no value, such as the flags used in the mysql my.cnf file.
# https://bugs.python.org/issue7005
config = ConfigParser.ConfigParser(allow_no_value=True)
config = configparser.ConfigParser(allow_no_value=True)
config.readfp(io.StringIO(file_contents))
return config
@ -259,33 +271,52 @@ class AmuletUtils(object):
"""Get last modification time of directory."""
return sentry_unit.directory_stat(directory)['mtime']
def _get_proc_start_time(self, sentry_unit, service, pgrep_full=False):
"""Get process' start time.
def _get_proc_start_time(self, sentry_unit, service, pgrep_full=None):
"""Get start time of a process based on the last modification time
of the /proc/pid directory.
Determine start time of the process based on the last modification
time of the /proc/pid directory. If pgrep_full is True, the process
name is matched against the full command line.
"""
if pgrep_full:
cmd = 'pgrep -o -f {}'.format(service)
else:
cmd = 'pgrep -o {}'.format(service)
cmd = cmd + ' | grep -v pgrep || exit 0'
cmd_out = sentry_unit.run(cmd)
self.log.debug('CMDout: ' + str(cmd_out))
if cmd_out[0]:
self.log.debug('Pid for %s %s' % (service, str(cmd_out[0])))
proc_dir = '/proc/{}'.format(cmd_out[0].strip())
return self._get_dir_mtime(sentry_unit, proc_dir)
:sentry_unit: The sentry unit to check for the service on
:service: service name to look for in process table
:pgrep_full: [Deprecated] Use full command line search mode with pgrep
:returns: epoch time of service process start
:param commands: list of bash commands
:param sentry_units: list of sentry unit pointers
:returns: None if successful; Failure message otherwise
"""
if pgrep_full is not None:
# /!\ DEPRECATION WARNING (beisner):
# No longer implemented, as pidof is now used instead of pgrep.
# https://bugs.launchpad.net/charm-helpers/+bug/1474030
self.log.warn('DEPRECATION WARNING: pgrep_full bool is no '
'longer implemented re: lp 1474030.')
pid_list = self.get_process_id_list(sentry_unit, service)
pid = pid_list[0]
proc_dir = '/proc/{}'.format(pid)
self.log.debug('Pid for {} on {}: {}'.format(
service, sentry_unit.info['unit_name'], pid))
return self._get_dir_mtime(sentry_unit, proc_dir)
def service_restarted(self, sentry_unit, service, filename,
pgrep_full=False, sleep_time=20):
pgrep_full=None, sleep_time=20):
"""Check if service was restarted.
Compare a service's start time vs a file's last modification time
(such as a config file for that service) to determine if the service
has been restarted.
"""
# /!\ DEPRECATION WARNING (beisner):
# This method is prone to races in that no before-time is known.
# Use validate_service_config_changed instead.
# NOTE(beisner) pgrep_full is no longer implemented, as pidof is now
# used instead of pgrep. pgrep_full is still passed through to ensure
# deprecation WARNS. lp1474030
self.log.warn('DEPRECATION WARNING: use '
'validate_service_config_changed instead of '
'service_restarted due to known races.')
time.sleep(sleep_time)
if (self._get_proc_start_time(sentry_unit, service, pgrep_full) >=
self._get_file_mtime(sentry_unit, filename)):
@ -294,78 +325,122 @@ class AmuletUtils(object):
return False
def service_restarted_since(self, sentry_unit, mtime, service,
pgrep_full=False, sleep_time=20,
retry_count=2):
pgrep_full=None, sleep_time=20,
retry_count=30, retry_sleep_time=10):
"""Check if service was been started after a given time.
Args:
sentry_unit (sentry): The sentry unit to check for the service on
mtime (float): The epoch time to check against
service (string): service name to look for in process table
pgrep_full (boolean): Use full command line search mode with pgrep
sleep_time (int): Seconds to sleep before looking for process
retry_count (int): If service is not found, how many times to retry
pgrep_full: [Deprecated] Use full command line search mode with pgrep
sleep_time (int): Initial sleep time (s) before looking for file
retry_sleep_time (int): Time (s) to sleep between retries
retry_count (int): If file is not found, how many times to retry
Returns:
bool: True if service found and its start time it newer than mtime,
False if service is older than mtime or if service was
not found.
"""
self.log.debug('Checking %s restarted since %s' % (service, mtime))
# NOTE(beisner) pgrep_full is no longer implemented, as pidof is now
# used instead of pgrep. pgrep_full is still passed through to ensure
# deprecation WARNS. lp1474030
unit_name = sentry_unit.info['unit_name']
self.log.debug('Checking that %s service restarted since %s on '
'%s' % (service, mtime, unit_name))
time.sleep(sleep_time)
proc_start_time = self._get_proc_start_time(sentry_unit, service,
pgrep_full)
while retry_count > 0 and not proc_start_time:
self.log.debug('No pid file found for service %s, will retry %i '
'more times' % (service, retry_count))
time.sleep(30)
proc_start_time = self._get_proc_start_time(sentry_unit, service,
pgrep_full)
retry_count = retry_count - 1
proc_start_time = None
tries = 0
while tries <= retry_count and not proc_start_time:
try:
proc_start_time = self._get_proc_start_time(sentry_unit,
service,
pgrep_full)
self.log.debug('Attempt {} to get {} proc start time on {} '
'OK'.format(tries, service, unit_name))
except IOError as e:
# NOTE(beisner) - race avoidance, proc may not exist yet.
# https://bugs.launchpad.net/charm-helpers/+bug/1474030
self.log.debug('Attempt {} to get {} proc start time on {} '
'failed\n{}'.format(tries, service,
unit_name, e))
time.sleep(retry_sleep_time)
tries += 1
if not proc_start_time:
self.log.warn('No proc start time found, assuming service did '
'not start')
return False
if proc_start_time >= mtime:
self.log.debug('proc start time is newer than provided mtime'
'(%s >= %s)' % (proc_start_time, mtime))
self.log.debug('Proc start time is newer than provided mtime'
'(%s >= %s) on %s (OK)' % (proc_start_time,
mtime, unit_name))
return True
else:
self.log.warn('proc start time (%s) is older than provided mtime '
'(%s), service did not restart' % (proc_start_time,
mtime))
self.log.warn('Proc start time (%s) is older than provided mtime '
'(%s) on %s, service did not '
'restart' % (proc_start_time, mtime, unit_name))
return False
def config_updated_since(self, sentry_unit, filename, mtime,
sleep_time=20):
sleep_time=20, retry_count=30,
retry_sleep_time=10):
"""Check if file was modified after a given time.
Args:
sentry_unit (sentry): The sentry unit to check the file mtime on
filename (string): The file to check mtime of
mtime (float): The epoch time to check against
sleep_time (int): Seconds to sleep before looking for process
sleep_time (int): Initial sleep time (s) before looking for file
retry_sleep_time (int): Time (s) to sleep between retries
retry_count (int): If file is not found, how many times to retry
Returns:
bool: True if file was modified more recently than mtime, False if
file was modified before mtime,
file was modified before mtime, or if file not found.
"""
self.log.debug('Checking %s updated since %s' % (filename, mtime))
unit_name = sentry_unit.info['unit_name']
self.log.debug('Checking that %s updated since %s on '
'%s' % (filename, mtime, unit_name))
time.sleep(sleep_time)
file_mtime = self._get_file_mtime(sentry_unit, filename)
file_mtime = None
tries = 0
while tries <= retry_count and not file_mtime:
try:
file_mtime = self._get_file_mtime(sentry_unit, filename)
self.log.debug('Attempt {} to get {} file mtime on {} '
'OK'.format(tries, filename, unit_name))
except IOError as e:
# NOTE(beisner) - race avoidance, file may not exist yet.
# https://bugs.launchpad.net/charm-helpers/+bug/1474030
self.log.debug('Attempt {} to get {} file mtime on {} '
'failed\n{}'.format(tries, filename,
unit_name, e))
time.sleep(retry_sleep_time)
tries += 1
if not file_mtime:
self.log.warn('Could not determine file mtime, assuming '
'file does not exist')
return False
if file_mtime >= mtime:
self.log.debug('File mtime is newer than provided mtime '
'(%s >= %s)' % (file_mtime, mtime))
'(%s >= %s) on %s (OK)' % (file_mtime,
mtime, unit_name))
return True
else:
self.log.warn('File mtime %s is older than provided mtime %s'
% (file_mtime, mtime))
self.log.warn('File mtime is older than provided mtime'
'(%s < on %s) on %s' % (file_mtime,
mtime, unit_name))
return False
def validate_service_config_changed(self, sentry_unit, mtime, service,
filename, pgrep_full=False,
sleep_time=20, retry_count=2):
filename, pgrep_full=None,
sleep_time=20, retry_count=30,
retry_sleep_time=10):
"""Check service and file were updated after mtime
Args:
@ -373,9 +448,10 @@ class AmuletUtils(object):
mtime (float): The epoch time to check against
service (string): service name to look for in process table
filename (string): The file to check mtime of
pgrep_full (boolean): Use full command line search mode with pgrep
sleep_time (int): Seconds to sleep before looking for process
pgrep_full: [Deprecated] Use full command line search mode with pgrep
sleep_time (int): Initial sleep in seconds to pass to test helpers
retry_count (int): If service is not found, how many times to retry
retry_sleep_time (int): Time in seconds to wait between retries
Typical Usage:
u = OpenStackAmuletUtils(ERROR)
@ -392,15 +468,27 @@ class AmuletUtils(object):
mtime, False if service is older than mtime or if service was
not found or if filename was modified before mtime.
"""
self.log.debug('Checking %s restarted since %s' % (service, mtime))
time.sleep(sleep_time)
service_restart = self.service_restarted_since(sentry_unit, mtime,
service,
pgrep_full=pgrep_full,
sleep_time=0,
retry_count=retry_count)
config_update = self.config_updated_since(sentry_unit, filename, mtime,
sleep_time=0)
# NOTE(beisner) pgrep_full is no longer implemented, as pidof is now
# used instead of pgrep. pgrep_full is still passed through to ensure
# deprecation WARNS. lp1474030
service_restart = self.service_restarted_since(
sentry_unit, mtime,
service,
pgrep_full=pgrep_full,
sleep_time=sleep_time,
retry_count=retry_count,
retry_sleep_time=retry_sleep_time)
config_update = self.config_updated_since(
sentry_unit,
filename,
mtime,
sleep_time=sleep_time,
retry_count=retry_count,
retry_sleep_time=retry_sleep_time)
return service_restart and config_update
def get_sentry_time(self, sentry_unit):
@ -418,7 +506,6 @@ class AmuletUtils(object):
"""Return a list of all Ubuntu releases in order of release."""
_d = distro_info.UbuntuDistroInfo()
_release_list = _d.all
self.log.debug('Ubuntu release list: {}'.format(_release_list))
return _release_list
def file_to_url(self, file_rel_path):
@ -450,15 +537,20 @@ class AmuletUtils(object):
cmd, code, output))
return None
def get_process_id_list(self, sentry_unit, process_name):
def get_process_id_list(self, sentry_unit, process_name,
expect_success=True):
"""Get a list of process ID(s) from a single sentry juju unit
for a single process name.
:param sentry_unit: Pointer to amulet sentry instance (juju unit)
:param sentry_unit: Amulet sentry instance (juju unit)
:param process_name: Process name
:param expect_success: If False, expect the PID to be missing,
raise if it is present.
:returns: List of process IDs
"""
cmd = 'pidof {}'.format(process_name)
cmd = 'pidof -x {}'.format(process_name)
if not expect_success:
cmd += " || exit 0 && exit 1"
output, code = sentry_unit.run(cmd)
if code != 0:
msg = ('{} `{}` returned {} '
@ -467,14 +559,23 @@ class AmuletUtils(object):
amulet.raise_status(amulet.FAIL, msg=msg)
return str(output).split()
def get_unit_process_ids(self, unit_processes):
def get_unit_process_ids(self, unit_processes, expect_success=True):
"""Construct a dict containing unit sentries, process names, and
process IDs."""
process IDs.
:param unit_processes: A dictionary of Amulet sentry instance
to list of process names.
:param expect_success: if False expect the processes to not be
running, raise if they are.
:returns: Dictionary of Amulet sentry instance to dictionary
of process names to PIDs.
"""
pid_dict = {}
for sentry_unit, process_list in unit_processes.iteritems():
for sentry_unit, process_list in six.iteritems(unit_processes):
pid_dict[sentry_unit] = {}
for process in process_list:
pids = self.get_process_id_list(sentry_unit, process)
pids = self.get_process_id_list(
sentry_unit, process, expect_success=expect_success)
pid_dict[sentry_unit].update({process: pids})
return pid_dict
@ -488,7 +589,7 @@ class AmuletUtils(object):
return ('Unit count mismatch. expected, actual: {}, '
'{} '.format(len(expected), len(actual)))
for (e_sentry, e_proc_names) in expected.iteritems():
for (e_sentry, e_proc_names) in six.iteritems(expected):
e_sentry_name = e_sentry.info['unit_name']
if e_sentry in actual.keys():
a_proc_names = actual[e_sentry]
@ -500,22 +601,40 @@ class AmuletUtils(object):
return ('Process name count mismatch. expected, actual: {}, '
'{}'.format(len(expected), len(actual)))
for (e_proc_name, e_pids_length), (a_proc_name, a_pids) in \
for (e_proc_name, e_pids), (a_proc_name, a_pids) in \
zip(e_proc_names.items(), a_proc_names.items()):
if e_proc_name != a_proc_name:
return ('Process name mismatch. expected, actual: {}, '
'{}'.format(e_proc_name, a_proc_name))
a_pids_length = len(a_pids)
if e_pids_length != a_pids_length:
return ('PID count mismatch. {} ({}) expected, actual: '
fail_msg = ('PID count mismatch. {} ({}) expected, actual: '
'{}, {} ({})'.format(e_sentry_name, e_proc_name,
e_pids_length, a_pids_length,
e_pids, a_pids_length,
a_pids))
# If expected is a list, ensure at least one PID quantity match
if isinstance(e_pids, list) and \
a_pids_length not in e_pids:
return fail_msg
# If expected is not bool and not list,
# ensure PID quantities match
elif not isinstance(e_pids, bool) and \
not isinstance(e_pids, list) and \
a_pids_length != e_pids:
return fail_msg
# If expected is bool True, ensure 1 or more PIDs exist
elif isinstance(e_pids, bool) and \
e_pids is True and a_pids_length < 1:
return fail_msg
# If expected is bool False, ensure 0 PIDs exist
elif isinstance(e_pids, bool) and \
e_pids is False and a_pids_length != 0:
return fail_msg
else:
self.log.debug('PID check OK: {} {} {}: '
'{}'.format(e_sentry_name, e_proc_name,
e_pids_length, a_pids))
e_pids, a_pids))
return None
def validate_list_of_identical_dicts(self, list_of_dicts):
@ -531,3 +650,180 @@ class AmuletUtils(object):
return 'Dicts within list are not identical'
return None
def validate_sectionless_conf(self, file_contents, expected):
"""A crude conf parser. Useful to inspect configuration files which
do not have section headers (as would be necessary in order to use
the configparser). Such as openstack-dashboard or rabbitmq confs."""
for line in file_contents.split('\n'):
if '=' in line:
args = line.split('=')
if len(args) <= 1:
continue
key = args[0].strip()
value = args[1].strip()
if key in expected.keys():
if expected[key] != value:
msg = ('Config mismatch. Expected, actual: {}, '
'{}'.format(expected[key], value))
amulet.raise_status(amulet.FAIL, msg=msg)
def get_unit_hostnames(self, units):
"""Return a dict of juju unit names to hostnames."""
host_names = {}
for unit in units:
host_names[unit.info['unit_name']] = \
str(unit.file_contents('/etc/hostname').strip())
self.log.debug('Unit host names: {}'.format(host_names))
return host_names
def run_cmd_unit(self, sentry_unit, cmd):
"""Run a command on a unit, return the output and exit code."""
output, code = sentry_unit.run(cmd)
if code == 0:
self.log.debug('{} `{}` command returned {} '
'(OK)'.format(sentry_unit.info['unit_name'],
cmd, code))
else:
msg = ('{} `{}` command returned {} '
'{}'.format(sentry_unit.info['unit_name'],
cmd, code, output))
amulet.raise_status(amulet.FAIL, msg=msg)
return str(output), code
def file_exists_on_unit(self, sentry_unit, file_name):
"""Check if a file exists on a unit."""
try:
sentry_unit.file_stat(file_name)
return True
except IOError:
return False
except Exception as e:
msg = 'Error checking file {}: {}'.format(file_name, e)
amulet.raise_status(amulet.FAIL, msg=msg)
def file_contents_safe(self, sentry_unit, file_name,
max_wait=60, fatal=False):
"""Get file contents from a sentry unit. Wrap amulet file_contents
with retry logic to address races where a file checks as existing,
but no longer exists by the time file_contents is called.
Return None if file not found. Optionally raise if fatal is True."""
unit_name = sentry_unit.info['unit_name']
file_contents = False
tries = 0
while not file_contents and tries < (max_wait / 4):
try:
file_contents = sentry_unit.file_contents(file_name)
except IOError:
self.log.debug('Attempt {} to open file {} from {} '
'failed'.format(tries, file_name,
unit_name))
time.sleep(4)
tries += 1
if file_contents:
return file_contents
elif not fatal:
return None
elif fatal:
msg = 'Failed to get file contents from unit.'
amulet.raise_status(amulet.FAIL, msg)
def port_knock_tcp(self, host="localhost", port=22, timeout=15):
"""Open a TCP socket to check for a listening sevice on a host.
:param host: host name or IP address, default to localhost
:param port: TCP port number, default to 22
:param timeout: Connect timeout, default to 15 seconds
:returns: True if successful, False if connect failed
"""
# Resolve host name if possible
try:
connect_host = socket.gethostbyname(host)
host_human = "{} ({})".format(connect_host, host)
except socket.error as e:
self.log.warn('Unable to resolve address: '
'{} ({}) Trying anyway!'.format(host, e))
connect_host = host
host_human = connect_host
# Attempt socket connection
try:
knock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
knock.settimeout(timeout)
knock.connect((connect_host, port))
knock.close()
self.log.debug('Socket connect OK for host '
'{} on port {}.'.format(host_human, port))
return True
except socket.error as e:
self.log.debug('Socket connect FAIL for'
' {} port {} ({})'.format(host_human, port, e))
return False
def port_knock_units(self, sentry_units, port=22,
timeout=15, expect_success=True):
"""Open a TCP socket to check for a listening sevice on each
listed juju unit.
:param sentry_units: list of sentry unit pointers
:param port: TCP port number, default to 22
:param timeout: Connect timeout, default to 15 seconds
:expect_success: True by default, set False to invert logic
:returns: None if successful, Failure message otherwise
"""
for unit in sentry_units:
host = unit.info['public-address']
connected = self.port_knock_tcp(host, port, timeout)
if not connected and expect_success:
return 'Socket connect failed.'
elif connected and not expect_success:
return 'Socket connected unexpectedly.'
def get_uuid_epoch_stamp(self):
"""Returns a stamp string based on uuid4 and epoch time. Useful in
generating test messages which need to be unique-ish."""
return '[{}-{}]'.format(uuid.uuid4(), time.time())
# amulet juju action helpers:
def run_action(self, unit_sentry, action,
_check_output=subprocess.check_output,
params=None):
"""Run the named action on a given unit sentry.
params a dict of parameters to use
_check_output parameter is used for dependency injection.
@return action_id.
"""
unit_id = unit_sentry.info["unit_name"]
command = ["juju", "action", "do", "--format=json", unit_id, action]
if params is not None:
for key, value in params.iteritems():
command.append("{}={}".format(key, value))
self.log.info("Running command: %s\n" % " ".join(command))
output = _check_output(command, universal_newlines=True)
data = json.loads(output)
action_id = data[u'Action queued with id']
return action_id
def wait_on_action(self, action_id, _check_output=subprocess.check_output):
"""Wait for a given action, returning if it completed or not.
_check_output parameter is used for dependency injection.
"""
command = ["juju", "action", "fetch", "--format=json", "--wait=0",
action_id]
output = _check_output(command, universal_newlines=True)
data = json.loads(output)
return data.get(u"status") == "completed"
def status_get(self, unit):
"""Return the current service status of this unit."""
raw_status, return_code = unit.run(
"status-get --format=json --include-data")
if return_code != 0:
return ("unknown", "")
status = json.loads(raw_status)
return (status["status"], status["message"])

View File

@ -148,6 +148,13 @@ define service {{
self.description = description
self.check_cmd = self._locate_cmd(check_cmd)
def _get_check_filename(self):
return os.path.join(NRPE.nrpe_confdir, '{}.cfg'.format(self.command))
def _get_service_filename(self, hostname):
return os.path.join(NRPE.nagios_exportdir,
'service__{}_{}.cfg'.format(hostname, self.command))
def _locate_cmd(self, check_cmd):
search_path = (
'/usr/lib/nagios/plugins',
@ -163,9 +170,21 @@ define service {{
log('Check command not found: {}'.format(parts[0]))
return ''
def _remove_service_files(self):
if not os.path.exists(NRPE.nagios_exportdir):
return
for f in os.listdir(NRPE.nagios_exportdir):
if f.endswith('_{}.cfg'.format(self.command)):
os.remove(os.path.join(NRPE.nagios_exportdir, f))
def remove(self, hostname):
nrpe_check_file = self._get_check_filename()
if os.path.exists(nrpe_check_file):
os.remove(nrpe_check_file)
self._remove_service_files()
def write(self, nagios_context, hostname, nagios_servicegroups):
nrpe_check_file = '/etc/nagios/nrpe.d/{}.cfg'.format(
self.command)
nrpe_check_file = self._get_check_filename()
with open(nrpe_check_file, 'w') as nrpe_check_config:
nrpe_check_config.write("# check {}\n".format(self.shortname))
nrpe_check_config.write("command[{}]={}\n".format(
@ -180,9 +199,7 @@ define service {{
def write_service_config(self, nagios_context, hostname,
nagios_servicegroups):
for f in os.listdir(NRPE.nagios_exportdir):
if re.search('.*{}.cfg'.format(self.command), f):
os.remove(os.path.join(NRPE.nagios_exportdir, f))
self._remove_service_files()
templ_vars = {
'nagios_hostname': hostname,
@ -192,8 +209,7 @@ define service {{
'command': self.command,
}
nrpe_service_text = Check.service_template.format(**templ_vars)
nrpe_service_file = '{}/service__{}_{}.cfg'.format(
NRPE.nagios_exportdir, hostname, self.command)
nrpe_service_file = self._get_service_filename(hostname)
with open(nrpe_service_file, 'w') as nrpe_service_config:
nrpe_service_config.write(str(nrpe_service_text))
@ -218,12 +234,32 @@ class NRPE(object):
if hostname:
self.hostname = hostname
else:
self.hostname = "{}-{}".format(self.nagios_context, self.unit_name)
nagios_hostname = get_nagios_hostname()
if nagios_hostname:
self.hostname = nagios_hostname
else:
self.hostname = "{}-{}".format(self.nagios_context, self.unit_name)
self.checks = []
def add_check(self, *args, **kwargs):
self.checks.append(Check(*args, **kwargs))
def remove_check(self, *args, **kwargs):
if kwargs.get('shortname') is None:
raise ValueError('shortname of check must be specified')
# Use sensible defaults if they're not specified - these are not
# actually used during removal, but they're required for constructing
# the Check object; check_disk is chosen because it's part of the
# nagios-plugins-basic package.
if kwargs.get('check_cmd') is None:
kwargs['check_cmd'] = 'check_disk'
if kwargs.get('description') is None:
kwargs['description'] = ''
check = Check(*args, **kwargs)
check.remove(self.hostname)
def write(self):
try:
nagios_uid = pwd.getpwnam('nagios').pw_uid
@ -260,7 +296,7 @@ def get_nagios_hostcontext(relation_name='nrpe-external-master'):
:param str relation_name: Name of relation nrpe sub joined to
"""
for rel in relations_of_type(relation_name):
if 'nagios_hostname' in rel:
if 'nagios_host_context' in rel:
return rel['nagios_host_context']
@ -301,11 +337,13 @@ def add_init_service_checks(nrpe, services, unit_name):
upstart_init = '/etc/init/%s.conf' % svc
sysv_init = '/etc/init.d/%s' % svc
if os.path.exists(upstart_init):
nrpe.add_check(
shortname=svc,
description='process check {%s}' % unit_name,
check_cmd='check_upstart_job %s' % svc
)
# Don't add a check for these services from neutron-gateway
if svc not in ['ext-port', 'os-charm-phy-nic-mtu']:
nrpe.add_check(
shortname=svc,
description='process check {%s}' % unit_name,
check_cmd='check_upstart_job %s' % svc
)
elif os.path.exists(sysv_init):
cronpath = '/etc/cron.d/nagios-service-check-%s' % svc
cron_file = ('*/5 * * * * root '

View File

@ -0,0 +1,15 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.

View File

@ -0,0 +1,19 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from os import path
TEMPLATES_DIR = path.join(path.dirname(__file__), 'templates')

View File

@ -0,0 +1,31 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from charmhelpers.core.hookenv import (
log,
DEBUG,
)
from charmhelpers.contrib.hardening.apache.checks import config
def run_apache_checks():
log("Starting Apache hardening checks.", level=DEBUG)
checks = config.get_audits()
for check in checks:
log("Running '%s' check" % (check.__class__.__name__), level=DEBUG)
check.ensure_compliance()
log("Apache hardening checks complete.", level=DEBUG)

View File

@ -0,0 +1,100 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import os
import re
import subprocess
from charmhelpers.core.hookenv import (
log,
INFO,
)
from charmhelpers.contrib.hardening.audits.file import (
FilePermissionAudit,
DirectoryPermissionAudit,
NoReadWriteForOther,
TemplatedFile,
)
from charmhelpers.contrib.hardening.audits.apache import DisabledModuleAudit
from charmhelpers.contrib.hardening.apache import TEMPLATES_DIR
from charmhelpers.contrib.hardening import utils
def get_audits():
"""Get Apache hardening config audits.
:returns: dictionary of audits
"""
if subprocess.call(['which', 'apache2'], stdout=subprocess.PIPE) != 0:
log("Apache server does not appear to be installed on this node - "
"skipping apache hardening", level=INFO)
return []
context = ApacheConfContext()
settings = utils.get_settings('apache')
audits = [
FilePermissionAudit(paths='/etc/apache2/apache2.conf', user='root',
group='root', mode=0o0640),
TemplatedFile(os.path.join(settings['common']['apache_dir'],
'mods-available/alias.conf'),
context,
TEMPLATES_DIR,
mode=0o0755,
user='root',
service_actions=[{'service': 'apache2',
'actions': ['restart']}]),
TemplatedFile(os.path.join(settings['common']['apache_dir'],
'conf-enabled/hardening.conf'),
context,
TEMPLATES_DIR,
mode=0o0640,
user='root',
service_actions=[{'service': 'apache2',
'actions': ['restart']}]),
DirectoryPermissionAudit(settings['common']['apache_dir'],
user='root',
group='root',
mode=0o640),
DisabledModuleAudit(settings['hardening']['modules_to_disable']),
NoReadWriteForOther(settings['common']['apache_dir']),
]
return audits
class ApacheConfContext(object):
"""Defines the set of key/value pairs to set in a apache config file.
This context, when called, will return a dictionary containing the
key/value pairs of setting to specify in the
/etc/apache/conf-enabled/hardening.conf file.
"""
def __call__(self):
settings = utils.get_settings('apache')
ctxt = settings['hardening']
out = subprocess.check_output(['apache2', '-v'])
ctxt['apache_version'] = re.search(r'.+version: Apache/(.+?)\s.+',
out).group(1)
ctxt['apache_icondir'] = '/usr/share/apache2/icons/'
ctxt['traceenable'] = settings['hardening']['traceenable']
return ctxt

View File

@ -0,0 +1,63 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
class BaseAudit(object): # NO-QA
"""Base class for hardening checks.
The lifecycle of a hardening check is to first check to see if the system
is in compliance for the specified check. If it is not in compliance, the
check method will return a value which will be supplied to the.
"""
def __init__(self, *args, **kwargs):
self.unless = kwargs.get('unless', None)
super(BaseAudit, self).__init__()
def ensure_compliance(self):
"""Checks to see if the current hardening check is in compliance or
not.
If the check that is performed is not in compliance, then an exception
should be raised.
"""
pass
def _take_action(self):
"""Determines whether to perform the action or not.
Checks whether or not an action should be taken. This is determined by
the truthy value for the unless parameter. If unless is a callback
method, it will be invoked with no parameters in order to determine
whether or not the action should be taken. Otherwise, the truthy value
of the unless attribute will determine if the action should be
performed.
"""
# Do the action if there isn't an unless override.
if self.unless is None:
return True
# Invoke the callback if there is one.
if hasattr(self.unless, '__call__'):
results = self.unless()
if results:
return False
else:
return True
if self.unless:
return False
else:
return True

View File

@ -0,0 +1,100 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import re
import subprocess
from six import string_types
from charmhelpers.core.hookenv import (
log,
INFO,
ERROR,
)
from charmhelpers.contrib.hardening.audits import BaseAudit
class DisabledModuleAudit(BaseAudit):
"""Audits Apache2 modules.
Determines if the apache2 modules are enabled. If the modules are enabled
then they are removed in the ensure_compliance.
"""
def __init__(self, modules):
if modules is None:
self.modules = []
elif isinstance(modules, string_types):
self.modules = [modules]
else:
self.modules = modules
def ensure_compliance(self):
"""Ensures that the modules are not loaded."""
if not self.modules:
return
try:
loaded_modules = self._get_loaded_modules()
non_compliant_modules = []
for module in self.modules:
if module in loaded_modules:
log("Module '%s' is enabled but should not be." %
(module), level=INFO)
non_compliant_modules.append(module)
if len(non_compliant_modules) == 0:
return
for module in non_compliant_modules:
self._disable_module(module)
self._restart_apache()
except subprocess.CalledProcessError as e:
log('Error occurred auditing apache module compliance. '
'This may have been already reported. '
'Output is: %s' % e.output, level=ERROR)
@staticmethod
def _get_loaded_modules():
"""Returns the modules which are enabled in Apache."""
output = subprocess.check_output(['apache2ctl', '-M'])
modules = []
for line in output.strip().split():
# Each line of the enabled module output looks like:
# module_name (static|shared)
# Plus a header line at the top of the output which is stripped
# out by the regex.
matcher = re.search(r'^ (\S*)', line)
if matcher:
modules.append(matcher.group(1))
return modules
@staticmethod
def _disable_module(module):
"""Disables the specified module in Apache."""
try:
subprocess.check_call(['a2dismod', module])
except subprocess.CalledProcessError as e:
# Note: catch error here to allow the attempt of disabling
# multiple modules in one go rather than failing after the
# first module fails.
log('Error occurred disabling module %s. '
'Output is: %s' % (module, e.output), level=ERROR)
@staticmethod
def _restart_apache():
"""Restarts the apache process"""
subprocess.check_output(['service', 'apache2', 'restart'])

View File

@ -0,0 +1,105 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import # required for external apt import
from apt import apt_pkg
from six import string_types
from charmhelpers.fetch import (
apt_cache,
apt_purge
)
from charmhelpers.core.hookenv import (
log,
DEBUG,
WARNING,
)
from charmhelpers.contrib.hardening.audits import BaseAudit
class AptConfig(BaseAudit):
def __init__(self, config, **kwargs):
self.config = config
def verify_config(self):
apt_pkg.init()
for cfg in self.config:
value = apt_pkg.config.get(cfg['key'], cfg.get('default', ''))
if value and value != cfg['expected']:
log("APT config '%s' has unexpected value '%s' "
"(expected='%s')" %
(cfg['key'], value, cfg['expected']), level=WARNING)
def ensure_compliance(self):
self.verify_config()
class RestrictedPackages(BaseAudit):
"""Class used to audit restricted packages on the system."""
def __init__(self, pkgs, **kwargs):
super(RestrictedPackages, self).__init__(**kwargs)
if isinstance(pkgs, string_types) or not hasattr(pkgs, '__iter__'):
self.pkgs = [pkgs]
else:
self.pkgs = pkgs
def ensure_compliance(self):
cache = apt_cache()
for p in self.pkgs:
if p not in cache:
continue
pkg = cache[p]
if not self.is_virtual_package(pkg):
if not pkg.current_ver:
log("Package '%s' is not installed." % pkg.name,
level=DEBUG)
continue
else:
log("Restricted package '%s' is installed" % pkg.name,
level=WARNING)
self.delete_package(cache, pkg)
else:
log("Checking restricted virtual package '%s' provides" %
pkg.name, level=DEBUG)
self.delete_package(cache, pkg)
def delete_package(self, cache, pkg):
"""Deletes the package from the system.
Deletes the package form the system, properly handling virtual
packages.
:param cache: the apt cache
:param pkg: the package to remove
"""
if self.is_virtual_package(pkg):
log("Package '%s' appears to be virtual - purging provides" %
pkg.name, level=DEBUG)
for _p in pkg.provides_list:
self.delete_package(cache, _p[2].parent_pkg)
elif not pkg.current_ver:
log("Package '%s' not installed" % pkg.name, level=DEBUG)
return
else:
log("Purging package '%s'" % pkg.name, level=DEBUG)
apt_purge(pkg.name)
def is_virtual_package(self, pkg):
return pkg.has_provides and not pkg.has_versions

View File

@ -0,0 +1,552 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import grp
import os
import pwd
import re
from subprocess import (
CalledProcessError,
check_output,
check_call,
)
from traceback import format_exc
from six import string_types
from stat import (
S_ISGID,
S_ISUID
)
from charmhelpers.core.hookenv import (
log,
DEBUG,
INFO,
WARNING,
ERROR,
)
from charmhelpers.core import unitdata
from charmhelpers.core.host import file_hash
from charmhelpers.contrib.hardening.audits import BaseAudit
from charmhelpers.contrib.hardening.templating import (
get_template_path,
render_and_write,
)
from charmhelpers.contrib.hardening import utils
class BaseFileAudit(BaseAudit):
"""Base class for file audits.
Provides api stubs for compliance check flow that must be used by any class
that implemented this one.
"""
def __init__(self, paths, always_comply=False, *args, **kwargs):
"""
:param paths: string path of list of paths of files we want to apply
compliance checks are criteria to.
:param always_comply: if true compliance criteria is always applied
else compliance is skipped for non-existent
paths.
"""
super(BaseFileAudit, self).__init__(*args, **kwargs)
self.always_comply = always_comply
if isinstance(paths, string_types) or not hasattr(paths, '__iter__'):
self.paths = [paths]
else:
self.paths = paths
def ensure_compliance(self):
"""Ensure that the all registered files comply to registered criteria.
"""
for p in self.paths:
if os.path.exists(p):
if self.is_compliant(p):
continue
log('File %s is not in compliance.' % p, level=INFO)
else:
if not self.always_comply:
log("Non-existent path '%s' - skipping compliance check"
% (p), level=INFO)
continue
if self._take_action():
log("Applying compliance criteria to '%s'" % (p), level=INFO)
self.comply(p)
def is_compliant(self, path):
"""Audits the path to see if it is compliance.
:param path: the path to the file that should be checked.
"""
raise NotImplementedError
def comply(self, path):
"""Enforces the compliance of a path.
:param path: the path to the file that should be enforced.
"""
raise NotImplementedError
@classmethod
def _get_stat(cls, path):
"""Returns the Posix st_stat information for the specified file path.
:param path: the path to get the st_stat information for.
:returns: an st_stat object for the path or None if the path doesn't
exist.
"""
return os.stat(path)
class FilePermissionAudit(BaseFileAudit):
"""Implements an audit for file permissions and ownership for a user.
This class implements functionality that ensures that a specific user/group
will own the file(s) specified and that the permissions specified are
applied properly to the file.
"""
def __init__(self, paths, user, group=None, mode=0o600, **kwargs):
self.user = user
self.group = group
self.mode = mode
super(FilePermissionAudit, self).__init__(paths, user, group, mode,
**kwargs)
@property
def user(self):
return self._user
@user.setter
def user(self, name):
try:
user = pwd.getpwnam(name)
except KeyError:
log('Unknown user %s' % name, level=ERROR)
user = None
self._user = user
@property
def group(self):
return self._group
@group.setter
def group(self, name):
try:
group = None
if name:
group = grp.getgrnam(name)
else:
group = grp.getgrgid(self.user.pw_gid)
except KeyError:
log('Unknown group %s' % name, level=ERROR)
self._group = group
def is_compliant(self, path):
"""Checks if the path is in compliance.
Used to determine if the path specified meets the necessary
requirements to be in compliance with the check itself.
:param path: the file path to check
:returns: True if the path is compliant, False otherwise.
"""
stat = self._get_stat(path)
user = self.user
group = self.group
compliant = True
if stat.st_uid != user.pw_uid or stat.st_gid != group.gr_gid:
log('File %s is not owned by %s:%s.' % (path, user.pw_name,
group.gr_name),
level=INFO)
compliant = False
# POSIX refers to the st_mode bits as corresponding to both the
# file type and file permission bits, where the least significant 12
# bits (o7777) are the suid (11), sgid (10), sticky bits (9), and the
# file permission bits (8-0)
perms = stat.st_mode & 0o7777
if perms != self.mode:
log('File %s has incorrect permissions, currently set to %s' %
(path, oct(stat.st_mode & 0o7777)), level=INFO)
compliant = False
return compliant
def comply(self, path):
"""Issues a chown and chmod to the file paths specified."""
utils.ensure_permissions(path, self.user.pw_name, self.group.gr_name,
self.mode)
class DirectoryPermissionAudit(FilePermissionAudit):
"""Performs a permission check for the specified directory path."""
def __init__(self, paths, user, group=None, mode=0o600,
recursive=True, **kwargs):
super(DirectoryPermissionAudit, self).__init__(paths, user, group,
mode, **kwargs)
self.recursive = recursive
def is_compliant(self, path):
"""Checks if the directory is compliant.
Used to determine if the path specified and all of its children
directories are in compliance with the check itself.
:param path: the directory path to check
:returns: True if the directory tree is compliant, otherwise False.
"""
if not os.path.isdir(path):
log('Path specified %s is not a directory.' % path, level=ERROR)
raise ValueError("%s is not a directory." % path)
if not self.recursive:
return super(DirectoryPermissionAudit, self).is_compliant(path)
compliant = True
for root, dirs, _ in os.walk(path):
if len(dirs) > 0:
continue
if not super(DirectoryPermissionAudit, self).is_compliant(root):
compliant = False
continue
return compliant
def comply(self, path):
for root, dirs, _ in os.walk(path):
if len(dirs) > 0:
super(DirectoryPermissionAudit, self).comply(root)
class ReadOnly(BaseFileAudit):
"""Audits that files and folders are read only."""
def __init__(self, paths, *args, **kwargs):
super(ReadOnly, self).__init__(paths=paths, *args, **kwargs)
def is_compliant(self, path):
try:
output = check_output(['find', path, '-perm', '-go+w',
'-type', 'f']).strip()
# The find above will find any files which have permission sets
# which allow too broad of write access. As such, the path is
# compliant if there is no output.
if output:
return False
return True
except CalledProcessError as e:
log('Error occurred checking finding writable files for %s. '
'Error information is: command %s failed with returncode '
'%d and output %s.\n%s' % (path, e.cmd, e.returncode, e.output,
format_exc(e)), level=ERROR)
return False
def comply(self, path):
try:
check_output(['chmod', 'go-w', '-R', path])
except CalledProcessError as e:
log('Error occurred removing writeable permissions for %s. '
'Error information is: command %s failed with returncode '
'%d and output %s.\n%s' % (path, e.cmd, e.returncode, e.output,
format_exc(e)), level=ERROR)
class NoReadWriteForOther(BaseFileAudit):
"""Ensures that the files found under the base path are readable or
writable by anyone other than the owner or the group.
"""
def __init__(self, paths):
super(NoReadWriteForOther, self).__init__(paths)
def is_compliant(self, path):
try:
cmd = ['find', path, '-perm', '-o+r', '-type', 'f', '-o',
'-perm', '-o+w', '-type', 'f']
output = check_output(cmd).strip()
# The find above here will find any files which have read or
# write permissions for other, meaning there is too broad of access
# to read/write the file. As such, the path is compliant if there's
# no output.
if output:
return False
return True
except CalledProcessError as e:
log('Error occurred while finding files which are readable or '
'writable to the world in %s. '
'Command output is: %s.' % (path, e.output), level=ERROR)
def comply(self, path):
try:
check_output(['chmod', '-R', 'o-rw', path])
except CalledProcessError as e:
log('Error occurred attempting to change modes of files under '
'path %s. Output of command is: %s' % (path, e.output))
class NoSUIDSGIDAudit(BaseFileAudit):
"""Audits that specified files do not have SUID/SGID bits set."""
def __init__(self, paths, *args, **kwargs):
super(NoSUIDSGIDAudit, self).__init__(paths=paths, *args, **kwargs)
def is_compliant(self, path):
stat = self._get_stat(path)
if (stat.st_mode & (S_ISGID | S_ISUID)) != 0:
return False
return True
def comply(self, path):
try:
log('Removing suid/sgid from %s.' % path, level=DEBUG)
check_output(['chmod', '-s', path])
except CalledProcessError as e:
log('Error occurred removing suid/sgid from %s.'
'Error information is: command %s failed with returncode '
'%d and output %s.\n%s' % (path, e.cmd, e.returncode, e.output,
format_exc(e)), level=ERROR)
class TemplatedFile(BaseFileAudit):
"""The TemplatedFileAudit audits the contents of a templated file.
This audit renders a file from a template, sets the appropriate file
permissions, then generates a hashsum with which to check the content
changed.
"""
def __init__(self, path, context, template_dir, mode, user='root',
group='root', service_actions=None, **kwargs):
self.context = context
self.user = user
self.group = group
self.mode = mode
self.template_dir = template_dir
self.service_actions = service_actions
super(TemplatedFile, self).__init__(paths=path, always_comply=True,
**kwargs)
def is_compliant(self, path):
"""Determines if the templated file is compliant.
A templated file is only compliant if it has not changed (as
determined by its sha256 hashsum) AND its file permissions are set
appropriately.
:param path: the path to check compliance.
"""
same_templates = self.templates_match(path)
same_content = self.contents_match(path)
same_permissions = self.permissions_match(path)
if same_content and same_permissions and same_templates:
return True
return False
def run_service_actions(self):
"""Run any actions on services requested."""
if not self.service_actions:
return
for svc_action in self.service_actions:
name = svc_action['service']
actions = svc_action['actions']
log("Running service '%s' actions '%s'" % (name, actions),
level=DEBUG)
for action in actions:
cmd = ['service', name, action]
try:
check_call(cmd)
except CalledProcessError as exc:
log("Service name='%s' action='%s' failed - %s" %
(name, action, exc), level=WARNING)
def comply(self, path):
"""Ensures the contents and the permissions of the file.
:param path: the path to correct
"""
dirname = os.path.dirname(path)
if not os.path.exists(dirname):
os.makedirs(dirname)
self.pre_write()
render_and_write(self.template_dir, path, self.context())
utils.ensure_permissions(path, self.user, self.group, self.mode)
self.run_service_actions()
self.save_checksum(path)
self.post_write()
def pre_write(self):
"""Invoked prior to writing the template."""
pass
def post_write(self):
"""Invoked after writing the template."""
pass
def templates_match(self, path):
"""Determines if the template files are the same.
The template file equality is determined by the hashsum of the
template files themselves. If there is no hashsum, then the content
cannot be sure to be the same so treat it as if they changed.
Otherwise, return whether or not the hashsums are the same.
:param path: the path to check
:returns: boolean
"""
template_path = get_template_path(self.template_dir, path)
key = 'hardening:template:%s' % template_path
template_checksum = file_hash(template_path)
kv = unitdata.kv()
stored_tmplt_checksum = kv.get(key)
if not stored_tmplt_checksum:
kv.set(key, template_checksum)
kv.flush()
log('Saved template checksum for %s.' % template_path,
level=DEBUG)
# Since we don't have a template checksum, then assume it doesn't
# match and return that the template is different.
return False
elif stored_tmplt_checksum != template_checksum:
kv.set(key, template_checksum)
kv.flush()
log('Updated template checksum for %s.' % template_path,
level=DEBUG)
return False
# Here the template hasn't changed based upon the calculated
# checksum of the template and what was previously stored.
return True
def contents_match(self, path):
"""Determines if the file content is the same.
This is determined by comparing hashsum of the file contents and
the saved hashsum. If there is no hashsum, then the content cannot
be sure to be the same so treat them as if they are not the same.
Otherwise, return True if the hashsums are the same, False if they
are not the same.
:param path: the file to check.
"""
checksum = file_hash(path)
kv = unitdata.kv()
stored_checksum = kv.get('hardening:%s' % path)
if not stored_checksum:
# If the checksum hasn't been generated, return False to ensure
# the file is written and the checksum stored.
log('Checksum for %s has not been calculated.' % path, level=DEBUG)
return False
elif stored_checksum != checksum:
log('Checksum mismatch for %s.' % path, level=DEBUG)
return False
return True
def permissions_match(self, path):
"""Determines if the file owner and permissions match.
:param path: the path to check.
"""
audit = FilePermissionAudit(path, self.user, self.group, self.mode)
return audit.is_compliant(path)
def save_checksum(self, path):
"""Calculates and saves the checksum for the path specified.
:param path: the path of the file to save the checksum.
"""
checksum = file_hash(path)
kv = unitdata.kv()
kv.set('hardening:%s' % path, checksum)
kv.flush()
class DeletedFile(BaseFileAudit):
"""Audit to ensure that a file is deleted."""
def __init__(self, paths):
super(DeletedFile, self).__init__(paths)
def is_compliant(self, path):
return not os.path.exists(path)
def comply(self, path):
os.remove(path)
class FileContentAudit(BaseFileAudit):
"""Audit the contents of a file."""
def __init__(self, paths, cases, **kwargs):
# Cases we expect to pass
self.pass_cases = cases.get('pass', [])
# Cases we expect to fail
self.fail_cases = cases.get('fail', [])
super(FileContentAudit, self).__init__(paths, **kwargs)
def is_compliant(self, path):
"""
Given a set of content matching cases i.e. tuple(regex, bool) where
bool value denotes whether or not regex is expected to match, check that
all cases match as expected with the contents of the file. Cases can be
expected to pass of fail.
:param path: Path of file to check.
:returns: Boolean value representing whether or not all cases are
found to be compliant.
"""
log("Auditing contents of file '%s'" % (path), level=DEBUG)
with open(path, 'r') as fd:
contents = fd.read()
matches = 0
for pattern in self.pass_cases:
key = re.compile(pattern, flags=re.MULTILINE)
results = re.search(key, contents)
if results:
matches += 1
else:
log("Pattern '%s' was expected to pass but instead it failed"
% (pattern), level=WARNING)
for pattern in self.fail_cases:
key = re.compile(pattern, flags=re.MULTILINE)
results = re.search(key, contents)
if not results:
matches += 1
else:
log("Pattern '%s' was expected to fail but instead it passed"
% (pattern), level=WARNING)
total = len(self.pass_cases) + len(self.fail_cases)
log("Checked %s cases and %s passed" % (total, matches), level=DEBUG)
return matches == total
def comply(self, *args, **kwargs):
"""NOOP since we just issue warnings. This is to avoid the
NotImplememtedError.
"""
log("Not applying any compliance criteria, only checks.", level=INFO)

View File

@ -0,0 +1,84 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import six
from collections import OrderedDict
from charmhelpers.core.hookenv import (
config,
log,
DEBUG,
WARNING,
)
from charmhelpers.contrib.hardening.host.checks import run_os_checks
from charmhelpers.contrib.hardening.ssh.checks import run_ssh_checks
from charmhelpers.contrib.hardening.mysql.checks import run_mysql_checks
from charmhelpers.contrib.hardening.apache.checks import run_apache_checks
def harden(overrides=None):
"""Hardening decorator.
This is the main entry point for running the hardening stack. In order to
run modules of the stack you must add this decorator to charm hook(s) and
ensure that your charm config.yaml contains the 'harden' option set to
one or more of the supported modules. Setting these will cause the
corresponding hardening code to be run when the hook fires.
This decorator can and should be applied to more than one hook or function
such that hardening modules are called multiple times. This is because
subsequent calls will perform auditing checks that will report any changes
to resources hardened by the first run (and possibly perform compliance
actions as a result of any detected infractions).
:param overrides: Optional list of stack modules used to override those
provided with 'harden' config.
:returns: Returns value returned by decorated function once executed.
"""
def _harden_inner1(f):
log("Hardening function '%s'" % (f.__name__), level=DEBUG)
def _harden_inner2(*args, **kwargs):
RUN_CATALOG = OrderedDict([('os', run_os_checks),
('ssh', run_ssh_checks),
('mysql', run_mysql_checks),
('apache', run_apache_checks)])
enabled = overrides or (config("harden") or "").split()
if enabled:
modules_to_run = []
# modules will always be performed in the following order
for module, func in six.iteritems(RUN_CATALOG):
if module in enabled:
enabled.remove(module)
modules_to_run.append(func)
if enabled:
log("Unknown hardening modules '%s' - ignoring" %
(', '.join(enabled)), level=WARNING)
for hardener in modules_to_run:
log("Executing hardening module '%s'" %
(hardener.__name__), level=DEBUG)
hardener()
else:
log("No hardening applied to '%s'" % (f.__name__), level=DEBUG)
return f(*args, **kwargs)
return _harden_inner2
return _harden_inner1

View File

@ -0,0 +1,19 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from os import path
TEMPLATES_DIR = path.join(path.dirname(__file__), 'templates')

View File

@ -0,0 +1,50 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from charmhelpers.core.hookenv import (
log,
DEBUG,
)
from charmhelpers.contrib.hardening.host.checks import (
apt,
limits,
login,
minimize_access,
pam,
profile,
securetty,
suid_sgid,
sysctl
)
def run_os_checks():
log("Starting OS hardening checks.", level=DEBUG)
checks = apt.get_audits()
checks.extend(limits.get_audits())
checks.extend(login.get_audits())
checks.extend(minimize_access.get_audits())
checks.extend(pam.get_audits())
checks.extend(profile.get_audits())
checks.extend(securetty.get_audits())
checks.extend(suid_sgid.get_audits())
checks.extend(sysctl.get_audits())
for check in checks:
log("Running '%s' check" % (check.__class__.__name__), level=DEBUG)
check.ensure_compliance()
log("OS hardening checks complete.", level=DEBUG)

View File

@ -0,0 +1,39 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from charmhelpers.contrib.hardening.utils import get_settings
from charmhelpers.contrib.hardening.audits.apt import (
AptConfig,
RestrictedPackages,
)
def get_audits():
"""Get OS hardening apt audits.
:returns: dictionary of audits
"""
audits = [AptConfig([{'key': 'APT::Get::AllowUnauthenticated',
'expected': 'false'}])]
settings = get_settings('os')
clean_packages = settings['security']['packages_clean']
if clean_packages:
security_packages = settings['security']['packages_list']
if security_packages:
audits.append(RestrictedPackages(security_packages))
return audits

View File

@ -0,0 +1,55 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from charmhelpers.contrib.hardening.audits.file import (
DirectoryPermissionAudit,
TemplatedFile,
)
from charmhelpers.contrib.hardening.host import TEMPLATES_DIR
from charmhelpers.contrib.hardening import utils
def get_audits():
"""Get OS hardening security limits audits.
:returns: dictionary of audits
"""
audits = []
settings = utils.get_settings('os')
# Ensure that the /etc/security/limits.d directory is only writable
# by the root user, but others can execute and read.
audits.append(DirectoryPermissionAudit('/etc/security/limits.d',
user='root', group='root',
mode=0o755))
# If core dumps are not enabled, then don't allow core dumps to be
# created as they may contain sensitive information.
if not settings['security']['kernel_enable_core_dump']:
audits.append(TemplatedFile('/etc/security/limits.d/10.hardcore.conf',
SecurityLimitsContext(),
template_dir=TEMPLATES_DIR,
user='root', group='root', mode=0o0440))
return audits
class SecurityLimitsContext(object):
def __call__(self):
settings = utils.get_settings('os')
ctxt = {'disable_core_dump':
not settings['security']['kernel_enable_core_dump']}
return ctxt

View File

@ -0,0 +1,67 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from six import string_types
from charmhelpers.contrib.hardening.audits.file import TemplatedFile
from charmhelpers.contrib.hardening.host import TEMPLATES_DIR
from charmhelpers.contrib.hardening import utils
def get_audits():
"""Get OS hardening login.defs audits.
:returns: dictionary of audits
"""
audits = [TemplatedFile('/etc/login.defs', LoginContext(),
template_dir=TEMPLATES_DIR,
user='root', group='root', mode=0o0444)]
return audits
class LoginContext(object):
def __call__(self):
settings = utils.get_settings('os')
# Octal numbers in yaml end up being turned into decimal,
# so check if the umask is entered as a string (e.g. '027')
# or as an octal umask as we know it (e.g. 002). If its not
# a string assume it to be octal and turn it into an octal
# string.
umask = settings['environment']['umask']
if not isinstance(umask, string_types):
umask = '%s' % oct(umask)
ctxt = {
'additional_user_paths':
settings['environment']['extra_user_paths'],
'umask': umask,
'pwd_max_age': settings['auth']['pw_max_age'],
'pwd_min_age': settings['auth']['pw_min_age'],
'uid_min': settings['auth']['uid_min'],
'sys_uid_min': settings['auth']['sys_uid_min'],
'sys_uid_max': settings['auth']['sys_uid_max'],
'gid_min': settings['auth']['gid_min'],
'sys_gid_min': settings['auth']['sys_gid_min'],
'sys_gid_max': settings['auth']['sys_gid_max'],
'login_retries': settings['auth']['retries'],
'login_timeout': settings['auth']['timeout'],
'chfn_restrict': settings['auth']['chfn_restrict'],
'allow_login_without_home': settings['auth']['allow_homeless']
}
return ctxt

View File

@ -0,0 +1,52 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from charmhelpers.contrib.hardening.audits.file import (
FilePermissionAudit,
ReadOnly,
)
from charmhelpers.contrib.hardening import utils
def get_audits():
"""Get OS hardening access audits.
:returns: dictionary of audits
"""
audits = []
settings = utils.get_settings('os')
# Remove write permissions from $PATH folders for all regular users.
# This prevents changing system-wide commands from normal users.
path_folders = {'/usr/local/sbin',
'/usr/local/bin',
'/usr/sbin',
'/usr/bin',
'/bin'}
extra_user_paths = settings['environment']['extra_user_paths']
path_folders.update(extra_user_paths)
audits.append(ReadOnly(path_folders))
# Only allow the root user to have access to the shadow file.
audits.append(FilePermissionAudit('/etc/shadow', 'root', 'root', 0o0600))
if 'change_user' not in settings['security']['users_allow']:
# su should only be accessible to user and group root, unless it is
# expressly defined to allow users to change to root via the
# security_users_allow config option.
audits.append(FilePermissionAudit('/bin/su', 'root', 'root', 0o750))
return audits

View File

@ -0,0 +1,134 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from subprocess import (
check_output,
CalledProcessError,
)
from charmhelpers.core.hookenv import (
log,
DEBUG,
ERROR,
)
from charmhelpers.fetch import (
apt_install,
apt_purge,
apt_update,
)
from charmhelpers.contrib.hardening.audits.file import (
TemplatedFile,
DeletedFile,
)
from charmhelpers.contrib.hardening import utils
from charmhelpers.contrib.hardening.host import TEMPLATES_DIR
def get_audits():
"""Get OS hardening PAM authentication audits.
:returns: dictionary of audits
"""
audits = []
settings = utils.get_settings('os')
if settings['auth']['pam_passwdqc_enable']:
audits.append(PasswdqcPAM('/etc/passwdqc.conf'))
if settings['auth']['retries']:
audits.append(Tally2PAM('/usr/share/pam-configs/tally2'))
else:
audits.append(DeletedFile('/usr/share/pam-configs/tally2'))
return audits
class PasswdqcPAMContext(object):
def __call__(self):
ctxt = {}
settings = utils.get_settings('os')
ctxt['auth_pam_passwdqc_options'] = \
settings['auth']['pam_passwdqc_options']
return ctxt
class PasswdqcPAM(TemplatedFile):
"""The PAM Audit verifies the linux PAM settings."""
def __init__(self, path):
super(PasswdqcPAM, self).__init__(path=path,
template_dir=TEMPLATES_DIR,
context=PasswdqcPAMContext(),
user='root',
group='root',
mode=0o0640)
def pre_write(self):
# Always remove?
for pkg in ['libpam-ccreds', 'libpam-cracklib']:
log("Purging package '%s'" % pkg, level=DEBUG),
apt_purge(pkg)
apt_update(fatal=True)
for pkg in ['libpam-passwdqc']:
log("Installing package '%s'" % pkg, level=DEBUG),
apt_install(pkg)
def post_write(self):
"""Updates the PAM configuration after the file has been written"""
try:
check_output(['pam-auth-update', '--package'])
except CalledProcessError as e:
log('Error calling pam-auth-update: %s' % e, level=ERROR)
class Tally2PAMContext(object):
def __call__(self):
ctxt = {}
settings = utils.get_settings('os')
ctxt['auth_lockout_time'] = settings['auth']['lockout_time']
ctxt['auth_retries'] = settings['auth']['retries']
return ctxt
class Tally2PAM(TemplatedFile):
"""The PAM Audit verifies the linux PAM settings."""
def __init__(self, path):
super(Tally2PAM, self).__init__(path=path,
template_dir=TEMPLATES_DIR,
context=Tally2PAMContext(),
user='root',
group='root',
mode=0o0640)
def pre_write(self):
# Always remove?
apt_purge('libpam-ccreds')
apt_update(fatal=True)
apt_install('libpam-modules')
def post_write(self):
"""Updates the PAM configuration after the file has been written"""
try:
check_output(['pam-auth-update', '--package'])
except CalledProcessError as e:
log('Error calling pam-auth-update: %s' % e, level=ERROR)

View File

@ -0,0 +1,45 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from charmhelpers.contrib.hardening.audits.file import TemplatedFile
from charmhelpers.contrib.hardening.host import TEMPLATES_DIR
from charmhelpers.contrib.hardening import utils
def get_audits():
"""Get OS hardening profile audits.
:returns: dictionary of audits
"""
audits = []
settings = utils.get_settings('os')
# If core dumps are not enabled, then don't allow core dumps to be
# created as they may contain sensitive information.
if not settings['security']['kernel_enable_core_dump']:
audits.append(TemplatedFile('/etc/profile.d/pinerolo_profile.sh',
ProfileContext(),
template_dir=TEMPLATES_DIR,
mode=0o0755, user='root', group='root'))
return audits
class ProfileContext(object):
def __call__(self):
ctxt = {}
return ctxt

View File

@ -0,0 +1,39 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from charmhelpers.contrib.hardening.audits.file import TemplatedFile
from charmhelpers.contrib.hardening.host import TEMPLATES_DIR
from charmhelpers.contrib.hardening import utils
def get_audits():
"""Get OS hardening Secure TTY audits.
:returns: dictionary of audits
"""
audits = []
audits.append(TemplatedFile('/etc/securetty', SecureTTYContext(),
template_dir=TEMPLATES_DIR,
mode=0o0400, user='root', group='root'))
return audits
class SecureTTYContext(object):
def __call__(self):
settings = utils.get_settings('os')
ctxt = {'ttys': settings['auth']['root_ttys']}
return ctxt

View File

@ -0,0 +1,131 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import subprocess
from charmhelpers.core.hookenv import (
log,
INFO,
)
from charmhelpers.contrib.hardening.audits.file import NoSUIDSGIDAudit
from charmhelpers.contrib.hardening import utils
BLACKLIST = ['/usr/bin/rcp', '/usr/bin/rlogin', '/usr/bin/rsh',
'/usr/libexec/openssh/ssh-keysign',
'/usr/lib/openssh/ssh-keysign',
'/sbin/netreport',
'/usr/sbin/usernetctl',
'/usr/sbin/userisdnctl',
'/usr/sbin/pppd',
'/usr/bin/lockfile',
'/usr/bin/mail-lock',
'/usr/bin/mail-unlock',
'/usr/bin/mail-touchlock',
'/usr/bin/dotlockfile',
'/usr/bin/arping',
'/usr/sbin/uuidd',
'/usr/bin/mtr',
'/usr/lib/evolution/camel-lock-helper-1.2',
'/usr/lib/pt_chown',
'/usr/lib/eject/dmcrypt-get-device',
'/usr/lib/mc/cons.saver']
WHITELIST = ['/bin/mount', '/bin/ping', '/bin/su', '/bin/umount',
'/sbin/pam_timestamp_check', '/sbin/unix_chkpwd', '/usr/bin/at',
'/usr/bin/gpasswd', '/usr/bin/locate', '/usr/bin/newgrp',
'/usr/bin/passwd', '/usr/bin/ssh-agent',
'/usr/libexec/utempter/utempter', '/usr/sbin/lockdev',
'/usr/sbin/sendmail.sendmail', '/usr/bin/expiry',
'/bin/ping6', '/usr/bin/traceroute6.iputils',
'/sbin/mount.nfs', '/sbin/umount.nfs',
'/sbin/mount.nfs4', '/sbin/umount.nfs4',
'/usr/bin/crontab',
'/usr/bin/wall', '/usr/bin/write',
'/usr/bin/screen',
'/usr/bin/mlocate',
'/usr/bin/chage', '/usr/bin/chfn', '/usr/bin/chsh',
'/bin/fusermount',
'/usr/bin/pkexec',
'/usr/bin/sudo', '/usr/bin/sudoedit',
'/usr/sbin/postdrop', '/usr/sbin/postqueue',
'/usr/sbin/suexec',
'/usr/lib/squid/ncsa_auth', '/usr/lib/squid/pam_auth',
'/usr/kerberos/bin/ksu',
'/usr/sbin/ccreds_validate',
'/usr/bin/Xorg',
'/usr/bin/X',
'/usr/lib/dbus-1.0/dbus-daemon-launch-helper',
'/usr/lib/vte/gnome-pty-helper',
'/usr/lib/libvte9/gnome-pty-helper',
'/usr/lib/libvte-2.90-9/gnome-pty-helper']
def get_audits():
"""Get OS hardening suid/sgid audits.
:returns: dictionary of audits
"""
checks = []
settings = utils.get_settings('os')
if not settings['security']['suid_sgid_enforce']:
log("Skipping suid/sgid hardening", level=INFO)
return checks
# Build the blacklist and whitelist of files for suid/sgid checks.
# There are a total of 4 lists:
# 1. the system blacklist
# 2. the system whitelist
# 3. the user blacklist
# 4. the user whitelist
#
# The blacklist is the set of paths which should NOT have the suid/sgid bit
# set and the whitelist is the set of paths which MAY have the suid/sgid
# bit setl. The user whitelist/blacklist effectively override the system
# whitelist/blacklist.
u_b = settings['security']['suid_sgid_blacklist']
u_w = settings['security']['suid_sgid_whitelist']
blacklist = set(BLACKLIST) - set(u_w + u_b)
whitelist = set(WHITELIST) - set(u_b + u_w)
checks.append(NoSUIDSGIDAudit(blacklist))
dry_run = settings['security']['suid_sgid_dry_run_on_unknown']
if settings['security']['suid_sgid_remove_from_unknown'] or dry_run:
# If the policy is a dry_run (e.g. complain only) or remove unknown
# suid/sgid bits then find all of the paths which have the suid/sgid
# bit set and then remove the whitelisted paths.
root_path = settings['environment']['root_path']
unknown_paths = find_paths_with_suid_sgid(root_path) - set(whitelist)
checks.append(NoSUIDSGIDAudit(unknown_paths, unless=dry_run))
return checks
def find_paths_with_suid_sgid(root_path):
"""Finds all paths/files which have an suid/sgid bit enabled.
Starting with the root_path, this will recursively find all paths which
have an suid or sgid bit set.
"""
cmd = ['find', root_path, '-perm', '-4000', '-o', '-perm', '-2000',
'-type', 'f', '!', '-path', '/proc/*', '-print']
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, _ = p.communicate()
return set(out.split('\n'))

View File

@ -0,0 +1,211 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import os
import platform
import re
import six
import subprocess
from charmhelpers.core.hookenv import (
log,
INFO,
WARNING,
)
from charmhelpers.contrib.hardening import utils
from charmhelpers.contrib.hardening.audits.file import (
FilePermissionAudit,
TemplatedFile,
)
from charmhelpers.contrib.hardening.host import TEMPLATES_DIR
SYSCTL_DEFAULTS = """net.ipv4.ip_forward=%(net_ipv4_ip_forward)s
net.ipv6.conf.all.forwarding=%(net_ipv6_conf_all_forwarding)s
net.ipv4.conf.all.rp_filter=1
net.ipv4.conf.default.rp_filter=1
net.ipv4.icmp_echo_ignore_broadcasts=1
net.ipv4.icmp_ignore_bogus_error_responses=1
net.ipv4.icmp_ratelimit=100
net.ipv4.icmp_ratemask=88089
net.ipv6.conf.all.disable_ipv6=%(net_ipv6_conf_all_disable_ipv6)s
net.ipv4.tcp_timestamps=%(net_ipv4_tcp_timestamps)s
net.ipv4.conf.all.arp_ignore=%(net_ipv4_conf_all_arp_ignore)s
net.ipv4.conf.all.arp_announce=%(net_ipv4_conf_all_arp_announce)s
net.ipv4.tcp_rfc1337=1
net.ipv4.tcp_syncookies=1
net.ipv4.conf.all.shared_media=1
net.ipv4.conf.default.shared_media=1
net.ipv4.conf.all.accept_source_route=0
net.ipv4.conf.default.accept_source_route=0
net.ipv4.conf.all.accept_redirects=0
net.ipv4.conf.default.accept_redirects=0
net.ipv6.conf.all.accept_redirects=0
net.ipv6.conf.default.accept_redirects=0
net.ipv4.conf.all.secure_redirects=0
net.ipv4.conf.default.secure_redirects=0
net.ipv4.conf.all.send_redirects=0
net.ipv4.conf.default.send_redirects=0
net.ipv4.conf.all.log_martians=0
net.ipv6.conf.default.router_solicitations=0
net.ipv6.conf.default.accept_ra_rtr_pref=0
net.ipv6.conf.default.accept_ra_pinfo=0
net.ipv6.conf.default.accept_ra_defrtr=0
net.ipv6.conf.default.autoconf=0
net.ipv6.conf.default.dad_transmits=0
net.ipv6.conf.default.max_addresses=1
net.ipv6.conf.all.accept_ra=0
net.ipv6.conf.default.accept_ra=0
kernel.modules_disabled=%(kernel_modules_disabled)s
kernel.sysrq=%(kernel_sysrq)s
fs.suid_dumpable=%(fs_suid_dumpable)s
kernel.randomize_va_space=2
"""
def get_audits():
"""Get OS hardening sysctl audits.
:returns: dictionary of audits
"""
audits = []
settings = utils.get_settings('os')
# Apply the sysctl settings which are configured to be applied.
audits.append(SysctlConf())
# Make sure that only root has access to the sysctl.conf file, and
# that it is read-only.
audits.append(FilePermissionAudit('/etc/sysctl.conf',
user='root',
group='root', mode=0o0440))
# If module loading is not enabled, then ensure that the modules
# file has the appropriate permissions and rebuild the initramfs
if not settings['security']['kernel_enable_module_loading']:
audits.append(ModulesTemplate())
return audits
class ModulesContext(object):
def __call__(self):
settings = utils.get_settings('os')
with open('/proc/cpuinfo', 'r') as fd:
cpuinfo = fd.readlines()
for line in cpuinfo:
match = re.search(r"^vendor_id\s+:\s+(.+)", line)
if match:
vendor = match.group(1)
if vendor == "GenuineIntel":
vendor = "intel"
elif vendor == "AuthenticAMD":
vendor = "amd"
ctxt = {'arch': platform.processor(),
'cpuVendor': vendor,
'desktop_enable': settings['general']['desktop_enable']}
return ctxt
class ModulesTemplate(object):
def __init__(self):
super(ModulesTemplate, self).__init__('/etc/initramfs-tools/modules',
ModulesContext(),
templates_dir=TEMPLATES_DIR,
user='root', group='root',
mode=0o0440)
def post_write(self):
subprocess.check_call(['update-initramfs', '-u'])
class SysCtlHardeningContext(object):
def __call__(self):
settings = utils.get_settings('os')
ctxt = {'sysctl': {}}
log("Applying sysctl settings", level=INFO)
extras = {'net_ipv4_ip_forward': 0,
'net_ipv6_conf_all_forwarding': 0,
'net_ipv6_conf_all_disable_ipv6': 1,
'net_ipv4_tcp_timestamps': 0,
'net_ipv4_conf_all_arp_ignore': 0,
'net_ipv4_conf_all_arp_announce': 0,
'kernel_sysrq': 0,
'fs_suid_dumpable': 0,
'kernel_modules_disabled': 1}
if settings['sysctl']['ipv6_enable']:
extras['net_ipv6_conf_all_disable_ipv6'] = 0
if settings['sysctl']['forwarding']:
extras['net_ipv4_ip_forward'] = 1
extras['net_ipv6_conf_all_forwarding'] = 1
if settings['sysctl']['arp_restricted']:
extras['net_ipv4_conf_all_arp_ignore'] = 1
extras['net_ipv4_conf_all_arp_announce'] = 2
if settings['security']['kernel_enable_module_loading']:
extras['kernel_modules_disabled'] = 0
if settings['sysctl']['kernel_enable_sysrq']:
sysrq_val = settings['sysctl']['kernel_secure_sysrq']
extras['kernel_sysrq'] = sysrq_val
if settings['security']['kernel_enable_core_dump']:
extras['fs_suid_dumpable'] = 1
settings.update(extras)
for d in (SYSCTL_DEFAULTS % settings).split():
d = d.strip().partition('=')
key = d[0].strip()
path = os.path.join('/proc/sys', key.replace('.', '/'))
if not os.path.exists(path):
log("Skipping '%s' since '%s' does not exist" % (key, path),
level=WARNING)
continue
ctxt['sysctl'][key] = d[2] or None
# Translate for python3
return {'sysctl_settings':
[(k, v) for k, v in six.iteritems(ctxt['sysctl'])]}
class SysctlConf(TemplatedFile):
"""An audit check for sysctl settings."""
def __init__(self):
self.conffile = '/etc/sysctl.d/99-juju-hardening.conf'
super(SysctlConf, self).__init__(self.conffile,
SysCtlHardeningContext(),
template_dir=TEMPLATES_DIR,
user='root', group='root',
mode=0o0440)
def post_write(self):
try:
subprocess.check_call(['sysctl', '-p', self.conffile])
except subprocess.CalledProcessError as e:
# NOTE: on some systems if sysctl cannot apply all settings it
# will return non-zero as well.
log("sysctl command returned an error (maybe some "
"keys could not be set) - %s" % (e),
level=WARNING)

View File

@ -0,0 +1,19 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from os import path
TEMPLATES_DIR = path.join(path.dirname(__file__), 'templates')

View File

@ -0,0 +1,31 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from charmhelpers.core.hookenv import (
log,
DEBUG,
)
from charmhelpers.contrib.hardening.mysql.checks import config
def run_mysql_checks():
log("Starting MySQL hardening checks.", level=DEBUG)
checks = config.get_audits()
for check in checks:
log("Running '%s' check" % (check.__class__.__name__), level=DEBUG)
check.ensure_compliance()
log("MySQL hardening checks complete.", level=DEBUG)

View File

@ -0,0 +1,89 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import six
import subprocess
from charmhelpers.core.hookenv import (
log,
WARNING,
)
from charmhelpers.contrib.hardening.audits.file import (
FilePermissionAudit,
DirectoryPermissionAudit,
TemplatedFile,
)
from charmhelpers.contrib.hardening.mysql import TEMPLATES_DIR
from charmhelpers.contrib.hardening import utils
def get_audits():
"""Get MySQL hardening config audits.
:returns: dictionary of audits
"""
if subprocess.call(['which', 'mysql'], stdout=subprocess.PIPE) != 0:
log("MySQL does not appear to be installed on this node - "
"skipping mysql hardening", level=WARNING)
return []
settings = utils.get_settings('mysql')
hardening_settings = settings['hardening']
my_cnf = hardening_settings['mysql-conf']
audits = [
FilePermissionAudit(paths=[my_cnf], user='root',
group='root', mode=0o0600),
TemplatedFile(hardening_settings['hardening-conf'],
MySQLConfContext(),
TEMPLATES_DIR,
mode=0o0750,
user='mysql',
group='root',
service_actions=[{'service': 'mysql',
'actions': ['restart']}]),
# MySQL and Percona charms do not allow configuration of the
# data directory, so use the default.
DirectoryPermissionAudit('/var/lib/mysql',
user='mysql',
group='mysql',
recursive=False,
mode=0o755),
DirectoryPermissionAudit('/etc/mysql',
user='root',
group='root',
recursive=False,
mode=0o700),
]
return audits
class MySQLConfContext(object):
"""Defines the set of key/value pairs to set in a mysql config file.
This context, when called, will return a dictionary containing the
key/value pairs of setting to specify in the
/etc/mysql/conf.d/hardening.cnf file.
"""
def __call__(self):
settings = utils.get_settings('mysql')
# Translate for python3
return {'mysql_settings':
[(k, v) for k, v in six.iteritems(settings['security'])]}

View File

@ -0,0 +1,19 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from os import path
TEMPLATES_DIR = path.join(path.dirname(__file__), 'templates')

View File

@ -0,0 +1,31 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from charmhelpers.core.hookenv import (
log,
DEBUG,
)
from charmhelpers.contrib.hardening.ssh.checks import config
def run_ssh_checks():
log("Starting SSH hardening checks.", level=DEBUG)
checks = config.get_audits()
for check in checks:
log("Running '%s' check" % (check.__class__.__name__), level=DEBUG)
check.ensure_compliance()
log("SSH hardening checks complete.", level=DEBUG)

View File

@ -0,0 +1,394 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import os
from charmhelpers.core.hookenv import (
log,
DEBUG,
)
from charmhelpers.fetch import (
apt_install,
apt_update,
)
from charmhelpers.core.host import lsb_release
from charmhelpers.contrib.hardening.audits.file import (
TemplatedFile,
FileContentAudit,
)
from charmhelpers.contrib.hardening.ssh import TEMPLATES_DIR
from charmhelpers.contrib.hardening import utils
def get_audits():
"""Get SSH hardening config audits.
:returns: dictionary of audits
"""
audits = [SSHConfig(), SSHDConfig(), SSHConfigFileContentAudit(),
SSHDConfigFileContentAudit()]
return audits
class SSHConfigContext(object):
type = 'client'
def get_macs(self, allow_weak_mac):
if allow_weak_mac:
weak_macs = 'weak'
else:
weak_macs = 'default'
default = 'hmac-sha2-512,hmac-sha2-256,hmac-ripemd160'
macs = {'default': default,
'weak': default + ',hmac-sha1'}
default = ('hmac-sha2-512-etm@openssh.com,'
'hmac-sha2-256-etm@openssh.com,'
'hmac-ripemd160-etm@openssh.com,umac-128-etm@openssh.com,'
'hmac-sha2-512,hmac-sha2-256,hmac-ripemd160')
macs_66 = {'default': default,
'weak': default + ',hmac-sha1'}
# Use newer ciphers on Ubuntu Trusty and above
if lsb_release()['DISTRIB_CODENAME'].lower() >= 'trusty':
log("Detected Ubuntu 14.04 or newer, using new macs", level=DEBUG)
macs = macs_66
return macs[weak_macs]
def get_kexs(self, allow_weak_kex):
if allow_weak_kex:
weak_kex = 'weak'
else:
weak_kex = 'default'
default = 'diffie-hellman-group-exchange-sha256'
weak = (default + ',diffie-hellman-group14-sha1,'
'diffie-hellman-group-exchange-sha1,'
'diffie-hellman-group1-sha1')
kex = {'default': default,
'weak': weak}
default = ('curve25519-sha256@libssh.org,'
'diffie-hellman-group-exchange-sha256')
weak = (default + ',diffie-hellman-group14-sha1,'
'diffie-hellman-group-exchange-sha1,'
'diffie-hellman-group1-sha1')
kex_66 = {'default': default,
'weak': weak}
# Use newer kex on Ubuntu Trusty and above
if lsb_release()['DISTRIB_CODENAME'].lower() >= 'trusty':
log('Detected Ubuntu 14.04 or newer, using new key exchange '
'algorithms', level=DEBUG)
kex = kex_66
return kex[weak_kex]
def get_ciphers(self, cbc_required):
if cbc_required:
weak_ciphers = 'weak'
else:
weak_ciphers = 'default'
default = 'aes256-ctr,aes192-ctr,aes128-ctr'
cipher = {'default': default,
'weak': default + 'aes256-cbc,aes192-cbc,aes128-cbc'}
default = ('chacha20-poly1305@openssh.com,aes256-gcm@openssh.com,'
'aes128-gcm@openssh.com,aes256-ctr,aes192-ctr,aes128-ctr')
ciphers_66 = {'default': default,
'weak': default + ',aes256-cbc,aes192-cbc,aes128-cbc'}
# Use newer ciphers on ubuntu Trusty and above
if lsb_release()['DISTRIB_CODENAME'].lower() >= 'trusty':
log('Detected Ubuntu 14.04 or newer, using new ciphers',
level=DEBUG)
cipher = ciphers_66
return cipher[weak_ciphers]
def __call__(self):
settings = utils.get_settings('ssh')
if settings['common']['network_ipv6_enable']:
addr_family = 'any'
else:
addr_family = 'inet'
ctxt = {
'addr_family': addr_family,
'remote_hosts': settings['common']['remote_hosts'],
'password_auth_allowed':
settings['client']['password_authentication'],
'ports': settings['common']['ports'],
'ciphers': self.get_ciphers(settings['client']['cbc_required']),
'macs': self.get_macs(settings['client']['weak_hmac']),
'kexs': self.get_kexs(settings['client']['weak_kex']),
'roaming': settings['client']['roaming'],
}
return ctxt
class SSHConfig(TemplatedFile):
def __init__(self):
path = '/etc/ssh/ssh_config'
super(SSHConfig, self).__init__(path=path,
template_dir=TEMPLATES_DIR,
context=SSHConfigContext(),
user='root',
group='root',
mode=0o0644)
def pre_write(self):
settings = utils.get_settings('ssh')
apt_update(fatal=True)
apt_install(settings['client']['package'])
if not os.path.exists('/etc/ssh'):
os.makedir('/etc/ssh')
# NOTE: don't recurse
utils.ensure_permissions('/etc/ssh', 'root', 'root', 0o0755,
maxdepth=0)
def post_write(self):
# NOTE: don't recurse
utils.ensure_permissions('/etc/ssh', 'root', 'root', 0o0755,
maxdepth=0)
class SSHDConfigContext(SSHConfigContext):
type = 'server'
def __call__(self):
settings = utils.get_settings('ssh')
if settings['common']['network_ipv6_enable']:
addr_family = 'any'
else:
addr_family = 'inet'
ctxt = {
'ssh_ip': settings['server']['listen_to'],
'password_auth_allowed':
settings['server']['password_authentication'],
'ports': settings['common']['ports'],
'addr_family': addr_family,
'ciphers': self.get_ciphers(settings['server']['cbc_required']),
'macs': self.get_macs(settings['server']['weak_hmac']),
'kexs': self.get_kexs(settings['server']['weak_kex']),
'host_key_files': settings['server']['host_key_files'],
'allow_root_with_key': settings['server']['allow_root_with_key'],
'password_authentication':
settings['server']['password_authentication'],
'use_priv_sep': settings['server']['use_privilege_separation'],
'use_pam': settings['server']['use_pam'],
'allow_x11_forwarding': settings['server']['allow_x11_forwarding'],
'print_motd': settings['server']['print_motd'],
'print_last_log': settings['server']['print_last_log'],
'client_alive_interval':
settings['server']['alive_interval'],
'client_alive_count': settings['server']['alive_count'],
'allow_tcp_forwarding': settings['server']['allow_tcp_forwarding'],
'allow_agent_forwarding':
settings['server']['allow_agent_forwarding'],
'deny_users': settings['server']['deny_users'],
'allow_users': settings['server']['allow_users'],
'deny_groups': settings['server']['deny_groups'],
'allow_groups': settings['server']['allow_groups'],
'use_dns': settings['server']['use_dns'],
'sftp_enable': settings['server']['sftp_enable'],
'sftp_group': settings['server']['sftp_group'],
'sftp_chroot': settings['server']['sftp_chroot'],
'max_auth_tries': settings['server']['max_auth_tries'],
'max_sessions': settings['server']['max_sessions'],
}
return ctxt
class SSHDConfig(TemplatedFile):
def __init__(self):
path = '/etc/ssh/sshd_config'
super(SSHDConfig, self).__init__(path=path,
template_dir=TEMPLATES_DIR,
context=SSHDConfigContext(),
user='root',
group='root',
mode=0o0600,
service_actions=[{'service': 'ssh',
'actions':
['restart']}])
def pre_write(self):
settings = utils.get_settings('ssh')
apt_update(fatal=True)
apt_install(settings['server']['package'])
if not os.path.exists('/etc/ssh'):
os.makedir('/etc/ssh')
# NOTE: don't recurse
utils.ensure_permissions('/etc/ssh', 'root', 'root', 0o0755,
maxdepth=0)
def post_write(self):
# NOTE: don't recurse
utils.ensure_permissions('/etc/ssh', 'root', 'root', 0o0755,
maxdepth=0)
class SSHConfigFileContentAudit(FileContentAudit):
def __init__(self):
self.path = '/etc/ssh/ssh_config'
super(SSHConfigFileContentAudit, self).__init__(self.path, {})
def is_compliant(self, *args, **kwargs):
self.pass_cases = []
self.fail_cases = []
settings = utils.get_settings('ssh')
if lsb_release()['DISTRIB_CODENAME'].lower() >= 'trusty':
if not settings['server']['weak_hmac']:
self.pass_cases.append(r'^MACs.+,hmac-ripemd160$')
else:
self.pass_cases.append(r'^MACs.+,hmac-sha1$')
if settings['server']['weak_kex']:
self.fail_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha256[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group14-sha1[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha1[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group1-sha1[,\s]?') # noqa
else:
self.pass_cases.append(r'^KexAlgorithms.+,diffie-hellman-group-exchange-sha256$') # noqa
self.fail_cases.append(r'^KexAlgorithms.*diffie-hellman-group14-sha1[,\s]?') # noqa
if settings['server']['cbc_required']:
self.pass_cases.append(r'^Ciphers\s.*-cbc[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes128-ctr[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes192-ctr[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes256-ctr[,\s]?')
else:
self.fail_cases.append(r'^Ciphers\s.*-cbc[,\s]?')
self.pass_cases.append(r'^Ciphers\schacha20-poly1305@openssh.com,.+') # noqa
self.pass_cases.append(r'^Ciphers\s.*aes128-ctr$')
self.pass_cases.append(r'^Ciphers\s.*aes192-ctr[,\s]?')
self.pass_cases.append(r'^Ciphers\s.*aes256-ctr[,\s]?')
else:
if not settings['client']['weak_hmac']:
self.fail_cases.append(r'^MACs.+,hmac-sha1$')
else:
self.pass_cases.append(r'^MACs.+,hmac-sha1$')
if settings['client']['weak_kex']:
self.fail_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha256[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group14-sha1[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha1[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group1-sha1[,\s]?') # noqa
else:
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha256$') # noqa
self.fail_cases.append(r'^KexAlgorithms\sdiffie-hellman-group14-sha1[,\s]?') # noqa
self.fail_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha1[,\s]?') # noqa
self.fail_cases.append(r'^KexAlgorithms\sdiffie-hellman-group1-sha1[,\s]?') # noqa
if settings['client']['cbc_required']:
self.pass_cases.append(r'^Ciphers\s.*-cbc[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes128-ctr[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes192-ctr[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes256-ctr[,\s]?')
else:
self.fail_cases.append(r'^Ciphers\s.*-cbc[,\s]?')
self.pass_cases.append(r'^Ciphers\s.*aes128-ctr[,\s]?')
self.pass_cases.append(r'^Ciphers\s.*aes192-ctr[,\s]?')
self.pass_cases.append(r'^Ciphers\s.*aes256-ctr[,\s]?')
if settings['client']['roaming']:
self.pass_cases.append(r'^UseRoaming yes$')
else:
self.fail_cases.append(r'^UseRoaming yes$')
return super(SSHConfigFileContentAudit, self).is_compliant(*args,
**kwargs)
class SSHDConfigFileContentAudit(FileContentAudit):
def __init__(self):
self.path = '/etc/ssh/sshd_config'
super(SSHDConfigFileContentAudit, self).__init__(self.path, {})
def is_compliant(self, *args, **kwargs):
self.pass_cases = []
self.fail_cases = []
settings = utils.get_settings('ssh')
if lsb_release()['DISTRIB_CODENAME'].lower() >= 'trusty':
if not settings['server']['weak_hmac']:
self.pass_cases.append(r'^MACs.+,hmac-ripemd160$')
else:
self.pass_cases.append(r'^MACs.+,hmac-sha1$')
if settings['server']['weak_kex']:
self.fail_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha256[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group14-sha1[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha1[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group1-sha1[,\s]?') # noqa
else:
self.pass_cases.append(r'^KexAlgorithms.+,diffie-hellman-group-exchange-sha256$') # noqa
self.fail_cases.append(r'^KexAlgorithms.*diffie-hellman-group14-sha1[,\s]?') # noqa
if settings['server']['cbc_required']:
self.pass_cases.append(r'^Ciphers\s.*-cbc[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes128-ctr[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes192-ctr[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes256-ctr[,\s]?')
else:
self.fail_cases.append(r'^Ciphers\s.*-cbc[,\s]?')
self.pass_cases.append(r'^Ciphers\schacha20-poly1305@openssh.com,.+') # noqa
self.pass_cases.append(r'^Ciphers\s.*aes128-ctr$')
self.pass_cases.append(r'^Ciphers\s.*aes192-ctr[,\s]?')
self.pass_cases.append(r'^Ciphers\s.*aes256-ctr[,\s]?')
else:
if not settings['server']['weak_hmac']:
self.pass_cases.append(r'^MACs.+,hmac-ripemd160$')
else:
self.pass_cases.append(r'^MACs.+,hmac-sha1$')
if settings['server']['weak_kex']:
self.fail_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha256[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group14-sha1[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha1[,\s]?') # noqa
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group1-sha1[,\s]?') # noqa
else:
self.pass_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha256$') # noqa
self.fail_cases.append(r'^KexAlgorithms\sdiffie-hellman-group14-sha1[,\s]?') # noqa
self.fail_cases.append(r'^KexAlgorithms\sdiffie-hellman-group-exchange-sha1[,\s]?') # noqa
self.fail_cases.append(r'^KexAlgorithms\sdiffie-hellman-group1-sha1[,\s]?') # noqa
if settings['server']['cbc_required']:
self.pass_cases.append(r'^Ciphers\s.*-cbc[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes128-ctr[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes192-ctr[,\s]?')
self.fail_cases.append(r'^Ciphers\s.*aes256-ctr[,\s]?')
else:
self.fail_cases.append(r'^Ciphers\s.*-cbc[,\s]?')
self.pass_cases.append(r'^Ciphers\s.*aes128-ctr[,\s]?')
self.pass_cases.append(r'^Ciphers\s.*aes192-ctr[,\s]?')
self.pass_cases.append(r'^Ciphers\s.*aes256-ctr[,\s]?')
if settings['server']['sftp_enable']:
self.pass_cases.append(r'^Subsystem\ssftp')
else:
self.fail_cases.append(r'^Subsystem\ssftp')
return super(SSHDConfigFileContentAudit, self).is_compliant(*args,
**kwargs)

View File

@ -0,0 +1,71 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import os
from charmhelpers.core.hookenv import (
log,
DEBUG,
WARNING,
)
try:
from jinja2 import FileSystemLoader, Environment
except ImportError:
from charmhelpers.fetch import apt_install
from charmhelpers.fetch import apt_update
apt_update(fatal=True)
apt_install('python-jinja2', fatal=True)
from jinja2 import FileSystemLoader, Environment
# NOTE: function separated from main rendering code to facilitate easier
# mocking in unit tests.
def write(path, data):
with open(path, 'wb') as out:
out.write(data)
def get_template_path(template_dir, path):
"""Returns the template file which would be used to render the path.
The path to the template file is returned.
:param template_dir: the directory the templates are located in
:param path: the file path to be written to.
:returns: path to the template file
"""
return os.path.join(template_dir, os.path.basename(path))
def render_and_write(template_dir, path, context):
"""Renders the specified template into the file.
:param template_dir: the directory to load the template from
:param path: the path to write the templated contents to
:param context: the parameters to pass to the rendering engine
"""
env = Environment(loader=FileSystemLoader(template_dir))
template_file = os.path.basename(path)
template = env.get_template(template_file)
log('Rendering from template: %s' % template.name, level=DEBUG)
rendered_content = template.render(context)
if not rendered_content:
log("Render returned None - skipping '%s'" % path,
level=WARNING)
return
write(path, rendered_content.encode('utf-8').strip())
log('Wrote template %s' % path, level=DEBUG)

View File

@ -0,0 +1,157 @@
# Copyright 2016 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import glob
import grp
import os
import pwd
import six
import yaml
from charmhelpers.core.hookenv import (
log,
DEBUG,
INFO,
WARNING,
ERROR,
)
# Global settings cache. Since each hook fire entails a fresh module import it
# is safe to hold this in memory and not risk missing config changes (since
# they will result in a new hook fire and thus re-import).
__SETTINGS__ = {}
def _get_defaults(modules):
"""Load the default config for the provided modules.
:param modules: stack modules config defaults to lookup.
:returns: modules default config dictionary.
"""
default = os.path.join(os.path.dirname(__file__),
'defaults/%s.yaml' % (modules))
return yaml.safe_load(open(default))
def _get_schema(modules):
"""Load the config schema for the provided modules.
NOTE: this schema is intended to have 1-1 relationship with they keys in
the default config and is used a means to verify valid overrides provided
by the user.
:param modules: stack modules config schema to lookup.
:returns: modules default schema dictionary.
"""
schema = os.path.join(os.path.dirname(__file__),
'defaults/%s.yaml.schema' % (modules))
return yaml.safe_load(open(schema))
def _get_user_provided_overrides(modules):
"""Load user-provided config overrides.
:param modules: stack modules to lookup in user overrides yaml file.
:returns: overrides dictionary.
"""
overrides = os.path.join(os.environ['JUJU_CHARM_DIR'],
'hardening.yaml')
if os.path.exists(overrides):
log("Found user-provided config overrides file '%s'" %
(overrides), level=DEBUG)
settings = yaml.safe_load(open(overrides))
if settings and settings.get(modules):
log("Applying '%s' overrides" % (modules), level=DEBUG)
return settings.get(modules)
log("No overrides found for '%s'" % (modules), level=DEBUG)
else:
log("No hardening config overrides file '%s' found in charm "
"root dir" % (overrides), level=DEBUG)
return {}
def _apply_overrides(settings, overrides, schema):
"""Get overrides config overlayed onto modules defaults.
:param modules: require stack modules config.
:returns: dictionary of modules config with user overrides applied.
"""
if overrides:
for k, v in six.iteritems(overrides):
if k in schema:
if schema[k] is None:
settings[k] = v
elif type(schema[k]) is dict:
settings[k] = _apply_overrides(settings[k], overrides[k],
schema[k])
else:
raise Exception("Unexpected type found in schema '%s'" %
type(schema[k]), level=ERROR)
else:
log("Unknown override key '%s' - ignoring" % (k), level=INFO)
return settings
def get_settings(modules):
global __SETTINGS__
if modules in __SETTINGS__:
return __SETTINGS__[modules]
schema = _get_schema(modules)
settings = _get_defaults(modules)
overrides = _get_user_provided_overrides(modules)
__SETTINGS__[modules] = _apply_overrides(settings, overrides, schema)
return __SETTINGS__[modules]
def ensure_permissions(path, user, group, permissions, maxdepth=-1):
"""Ensure permissions for path.
If path is a file, apply to file and return. If path is a directory,
apply recursively (if required) to directory contents and return.
:param user: user name
:param group: group name
:param permissions: octal permissions
:param maxdepth: maximum recursion depth. A negative maxdepth allows
infinite recursion and maxdepth=0 means no recursion.
:returns: None
"""
if not os.path.exists(path):
log("File '%s' does not exist - cannot set permissions" % (path),
level=WARNING)
return
_user = pwd.getpwnam(user)
os.chown(path, _user.pw_uid, grp.getgrnam(group).gr_gid)
os.chmod(path, permissions)
if maxdepth == 0:
log("Max recursion depth reached - skipping further recursion",
level=DEBUG)
return
elif maxdepth > 0:
maxdepth -= 1
if os.path.isdir(path):
contents = glob.glob("%s/*" % (path))
for c in contents:
ensure_permissions(c, user=user, group=group,
permissions=permissions, maxdepth=maxdepth)

View File

@ -0,0 +1,151 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2014-2015 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
__author__ = "Jorge Niedbalski <jorge.niedbalski@canonical.com>"
from charmhelpers.fetch import (
apt_install,
apt_update,
)
from charmhelpers.core.hookenv import (
log,
INFO,
)
try:
from netifaces import interfaces as network_interfaces
except ImportError:
apt_install('python-netifaces')
from netifaces import interfaces as network_interfaces
import os
import re
import subprocess
from charmhelpers.core.kernel import modprobe
REQUIRED_MODULES = (
"mlx4_ib",
"mlx4_en",
"mlx4_core",
"ib_ipath",
"ib_mthca",
"ib_srpt",
"ib_srp",
"ib_ucm",
"ib_isert",
"ib_iser",
"ib_ipoib",
"ib_cm",
"ib_uverbs"
"ib_umad",
"ib_sa",
"ib_mad",
"ib_core",
"ib_addr",
"rdma_ucm",
)
REQUIRED_PACKAGES = (
"ibutils",
"infiniband-diags",
"ibverbs-utils",
)
IPOIB_DRIVERS = (
"ib_ipoib",
)
ABI_VERSION_FILE = "/sys/class/infiniband_mad/abi_version"
class DeviceInfo(object):
pass
def install_packages():
apt_update()
apt_install(REQUIRED_PACKAGES, fatal=True)
def load_modules():
for module in REQUIRED_MODULES:
modprobe(module, persist=True)
def is_enabled():
"""Check if infiniband is loaded on the system"""
return os.path.exists(ABI_VERSION_FILE)
def stat():
"""Return full output of ibstat"""
return subprocess.check_output(["ibstat"])
def devices():
"""Returns a list of IB enabled devices"""
return subprocess.check_output(['ibstat', '-l']).splitlines()
def device_info(device):
"""Returns a DeviceInfo object with the current device settings"""
status = subprocess.check_output([
'ibstat', device, '-s']).splitlines()
regexes = {
"CA type: (.*)": "device_type",
"Number of ports: (.*)": "num_ports",
"Firmware version: (.*)": "fw_ver",
"Hardware version: (.*)": "hw_ver",
"Node GUID: (.*)": "node_guid",
"System image GUID: (.*)": "sys_guid",
}
device = DeviceInfo()
for line in status:
for expression, key in regexes.items():
matches = re.search(expression, line)
if matches:
setattr(device, key, matches.group(1))
return device
def ipoib_interfaces():
"""Return a list of IPOIB capable ethernet interfaces"""
interfaces = []
for interface in network_interfaces():
try:
driver = re.search('^driver: (.+)$', subprocess.check_output([
'ethtool', '-i',
interface]), re.M).group(1)
if driver in IPOIB_DRIVERS:
interfaces.append(interface)
except:
log("Skipping interface %s" % interface, level=INFO)
continue
return interfaces

View File

@ -23,7 +23,7 @@ import socket
from functools import partial
from charmhelpers.core.hookenv import unit_get
from charmhelpers.fetch import apt_install
from charmhelpers.fetch import apt_install, apt_update
from charmhelpers.core.hookenv import (
log,
WARNING,
@ -32,13 +32,15 @@ from charmhelpers.core.hookenv import (
try:
import netifaces
except ImportError:
apt_install('python-netifaces')
apt_update(fatal=True)
apt_install('python-netifaces', fatal=True)
import netifaces
try:
import netaddr
except ImportError:
apt_install('python-netaddr')
apt_update(fatal=True)
apt_install('python-netaddr', fatal=True)
import netaddr
@ -51,7 +53,7 @@ def _validate_cidr(network):
def no_ip_found_error_out(network):
errmsg = ("No IP address found in network: %s" % network)
errmsg = ("No IP address found in network(s): %s" % network)
raise ValueError(errmsg)
@ -59,7 +61,7 @@ def get_address_in_network(network, fallback=None, fatal=False):
"""Get an IPv4 or IPv6 address within the network from the host.
:param network (str): CIDR presentation format. For example,
'192.168.1.0/24'.
'192.168.1.0/24'. Supports multiple networks as a space-delimited list.
:param fallback (str): If no address is found, return fallback.
:param fatal (boolean): If no address is found, fallback is not
set and fatal is True then exit(1).
@ -73,24 +75,26 @@ def get_address_in_network(network, fallback=None, fatal=False):
else:
return None
_validate_cidr(network)
network = netaddr.IPNetwork(network)
for iface in netifaces.interfaces():
addresses = netifaces.ifaddresses(iface)
if network.version == 4 and netifaces.AF_INET in addresses:
addr = addresses[netifaces.AF_INET][0]['addr']
netmask = addresses[netifaces.AF_INET][0]['netmask']
cidr = netaddr.IPNetwork("%s/%s" % (addr, netmask))
if cidr in network:
return str(cidr.ip)
networks = network.split() or [network]
for network in networks:
_validate_cidr(network)
network = netaddr.IPNetwork(network)
for iface in netifaces.interfaces():
addresses = netifaces.ifaddresses(iface)
if network.version == 4 and netifaces.AF_INET in addresses:
addr = addresses[netifaces.AF_INET][0]['addr']
netmask = addresses[netifaces.AF_INET][0]['netmask']
cidr = netaddr.IPNetwork("%s/%s" % (addr, netmask))
if cidr in network:
return str(cidr.ip)
if network.version == 6 and netifaces.AF_INET6 in addresses:
for addr in addresses[netifaces.AF_INET6]:
if not addr['addr'].startswith('fe80'):
cidr = netaddr.IPNetwork("%s/%s" % (addr['addr'],
addr['netmask']))
if cidr in network:
return str(cidr.ip)
if network.version == 6 and netifaces.AF_INET6 in addresses:
for addr in addresses[netifaces.AF_INET6]:
if not addr['addr'].startswith('fe80'):
cidr = netaddr.IPNetwork("%s/%s" % (addr['addr'],
addr['netmask']))
if cidr in network:
return str(cidr.ip)
if fallback is not None:
return fallback
@ -187,6 +191,15 @@ get_iface_for_address = partial(_get_for_address, key='iface')
get_netmask_for_address = partial(_get_for_address, key='netmask')
def resolve_network_cidr(ip_address):
'''
Resolves the full address cidr of an ip_address based on
configured network interfaces
'''
netmask = get_netmask_for_address(ip_address)
return str(netaddr.IPNetwork("%s/%s" % (ip_address, netmask)).cidr)
def format_ipv6_addr(address):
"""If address is IPv6, wrap it in '[]' otherwise return None.
@ -435,8 +448,12 @@ def get_hostname(address, fqdn=True):
rev = dns.reversename.from_address(address)
result = ns_query(rev)
if not result:
return None
try:
result = socket.gethostbyaddr(address)[0]
except:
return None
else:
result = address
@ -448,3 +465,18 @@ def get_hostname(address, fqdn=True):
return result
else:
return result.split('.')[0]
def port_has_listener(address, port):
"""
Returns True if the address:port is open and being listened to,
else False.
@param address: an IP address or hostname
@param port: integer port
Note calls 'zc' via a subprocess shell
"""
cmd = ['nc', '-z', address, str(port)]
result = subprocess.call(cmd)
return not(bool(result))

View File

@ -25,10 +25,14 @@ from charmhelpers.core.host import (
)
def add_bridge(name):
def add_bridge(name, datapath_type=None):
''' Add the named bridge to openvswitch '''
log('Creating bridge {}'.format(name))
subprocess.check_call(["ovs-vsctl", "--", "--may-exist", "add-br", name])
cmd = ["ovs-vsctl", "--", "--may-exist", "add-br", name]
if datapath_type is not None:
cmd += ['--', 'set', 'bridge', name,
'datapath_type={}'.format(datapath_type)]
subprocess.check_call(cmd)
def del_bridge(name):

View File

@ -40,7 +40,9 @@ Examples:
import re
import os
import subprocess
from charmhelpers.core import hookenv
from charmhelpers.core.kernel import modprobe, is_module_loaded
__author__ = "Felipe Reyes <felipe.reyes@canonical.com>"
@ -82,14 +84,11 @@ def is_ipv6_ok(soft_fail=False):
# do we have IPv6 in the machine?
if os.path.isdir('/proc/sys/net/ipv6'):
# is ip6tables kernel module loaded?
lsmod = subprocess.check_output(['lsmod'], universal_newlines=True)
matches = re.findall('^ip6_tables[ ]+', lsmod, re.M)
if len(matches) == 0:
if not is_module_loaded('ip6_tables'):
# ip6tables support isn't complete, let's try to load it
try:
subprocess.check_output(['modprobe', 'ip6_tables'],
universal_newlines=True)
# great, we could load the module
modprobe('ip6_tables')
# great, we can load the module
return True
except subprocess.CalledProcessError as ex:
hookenv.log("Couldn't load ip6_tables module: %s" % ex.output,

View File

@ -14,12 +14,18 @@
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import logging
import re
import sys
import six
from collections import OrderedDict
from charmhelpers.contrib.amulet.deployment import (
AmuletDeployment
)
DEBUG = logging.DEBUG
ERROR = logging.ERROR
class OpenStackAmuletDeployment(AmuletDeployment):
"""OpenStack amulet deployment.
@ -28,9 +34,12 @@ class OpenStackAmuletDeployment(AmuletDeployment):
that is specifically for use by OpenStack charms.
"""
def __init__(self, series=None, openstack=None, source=None, stable=True):
def __init__(self, series=None, openstack=None, source=None,
stable=True, log_level=DEBUG):
"""Initialize the deployment environment."""
super(OpenStackAmuletDeployment, self).__init__(series)
self.log = self.get_logger(level=log_level)
self.log.info('OpenStackAmuletDeployment: init')
self.openstack = openstack
self.source = source
self.stable = stable
@ -38,26 +47,55 @@ class OpenStackAmuletDeployment(AmuletDeployment):
# out.
self.current_next = "trusty"
def get_logger(self, name="deployment-logger", level=logging.DEBUG):
"""Get a logger object that will log to stdout."""
log = logging
logger = log.getLogger(name)
fmt = log.Formatter("%(asctime)s %(funcName)s "
"%(levelname)s: %(message)s")
handler = log.StreamHandler(stream=sys.stdout)
handler.setLevel(level)
handler.setFormatter(fmt)
logger.addHandler(handler)
logger.setLevel(level)
return logger
def _determine_branch_locations(self, other_services):
"""Determine the branch locations for the other services.
Determine if the local branch being tested is derived from its
stable or next (dev) branch, and based on this, use the corresonding
stable or next branches for the other_services."""
base_charms = ['mysql', 'mongodb']
self.log.info('OpenStackAmuletDeployment: determine branch locations')
# Charms outside the lp:~openstack-charmers namespace
base_charms = ['mysql', 'mongodb', 'nrpe']
# Force these charms to current series even when using an older series.
# ie. Use trusty/nrpe even when series is precise, as the P charm
# does not possess the necessary external master config and hooks.
force_series_current = ['nrpe']
if self.series in ['precise', 'trusty']:
base_series = self.series
else:
base_series = self.current_next
if self.stable:
for svc in other_services:
for svc in other_services:
if svc['name'] in force_series_current:
base_series = self.current_next
# If a location has been explicitly set, use it
if svc.get('location'):
continue
if self.stable:
temp = 'lp:charms/{}/{}'
svc['location'] = temp.format(base_series,
svc['name'])
else:
for svc in other_services:
else:
if svc['name'] in base_charms:
temp = 'lp:charms/{}/{}'
svc['location'] = temp.format(base_series,
@ -66,10 +104,13 @@ class OpenStackAmuletDeployment(AmuletDeployment):
temp = 'lp:~openstack-charmers/charms/{}/{}/next'
svc['location'] = temp.format(self.current_next,
svc['name'])
return other_services
def _add_services(self, this_service, other_services):
"""Add services to the deployment and set openstack-origin/source."""
self.log.info('OpenStackAmuletDeployment: adding services')
other_services = self._determine_branch_locations(other_services)
super(OpenStackAmuletDeployment, self)._add_services(this_service,
@ -77,29 +118,105 @@ class OpenStackAmuletDeployment(AmuletDeployment):
services = other_services
services.append(this_service)
# Charms which should use the source config option
use_source = ['mysql', 'mongodb', 'rabbitmq-server', 'ceph',
'ceph-osd', 'ceph-radosgw']
# Most OpenStack subordinate charms do not expose an origin option
# as that is controlled by the principle.
ignore = ['cinder-ceph', 'hacluster', 'neutron-openvswitch']
'ceph-osd', 'ceph-radosgw', 'ceph-mon']
# Charms which can not use openstack-origin, ie. many subordinates
no_origin = ['cinder-ceph', 'hacluster', 'neutron-openvswitch', 'nrpe',
'openvswitch-odl', 'neutron-api-odl', 'odl-controller',
'cinder-backup', 'nexentaedge-data',
'nexentaedge-iscsi-gw', 'nexentaedge-swift-gw',
'cinder-nexentaedge', 'nexentaedge-mgmt']
if self.openstack:
for svc in services:
if svc['name'] not in use_source + ignore:
if svc['name'] not in use_source + no_origin:
config = {'openstack-origin': self.openstack}
self.d.configure(svc['name'], config)
if self.source:
for svc in services:
if svc['name'] in use_source and svc['name'] not in ignore:
if svc['name'] in use_source and svc['name'] not in no_origin:
config = {'source': self.source}
self.d.configure(svc['name'], config)
def _configure_services(self, configs):
"""Configure all of the services."""
self.log.info('OpenStackAmuletDeployment: configure services')
for service, config in six.iteritems(configs):
self.d.configure(service, config)
def _auto_wait_for_status(self, message=None, exclude_services=None,
include_only=None, timeout=1800):
"""Wait for all units to have a specific extended status, except
for any defined as excluded. Unless specified via message, any
status containing any case of 'ready' will be considered a match.
Examples of message usage:
Wait for all unit status to CONTAIN any case of 'ready' or 'ok':
message = re.compile('.*ready.*|.*ok.*', re.IGNORECASE)
Wait for all units to reach this status (exact match):
message = re.compile('^Unit is ready and clustered$')
Wait for all units to reach any one of these (exact match):
message = re.compile('Unit is ready|OK|Ready')
Wait for at least one unit to reach this status (exact match):
message = {'ready'}
See Amulet's sentry.wait_for_messages() for message usage detail.
https://github.com/juju/amulet/blob/master/amulet/sentry.py
:param message: Expected status match
:param exclude_services: List of juju service names to ignore,
not to be used in conjuction with include_only.
:param include_only: List of juju service names to exclusively check,
not to be used in conjuction with exclude_services.
:param timeout: Maximum time in seconds to wait for status match
:returns: None. Raises if timeout is hit.
"""
self.log.info('Waiting for extended status on units...')
all_services = self.d.services.keys()
if exclude_services and include_only:
raise ValueError('exclude_services can not be used '
'with include_only')
if message:
if isinstance(message, re._pattern_type):
match = message.pattern
else:
match = message
self.log.debug('Custom extended status wait match: '
'{}'.format(match))
else:
self.log.debug('Default extended status wait match: contains '
'READY (case-insensitive)')
message = re.compile('.*ready.*', re.IGNORECASE)
if exclude_services:
self.log.debug('Excluding services from extended status match: '
'{}'.format(exclude_services))
else:
exclude_services = []
if include_only:
services = include_only
else:
services = list(set(all_services) - set(exclude_services))
self.log.debug('Waiting up to {}s for extended status on services: '
'{}'.format(timeout, services))
service_messages = {service: message for service in services}
self.d.sentry.wait_for_messages(service_messages, timeout=timeout)
self.log.info('OK')
def _get_openstack_release(self):
"""Get openstack release.
@ -111,7 +228,8 @@ class OpenStackAmuletDeployment(AmuletDeployment):
self.precise_havana, self.precise_icehouse,
self.trusty_icehouse, self.trusty_juno, self.utopic_juno,
self.trusty_kilo, self.vivid_kilo, self.trusty_liberty,
self.wily_liberty) = range(12)
self.wily_liberty, self.trusty_mitaka,
self.xenial_mitaka) = range(14)
releases = {
('precise', None): self.precise_essex,
@ -123,9 +241,11 @@ class OpenStackAmuletDeployment(AmuletDeployment):
('trusty', 'cloud:trusty-juno'): self.trusty_juno,
('trusty', 'cloud:trusty-kilo'): self.trusty_kilo,
('trusty', 'cloud:trusty-liberty'): self.trusty_liberty,
('trusty', 'cloud:trusty-mitaka'): self.trusty_mitaka,
('utopic', None): self.utopic_juno,
('vivid', None): self.vivid_kilo,
('wily', None): self.wily_liberty}
('wily', None): self.wily_liberty,
('xenial', None): self.xenial_mitaka}
return releases[(self.series, self.openstack)]
def _get_openstack_release_string(self):
@ -142,6 +262,7 @@ class OpenStackAmuletDeployment(AmuletDeployment):
('utopic', 'juno'),
('vivid', 'kilo'),
('wily', 'liberty'),
('xenial', 'mitaka'),
])
if self.openstack:
os_origin = self.openstack.split(':')[1]

View File

@ -18,6 +18,7 @@ import amulet
import json
import logging
import os
import re
import six
import time
import urllib
@ -26,7 +27,12 @@ import cinderclient.v1.client as cinder_client
import glanceclient.v1.client as glance_client
import heatclient.v1.client as heat_client
import keystoneclient.v2_0 as keystone_client
import novaclient.v1_1.client as nova_client
from keystoneclient.auth.identity import v3 as keystone_id_v3
from keystoneclient import session as keystone_session
from keystoneclient.v3 import client as keystone_client_v3
import novaclient.client as nova_client
import pika
import swiftclient
from charmhelpers.contrib.amulet.utils import (
@ -36,6 +42,8 @@ from charmhelpers.contrib.amulet.utils import (
DEBUG = logging.DEBUG
ERROR = logging.ERROR
NOVA_CLIENT_VERSION = "2"
class OpenStackAmuletUtils(AmuletUtils):
"""OpenStack amulet utilities.
@ -137,7 +145,7 @@ class OpenStackAmuletUtils(AmuletUtils):
return "role {} does not exist".format(e['name'])
return ret
def validate_user_data(self, expected, actual):
def validate_user_data(self, expected, actual, api_version=None):
"""Validate user data.
Validate a list of actual user data vs a list of expected user
@ -148,10 +156,15 @@ class OpenStackAmuletUtils(AmuletUtils):
for e in expected:
found = False
for act in actual:
a = {'enabled': act.enabled, 'name': act.name,
'email': act.email, 'tenantId': act.tenantId,
'id': act.id}
if e['name'] == a['name']:
if e['name'] == act.name:
a = {'enabled': act.enabled, 'name': act.name,
'email': act.email, 'id': act.id}
if api_version == 3:
a['default_project_id'] = getattr(act,
'default_project_id',
'none')
else:
a['tenantId'] = act.tenantId
found = True
ret = self._validate_dict_data(e, a)
if ret:
@ -186,15 +199,30 @@ class OpenStackAmuletUtils(AmuletUtils):
return cinder_client.Client(username, password, tenant, ept)
def authenticate_keystone_admin(self, keystone_sentry, user, password,
tenant):
tenant=None, api_version=None,
keystone_ip=None):
"""Authenticates admin user with the keystone admin endpoint."""
self.log.debug('Authenticating keystone admin...')
unit = keystone_sentry
service_ip = unit.relation('shared-db',
'mysql:shared-db')['private-address']
ep = "http://{}:35357/v2.0".format(service_ip.strip().decode('utf-8'))
return keystone_client.Client(username=user, password=password,
tenant_name=tenant, auth_url=ep)
if not keystone_ip:
keystone_ip = unit.relation('shared-db',
'mysql:shared-db')['private-address']
base_ep = "http://{}:35357".format(keystone_ip.strip().decode('utf-8'))
if not api_version or api_version == 2:
ep = base_ep + "/v2.0"
return keystone_client.Client(username=user, password=password,
tenant_name=tenant, auth_url=ep)
else:
ep = base_ep + "/v3"
auth = keystone_id_v3.Password(
user_domain_name='admin_domain',
username=user,
password=password,
domain_name='admin_domain',
auth_url=ep,
)
sess = keystone_session.Session(auth=auth)
return keystone_client_v3.Client(session=sess)
def authenticate_keystone_user(self, keystone, user, password, tenant):
"""Authenticates a regular user with the keystone public endpoint."""
@ -223,7 +251,8 @@ class OpenStackAmuletUtils(AmuletUtils):
self.log.debug('Authenticating nova user ({})...'.format(user))
ep = keystone.service_catalog.url_for(service_type='identity',
endpoint_type='publicURL')
return nova_client.Client(username=user, api_key=password,
return nova_client.Client(NOVA_CLIENT_VERSION,
username=user, api_key=password,
project_id=tenant, auth_url=ep)
def authenticate_swift_user(self, keystone, user, password, tenant):
@ -602,3 +631,382 @@ class OpenStackAmuletUtils(AmuletUtils):
self.log.debug('Ceph {} samples (OK): '
'{}'.format(sample_type, samples))
return None
# rabbitmq/amqp specific helpers:
def rmq_wait_for_cluster(self, deployment, init_sleep=15, timeout=1200):
"""Wait for rmq units extended status to show cluster readiness,
after an optional initial sleep period. Initial sleep is likely
necessary to be effective following a config change, as status
message may not instantly update to non-ready."""
if init_sleep:
time.sleep(init_sleep)
message = re.compile('^Unit is ready and clustered$')
deployment._auto_wait_for_status(message=message,
timeout=timeout,
include_only=['rabbitmq-server'])
def add_rmq_test_user(self, sentry_units,
username="testuser1", password="changeme"):
"""Add a test user via the first rmq juju unit, check connection as
the new user against all sentry units.
:param sentry_units: list of sentry unit pointers
:param username: amqp user name, default to testuser1
:param password: amqp user password
:returns: None if successful. Raise on error.
"""
self.log.debug('Adding rmq user ({})...'.format(username))
# Check that user does not already exist
cmd_user_list = 'rabbitmqctl list_users'
output, _ = self.run_cmd_unit(sentry_units[0], cmd_user_list)
if username in output:
self.log.warning('User ({}) already exists, returning '
'gracefully.'.format(username))
return
perms = '".*" ".*" ".*"'
cmds = ['rabbitmqctl add_user {} {}'.format(username, password),
'rabbitmqctl set_permissions {} {}'.format(username, perms)]
# Add user via first unit
for cmd in cmds:
output, _ = self.run_cmd_unit(sentry_units[0], cmd)
# Check connection against the other sentry_units
self.log.debug('Checking user connect against units...')
for sentry_unit in sentry_units:
connection = self.connect_amqp_by_unit(sentry_unit, ssl=False,
username=username,
password=password)
connection.close()
def delete_rmq_test_user(self, sentry_units, username="testuser1"):
"""Delete a rabbitmq user via the first rmq juju unit.
:param sentry_units: list of sentry unit pointers
:param username: amqp user name, default to testuser1
:param password: amqp user password
:returns: None if successful or no such user.
"""
self.log.debug('Deleting rmq user ({})...'.format(username))
# Check that the user exists
cmd_user_list = 'rabbitmqctl list_users'
output, _ = self.run_cmd_unit(sentry_units[0], cmd_user_list)
if username not in output:
self.log.warning('User ({}) does not exist, returning '
'gracefully.'.format(username))
return
# Delete the user
cmd_user_del = 'rabbitmqctl delete_user {}'.format(username)
output, _ = self.run_cmd_unit(sentry_units[0], cmd_user_del)
def get_rmq_cluster_status(self, sentry_unit):
"""Execute rabbitmq cluster status command on a unit and return
the full output.
:param unit: sentry unit
:returns: String containing console output of cluster status command
"""
cmd = 'rabbitmqctl cluster_status'
output, _ = self.run_cmd_unit(sentry_unit, cmd)
self.log.debug('{} cluster_status:\n{}'.format(
sentry_unit.info['unit_name'], output))
return str(output)
def get_rmq_cluster_running_nodes(self, sentry_unit):
"""Parse rabbitmqctl cluster_status output string, return list of
running rabbitmq cluster nodes.
:param unit: sentry unit
:returns: List containing node names of running nodes
"""
# NOTE(beisner): rabbitmqctl cluster_status output is not
# json-parsable, do string chop foo, then json.loads that.
str_stat = self.get_rmq_cluster_status(sentry_unit)
if 'running_nodes' in str_stat:
pos_start = str_stat.find("{running_nodes,") + 15
pos_end = str_stat.find("]},", pos_start) + 1
str_run_nodes = str_stat[pos_start:pos_end].replace("'", '"')
run_nodes = json.loads(str_run_nodes)
return run_nodes
else:
return []
def validate_rmq_cluster_running_nodes(self, sentry_units):
"""Check that all rmq unit hostnames are represented in the
cluster_status output of all units.
:param host_names: dict of juju unit names to host names
:param units: list of sentry unit pointers (all rmq units)
:returns: None if successful, otherwise return error message
"""
host_names = self.get_unit_hostnames(sentry_units)
errors = []
# Query every unit for cluster_status running nodes
for query_unit in sentry_units:
query_unit_name = query_unit.info['unit_name']
running_nodes = self.get_rmq_cluster_running_nodes(query_unit)
# Confirm that every unit is represented in the queried unit's
# cluster_status running nodes output.
for validate_unit in sentry_units:
val_host_name = host_names[validate_unit.info['unit_name']]
val_node_name = 'rabbit@{}'.format(val_host_name)
if val_node_name not in running_nodes:
errors.append('Cluster member check failed on {}: {} not '
'in {}\n'.format(query_unit_name,
val_node_name,
running_nodes))
if errors:
return ''.join(errors)
def rmq_ssl_is_enabled_on_unit(self, sentry_unit, port=None):
"""Check a single juju rmq unit for ssl and port in the config file."""
host = sentry_unit.info['public-address']
unit_name = sentry_unit.info['unit_name']
conf_file = '/etc/rabbitmq/rabbitmq.config'
conf_contents = str(self.file_contents_safe(sentry_unit,
conf_file, max_wait=16))
# Checks
conf_ssl = 'ssl' in conf_contents
conf_port = str(port) in conf_contents
# Port explicitly checked in config
if port and conf_port and conf_ssl:
self.log.debug('SSL is enabled @{}:{} '
'({})'.format(host, port, unit_name))
return True
elif port and not conf_port and conf_ssl:
self.log.debug('SSL is enabled @{} but not on port {} '
'({})'.format(host, port, unit_name))
return False
# Port not checked (useful when checking that ssl is disabled)
elif not port and conf_ssl:
self.log.debug('SSL is enabled @{}:{} '
'({})'.format(host, port, unit_name))
return True
elif not conf_ssl:
self.log.debug('SSL not enabled @{}:{} '
'({})'.format(host, port, unit_name))
return False
else:
msg = ('Unknown condition when checking SSL status @{}:{} '
'({})'.format(host, port, unit_name))
amulet.raise_status(amulet.FAIL, msg)
def validate_rmq_ssl_enabled_units(self, sentry_units, port=None):
"""Check that ssl is enabled on rmq juju sentry units.
:param sentry_units: list of all rmq sentry units
:param port: optional ssl port override to validate
:returns: None if successful, otherwise return error message
"""
for sentry_unit in sentry_units:
if not self.rmq_ssl_is_enabled_on_unit(sentry_unit, port=port):
return ('Unexpected condition: ssl is disabled on unit '
'({})'.format(sentry_unit.info['unit_name']))
return None
def validate_rmq_ssl_disabled_units(self, sentry_units):
"""Check that ssl is enabled on listed rmq juju sentry units.
:param sentry_units: list of all rmq sentry units
:returns: True if successful. Raise on error.
"""
for sentry_unit in sentry_units:
if self.rmq_ssl_is_enabled_on_unit(sentry_unit):
return ('Unexpected condition: ssl is enabled on unit '
'({})'.format(sentry_unit.info['unit_name']))
return None
def configure_rmq_ssl_on(self, sentry_units, deployment,
port=None, max_wait=60):
"""Turn ssl charm config option on, with optional non-default
ssl port specification. Confirm that it is enabled on every
unit.
:param sentry_units: list of sentry units
:param deployment: amulet deployment object pointer
:param port: amqp port, use defaults if None
:param max_wait: maximum time to wait in seconds to confirm
:returns: None if successful. Raise on error.
"""
self.log.debug('Setting ssl charm config option: on')
# Enable RMQ SSL
config = {'ssl': 'on'}
if port:
config['ssl_port'] = port
deployment.d.configure('rabbitmq-server', config)
# Wait for unit status
self.rmq_wait_for_cluster(deployment)
# Confirm
tries = 0
ret = self.validate_rmq_ssl_enabled_units(sentry_units, port=port)
while ret and tries < (max_wait / 4):
time.sleep(4)
self.log.debug('Attempt {}: {}'.format(tries, ret))
ret = self.validate_rmq_ssl_enabled_units(sentry_units, port=port)
tries += 1
if ret:
amulet.raise_status(amulet.FAIL, ret)
def configure_rmq_ssl_off(self, sentry_units, deployment, max_wait=60):
"""Turn ssl charm config option off, confirm that it is disabled
on every unit.
:param sentry_units: list of sentry units
:param deployment: amulet deployment object pointer
:param max_wait: maximum time to wait in seconds to confirm
:returns: None if successful. Raise on error.
"""
self.log.debug('Setting ssl charm config option: off')
# Disable RMQ SSL
config = {'ssl': 'off'}
deployment.d.configure('rabbitmq-server', config)
# Wait for unit status
self.rmq_wait_for_cluster(deployment)
# Confirm
tries = 0
ret = self.validate_rmq_ssl_disabled_units(sentry_units)
while ret and tries < (max_wait / 4):
time.sleep(4)
self.log.debug('Attempt {}: {}'.format(tries, ret))
ret = self.validate_rmq_ssl_disabled_units(sentry_units)
tries += 1
if ret:
amulet.raise_status(amulet.FAIL, ret)
def connect_amqp_by_unit(self, sentry_unit, ssl=False,
port=None, fatal=True,
username="testuser1", password="changeme"):
"""Establish and return a pika amqp connection to the rabbitmq service
running on a rmq juju unit.
:param sentry_unit: sentry unit pointer
:param ssl: boolean, default to False
:param port: amqp port, use defaults if None
:param fatal: boolean, default to True (raises on connect error)
:param username: amqp user name, default to testuser1
:param password: amqp user password
:returns: pika amqp connection pointer or None if failed and non-fatal
"""
host = sentry_unit.info['public-address']
unit_name = sentry_unit.info['unit_name']
# Default port logic if port is not specified
if ssl and not port:
port = 5671
elif not ssl and not port:
port = 5672
self.log.debug('Connecting to amqp on {}:{} ({}) as '
'{}...'.format(host, port, unit_name, username))
try:
credentials = pika.PlainCredentials(username, password)
parameters = pika.ConnectionParameters(host=host, port=port,
credentials=credentials,
ssl=ssl,
connection_attempts=3,
retry_delay=5,
socket_timeout=1)
connection = pika.BlockingConnection(parameters)
assert connection.server_properties['product'] == 'RabbitMQ'
self.log.debug('Connect OK')
return connection
except Exception as e:
msg = ('amqp connection failed to {}:{} as '
'{} ({})'.format(host, port, username, str(e)))
if fatal:
amulet.raise_status(amulet.FAIL, msg)
else:
self.log.warn(msg)
return None
def publish_amqp_message_by_unit(self, sentry_unit, message,
queue="test", ssl=False,
username="testuser1",
password="changeme",
port=None):
"""Publish an amqp message to a rmq juju unit.
:param sentry_unit: sentry unit pointer
:param message: amqp message string
:param queue: message queue, default to test
:param username: amqp user name, default to testuser1
:param password: amqp user password
:param ssl: boolean, default to False
:param port: amqp port, use defaults if None
:returns: None. Raises exception if publish failed.
"""
self.log.debug('Publishing message to {} queue:\n{}'.format(queue,
message))
connection = self.connect_amqp_by_unit(sentry_unit, ssl=ssl,
port=port,
username=username,
password=password)
# NOTE(beisner): extra debug here re: pika hang potential:
# https://github.com/pika/pika/issues/297
# https://groups.google.com/forum/#!topic/rabbitmq-users/Ja0iyfF0Szw
self.log.debug('Defining channel...')
channel = connection.channel()
self.log.debug('Declaring queue...')
channel.queue_declare(queue=queue, auto_delete=False, durable=True)
self.log.debug('Publishing message...')
channel.basic_publish(exchange='', routing_key=queue, body=message)
self.log.debug('Closing channel...')
channel.close()
self.log.debug('Closing connection...')
connection.close()
def get_amqp_message_by_unit(self, sentry_unit, queue="test",
username="testuser1",
password="changeme",
ssl=False, port=None):
"""Get an amqp message from a rmq juju unit.
:param sentry_unit: sentry unit pointer
:param queue: message queue, default to test
:param username: amqp user name, default to testuser1
:param password: amqp user password
:param ssl: boolean, default to False
:param port: amqp port, use defaults if None
:returns: amqp message body as string. Raise if get fails.
"""
connection = self.connect_amqp_by_unit(sentry_unit, ssl=ssl,
port=port,
username=username,
password=password)
channel = connection.channel()
method_frame, _, body = channel.basic_get(queue)
if method_frame:
self.log.debug('Retreived message from {} queue:\n{}'.format(queue,
body))
channel.basic_ack(method_frame.delivery_tag)
channel.close()
connection.close()
return body
else:
msg = 'No message retrieved.'
amulet.raise_status(amulet.FAIL, msg)

View File

@ -14,12 +14,13 @@
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import glob
import json
import os
import re
import time
from base64 import b64decode
from subprocess import check_call
from subprocess import check_call, CalledProcessError
import six
import yaml
@ -44,16 +45,20 @@ from charmhelpers.core.hookenv import (
INFO,
WARNING,
ERROR,
status_set,
)
from charmhelpers.core.sysctl import create as sysctl_create
from charmhelpers.core.strutils import bool_from_string
from charmhelpers.core.host import (
get_bond_master,
is_phy_iface,
list_nics,
get_nic_hwaddr,
mkdir,
write_file,
pwgen,
)
from charmhelpers.contrib.hahelpers.cluster import (
determine_apache_port,
@ -84,6 +89,14 @@ from charmhelpers.contrib.network.ip import (
is_bridge_member,
)
from charmhelpers.contrib.openstack.utils import get_host_ip
from charmhelpers.core.unitdata import kv
try:
import psutil
except ImportError:
apt_install('python-psutil', fatal=True)
import psutil
CA_CERT_PATH = '/usr/local/share/ca-certificates/keystone_juju_ca_cert.crt'
ADDRESS_TYPES = ['admin', 'internal', 'public']
@ -192,10 +205,50 @@ def config_flags_parser(config_flags):
class OSContextGenerator(object):
"""Base class for all context generators."""
interfaces = []
related = False
complete = False
missing_data = []
def __call__(self):
raise NotImplementedError
def context_complete(self, ctxt):
"""Check for missing data for the required context data.
Set self.missing_data if it exists and return False.
Set self.complete if no missing data and return True.
"""
# Fresh start
self.complete = False
self.missing_data = []
for k, v in six.iteritems(ctxt):
if v is None or v == '':
if k not in self.missing_data:
self.missing_data.append(k)
if self.missing_data:
self.complete = False
log('Missing required data: %s' % ' '.join(self.missing_data), level=INFO)
else:
self.complete = True
return self.complete
def get_related(self):
"""Check if any of the context interfaces have relation ids.
Set self.related and return True if one of the interfaces
has relation ids.
"""
# Fresh start
self.related = False
try:
for interface in self.interfaces:
if relation_ids(interface):
self.related = True
return self.related
except AttributeError as e:
log("{} {}"
"".format(self, e), 'INFO')
return self.related
class SharedDBContext(OSContextGenerator):
interfaces = ['shared-db']
@ -211,6 +264,7 @@ class SharedDBContext(OSContextGenerator):
self.database = database
self.user = user
self.ssl_dir = ssl_dir
self.rel_name = self.interfaces[0]
def __call__(self):
self.database = self.database or config('database')
@ -244,6 +298,7 @@ class SharedDBContext(OSContextGenerator):
password_setting = self.relation_prefix + '_password'
for rid in relation_ids(self.interfaces[0]):
self.related = True
for unit in related_units(rid):
rdata = relation_get(rid=rid, unit=unit)
host = rdata.get('db_host')
@ -255,7 +310,7 @@ class SharedDBContext(OSContextGenerator):
'database_password': rdata.get(password_setting),
'database_type': 'mysql'
}
if context_complete(ctxt):
if self.context_complete(ctxt):
db_ssl(rdata, ctxt, self.ssl_dir)
return ctxt
return {}
@ -276,6 +331,7 @@ class PostgresqlDBContext(OSContextGenerator):
ctxt = {}
for rid in relation_ids(self.interfaces[0]):
self.related = True
for unit in related_units(rid):
rel_host = relation_get('host', rid=rid, unit=unit)
rel_user = relation_get('user', rid=rid, unit=unit)
@ -285,7 +341,7 @@ class PostgresqlDBContext(OSContextGenerator):
'database_user': rel_user,
'database_password': rel_passwd,
'database_type': 'postgresql'}
if context_complete(ctxt):
if self.context_complete(ctxt):
return ctxt
return {}
@ -346,6 +402,7 @@ class IdentityServiceContext(OSContextGenerator):
ctxt['signing_dir'] = cachedir
for rid in relation_ids(self.rel_name):
self.related = True
for unit in related_units(rid):
rdata = relation_get(rid=rid, unit=unit)
serv_host = rdata.get('service_host')
@ -354,6 +411,7 @@ class IdentityServiceContext(OSContextGenerator):
auth_host = format_ipv6_addr(auth_host) or auth_host
svc_protocol = rdata.get('service_protocol') or 'http'
auth_protocol = rdata.get('auth_protocol') or 'http'
api_version = rdata.get('api_version') or '2.0'
ctxt.update({'service_port': rdata.get('service_port'),
'service_host': serv_host,
'auth_host': auth_host,
@ -362,9 +420,10 @@ class IdentityServiceContext(OSContextGenerator):
'admin_user': rdata.get('service_username'),
'admin_password': rdata.get('service_password'),
'service_protocol': svc_protocol,
'auth_protocol': auth_protocol})
'auth_protocol': auth_protocol,
'api_version': api_version})
if context_complete(ctxt):
if self.context_complete(ctxt):
# NOTE(jamespage) this is required for >= icehouse
# so a missing value just indicates keystone needs
# upgrading
@ -403,6 +462,7 @@ class AMQPContext(OSContextGenerator):
ctxt = {}
for rid in relation_ids(self.rel_name):
ha_vip_only = False
self.related = True
for unit in related_units(rid):
if relation_get('clustered', rid=rid, unit=unit):
ctxt['clustered'] = True
@ -435,7 +495,7 @@ class AMQPContext(OSContextGenerator):
ha_vip_only = relation_get('ha-vip-only',
rid=rid, unit=unit) is not None
if context_complete(ctxt):
if self.context_complete(ctxt):
if 'rabbit_ssl_ca' in ctxt:
if not self.ssl_dir:
log("Charm not setup for ssl support but ssl ca "
@ -467,7 +527,7 @@ class AMQPContext(OSContextGenerator):
ctxt['oslo_messaging_flags'] = config_flags_parser(
oslo_messaging_flags)
if not context_complete(ctxt):
if not self.complete:
return {}
return ctxt
@ -483,13 +543,15 @@ class CephContext(OSContextGenerator):
log('Generating template context for ceph', level=DEBUG)
mon_hosts = []
auth = None
key = None
use_syslog = str(config('use-syslog')).lower()
ctxt = {
'use_syslog': str(config('use-syslog')).lower()
}
for rid in relation_ids('ceph'):
for unit in related_units(rid):
auth = relation_get('auth', rid=rid, unit=unit)
key = relation_get('key', rid=rid, unit=unit)
if not ctxt.get('auth'):
ctxt['auth'] = relation_get('auth', rid=rid, unit=unit)
if not ctxt.get('key'):
ctxt['key'] = relation_get('key', rid=rid, unit=unit)
ceph_pub_addr = relation_get('ceph-public-address', rid=rid,
unit=unit)
unit_priv_addr = relation_get('private-address', rid=rid,
@ -498,15 +560,12 @@ class CephContext(OSContextGenerator):
ceph_addr = format_ipv6_addr(ceph_addr) or ceph_addr
mon_hosts.append(ceph_addr)
ctxt = {'mon_hosts': ' '.join(sorted(mon_hosts)),
'auth': auth,
'key': key,
'use_syslog': use_syslog}
ctxt['mon_hosts'] = ' '.join(sorted(mon_hosts))
if not os.path.isdir('/etc/ceph'):
os.mkdir('/etc/ceph')
if not context_complete(ctxt):
if not self.context_complete(ctxt):
return {}
ensure_packages(['ceph-common'])
@ -579,15 +638,28 @@ class HAProxyContext(OSContextGenerator):
if config('haproxy-client-timeout'):
ctxt['haproxy_client_timeout'] = config('haproxy-client-timeout')
if config('haproxy-queue-timeout'):
ctxt['haproxy_queue_timeout'] = config('haproxy-queue-timeout')
if config('haproxy-connect-timeout'):
ctxt['haproxy_connect_timeout'] = config('haproxy-connect-timeout')
if config('prefer-ipv6'):
ctxt['ipv6'] = True
ctxt['local_host'] = 'ip6-localhost'
ctxt['haproxy_host'] = '::'
ctxt['stat_port'] = ':::8888'
else:
ctxt['local_host'] = '127.0.0.1'
ctxt['haproxy_host'] = '0.0.0.0'
ctxt['stat_port'] = ':8888'
ctxt['stat_port'] = '8888'
db = kv()
ctxt['stat_password'] = db.get('stat-password')
if not ctxt['stat_password']:
ctxt['stat_password'] = db.set('stat-password',
pwgen(32))
db.flush()
for frontend in cluster_hosts:
if (len(cluster_hosts[frontend]['backends']) > 1 or
@ -878,19 +950,6 @@ class NeutronContext(OSContextGenerator):
return calico_ctxt
def pg_ctxt(self):
driver = neutron_plugin_attribute(self.plugin, 'driver',
self.network_manager)
config = neutron_plugin_attribute(self.plugin, 'config',
self.network_manager)
ovs_ctxt = {'core_plugin': driver,
'neutron_plugin': 'plumgrid',
'neutron_security_groups': self.neutron_security_groups,
'local_ip': unit_private_ip(),
'config': config}
return ovs_ctxt
def neutron_ctxt(self):
if https():
proto = 'https'
@ -906,6 +965,31 @@ class NeutronContext(OSContextGenerator):
'neutron_url': '%s://%s:%s' % (proto, host, '9696')}
return ctxt
def pg_ctxt(self):
driver = neutron_plugin_attribute(self.plugin, 'driver',
self.network_manager)
config = neutron_plugin_attribute(self.plugin, 'config',
self.network_manager)
ovs_ctxt = {'core_plugin': driver,
'neutron_plugin': 'plumgrid',
'neutron_security_groups': self.neutron_security_groups,
'local_ip': unit_private_ip(),
'config': config}
return ovs_ctxt
def midonet_ctxt(self):
driver = neutron_plugin_attribute(self.plugin, 'driver',
self.network_manager)
midonet_config = neutron_plugin_attribute(self.plugin, 'config',
self.network_manager)
mido_ctxt = {'core_plugin': driver,
'neutron_plugin': 'midonet',
'neutron_security_groups': self.neutron_security_groups,
'local_ip': unit_private_ip(),
'config': midonet_config}
return mido_ctxt
def __call__(self):
if self.network_manager not in ['quantum', 'neutron']:
return {}
@ -927,6 +1011,8 @@ class NeutronContext(OSContextGenerator):
ctxt.update(self.nuage_ctxt())
elif self.plugin == 'plumgrid':
ctxt.update(self.pg_ctxt())
elif self.plugin == 'midonet':
ctxt.update(self.midonet_ctxt())
alchemy_flags = config('neutron-alchemy-flags')
if alchemy_flags:
@ -938,7 +1024,6 @@ class NeutronContext(OSContextGenerator):
class NeutronPortContext(OSContextGenerator):
NIC_PREFIXES = ['eth', 'bond']
def resolve_ports(self, ports):
"""Resolve NICs not yet bound to bridge(s)
@ -950,7 +1035,18 @@ class NeutronPortContext(OSContextGenerator):
hwaddr_to_nic = {}
hwaddr_to_ip = {}
for nic in list_nics(self.NIC_PREFIXES):
for nic in list_nics():
# Ignore virtual interfaces (bond masters will be identified from
# their slaves)
if not is_phy_iface(nic):
continue
_nic = get_bond_master(nic)
if _nic:
log("Replacing iface '%s' with bond master '%s'" % (nic, _nic),
level=DEBUG)
nic = _nic
hwaddr = get_nic_hwaddr(nic)
hwaddr_to_nic[hwaddr] = nic
addresses = get_ipv4_addr(nic, fatal=False)
@ -976,7 +1072,8 @@ class NeutronPortContext(OSContextGenerator):
# trust it to be the real external network).
resolved.append(entry)
return resolved
# Ensure no duplicates
return list(set(resolved))
class OSConfigFlagContext(OSContextGenerator):
@ -1016,6 +1113,20 @@ class OSConfigFlagContext(OSContextGenerator):
config_flags_parser(config_flags)}
class LibvirtConfigFlagsContext(OSContextGenerator):
"""
This context provides support for extending
the libvirt section through user-defined flags.
"""
def __call__(self):
ctxt = {}
libvirt_flags = config('libvirt-flags')
if libvirt_flags:
ctxt['libvirt_flags'] = config_flags_parser(
libvirt_flags)
return ctxt
class SubordinateConfigContext(OSContextGenerator):
"""
@ -1048,7 +1159,7 @@ class SubordinateConfigContext(OSContextGenerator):
ctxt = {
... other context ...
'subordinate_config': {
'subordinate_configuration': {
'DEFAULT': {
'key1': 'value1',
},
@ -1066,13 +1177,22 @@ class SubordinateConfigContext(OSContextGenerator):
:param config_file : Service's config file to query sections
:param interface : Subordinate interface to inspect
"""
self.service = service
self.config_file = config_file
self.interface = interface
if isinstance(service, list):
self.services = service
else:
self.services = [service]
if isinstance(interface, list):
self.interfaces = interface
else:
self.interfaces = [interface]
def __call__(self):
ctxt = {'sections': {}}
for rid in relation_ids(self.interface):
rids = []
for interface in self.interfaces:
rids.extend(relation_ids(interface))
for rid in rids:
for unit in related_units(rid):
sub_config = relation_get('subordinate_configuration',
rid=rid, unit=unit)
@ -1080,33 +1200,37 @@ class SubordinateConfigContext(OSContextGenerator):
try:
sub_config = json.loads(sub_config)
except:
log('Could not parse JSON from subordinate_config '
'setting from %s' % rid, level=ERROR)
log('Could not parse JSON from '
'subordinate_configuration setting from %s'
% rid, level=ERROR)
continue
if self.service not in sub_config:
log('Found subordinate_config on %s but it contained'
'nothing for %s service' % (rid, self.service),
level=INFO)
continue
for service in self.services:
if service not in sub_config:
log('Found subordinate_configuration on %s but it '
'contained nothing for %s service'
% (rid, service), level=INFO)
continue
sub_config = sub_config[self.service]
if self.config_file not in sub_config:
log('Found subordinate_config on %s but it contained'
'nothing for %s' % (rid, self.config_file),
level=INFO)
continue
sub_config = sub_config[self.config_file]
for k, v in six.iteritems(sub_config):
if k == 'sections':
for section, config_dict in six.iteritems(v):
log("adding section '%s'" % (section),
level=DEBUG)
ctxt[k][section] = config_dict
else:
ctxt[k] = v
sub_config = sub_config[service]
if self.config_file not in sub_config:
log('Found subordinate_configuration on %s but it '
'contained nothing for %s'
% (rid, self.config_file), level=INFO)
continue
sub_config = sub_config[self.config_file]
for k, v in six.iteritems(sub_config):
if k == 'sections':
for section, config_list in six.iteritems(v):
log("adding section '%s'" % (section),
level=DEBUG)
if ctxt[k].get(section):
ctxt[k][section].extend(config_list)
else:
ctxt[k][section] = config_list
else:
ctxt[k] = v
log("%d section(s) found" % (len(ctxt['sections'])), level=DEBUG)
return ctxt
@ -1143,13 +1267,11 @@ class WorkerConfigContext(OSContextGenerator):
@property
def num_cpus(self):
try:
from psutil import NUM_CPUS
except ImportError:
apt_install('python-psutil', fatal=True)
from psutil import NUM_CPUS
return NUM_CPUS
# NOTE: use cpu_count if present (16.04 support)
if hasattr(psutil, 'cpu_count'):
return psutil.cpu_count()
else:
return psutil.NUM_CPUS
def __call__(self):
multiplier = config('worker-multiplier') or 0
@ -1283,15 +1405,19 @@ class DataPortContext(NeutronPortContext):
def __call__(self):
ports = config('data-port')
if ports:
# Map of {port/mac:bridge}
portmap = parse_data_port_mappings(ports)
ports = portmap.values()
ports = portmap.keys()
# Resolve provided ports or mac addresses and filter out those
# already attached to a bridge.
resolved = self.resolve_ports(ports)
# FIXME: is this necessary?
normalized = {get_nic_hwaddr(port): port for port in resolved
if port not in ports}
normalized.update({port: port for port in resolved
if port in ports})
if resolved:
return {bridge: normalized[port] for bridge, port in
return {normalized[port]: bridge for port, bridge in
six.iteritems(portmap) if port in normalized.keys()}
return None
@ -1302,12 +1428,22 @@ class PhyNICMTUContext(DataPortContext):
def __call__(self):
ctxt = {}
mappings = super(PhyNICMTUContext, self).__call__()
if mappings and mappings.values():
ports = mappings.values()
if mappings and mappings.keys():
ports = sorted(mappings.keys())
napi_settings = NeutronAPIContext()()
mtu = napi_settings.get('network_device_mtu')
all_ports = set()
# If any of ports is a vlan device, its underlying device must have
# mtu applied first.
for port in ports:
for lport in glob.glob("/sys/class/net/%s/lower_*" % port):
lport = os.path.basename(lport)
all_ports.add(lport.split('_')[1])
all_ports = list(all_ports)
all_ports.extend(ports)
if mtu:
ctxt["devs"] = '\\n'.join(ports)
ctxt["devs"] = '\\n'.join(all_ports)
ctxt['mtu'] = mtu
return ctxt
@ -1338,7 +1474,110 @@ class NetworkServiceContext(OSContextGenerator):
rdata.get('service_protocol') or 'http',
'auth_protocol':
rdata.get('auth_protocol') or 'http',
'api_version':
rdata.get('api_version') or '2.0',
}
if context_complete(ctxt):
if self.context_complete(ctxt):
return ctxt
return {}
class InternalEndpointContext(OSContextGenerator):
"""Internal endpoint context.
This context provides the endpoint type used for communication between
services e.g. between Nova and Cinder internally. Openstack uses Public
endpoints by default so this allows admins to optionally use internal
endpoints.
"""
def __call__(self):
return {'use_internal_endpoints': config('use-internal-endpoints')}
class AppArmorContext(OSContextGenerator):
"""Base class for apparmor contexts."""
def __init__(self):
self._ctxt = None
self.aa_profile = None
self.aa_utils_packages = ['apparmor-utils']
@property
def ctxt(self):
if self._ctxt is not None:
return self._ctxt
self._ctxt = self._determine_ctxt()
return self._ctxt
def _determine_ctxt(self):
"""
Validate aa-profile-mode settings is disable, enforce, or complain.
:return ctxt: Dictionary of the apparmor profile or None
"""
if config('aa-profile-mode') in ['disable', 'enforce', 'complain']:
ctxt = {'aa-profile-mode': config('aa-profile-mode')}
else:
ctxt = None
return ctxt
def __call__(self):
return self.ctxt
def install_aa_utils(self):
"""
Install packages required for apparmor configuration.
"""
log("Installing apparmor utils.")
ensure_packages(self.aa_utils_packages)
def manually_disable_aa_profile(self):
"""
Manually disable an apparmor profile.
If aa-profile-mode is set to disabled (default) this is required as the
template has been written but apparmor is yet unaware of the profile
and aa-disable aa-profile fails. Without this the profile would kick
into enforce mode on the next service restart.
"""
profile_path = '/etc/apparmor.d'
disable_path = '/etc/apparmor.d/disable'
if not os.path.lexists(os.path.join(disable_path, self.aa_profile)):
os.symlink(os.path.join(profile_path, self.aa_profile),
os.path.join(disable_path, self.aa_profile))
def setup_aa_profile(self):
"""
Setup an apparmor profile.
The ctxt dictionary will contain the apparmor profile mode and
the apparmor profile name.
Makes calls out to aa-disable, aa-complain, or aa-enforce to setup
the apparmor profile.
"""
self()
if not self.ctxt:
log("Not enabling apparmor Profile")
return
self.install_aa_utils()
cmd = ['aa-{}'.format(self.ctxt['aa-profile-mode'])]
cmd.append(self.ctxt['aa-profile'])
log("Setting up the apparmor profile for {} in {} mode."
"".format(self.ctxt['aa-profile'], self.ctxt['aa-profile-mode']))
try:
check_call(cmd)
except CalledProcessError as e:
# If aa-profile-mode is set to disabled (default) manual
# disabling is required as the template has been written but
# apparmor is yet unaware of the profile and aa-disable aa-profile
# fails. If aa-disable learns to read profile files first this can
# be removed.
if self.ctxt['aa-profile-mode'] == 'disable':
log("Manually disabling the apparmor profile for {}."
"".format(self.ctxt['aa-profile']))
self.manually_disable_aa_profile()
return
status_set('blocked', "Apparmor profile {} failed to be set to {}."
"".format(self.ctxt['aa-profile'],
self.ctxt['aa-profile-mode']))
raise e

View File

@ -14,16 +14,19 @@
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
from charmhelpers.core.hookenv import (
config,
unit_get,
service_name,
network_get_primary_address,
)
from charmhelpers.contrib.network.ip import (
get_address_in_network,
is_address_in_network,
is_ipv6,
get_ipv6_addr,
resolve_network_cidr,
)
from charmhelpers.contrib.hahelpers.cluster import is_clustered
@ -33,16 +36,19 @@ ADMIN = 'admin'
ADDRESS_MAP = {
PUBLIC: {
'binding': 'public',
'config': 'os-public-network',
'fallback': 'public-address',
'override': 'os-public-hostname',
},
INTERNAL: {
'binding': 'internal',
'config': 'os-internal-network',
'fallback': 'private-address',
'override': 'os-internal-hostname',
},
ADMIN: {
'binding': 'admin',
'config': 'os-admin-network',
'fallback': 'private-address',
'override': 'os-admin-hostname',
@ -110,7 +116,7 @@ def resolve_address(endpoint_type=PUBLIC):
correct network. If clustered with no nets defined, return primary vip.
If not clustered, return unit address ensuring address is on configured net
split if one is configured.
split if one is configured, or a Juju 2.0 extra-binding has been used.
:param endpoint_type: Network endpoing type
"""
@ -125,23 +131,45 @@ def resolve_address(endpoint_type=PUBLIC):
net_type = ADDRESS_MAP[endpoint_type]['config']
net_addr = config(net_type)
net_fallback = ADDRESS_MAP[endpoint_type]['fallback']
binding = ADDRESS_MAP[endpoint_type]['binding']
clustered = is_clustered()
if clustered:
if not net_addr:
# If no net-splits defined, we expect a single vip
resolved_address = vips[0]
else:
if clustered and vips:
if net_addr:
for vip in vips:
if is_address_in_network(net_addr, vip):
resolved_address = vip
break
else:
# NOTE: endeavour to check vips against network space
# bindings
try:
bound_cidr = resolve_network_cidr(
network_get_primary_address(binding)
)
for vip in vips:
if is_address_in_network(bound_cidr, vip):
resolved_address = vip
break
except NotImplementedError:
# If no net-splits configured and no support for extra
# bindings/network spaces so we expect a single vip
resolved_address = vips[0]
else:
if config('prefer-ipv6'):
fallback_addr = get_ipv6_addr(exc_list=vips)[0]
else:
fallback_addr = unit_get(net_fallback)
resolved_address = get_address_in_network(net_addr, fallback_addr)
if net_addr:
resolved_address = get_address_in_network(net_addr, fallback_addr)
else:
# NOTE: only try to use extra bindings if legacy network
# configuration is not in use
try:
resolved_address = network_get_primary_address(binding)
except NotImplementedError:
resolved_address = fallback_addr
if resolved_address is None:
raise ValueError("Unable to resolve a suitable IP address based on "

View File

@ -50,7 +50,7 @@ def determine_dkms_package():
if kernel_version() >= (3, 13):
return []
else:
return ['openvswitch-datapath-dkms']
return [headers_package(), 'openvswitch-datapath-dkms']
# legacy
@ -70,7 +70,7 @@ def quantum_plugins():
relation_prefix='neutron',
ssl_dir=QUANTUM_CONF_DIR)],
'services': ['quantum-plugin-openvswitch-agent'],
'packages': [[headers_package()] + determine_dkms_package(),
'packages': [determine_dkms_package(),
['quantum-plugin-openvswitch-agent']],
'server_packages': ['quantum-server',
'quantum-plugin-openvswitch'],
@ -111,7 +111,7 @@ def neutron_plugins():
relation_prefix='neutron',
ssl_dir=NEUTRON_CONF_DIR)],
'services': ['neutron-plugin-openvswitch-agent'],
'packages': [[headers_package()] + determine_dkms_package(),
'packages': [determine_dkms_package(),
['neutron-plugin-openvswitch-agent']],
'server_packages': ['neutron-server',
'neutron-plugin-openvswitch'],
@ -155,7 +155,7 @@ def neutron_plugins():
relation_prefix='neutron',
ssl_dir=NEUTRON_CONF_DIR)],
'services': [],
'packages': [[headers_package()] + determine_dkms_package(),
'packages': [determine_dkms_package(),
['neutron-plugin-cisco']],
'server_packages': ['neutron-server',
'neutron-plugin-cisco'],
@ -174,7 +174,7 @@ def neutron_plugins():
'neutron-dhcp-agent',
'nova-api-metadata',
'etcd'],
'packages': [[headers_package()] + determine_dkms_package(),
'packages': [determine_dkms_package(),
['calico-compute',
'bird',
'neutron-dhcp-agent',
@ -209,6 +209,20 @@ def neutron_plugins():
'server_packages': ['neutron-server',
'neutron-plugin-plumgrid'],
'server_services': ['neutron-server']
},
'midonet': {
'config': '/etc/neutron/plugins/midonet/midonet.ini',
'driver': 'midonet.neutron.plugin.MidonetPluginV2',
'contexts': [
context.SharedDBContext(user=config('neutron-database-user'),
database=config('neutron-database'),
relation_prefix='neutron',
ssl_dir=NEUTRON_CONF_DIR)],
'services': [],
'packages': [determine_dkms_package()],
'server_packages': ['neutron-server',
'python-neutron-plugin-midonet'],
'server_services': ['neutron-server']
}
}
if release >= 'icehouse':
@ -219,6 +233,20 @@ def neutron_plugins():
'neutron-plugin-ml2']
# NOTE: patch in vmware renames nvp->nsx for icehouse onwards
plugins['nvp'] = plugins['nsx']
if release >= 'kilo':
plugins['midonet']['driver'] = (
'neutron.plugins.midonet.plugin.MidonetPluginV2')
if release >= 'liberty':
plugins['midonet']['driver'] = (
'midonet.neutron.plugin_v1.MidonetPluginV2')
plugins['midonet']['server_packages'].remove(
'python-neutron-plugin-midonet')
plugins['midonet']['server_packages'].append(
'python-networking-midonet')
plugins['plumgrid']['driver'] = (
'networking_plumgrid.neutron.plugins.plugin.NeutronPluginPLUMgridV2')
plugins['plumgrid']['server_packages'].remove(
'neutron-plugin-plumgrid')
return plugins
@ -269,17 +297,30 @@ def network_manager():
return 'neutron'
def parse_mappings(mappings):
def parse_mappings(mappings, key_rvalue=False):
"""By default mappings are lvalue keyed.
If key_rvalue is True, the mapping will be reversed to allow multiple
configs for the same lvalue.
"""
parsed = {}
if mappings:
mappings = mappings.split()
for m in mappings:
p = m.partition(':')
key = p[0].strip()
if p[1]:
parsed[key] = p[2].strip()
if key_rvalue:
key_index = 2
val_index = 0
# if there is no rvalue skip to next
if not p[1]:
continue
else:
parsed[key] = ''
key_index = 0
val_index = 2
key = p[key_index].strip()
parsed[key] = p[val_index].strip()
return parsed
@ -297,25 +338,25 @@ def parse_bridge_mappings(mappings):
def parse_data_port_mappings(mappings, default_bridge='br-data'):
"""Parse data port mappings.
Mappings must be a space-delimited list of bridge:port mappings.
Mappings must be a space-delimited list of bridge:port.
Returns dict of the form {bridge:port}.
Returns dict of the form {port:bridge} where ports may be mac addresses or
interface names.
"""
_mappings = parse_mappings(mappings)
# NOTE(dosaboy): we use rvalue for key to allow multiple values to be
# proposed for <port> since it may be a mac address which will differ
# across units this allowing first-known-good to be chosen.
_mappings = parse_mappings(mappings, key_rvalue=True)
if not _mappings or list(_mappings.values()) == ['']:
if not mappings:
return {}
# For backwards-compatibility we need to support port-only provided in
# config.
_mappings = {default_bridge: mappings.split()[0]}
bridges = _mappings.keys()
ports = _mappings.values()
if len(set(bridges)) != len(bridges):
raise Exception("It is not allowed to have more than one port "
"configured on the same bridge")
_mappings = {mappings.split()[0]: default_bridge}
ports = _mappings.keys()
if len(set(ports)) != len(ports):
raise Exception("It is not allowed to have the same port configured "
"on more than one bridge")

View File

@ -18,7 +18,7 @@ import os
import six
from charmhelpers.fetch import apt_install
from charmhelpers.fetch import apt_install, apt_update
from charmhelpers.core.hookenv import (
log,
ERROR,
@ -29,6 +29,7 @@ from charmhelpers.contrib.openstack.utils import OPENSTACK_CODENAMES
try:
from jinja2 import FileSystemLoader, ChoiceLoader, Environment, exceptions
except ImportError:
apt_update(fatal=True)
apt_install('python-jinja2', fatal=True)
from jinja2 import FileSystemLoader, ChoiceLoader, Environment, exceptions
@ -112,7 +113,7 @@ class OSConfigTemplate(object):
def complete_contexts(self):
'''
Return a list of interfaces that have atisfied contexts.
Return a list of interfaces that have satisfied contexts.
'''
if self._complete_contexts:
return self._complete_contexts
@ -293,3 +294,30 @@ class OSConfigRenderer(object):
[interfaces.extend(i.complete_contexts())
for i in six.itervalues(self.templates)]
return interfaces
def get_incomplete_context_data(self, interfaces):
'''
Return dictionary of relation status of interfaces and any missing
required context data. Example:
{'amqp': {'missing_data': ['rabbitmq_password'], 'related': True},
'zeromq-configuration': {'related': False}}
'''
incomplete_context_data = {}
for i in six.itervalues(self.templates):
for context in i.contexts:
for interface in interfaces:
related = False
if interface in context.interfaces:
related = context.get_related()
missing_data = context.missing_data
if missing_data:
incomplete_context_data[interface] = {'missing_data': missing_data}
if related:
if incomplete_context_data.get(interface):
incomplete_context_data[interface].update({'related': True})
else:
incomplete_context_data[interface] = {'related': True}
else:
incomplete_context_data[interface] = {'related': False}
return incomplete_context_data

File diff suppressed because it is too large Load Diff

View File

@ -59,7 +59,7 @@ def some_hook():
"""
def leader_get(attribute=None):
def leader_get(attribute=None, rid=None):
"""Wrapper to ensure that settings are migrated from the peer relation.
This is to support upgrading an environment that does not support
@ -94,7 +94,8 @@ def leader_get(attribute=None):
# If attribute not present in leader db, check if this unit has set
# the attribute in the peer relation
if not leader_settings:
peer_setting = relation_get(attribute=attribute, unit=local_unit())
peer_setting = _relation_get(attribute=attribute, unit=local_unit(),
rid=rid)
if peer_setting:
leader_set(settings={attribute: peer_setting})
leader_settings = peer_setting
@ -103,7 +104,7 @@ def leader_get(attribute=None):
settings_migrated = True
migrated.add(attribute)
else:
r_settings = relation_get(unit=local_unit())
r_settings = _relation_get(unit=local_unit(), rid=rid)
if r_settings:
for key in set(r_settings.keys()).difference(migrated):
# Leader setting wins
@ -151,7 +152,7 @@ def relation_get(attribute=None, unit=None, rid=None):
"""
try:
if rid in relation_ids('cluster'):
return leader_get(attribute)
return leader_get(attribute, rid)
else:
raise NotImplementedError
except NotImplementedError:

View File

@ -19,20 +19,35 @@
import os
import subprocess
import sys
from charmhelpers.fetch import apt_install, apt_update
from charmhelpers.core.hookenv import charm_dir, log
try:
from pip import main as pip_execute
except ImportError:
apt_update()
apt_install('python-pip')
from pip import main as pip_execute
__author__ = "Jorge Niedbalski <jorge.niedbalski@canonical.com>"
def pip_execute(*args, **kwargs):
"""Overriden pip_execute() to stop sys.path being changed.
The act of importing main from the pip module seems to cause add wheels
from the /usr/share/python-wheels which are installed by various tools.
This function ensures that sys.path remains the same after the call is
executed.
"""
try:
_path = sys.path
try:
from pip import main as _pip_execute
except ImportError:
apt_update()
apt_install('python-pip')
from pip import main as _pip_execute
_pip_execute(*args, **kwargs)
finally:
sys.path = _path
def parse_options(given, available):
"""Given a set of options, check if available"""
for key, value in sorted(given.items()):
@ -42,8 +57,12 @@ def parse_options(given, available):
yield "--{0}={1}".format(key, value)
def pip_install_requirements(requirements, **options):
"""Install a requirements file """
def pip_install_requirements(requirements, constraints=None, **options):
"""Install a requirements file.
:param constraints: Path to pip constraints file.
http://pip.readthedocs.org/en/stable/user_guide/#constraints-files
"""
command = ["install"]
available_options = ('proxy', 'src', 'log', )
@ -51,8 +70,13 @@ def pip_install_requirements(requirements, **options):
command.append(option)
command.append("-r {0}".format(requirements))
log("Installing from file: {} with options: {}".format(requirements,
command))
if constraints:
command.append("-c {0}".format(constraints))
log("Installing from file: {} with constraints {} "
"and options: {}".format(requirements, constraints, command))
else:
log("Installing from file: {} with options: {}".format(requirements,
command))
pip_execute(command)

View File

@ -23,11 +23,16 @@
# James Page <james.page@ubuntu.com>
# Adam Gandelman <adamg@ubuntu.com>
#
import bisect
import errno
import hashlib
import six
import os
import shutil
import json
import time
import uuid
from subprocess import (
check_call,
@ -35,8 +40,10 @@ from subprocess import (
CalledProcessError,
)
from charmhelpers.core.hookenv import (
local_unit,
relation_get,
relation_ids,
relation_set,
related_units,
log,
DEBUG,
@ -56,6 +63,8 @@ from charmhelpers.fetch import (
apt_install,
)
from charmhelpers.core.kernel import modprobe
KEYRING = '/etc/ceph/ceph.client.{}.keyring'
KEYFILE = '/etc/ceph/ceph.client.{}.key'
@ -67,6 +76,559 @@ log to syslog = {use_syslog}
err to syslog = {use_syslog}
clog to syslog = {use_syslog}
"""
# For 50 < osds < 240,000 OSDs (Roughly 1 Exabyte at 6T OSDs)
powers_of_two = [8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608]
def validator(value, valid_type, valid_range=None):
"""
Used to validate these: http://docs.ceph.com/docs/master/rados/operations/pools/#set-pool-values
Example input:
validator(value=1,
valid_type=int,
valid_range=[0, 2])
This says I'm testing value=1. It must be an int inclusive in [0,2]
:param value: The value to validate
:param valid_type: The type that value should be.
:param valid_range: A range of values that value can assume.
:return:
"""
assert isinstance(value, valid_type), "{} is not a {}".format(
value,
valid_type)
if valid_range is not None:
assert isinstance(valid_range, list), \
"valid_range must be a list, was given {}".format(valid_range)
# If we're dealing with strings
if valid_type is six.string_types:
assert value in valid_range, \
"{} is not in the list {}".format(value, valid_range)
# Integer, float should have a min and max
else:
if len(valid_range) != 2:
raise ValueError(
"Invalid valid_range list of {} for {}. "
"List must be [min,max]".format(valid_range, value))
assert value >= valid_range[0], \
"{} is less than minimum allowed value of {}".format(
value, valid_range[0])
assert value <= valid_range[1], \
"{} is greater than maximum allowed value of {}".format(
value, valid_range[1])
class PoolCreationError(Exception):
"""
A custom error to inform the caller that a pool creation failed. Provides an error message
"""
def __init__(self, message):
super(PoolCreationError, self).__init__(message)
class Pool(object):
"""
An object oriented approach to Ceph pool creation. This base class is inherited by ReplicatedPool and ErasurePool.
Do not call create() on this base class as it will not do anything. Instantiate a child class and call create().
"""
def __init__(self, service, name):
self.service = service
self.name = name
# Create the pool if it doesn't exist already
# To be implemented by subclasses
def create(self):
pass
def add_cache_tier(self, cache_pool, mode):
"""
Adds a new cache tier to an existing pool.
:param cache_pool: six.string_types. The cache tier pool name to add.
:param mode: six.string_types. The caching mode to use for this pool. valid range = ["readonly", "writeback"]
:return: None
"""
# Check the input types and values
validator(value=cache_pool, valid_type=six.string_types)
validator(value=mode, valid_type=six.string_types, valid_range=["readonly", "writeback"])
check_call(['ceph', '--id', self.service, 'osd', 'tier', 'add', self.name, cache_pool])
check_call(['ceph', '--id', self.service, 'osd', 'tier', 'cache-mode', cache_pool, mode])
check_call(['ceph', '--id', self.service, 'osd', 'tier', 'set-overlay', self.name, cache_pool])
check_call(['ceph', '--id', self.service, 'osd', 'pool', 'set', cache_pool, 'hit_set_type', 'bloom'])
def remove_cache_tier(self, cache_pool):
"""
Removes a cache tier from Ceph. Flushes all dirty objects from writeback pools and waits for that to complete.
:param cache_pool: six.string_types. The cache tier pool name to remove.
:return: None
"""
# read-only is easy, writeback is much harder
mode = get_cache_mode(self.service, cache_pool)
version = ceph_version()
if mode == 'readonly':
check_call(['ceph', '--id', self.service, 'osd', 'tier', 'cache-mode', cache_pool, 'none'])
check_call(['ceph', '--id', self.service, 'osd', 'tier', 'remove', self.name, cache_pool])
elif mode == 'writeback':
pool_forward_cmd = ['ceph', '--id', self.service, 'osd', 'tier',
'cache-mode', cache_pool, 'forward']
if version >= '10.1':
# Jewel added a mandatory flag
pool_forward_cmd.append('--yes-i-really-mean-it')
check_call(pool_forward_cmd)
# Flush the cache and wait for it to return
check_call(['rados', '--id', self.service, '-p', cache_pool, 'cache-flush-evict-all'])
check_call(['ceph', '--id', self.service, 'osd', 'tier', 'remove-overlay', self.name])
check_call(['ceph', '--id', self.service, 'osd', 'tier', 'remove', self.name, cache_pool])
def get_pgs(self, pool_size):
"""
:param pool_size: int. pool_size is either the number of replicas for replicated pools or the K+M sum for
erasure coded pools
:return: int. The number of pgs to use.
"""
validator(value=pool_size, valid_type=int)
osd_list = get_osds(self.service)
if not osd_list:
# NOTE(james-page): Default to 200 for older ceph versions
# which don't support OSD query from cli
return 200
osd_list_length = len(osd_list)
# Calculate based on Ceph best practices
if osd_list_length < 5:
return 128
elif 5 < osd_list_length < 10:
return 512
elif 10 < osd_list_length < 50:
return 4096
else:
estimate = (osd_list_length * 100) / pool_size
# Return the next nearest power of 2
index = bisect.bisect_right(powers_of_two, estimate)
return powers_of_two[index]
class ReplicatedPool(Pool):
def __init__(self, service, name, pg_num=None, replicas=2):
super(ReplicatedPool, self).__init__(service=service, name=name)
self.replicas = replicas
if pg_num is None:
self.pg_num = self.get_pgs(self.replicas)
else:
self.pg_num = pg_num
def create(self):
if not pool_exists(self.service, self.name):
# Create it
cmd = ['ceph', '--id', self.service, 'osd', 'pool', 'create',
self.name, str(self.pg_num)]
try:
check_call(cmd)
# Set the pool replica size
update_pool(client=self.service,
pool=self.name,
settings={'size': str(self.replicas)})
except CalledProcessError:
raise
# Default jerasure erasure coded pool
class ErasurePool(Pool):
def __init__(self, service, name, erasure_code_profile="default"):
super(ErasurePool, self).__init__(service=service, name=name)
self.erasure_code_profile = erasure_code_profile
def create(self):
if not pool_exists(self.service, self.name):
# Try to find the erasure profile information so we can properly size the pgs
erasure_profile = get_erasure_profile(service=self.service, name=self.erasure_code_profile)
# Check for errors
if erasure_profile is None:
log(message='Failed to discover erasure_profile named={}'.format(self.erasure_code_profile),
level=ERROR)
raise PoolCreationError(message='unable to find erasure profile {}'.format(self.erasure_code_profile))
if 'k' not in erasure_profile or 'm' not in erasure_profile:
# Error
log(message='Unable to find k (data chunks) or m (coding chunks) in {}'.format(erasure_profile),
level=ERROR)
raise PoolCreationError(
message='unable to find k (data chunks) or m (coding chunks) in {}'.format(erasure_profile))
pgs = self.get_pgs(int(erasure_profile['k']) + int(erasure_profile['m']))
# Create it
cmd = ['ceph', '--id', self.service, 'osd', 'pool', 'create', self.name, str(pgs), str(pgs),
'erasure', self.erasure_code_profile]
try:
check_call(cmd)
except CalledProcessError:
raise
"""Get an existing erasure code profile if it already exists.
Returns json formatted output"""
def get_mon_map(service):
"""
Returns the current monitor map.
:param service: six.string_types. The Ceph user name to run the command under
:return: json string. :raise: ValueError if the monmap fails to parse.
Also raises CalledProcessError if our ceph command fails
"""
try:
mon_status = check_output(
['ceph', '--id', service,
'mon_status', '--format=json'])
try:
return json.loads(mon_status)
except ValueError as v:
log("Unable to parse mon_status json: {}. Error: {}".format(
mon_status, v.message))
raise
except CalledProcessError as e:
log("mon_status command failed with message: {}".format(
e.message))
raise
def hash_monitor_names(service):
"""
Uses the get_mon_map() function to get information about the monitor
cluster.
Hash the name of each monitor. Return a sorted list of monitor hashes
in an ascending order.
:param service: six.string_types. The Ceph user name to run the command under
:rtype : dict. json dict of monitor name, ip address and rank
example: {
'name': 'ip-172-31-13-165',
'rank': 0,
'addr': '172.31.13.165:6789/0'}
"""
try:
hash_list = []
monitor_list = get_mon_map(service=service)
if monitor_list['monmap']['mons']:
for mon in monitor_list['monmap']['mons']:
hash_list.append(
hashlib.sha224(mon['name'].encode('utf-8')).hexdigest())
return sorted(hash_list)
else:
return None
except (ValueError, CalledProcessError):
raise
def monitor_key_delete(service, key):
"""
Delete a key and value pair from the monitor cluster
:param service: six.string_types. The Ceph user name to run the command under
Deletes a key value pair on the monitor cluster.
:param key: six.string_types. The key to delete.
"""
try:
check_output(
['ceph', '--id', service,
'config-key', 'del', str(key)])
except CalledProcessError as e:
log("Monitor config-key put failed with message: {}".format(
e.output))
raise
def monitor_key_set(service, key, value):
"""
Sets a key value pair on the monitor cluster.
:param service: six.string_types. The Ceph user name to run the command under
:param key: six.string_types. The key to set.
:param value: The value to set. This will be converted to a string
before setting
"""
try:
check_output(
['ceph', '--id', service,
'config-key', 'put', str(key), str(value)])
except CalledProcessError as e:
log("Monitor config-key put failed with message: {}".format(
e.output))
raise
def monitor_key_get(service, key):
"""
Gets the value of an existing key in the monitor cluster.
:param service: six.string_types. The Ceph user name to run the command under
:param key: six.string_types. The key to search for.
:return: Returns the value of that key or None if not found.
"""
try:
output = check_output(
['ceph', '--id', service,
'config-key', 'get', str(key)])
return output
except CalledProcessError as e:
log("Monitor config-key get failed with message: {}".format(
e.output))
return None
def monitor_key_exists(service, key):
"""
Searches for the existence of a key in the monitor cluster.
:param service: six.string_types. The Ceph user name to run the command under
:param key: six.string_types. The key to search for
:return: Returns True if the key exists, False if not and raises an
exception if an unknown error occurs. :raise: CalledProcessError if
an unknown error occurs
"""
try:
check_call(
['ceph', '--id', service,
'config-key', 'exists', str(key)])
# I can return true here regardless because Ceph returns
# ENOENT if the key wasn't found
return True
except CalledProcessError as e:
if e.returncode == errno.ENOENT:
return False
else:
log("Unknown error from ceph config-get exists: {} {}".format(
e.returncode, e.output))
raise
def get_erasure_profile(service, name):
"""
:param service: six.string_types. The Ceph user name to run the command under
:param name:
:return:
"""
try:
out = check_output(['ceph', '--id', service,
'osd', 'erasure-code-profile', 'get',
name, '--format=json'])
return json.loads(out)
except (CalledProcessError, OSError, ValueError):
return None
def pool_set(service, pool_name, key, value):
"""
Sets a value for a RADOS pool in ceph.
:param service: six.string_types. The Ceph user name to run the command under
:param pool_name: six.string_types
:param key: six.string_types
:param value:
:return: None. Can raise CalledProcessError
"""
cmd = ['ceph', '--id', service, 'osd', 'pool', 'set', pool_name, key, value]
try:
check_call(cmd)
except CalledProcessError:
raise
def snapshot_pool(service, pool_name, snapshot_name):
"""
Snapshots a RADOS pool in ceph.
:param service: six.string_types. The Ceph user name to run the command under
:param pool_name: six.string_types
:param snapshot_name: six.string_types
:return: None. Can raise CalledProcessError
"""
cmd = ['ceph', '--id', service, 'osd', 'pool', 'mksnap', pool_name, snapshot_name]
try:
check_call(cmd)
except CalledProcessError:
raise
def remove_pool_snapshot(service, pool_name, snapshot_name):
"""
Remove a snapshot from a RADOS pool in ceph.
:param service: six.string_types. The Ceph user name to run the command under
:param pool_name: six.string_types
:param snapshot_name: six.string_types
:return: None. Can raise CalledProcessError
"""
cmd = ['ceph', '--id', service, 'osd', 'pool', 'rmsnap', pool_name, snapshot_name]
try:
check_call(cmd)
except CalledProcessError:
raise
# max_bytes should be an int or long
def set_pool_quota(service, pool_name, max_bytes):
"""
:param service: six.string_types. The Ceph user name to run the command under
:param pool_name: six.string_types
:param max_bytes: int or long
:return: None. Can raise CalledProcessError
"""
# Set a byte quota on a RADOS pool in ceph.
cmd = ['ceph', '--id', service, 'osd', 'pool', 'set-quota', pool_name,
'max_bytes', str(max_bytes)]
try:
check_call(cmd)
except CalledProcessError:
raise
def remove_pool_quota(service, pool_name):
"""
Set a byte quota on a RADOS pool in ceph.
:param service: six.string_types. The Ceph user name to run the command under
:param pool_name: six.string_types
:return: None. Can raise CalledProcessError
"""
cmd = ['ceph', '--id', service, 'osd', 'pool', 'set-quota', pool_name, 'max_bytes', '0']
try:
check_call(cmd)
except CalledProcessError:
raise
def remove_erasure_profile(service, profile_name):
"""
Create a new erasure code profile if one does not already exist for it. Updates
the profile if it exists. Please see http://docs.ceph.com/docs/master/rados/operations/erasure-code-profile/
for more details
:param service: six.string_types. The Ceph user name to run the command under
:param profile_name: six.string_types
:return: None. Can raise CalledProcessError
"""
cmd = ['ceph', '--id', service, 'osd', 'erasure-code-profile', 'rm',
profile_name]
try:
check_call(cmd)
except CalledProcessError:
raise
def create_erasure_profile(service, profile_name, erasure_plugin_name='jerasure',
failure_domain='host',
data_chunks=2, coding_chunks=1,
locality=None, durability_estimator=None):
"""
Create a new erasure code profile if one does not already exist for it. Updates
the profile if it exists. Please see http://docs.ceph.com/docs/master/rados/operations/erasure-code-profile/
for more details
:param service: six.string_types. The Ceph user name to run the command under
:param profile_name: six.string_types
:param erasure_plugin_name: six.string_types
:param failure_domain: six.string_types. One of ['chassis', 'datacenter', 'host', 'osd', 'pdu', 'pod', 'rack', 'region',
'room', 'root', 'row'])
:param data_chunks: int
:param coding_chunks: int
:param locality: int
:param durability_estimator: int
:return: None. Can raise CalledProcessError
"""
# Ensure this failure_domain is allowed by Ceph
validator(failure_domain, six.string_types,
['chassis', 'datacenter', 'host', 'osd', 'pdu', 'pod', 'rack', 'region', 'room', 'root', 'row'])
cmd = ['ceph', '--id', service, 'osd', 'erasure-code-profile', 'set', profile_name,
'plugin=' + erasure_plugin_name, 'k=' + str(data_chunks), 'm=' + str(coding_chunks),
'ruleset_failure_domain=' + failure_domain]
if locality is not None and durability_estimator is not None:
raise ValueError("create_erasure_profile should be called with k, m and one of l or c but not both.")
# Add plugin specific information
if locality is not None:
# For local erasure codes
cmd.append('l=' + str(locality))
if durability_estimator is not None:
# For Shec erasure codes
cmd.append('c=' + str(durability_estimator))
if erasure_profile_exists(service, profile_name):
cmd.append('--force')
try:
check_call(cmd)
except CalledProcessError:
raise
def rename_pool(service, old_name, new_name):
"""
Rename a Ceph pool from old_name to new_name
:param service: six.string_types. The Ceph user name to run the command under
:param old_name: six.string_types
:param new_name: six.string_types
:return: None
"""
validator(value=old_name, valid_type=six.string_types)
validator(value=new_name, valid_type=six.string_types)
cmd = ['ceph', '--id', service, 'osd', 'pool', 'rename', old_name, new_name]
check_call(cmd)
def erasure_profile_exists(service, name):
"""
Check to see if an Erasure code profile already exists.
:param service: six.string_types. The Ceph user name to run the command under
:param name: six.string_types
:return: int or None
"""
validator(value=name, valid_type=six.string_types)
try:
check_call(['ceph', '--id', service,
'osd', 'erasure-code-profile', 'get',
name])
return True
except CalledProcessError:
return False
def get_cache_mode(service, pool_name):
"""
Find the current caching mode of the pool_name given.
:param service: six.string_types. The Ceph user name to run the command under
:param pool_name: six.string_types
:return: int or None
"""
validator(value=service, valid_type=six.string_types)
validator(value=pool_name, valid_type=six.string_types)
out = check_output(['ceph', '--id', service, 'osd', 'dump', '--format=json'])
try:
osd_json = json.loads(out)
for pool in osd_json['pools']:
if pool['pool_name'] == pool_name:
return pool['cache_mode']
return None
except ValueError:
raise
def pool_exists(service, name):
"""Check to see if a RADOS pool already exists."""
try:
out = check_output(['rados', '--id', service,
'lspools']).decode('UTF-8')
except CalledProcessError:
return False
return name in out.split()
def get_osds(service):
"""Return a list of all Ceph Object Storage Daemons currently in the
cluster.
"""
version = ceph_version()
if version and version >= '0.56':
return json.loads(check_output(['ceph', '--id', service,
'osd', 'ls',
'--format=json']).decode('UTF-8'))
return None
def install():
@ -96,53 +658,37 @@ def create_rbd_image(service, pool, image, sizemb):
check_call(cmd)
def pool_exists(service, name):
"""Check to see if a RADOS pool already exists."""
try:
out = check_output(['rados', '--id', service,
'lspools']).decode('UTF-8')
except CalledProcessError:
return False
def update_pool(client, pool, settings):
cmd = ['ceph', '--id', client, 'osd', 'pool', 'set', pool]
for k, v in six.iteritems(settings):
cmd.append(k)
cmd.append(v)
return name in out
check_call(cmd)
def get_osds(service):
"""Return a list of all Ceph Object Storage Daemons currently in the
cluster.
"""
version = ceph_version()
if version and version >= '0.56':
return json.loads(check_output(['ceph', '--id', service,
'osd', 'ls',
'--format=json']).decode('UTF-8'))
return None
def create_pool(service, name, replicas=3):
def create_pool(service, name, replicas=3, pg_num=None):
"""Create a new RADOS pool."""
if pool_exists(service, name):
log("Ceph pool {} already exists, skipping creation".format(name),
level=WARNING)
return
# Calculate the number of placement groups based
# on upstream recommended best practices.
osds = get_osds(service)
if osds:
pgnum = (len(osds) * 100 // replicas)
else:
# NOTE(james-page): Default to 200 for older ceph versions
# which don't support OSD query from cli
pgnum = 200
if not pg_num:
# Calculate the number of placement groups based
# on upstream recommended best practices.
osds = get_osds(service)
if osds:
pg_num = (len(osds) * 100 // replicas)
else:
# NOTE(james-page): Default to 200 for older ceph versions
# which don't support OSD query from cli
pg_num = 200
cmd = ['ceph', '--id', service, 'osd', 'pool', 'create', name, str(pgnum)]
cmd = ['ceph', '--id', service, 'osd', 'pool', 'create', name, str(pg_num)]
check_call(cmd)
cmd = ['ceph', '--id', service, 'osd', 'pool', 'set', name, 'size',
str(replicas)]
check_call(cmd)
update_pool(service, name, settings={'size': str(replicas)})
def delete_pool(service, name):
@ -197,10 +743,10 @@ def create_key_file(service, key):
log('Created new keyfile at %s.' % keyfile, level=INFO)
def get_ceph_nodes():
"""Query named relation 'ceph' to determine current nodes."""
def get_ceph_nodes(relation='ceph'):
"""Query named relation to determine current nodes."""
hosts = []
for r_id in relation_ids('ceph'):
for r_id in relation_ids(relation):
for unit in related_units(r_id):
hosts.append(relation_get('private-address', unit=unit, rid=r_id))
@ -288,17 +834,6 @@ def place_data_on_block_device(blk_device, data_src_dst):
os.chown(data_src_dst, uid, gid)
# TODO: re-use
def modprobe(module):
"""Load a kernel module and configure for auto-load on reboot."""
log('Loading kernel module', level=INFO)
cmd = ['modprobe', module]
check_call(cmd)
with open('/etc/modules', 'r+') as modules:
if module not in modules.read():
modules.write(module)
def copy_files(src, dst, symlinks=False, ignore=None):
"""Copy files from src to dst."""
for item in os.listdir(src):
@ -363,14 +898,14 @@ def ensure_ceph_storage(service, pool, rbd_img, sizemb, mount_point,
service_start(svc)
def ensure_ceph_keyring(service, user=None, group=None):
def ensure_ceph_keyring(service, user=None, group=None, relation='ceph'):
"""Ensures a ceph keyring is created for a named service and optionally
ensures user and group ownership.
Returns False if no ceph key is available in relation state.
"""
key = None
for rid in relation_ids('ceph'):
for rid in relation_ids(relation):
for unit in related_units(rid):
key = relation_get('key', rid=rid, unit=unit)
if key:
@ -411,17 +946,60 @@ class CephBrokerRq(object):
The API is versioned and defaults to version 1.
"""
def __init__(self, api_version=1):
def __init__(self, api_version=1, request_id=None):
self.api_version = api_version
if request_id:
self.request_id = request_id
else:
self.request_id = str(uuid.uuid1())
self.ops = []
def add_op_create_pool(self, name, replica_count=3):
def add_op_create_pool(self, name, replica_count=3, pg_num=None):
"""Adds an operation to create a pool.
@param pg_num setting: optional setting. If not provided, this value
will be calculated by the broker based on how many OSDs are in the
cluster at the time of creation. Note that, if provided, this value
will be capped at the current available maximum.
"""
self.ops.append({'op': 'create-pool', 'name': name,
'replicas': replica_count})
'replicas': replica_count, 'pg_num': pg_num})
def set_ops(self, ops):
"""Set request ops to provided value.
Useful for injecting ops that come from a previous request
to allow comparisons to ensure validity.
"""
self.ops = ops
@property
def request(self):
return json.dumps({'api-version': self.api_version, 'ops': self.ops})
return json.dumps({'api-version': self.api_version, 'ops': self.ops,
'request-id': self.request_id})
def _ops_equal(self, other):
if len(self.ops) == len(other.ops):
for req_no in range(0, len(self.ops)):
for key in ['replicas', 'name', 'op', 'pg_num']:
if self.ops[req_no].get(key) != other.ops[req_no].get(key):
return False
else:
return False
return True
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
if self.api_version == other.api_version and \
self._ops_equal(other):
return True
else:
return False
def __ne__(self, other):
return not self.__eq__(other)
class CephBrokerRsp(object):
@ -431,10 +1009,15 @@ class CephBrokerRsp(object):
The API is versioned and defaults to version 1.
"""
def __init__(self, encoded_rsp):
self.api_version = None
self.rsp = json.loads(encoded_rsp)
@property
def request_id(self):
return self.rsp.get('request-id')
@property
def exit_code(self):
return self.rsp.get('exit-code')
@ -442,3 +1025,182 @@ class CephBrokerRsp(object):
@property
def exit_msg(self):
return self.rsp.get('stderr')
# Ceph Broker Conversation:
# If a charm needs an action to be taken by ceph it can create a CephBrokerRq
# and send that request to ceph via the ceph relation. The CephBrokerRq has a
# unique id so that the client can identity which CephBrokerRsp is associated
# with the request. Ceph will also respond to each client unit individually
# creating a response key per client unit eg glance/0 will get a CephBrokerRsp
# via key broker-rsp-glance-0
#
# To use this the charm can just do something like:
#
# from charmhelpers.contrib.storage.linux.ceph import (
# send_request_if_needed,
# is_request_complete,
# CephBrokerRq,
# )
#
# @hooks.hook('ceph-relation-changed')
# def ceph_changed():
# rq = CephBrokerRq()
# rq.add_op_create_pool(name='poolname', replica_count=3)
#
# if is_request_complete(rq):
# <Request complete actions>
# else:
# send_request_if_needed(get_ceph_request())
#
# CephBrokerRq and CephBrokerRsp are serialized into JSON. Below is an example
# of glance having sent a request to ceph which ceph has successfully processed
# 'ceph:8': {
# 'ceph/0': {
# 'auth': 'cephx',
# 'broker-rsp-glance-0': '{"request-id": "0bc7dc54", "exit-code": 0}',
# 'broker_rsp': '{"request-id": "0da543b8", "exit-code": 0}',
# 'ceph-public-address': '10.5.44.103',
# 'key': 'AQCLDttVuHXINhAAvI144CB09dYchhHyTUY9BQ==',
# 'private-address': '10.5.44.103',
# },
# 'glance/0': {
# 'broker_req': ('{"api-version": 1, "request-id": "0bc7dc54", '
# '"ops": [{"replicas": 3, "name": "glance", '
# '"op": "create-pool"}]}'),
# 'private-address': '10.5.44.109',
# },
# }
def get_previous_request(rid):
"""Return the last ceph broker request sent on a given relation
@param rid: Relation id to query for request
"""
request = None
broker_req = relation_get(attribute='broker_req', rid=rid,
unit=local_unit())
if broker_req:
request_data = json.loads(broker_req)
request = CephBrokerRq(api_version=request_data['api-version'],
request_id=request_data['request-id'])
request.set_ops(request_data['ops'])
return request
def get_request_states(request, relation='ceph'):
"""Return a dict of requests per relation id with their corresponding
completion state.
This allows a charm, which has a request for ceph, to see whether there is
an equivalent request already being processed and if so what state that
request is in.
@param request: A CephBrokerRq object
"""
complete = []
requests = {}
for rid in relation_ids(relation):
complete = False
previous_request = get_previous_request(rid)
if request == previous_request:
sent = True
complete = is_request_complete_for_rid(previous_request, rid)
else:
sent = False
complete = False
requests[rid] = {
'sent': sent,
'complete': complete,
}
return requests
def is_request_sent(request, relation='ceph'):
"""Check to see if a functionally equivalent request has already been sent
Returns True if a similair request has been sent
@param request: A CephBrokerRq object
"""
states = get_request_states(request, relation=relation)
for rid in states.keys():
if not states[rid]['sent']:
return False
return True
def is_request_complete(request, relation='ceph'):
"""Check to see if a functionally equivalent request has already been
completed
Returns True if a similair request has been completed
@param request: A CephBrokerRq object
"""
states = get_request_states(request, relation=relation)
for rid in states.keys():
if not states[rid]['complete']:
return False
return True
def is_request_complete_for_rid(request, rid):
"""Check if a given request has been completed on the given relation
@param request: A CephBrokerRq object
@param rid: Relation ID
"""
broker_key = get_broker_rsp_key()
for unit in related_units(rid):
rdata = relation_get(rid=rid, unit=unit)
if rdata.get(broker_key):
rsp = CephBrokerRsp(rdata.get(broker_key))
if rsp.request_id == request.request_id:
if not rsp.exit_code:
return True
else:
# The remote unit sent no reply targeted at this unit so either the
# remote ceph cluster does not support unit targeted replies or it
# has not processed our request yet.
if rdata.get('broker_rsp'):
request_data = json.loads(rdata['broker_rsp'])
if request_data.get('request-id'):
log('Ignoring legacy broker_rsp without unit key as remote '
'service supports unit specific replies', level=DEBUG)
else:
log('Using legacy broker_rsp as remote service does not '
'supports unit specific replies', level=DEBUG)
rsp = CephBrokerRsp(rdata['broker_rsp'])
if not rsp.exit_code:
return True
return False
def get_broker_rsp_key():
"""Return broker response key for this unit
This is the key that ceph is going to use to pass request status
information back to this unit
"""
return 'broker-rsp-' + local_unit().replace('/', '-')
def send_request_if_needed(request, relation='ceph'):
"""Send broker request if an equivalent request has not already been sent
@param request: A CephBrokerRq object
"""
if is_request_sent(request, relation=relation):
log('Request already sent but not complete, not sending new request',
level=DEBUG)
else:
for rid in relation_ids(relation):
log('Sending request {}'.format(request.request_id), level=DEBUG)
relation_set(relation_id=rid, broker_req=request.request)

View File

@ -76,3 +76,13 @@ def ensure_loopback_device(path, size):
check_call(cmd)
return create_loopback(path)
def is_mapped_loopback_device(device):
"""
Checks if a given device name is an existing/mapped loopback device.
:param device: str: Full path to the device (eg, /dev/loop1).
:returns: str: Path to the backing file if is a loopback device
empty string otherwise
"""
return loopback_devices().get(device, "")

View File

@ -43,9 +43,10 @@ def zap_disk(block_device):
:param block_device: str: Full path of block device to clean.
'''
# https://github.com/ceph/ceph/commit/fdd7f8d83afa25c4e09aaedd90ab93f3b64a677b
# sometimes sgdisk exits non-zero; this is OK, dd will clean up
call(['sgdisk', '--zap-all', '--mbrtogpt',
'--clear', block_device])
call(['sgdisk', '--zap-all', '--', block_device])
call(['sgdisk', '--clear', '--mbrtogpt', '--', block_device])
dev_end = check_output(['blockdev', '--getsz',
block_device]).decode('UTF-8')
gpt_end = int(dev_end.split()[0]) - 100
@ -63,8 +64,8 @@ def is_device_mounted(device):
:returns: boolean: True if the path represents a mounted device, False if
it doesn't.
'''
is_partition = bool(re.search(r".*[0-9]+\b", device))
out = check_output(['mount']).decode('UTF-8')
if is_partition:
return bool(re.search(device + r"\b", out))
return bool(re.search(device + r"[0-9]*\b", out))
try:
out = check_output(['lsblk', '-P', device]).decode('UTF-8')
except:
return False
return bool(re.search(r'MOUNTPOINT=".+"', out))

View File

@ -18,14 +18,15 @@
Templating using the python-jinja2 package.
"""
import six
from charmhelpers.fetch import apt_install
from charmhelpers.fetch import apt_install, apt_update
try:
import jinja2
except ImportError:
apt_update(fatal=True)
if six.PY3:
apt_install(["python3-jinja2"])
apt_install(["python3-jinja2"], fatal=True)
else:
apt_install(["python-jinja2"])
apt_install(["python-jinja2"], fatal=True)
import jinja2

View File

@ -74,6 +74,7 @@ def cached(func):
res = func(*args, **kwargs)
cache[key] = res
return res
wrapper._wrapped = func
return wrapper
@ -173,9 +174,19 @@ def relation_type():
return os.environ.get('JUJU_RELATION', None)
def relation_id():
"""The relation ID for the current relation hook"""
return os.environ.get('JUJU_RELATION_ID', None)
@cached
def relation_id(relation_name=None, service_or_unit=None):
"""The relation ID for the current or a specified relation"""
if not relation_name and not service_or_unit:
return os.environ.get('JUJU_RELATION_ID', None)
elif relation_name and service_or_unit:
service_name = service_or_unit.split('/')[0]
for relid in relation_ids(relation_name):
remote_service = remote_service_name(relid)
if remote_service == service_name:
return relid
else:
raise ValueError('Must specify neither or both of relation_name and service_or_unit')
def local_unit():
@ -193,9 +204,20 @@ def service_name():
return local_unit().split('/')[0]
@cached
def remote_service_name(relid=None):
"""The remote service name for a given relation-id (or the current relation)"""
if relid is None:
unit = remote_unit()
else:
units = related_units(relid)
unit = units[0] if units else None
return unit.split('/')[0] if unit else None
def hook_name():
"""The name of the currently executing hook"""
return os.path.basename(sys.argv[0])
return os.environ.get('JUJU_HOOK_NAME', os.path.basename(sys.argv[0]))
class Config(dict):
@ -468,6 +490,76 @@ def relation_types():
return rel_types
@cached
def peer_relation_id():
'''Get the peers relation id if a peers relation has been joined, else None.'''
md = metadata()
section = md.get('peers')
if section:
for key in section:
relids = relation_ids(key)
if relids:
return relids[0]
return None
@cached
def relation_to_interface(relation_name):
"""
Given the name of a relation, return the interface that relation uses.
:returns: The interface name, or ``None``.
"""
return relation_to_role_and_interface(relation_name)[1]
@cached
def relation_to_role_and_interface(relation_name):
"""
Given the name of a relation, return the role and the name of the interface
that relation uses (where role is one of ``provides``, ``requires``, or ``peers``).
:returns: A tuple containing ``(role, interface)``, or ``(None, None)``.
"""
_metadata = metadata()
for role in ('provides', 'requires', 'peers'):
interface = _metadata.get(role, {}).get(relation_name, {}).get('interface')
if interface:
return role, interface
return None, None
@cached
def role_and_interface_to_relations(role, interface_name):
"""
Given a role and interface name, return a list of relation names for the
current charm that use that interface under that role (where role is one
of ``provides``, ``requires``, or ``peers``).
:returns: A list of relation names.
"""
_metadata = metadata()
results = []
for relation_name, relation in _metadata.get(role, {}).items():
if relation['interface'] == interface_name:
results.append(relation_name)
return results
@cached
def interface_to_relations(interface_name):
"""
Given an interface, return a list of relation names for the current
charm that use that interface.
:returns: A list of relation names.
"""
results = []
for role in ('provides', 'requires', 'peers'):
results.extend(role_and_interface_to_relations(role, interface_name))
return results
@cached
def charm_name():
"""Get the name of the current charm as is specified on metadata.yaml"""
@ -544,6 +636,38 @@ def unit_private_ip():
return unit_get('private-address')
@cached
def storage_get(attribute=None, storage_id=None):
"""Get storage attributes"""
_args = ['storage-get', '--format=json']
if storage_id:
_args.extend(('-s', storage_id))
if attribute:
_args.append(attribute)
try:
return json.loads(subprocess.check_output(_args).decode('UTF-8'))
except ValueError:
return None
@cached
def storage_list(storage_name=None):
"""List the storage IDs for the unit"""
_args = ['storage-list', '--format=json']
if storage_name:
_args.append(storage_name)
try:
return json.loads(subprocess.check_output(_args).decode('UTF-8'))
except ValueError:
return None
except OSError as e:
import errno
if e.errno == errno.ENOENT:
# storage-list does not exist
return []
raise
class UnregisteredHookError(Exception):
"""Raised when an undefined hook is called"""
pass
@ -644,6 +768,21 @@ def action_fail(message):
subprocess.check_call(['action-fail', message])
def action_name():
"""Get the name of the currently executing action."""
return os.environ.get('JUJU_ACTION_NAME')
def action_uuid():
"""Get the UUID of the currently executing action."""
return os.environ.get('JUJU_ACTION_UUID')
def action_tag():
"""Get the tag for the currently executing action."""
return os.environ.get('JUJU_ACTION_TAG')
def status_set(workload_state, message):
"""Set the workload state with a message
@ -673,25 +812,28 @@ def status_set(workload_state, message):
def status_get():
"""Retrieve the previously set juju workload state
"""Retrieve the previously set juju workload state and message
If the status-get command is not found then assume this is juju < 1.23 and
return 'unknown', ""
If the status-set command is not found then assume this is juju < 1.23 and
return 'unknown'
"""
cmd = ['status-get']
cmd = ['status-get', "--format=json", "--include-data"]
try:
raw_status = subprocess.check_output(cmd, universal_newlines=True)
status = raw_status.rstrip()
return status
raw_status = subprocess.check_output(cmd)
except OSError as e:
if e.errno == errno.ENOENT:
return 'unknown'
return ('unknown', "")
else:
raise
else:
status = json.loads(raw_status.decode("UTF-8"))
return (status["status"], status["message"])
def translate_exc(from_exc, to_exc):
def inner_translate_exc1(f):
@wraps(f)
def inner_translate_exc2(*args, **kwargs):
try:
return f(*args, **kwargs)
@ -736,6 +878,58 @@ def leader_set(settings=None, **kwargs):
subprocess.check_call(cmd)
@translate_exc(from_exc=OSError, to_exc=NotImplementedError)
def payload_register(ptype, klass, pid):
""" is used while a hook is running to let Juju know that a
payload has been started."""
cmd = ['payload-register']
for x in [ptype, klass, pid]:
cmd.append(x)
subprocess.check_call(cmd)
@translate_exc(from_exc=OSError, to_exc=NotImplementedError)
def payload_unregister(klass, pid):
""" is used while a hook is running to let Juju know
that a payload has been manually stopped. The <class> and <id> provided
must match a payload that has been previously registered with juju using
payload-register."""
cmd = ['payload-unregister']
for x in [klass, pid]:
cmd.append(x)
subprocess.check_call(cmd)
@translate_exc(from_exc=OSError, to_exc=NotImplementedError)
def payload_status_set(klass, pid, status):
"""is used to update the current status of a registered payload.
The <class> and <id> provided must match a payload that has been previously
registered with juju using payload-register. The <status> must be one of the
follow: starting, started, stopping, stopped"""
cmd = ['payload-status-set']
for x in [klass, pid, status]:
cmd.append(x)
subprocess.check_call(cmd)
@translate_exc(from_exc=OSError, to_exc=NotImplementedError)
def resource_get(name):
"""used to fetch the resource path of the given name.
<name> must match a name of defined resource in metadata.yaml
returns either a path or False if resource not available
"""
if not name:
return False
cmd = ['resource-get', name]
try:
return subprocess.check_output(cmd).decode('UTF-8')
except subprocess.CalledProcessError:
return False
@cached
def juju_version():
"""Full version string (eg. '1.23.3.1-trusty-amd64')"""
@ -800,3 +994,16 @@ def _run_atexit():
for callback, args, kwargs in reversed(_atexit):
callback(*args, **kwargs)
del _atexit[:]
@translate_exc(from_exc=OSError, to_exc=NotImplementedError)
def network_get_primary_address(binding):
'''
Retrieve the primary network address for a named binding
:param binding: string. The name of a relation of extra-binding
:return: string. The primary IP address for the named binding
:raise: NotImplementedError if run on Juju < 2.0
'''
cmd = ['network-get', '--primary-address', binding]
return subprocess.check_output(cmd).strip()

View File

@ -30,6 +30,8 @@ import random
import string
import subprocess
import hashlib
import functools
import itertools
from contextlib import contextmanager
from collections import OrderedDict
@ -63,54 +65,96 @@ def service_reload(service_name, restart_on_failure=False):
return service_result
def service_pause(service_name, init_dir=None):
def service_pause(service_name, init_dir="/etc/init", initd_dir="/etc/init.d"):
"""Pause a system service.
Stop it, and prevent it from starting again at boot."""
if init_dir is None:
init_dir = "/etc/init"
stopped = service_stop(service_name)
# XXX: Support systemd too
override_path = os.path.join(
init_dir, '{}.conf.override'.format(service_name))
with open(override_path, 'w') as fh:
fh.write("manual\n")
stopped = True
if service_running(service_name):
stopped = service_stop(service_name)
upstart_file = os.path.join(init_dir, "{}.conf".format(service_name))
sysv_file = os.path.join(initd_dir, service_name)
if init_is_systemd():
service('disable', service_name)
elif os.path.exists(upstart_file):
override_path = os.path.join(
init_dir, '{}.override'.format(service_name))
with open(override_path, 'w') as fh:
fh.write("manual\n")
elif os.path.exists(sysv_file):
subprocess.check_call(["update-rc.d", service_name, "disable"])
else:
raise ValueError(
"Unable to detect {0} as SystemD, Upstart {1} or"
" SysV {2}".format(
service_name, upstart_file, sysv_file))
return stopped
def service_resume(service_name, init_dir=None):
def service_resume(service_name, init_dir="/etc/init",
initd_dir="/etc/init.d"):
"""Resume a system service.
Reenable starting again at boot. Start the service"""
# XXX: Support systemd too
if init_dir is None:
init_dir = "/etc/init"
override_path = os.path.join(
init_dir, '{}.conf.override'.format(service_name))
if os.path.exists(override_path):
os.unlink(override_path)
started = service_start(service_name)
upstart_file = os.path.join(init_dir, "{}.conf".format(service_name))
sysv_file = os.path.join(initd_dir, service_name)
if init_is_systemd():
service('enable', service_name)
elif os.path.exists(upstart_file):
override_path = os.path.join(
init_dir, '{}.override'.format(service_name))
if os.path.exists(override_path):
os.unlink(override_path)
elif os.path.exists(sysv_file):
subprocess.check_call(["update-rc.d", service_name, "enable"])
else:
raise ValueError(
"Unable to detect {0} as SystemD, Upstart {1} or"
" SysV {2}".format(
service_name, upstart_file, sysv_file))
started = service_running(service_name)
if not started:
started = service_start(service_name)
return started
def service(action, service_name):
"""Control a system service"""
cmd = ['service', service_name, action]
if init_is_systemd():
cmd = ['systemctl', action, service_name]
else:
cmd = ['service', service_name, action]
return subprocess.call(cmd) == 0
def service_running(service):
def systemv_services_running():
output = subprocess.check_output(
['service', '--status-all'],
stderr=subprocess.STDOUT).decode('UTF-8')
return [row.split()[-1] for row in output.split('\n') if '[ + ]' in row]
def service_running(service_name):
"""Determine whether a system service is running"""
try:
output = subprocess.check_output(
['service', service, 'status'],
stderr=subprocess.STDOUT).decode('UTF-8')
except subprocess.CalledProcessError:
return False
if init_is_systemd():
return service('is-active', service_name)
else:
if ("start/running" in output or "is running" in output):
return True
try:
output = subprocess.check_output(
['service', service_name, 'status'],
stderr=subprocess.STDOUT).decode('UTF-8')
except subprocess.CalledProcessError:
return False
else:
# This works for upstart scripts where the 'service' command
# returns a consistent string to represent running 'start/running'
if ("start/running" in output or "is running" in output or
"up and running" in output):
return True
# Check System V scripts init script return codes
if service_name in systemv_services_running():
return True
return False
@ -126,8 +170,29 @@ def service_available(service_name):
return True
def adduser(username, password=None, shell='/bin/bash', system_user=False):
"""Add a user to the system"""
SYSTEMD_SYSTEM = '/run/systemd/system'
def init_is_systemd():
"""Return True if the host system uses systemd, False otherwise."""
return os.path.isdir(SYSTEMD_SYSTEM)
def adduser(username, password=None, shell='/bin/bash', system_user=False,
primary_group=None, secondary_groups=None):
"""Add a user to the system.
Will log but otherwise succeed if the user already exists.
:param str username: Username to create
:param str password: Password for user; if ``None``, create a system user
:param str shell: The default shell for the user
:param bool system_user: Whether to create a login or system user
:param str primary_group: Primary group for user; defaults to username
:param list secondary_groups: Optional list of additional groups
:returns: The password database entry struct, as returned by `pwd.getpwnam`
"""
try:
user_info = pwd.getpwnam(username)
log('user {0} already exists!'.format(username))
@ -142,12 +207,32 @@ def adduser(username, password=None, shell='/bin/bash', system_user=False):
'--shell', shell,
'--password', password,
])
if not primary_group:
try:
grp.getgrnam(username)
primary_group = username # avoid "group exists" error
except KeyError:
pass
if primary_group:
cmd.extend(['-g', primary_group])
if secondary_groups:
cmd.extend(['-G', ','.join(secondary_groups)])
cmd.append(username)
subprocess.check_call(cmd)
user_info = pwd.getpwnam(username)
return user_info
def user_exists(username):
"""Check if a user exists"""
try:
pwd.getpwnam(username)
user_exists = True
except KeyError:
user_exists = False
return user_exists
def add_group(group_name, system_group=False):
"""Add a group to the system"""
try:
@ -229,14 +314,12 @@ def write_file(path, content, owner='root', group='root', perms=0o444):
def fstab_remove(mp):
"""Remove the given mountpoint entry from /etc/fstab
"""
"""Remove the given mountpoint entry from /etc/fstab"""
return Fstab.remove_by_mountpoint(mp)
def fstab_add(dev, mp, fs, options=None):
"""Adds the given device entry to the /etc/fstab file
"""
"""Adds the given device entry to the /etc/fstab file"""
return Fstab.add(dev, mp, fs, options=options)
@ -280,9 +363,19 @@ def mounts():
return system_mounts
def fstab_mount(mountpoint):
"""Mount filesystem using fstab"""
cmd_args = ['mount', mountpoint]
try:
subprocess.check_output(cmd_args)
except subprocess.CalledProcessError as e:
log('Error unmounting {}\n{}'.format(mountpoint, e.output))
return False
return True
def file_hash(path, hash_type='md5'):
"""
Generate a hash checksum of the contents of 'path' or None if not found.
"""Generate a hash checksum of the contents of 'path' or None if not found.
:param str hash_type: Any hash alrgorithm supported by :mod:`hashlib`,
such as md5, sha1, sha256, sha512, etc.
@ -297,10 +390,9 @@ def file_hash(path, hash_type='md5'):
def path_hash(path):
"""
Generate a hash checksum of all files matching 'path'. Standard wildcards
like '*' and '?' are supported, see documentation for the 'glob' module for
more information.
"""Generate a hash checksum of all files matching 'path'. Standard
wildcards like '*' and '?' are supported, see documentation for the 'glob'
module for more information.
:return: dict: A { filename: hash } dictionary for all matched files.
Empty if none found.
@ -312,8 +404,7 @@ def path_hash(path):
def check_hash(path, checksum, hash_type='md5'):
"""
Validate a file using a cryptographic checksum.
"""Validate a file using a cryptographic checksum.
:param str checksum: Value of the checksum used to validate the file.
:param str hash_type: Hash algorithm used to generate `checksum`.
@ -328,10 +419,11 @@ def check_hash(path, checksum, hash_type='md5'):
class ChecksumError(ValueError):
"""A class derived from Value error to indicate the checksum failed."""
pass
def restart_on_change(restart_map, stopstart=False):
def restart_on_change(restart_map, stopstart=False, restart_functions=None):
"""Restart services based on configuration files changing
This function is used a decorator, for example::
@ -349,27 +441,58 @@ def restart_on_change(restart_map, stopstart=False):
restarted if any file matching the pattern got changed, created
or removed. Standard wildcards are supported, see documentation
for the 'glob' module for more information.
@param restart_map: {path_file_name: [service_name, ...]
@param stopstart: DEFAULT false; whether to stop, start OR restart
@param restart_functions: nonstandard functions to use to restart services
{svc: func, ...}
@returns result from decorated function
"""
def wrap(f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
checksums = {path: path_hash(path) for path in restart_map}
f(*args, **kwargs)
restarts = []
for path in restart_map:
if path_hash(path) != checksums[path]:
restarts += restart_map[path]
services_list = list(OrderedDict.fromkeys(restarts))
if not stopstart:
for service_name in services_list:
service('restart', service_name)
else:
for action in ['stop', 'start']:
for service_name in services_list:
service(action, service_name)
return restart_on_change_helper(
(lambda: f(*args, **kwargs)), restart_map, stopstart,
restart_functions)
return wrapped_f
return wrap
def restart_on_change_helper(lambda_f, restart_map, stopstart=False,
restart_functions=None):
"""Helper function to perform the restart_on_change function.
This is provided for decorators to restart services if files described
in the restart_map have changed after an invocation of lambda_f().
@param lambda_f: function to call.
@param restart_map: {file: [service, ...]}
@param stopstart: whether to stop, start or restart a service
@param restart_functions: nonstandard functions to use to restart services
{svc: func, ...}
@returns result of lambda_f()
"""
if restart_functions is None:
restart_functions = {}
checksums = {path: path_hash(path) for path in restart_map}
r = lambda_f()
# create a list of lists of the services to restart
restarts = [restart_map[path]
for path in restart_map
if path_hash(path) != checksums[path]]
# create a flat list of ordered services without duplicates from lists
services_list = list(OrderedDict.fromkeys(itertools.chain(*restarts)))
if services_list:
actions = ('stop', 'start') if stopstart else ('restart',)
for service_name in services_list:
if service_name in restart_functions:
restart_functions[service_name](service_name)
else:
for action in actions:
service(action, service_name)
return r
def lsb_release():
"""Return /etc/lsb-release in a dict"""
d = {}
@ -396,36 +519,92 @@ def pwgen(length=None):
return(''.join(random_chars))
def list_nics(nic_type):
'''Return a list of nics of given type(s)'''
def is_phy_iface(interface):
"""Returns True if interface is not virtual, otherwise False."""
if interface:
sys_net = '/sys/class/net'
if os.path.isdir(sys_net):
for iface in glob.glob(os.path.join(sys_net, '*')):
if '/virtual/' in os.path.realpath(iface):
continue
if interface == os.path.basename(iface):
return True
return False
def get_bond_master(interface):
"""Returns bond master if interface is bond slave otherwise None.
NOTE: the provided interface is expected to be physical
"""
if interface:
iface_path = '/sys/class/net/%s' % (interface)
if os.path.exists(iface_path):
if '/virtual/' in os.path.realpath(iface_path):
return None
master = os.path.join(iface_path, 'master')
if os.path.exists(master):
master = os.path.realpath(master)
# make sure it is a bond master
if os.path.exists(os.path.join(master, 'bonding')):
return os.path.basename(master)
return None
def list_nics(nic_type=None):
"""Return a list of nics of given type(s)"""
if isinstance(nic_type, six.string_types):
int_types = [nic_type]
else:
int_types = nic_type
interfaces = []
for int_type in int_types:
cmd = ['ip', 'addr', 'show', 'label', int_type + '*']
if nic_type:
for int_type in int_types:
cmd = ['ip', 'addr', 'show', 'label', int_type + '*']
ip_output = subprocess.check_output(cmd).decode('UTF-8')
ip_output = ip_output.split('\n')
ip_output = (line for line in ip_output if line)
for line in ip_output:
if line.split()[1].startswith(int_type):
matched = re.search('.*: (' + int_type +
r'[0-9]+\.[0-9]+)@.*', line)
if matched:
iface = matched.groups()[0]
else:
iface = line.split()[1].replace(":", "")
if iface not in interfaces:
interfaces.append(iface)
else:
cmd = ['ip', 'a']
ip_output = subprocess.check_output(cmd).decode('UTF-8').split('\n')
ip_output = (line for line in ip_output if line)
ip_output = (line.strip() for line in ip_output if line)
key = re.compile('^[0-9]+:\s+(.+):')
for line in ip_output:
if line.split()[1].startswith(int_type):
matched = re.search('.*: (' + int_type + r'[0-9]+\.[0-9]+)@.*', line)
if matched:
interface = matched.groups()[0]
else:
interface = line.split()[1].replace(":", "")
interfaces.append(interface)
matched = re.search(key, line)
if matched:
iface = matched.group(1)
iface = iface.partition("@")[0]
if iface not in interfaces:
interfaces.append(iface)
return interfaces
def set_nic_mtu(nic, mtu):
'''Set MTU on a network interface'''
"""Set the Maximum Transmission Unit (MTU) on a network interface."""
cmd = ['ip', 'link', 'set', nic, 'mtu', mtu]
subprocess.check_call(cmd)
def get_nic_mtu(nic):
"""Return the Maximum Transmission Unit (MTU) for a network interface."""
cmd = ['ip', 'addr', 'show', nic]
ip_output = subprocess.check_output(cmd).decode('UTF-8').split('\n')
mtu = ""
@ -437,6 +616,7 @@ def get_nic_mtu(nic):
def get_nic_hwaddr(nic):
"""Return the Media Access Control (MAC) for a network interface."""
cmd = ['ip', '-o', '-0', 'addr', 'show', nic]
ip_output = subprocess.check_output(cmd).decode('UTF-8')
hwaddr = ""
@ -447,7 +627,7 @@ def get_nic_hwaddr(nic):
def cmp_pkgrevno(package, revno, pkgcache=None):
'''Compare supplied revno with the revno of the installed package
"""Compare supplied revno with the revno of the installed package
* 1 => Installed revno is greater than supplied arg
* 0 => Installed revno is the same as supplied arg
@ -456,7 +636,7 @@ def cmp_pkgrevno(package, revno, pkgcache=None):
This function imports apt_cache function from charmhelpers.fetch if
the pkgcache argument is None. Be sure to add charmhelpers.fetch if
you call this function, or pass an apt_pkg.Cache() instance.
'''
"""
import apt_pkg
if not pkgcache:
from charmhelpers.fetch import apt_cache
@ -466,15 +646,30 @@ def cmp_pkgrevno(package, revno, pkgcache=None):
@contextmanager
def chdir(d):
def chdir(directory):
"""Change the current working directory to a different directory for a code
block and return the previous directory after the block exits. Useful to
run commands from a specificed directory.
:param str directory: The directory path to change to for this context.
"""
cur = os.getcwd()
try:
yield os.chdir(d)
yield os.chdir(directory)
finally:
os.chdir(cur)
def chownr(path, owner, group, follow_links=True):
def chownr(path, owner, group, follow_links=True, chowntopdir=False):
"""Recursively change user and group ownership of files and directories
in given path. Doesn't chown path itself by default, only its children.
:param str path: The string path to start changing ownership.
:param str owner: The owner string to use when looking up the uid.
:param str group: The group string to use when looking up the gid.
:param bool follow_links: Also Chown links if True
:param bool chowntopdir: Also chown path itself if True
"""
uid = pwd.getpwnam(owner).pw_uid
gid = grp.getgrnam(group).gr_gid
if follow_links:
@ -482,6 +677,10 @@ def chownr(path, owner, group, follow_links=True):
else:
chown = os.lchown
if chowntopdir:
broken_symlink = os.path.lexists(path) and not os.path.exists(path)
if not broken_symlink:
chown(path, uid, gid)
for root, dirs, files in os.walk(path):
for name in dirs + files:
full = os.path.join(root, name)
@ -491,4 +690,28 @@ def chownr(path, owner, group, follow_links=True):
def lchownr(path, owner, group):
"""Recursively change user and group ownership of files and directories
in a given path, not following symbolic links. See the documentation for
'os.lchown' for more information.
:param str path: The string path to start changing ownership.
:param str owner: The owner string to use when looking up the uid.
:param str group: The group string to use when looking up the gid.
"""
chownr(path, owner, group, follow_links=False)
def get_total_ram():
"""The total amount of system RAM in bytes.
This is what is reported by the OS, and may be overcommitted when
there are multiple containers hosted on the same machine.
"""
with open('/proc/meminfo', 'r') as f:
for line in f.readlines():
if line:
key, value, unit = line.split()
if key == 'MemTotal:':
assert unit == 'kB', 'Unknown unit'
return int(value) * 1024 # Classic, not KiB.
raise NotImplementedError()

View File

@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2015 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import yaml
from charmhelpers.core import fstab
from charmhelpers.core import sysctl
from charmhelpers.core.host import (
add_group,
add_user_to_group,
fstab_mount,
mkdir,
)
from charmhelpers.core.strutils import bytes_from_string
from subprocess import check_output
def hugepage_support(user, group='hugetlb', nr_hugepages=256,
max_map_count=65536, mnt_point='/run/hugepages/kvm',
pagesize='2MB', mount=True, set_shmmax=False):
"""Enable hugepages on system.
Args:
user (str) -- Username to allow access to hugepages to
group (str) -- Group name to own hugepages
nr_hugepages (int) -- Number of pages to reserve
max_map_count (int) -- Number of Virtual Memory Areas a process can own
mnt_point (str) -- Directory to mount hugepages on
pagesize (str) -- Size of hugepages
mount (bool) -- Whether to Mount hugepages
"""
group_info = add_group(group)
gid = group_info.gr_gid
add_user_to_group(user, group)
if max_map_count < 2 * nr_hugepages:
max_map_count = 2 * nr_hugepages
sysctl_settings = {
'vm.nr_hugepages': nr_hugepages,
'vm.max_map_count': max_map_count,
'vm.hugetlb_shm_group': gid,
}
if set_shmmax:
shmmax_current = int(check_output(['sysctl', '-n', 'kernel.shmmax']))
shmmax_minsize = bytes_from_string(pagesize) * nr_hugepages
if shmmax_minsize > shmmax_current:
sysctl_settings['kernel.shmmax'] = shmmax_minsize
sysctl.create(yaml.dump(sysctl_settings), '/etc/sysctl.d/10-hugepage.conf')
mkdir(mnt_point, owner='root', group='root', perms=0o755, force=False)
lfstab = fstab.Fstab()
fstab_entry = lfstab.get_entry_by_attr('mountpoint', mnt_point)
if fstab_entry:
lfstab.remove_entry(fstab_entry)
entry = lfstab.Entry('nodev', mnt_point, 'hugetlbfs',
'mode=1770,gid={},pagesize={}'.format(gid, pagesize), 0, 0)
lfstab.add_entry(entry)
if mount:
fstab_mount(mnt_point)

View File

@ -0,0 +1,68 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2014-2015 Canonical Limited.
#
# This file is part of charm-helpers.
#
# charm-helpers is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 as
# published by the Free Software Foundation.
#
# charm-helpers is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
__author__ = "Jorge Niedbalski <jorge.niedbalski@canonical.com>"
from charmhelpers.core.hookenv import (
log,
INFO
)
from subprocess import check_call, check_output
import re
def modprobe(module, persist=True):
"""Load a kernel module and configure for auto-load on reboot."""
cmd = ['modprobe', module]
log('Loading kernel module %s' % module, level=INFO)
check_call(cmd)
if persist:
with open('/etc/modules', 'r+') as modules:
if module not in modules.read():
modules.write(module)
def rmmod(module, force=False):
"""Remove a module from the linux kernel"""
cmd = ['rmmod']
if force:
cmd.append('-f')
cmd.append(module)
log('Removing kernel module %s' % module, level=INFO)
return check_call(cmd)
def lsmod():
"""Shows what kernel modules are currently loaded"""
return check_output(['lsmod'],
universal_newlines=True)
def is_module_loaded(module):
"""Checks if a kernel module is already loaded"""
matches = re.findall('^%s[ ]+' % module, lsmod(), re.M)
return len(matches) > 0
def update_initramfs(version='all'):
"""Updates an initramfs image"""
return check_call(["update-initramfs", "-k", version, "-u"])

View File

@ -16,7 +16,9 @@
import os
import yaml
from charmhelpers.core import hookenv
from charmhelpers.core import host
from charmhelpers.core import templating
from charmhelpers.core.services.base import ManagerCallback
@ -240,27 +242,50 @@ class TemplateCallback(ManagerCallback):
:param str source: The template source file, relative to
`$CHARM_DIR/templates`
:param str target: The target to write the rendered template to
:param str target: The target to write the rendered template to (or None)
:param str owner: The owner of the rendered file
:param str group: The group of the rendered file
:param int perms: The permissions of the rendered file
:param partial on_change_action: functools partial to be executed when
rendered file changes
:param jinja2 loader template_loader: A jinja2 template loader
:return str: The rendered template
"""
def __init__(self, source, target,
owner='root', group='root', perms=0o444):
owner='root', group='root', perms=0o444,
on_change_action=None, template_loader=None):
self.source = source
self.target = target
self.owner = owner
self.group = group
self.perms = perms
self.on_change_action = on_change_action
self.template_loader = template_loader
def __call__(self, manager, service_name, event_name):
pre_checksum = ''
if self.on_change_action and os.path.isfile(self.target):
pre_checksum = host.file_hash(self.target)
service = manager.get_service(service_name)
context = {}
context = {'ctx': {}}
for ctx in service.get('required_data', []):
context.update(ctx)
templating.render(self.source, self.target, context,
self.owner, self.group, self.perms)
context['ctx'].update(ctx)
result = templating.render(self.source, self.target, context,
self.owner, self.group, self.perms,
template_loader=self.template_loader)
if self.on_change_action:
if pre_checksum == host.file_hash(self.target):
hookenv.log(
'No change detected: {}'.format(self.target),
hookenv.DEBUG)
else:
self.on_change_action()
return result
# Convenience aliases for templates

View File

@ -18,6 +18,7 @@
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import six
import re
def bool_from_string(value):
@ -40,3 +41,32 @@ def bool_from_string(value):
msg = "Unable to interpret string value '%s' as boolean" % (value)
raise ValueError(msg)
def bytes_from_string(value):
"""Interpret human readable string value as bytes.
Returns int
"""
BYTE_POWER = {
'K': 1,
'KB': 1,
'M': 2,
'MB': 2,
'G': 3,
'GB': 3,
'T': 4,
'TB': 4,
'P': 5,
'PB': 5,
}
if isinstance(value, six.string_types):
value = six.text_type(value)
else:
msg = "Unable to interpret non-string value '%s' as boolean" % (value)
raise ValueError(msg)
matches = re.match("([0-9]+)([a-zA-Z]+)", value)
if not matches:
msg = "Unable to interpret string value '%s' as bytes" % (value)
raise ValueError(msg)
return int(matches.group(1)) * (1024 ** BYTE_POWER[matches.group(2)])

View File

@ -21,13 +21,14 @@ from charmhelpers.core import hookenv
def render(source, target, context, owner='root', group='root',
perms=0o444, templates_dir=None, encoding='UTF-8'):
perms=0o444, templates_dir=None, encoding='UTF-8', template_loader=None):
"""
Render a template.
The `source` path, if not absolute, is relative to the `templates_dir`.
The `target` path should be absolute.
The `target` path should be absolute. It can also be `None`, in which
case no file will be written.
The context should be a dict containing the values to be replaced in the
template.
@ -36,6 +37,9 @@ def render(source, target, context, owner='root', group='root',
If omitted, `templates_dir` defaults to the `templates` folder in the charm.
The rendered template will be written to the file as well as being returned
as a string.
Note: Using this requires python-jinja2; if it is not installed, calling
this will attempt to use charmhelpers.fetch.apt_install to install it.
"""
@ -52,17 +56,26 @@ def render(source, target, context, owner='root', group='root',
apt_install('python-jinja2', fatal=True)
from jinja2 import FileSystemLoader, Environment, exceptions
if templates_dir is None:
templates_dir = os.path.join(hookenv.charm_dir(), 'templates')
loader = Environment(loader=FileSystemLoader(templates_dir))
if template_loader:
template_env = Environment(loader=template_loader)
else:
if templates_dir is None:
templates_dir = os.path.join(hookenv.charm_dir(), 'templates')
template_env = Environment(loader=FileSystemLoader(templates_dir))
try:
source = source
template = loader.get_template(source)
template = template_env.get_template(source)
except exceptions.TemplateNotFound as e:
hookenv.log('Could not load template %s from %s.' %
(source, templates_dir),
level=hookenv.ERROR)
raise e
content = template.render(context)
host.mkdir(os.path.dirname(target), owner, group, perms=0o755)
host.write_file(target, content.encode(encoding), owner, group, perms)
if target is not None:
target_dir = os.path.dirname(target)
if not os.path.exists(target_dir):
# This is a terrible default directory permission, as the file
# or its siblings will often contain secrets.
host.mkdir(os.path.dirname(target), owner, group, perms=0o755)
host.write_file(target, content.encode(encoding), owner, group, perms)
return content

View File

@ -152,6 +152,7 @@ associated to the hookname.
import collections
import contextlib
import datetime
import itertools
import json
import os
import pprint
@ -164,8 +165,7 @@ __author__ = 'Kapil Thangavelu <kapil.foss@gmail.com>'
class Storage(object):
"""Simple key value database for local unit state within charms.
Modifications are automatically committed at hook exit. That's
currently regardless of exit code.
Modifications are not persisted unless :meth:`flush` is called.
To support dicts, lists, integer, floats, and booleans values
are automatically json encoded/decoded.
@ -173,8 +173,11 @@ class Storage(object):
def __init__(self, path=None):
self.db_path = path
if path is None:
self.db_path = os.path.join(
os.environ.get('CHARM_DIR', ''), '.unit-state.db')
if 'UNIT_STATE_DB' in os.environ:
self.db_path = os.environ['UNIT_STATE_DB']
else:
self.db_path = os.path.join(
os.environ.get('CHARM_DIR', ''), '.unit-state.db')
self.conn = sqlite3.connect('%s' % self.db_path)
self.cursor = self.conn.cursor()
self.revision = None
@ -189,15 +192,8 @@ class Storage(object):
self.conn.close()
self._closed = True
def _scoped_query(self, stmt, params=None):
if params is None:
params = []
return stmt, params
def get(self, key, default=None, record=False):
self.cursor.execute(
*self._scoped_query(
'select data from kv where key=?', [key]))
self.cursor.execute('select data from kv where key=?', [key])
result = self.cursor.fetchone()
if not result:
return default
@ -206,33 +202,81 @@ class Storage(object):
return json.loads(result[0])
def getrange(self, key_prefix, strip=False):
stmt = "select key, data from kv where key like '%s%%'" % key_prefix
self.cursor.execute(*self._scoped_query(stmt))
"""
Get a range of keys starting with a common prefix as a mapping of
keys to values.
:param str key_prefix: Common prefix among all keys
:param bool strip: Optionally strip the common prefix from the key
names in the returned dict
:return dict: A (possibly empty) dict of key-value mappings
"""
self.cursor.execute("select key, data from kv where key like ?",
['%s%%' % key_prefix])
result = self.cursor.fetchall()
if not result:
return None
return {}
if not strip:
key_prefix = ''
return dict([
(k[len(key_prefix):], json.loads(v)) for k, v in result])
def update(self, mapping, prefix=""):
"""
Set the values of multiple keys at once.
:param dict mapping: Mapping of keys to values
:param str prefix: Optional prefix to apply to all keys in `mapping`
before setting
"""
for k, v in mapping.items():
self.set("%s%s" % (prefix, k), v)
def unset(self, key):
"""
Remove a key from the database entirely.
"""
self.cursor.execute('delete from kv where key=?', [key])
if self.revision and self.cursor.rowcount:
self.cursor.execute(
'insert into kv_revisions values (?, ?, ?)',
[key, self.revision, json.dumps('DELETED')])
def unsetrange(self, keys=None, prefix=""):
"""
Remove a range of keys starting with a common prefix, from the database
entirely.
:param list keys: List of keys to remove.
:param str prefix: Optional prefix to apply to all keys in ``keys``
before removing.
"""
if keys is not None:
keys = ['%s%s' % (prefix, key) for key in keys]
self.cursor.execute('delete from kv where key in (%s)' % ','.join(['?'] * len(keys)), keys)
if self.revision and self.cursor.rowcount:
self.cursor.execute(
'insert into kv_revisions values %s' % ','.join(['(?, ?, ?)'] * len(keys)),
list(itertools.chain.from_iterable((key, self.revision, json.dumps('DELETED')) for key in keys)))
else:
self.cursor.execute('delete from kv where key like ?',
['%s%%' % prefix])
if self.revision and self.cursor.rowcount:
self.cursor.execute(
'insert into kv_revisions values (?, ?, ?)',
['%s%%' % prefix, self.revision, json.dumps('DELETED')])
def set(self, key, value):
"""
Set a value in the database.
:param str key: Key to set the value for
:param value: Any JSON-serializable value to be set
"""
serialized = json.dumps(value)
self.cursor.execute(
'select data from kv where key=?', [key])
self.cursor.execute('select data from kv where key=?', [key])
exists = self.cursor.fetchone()
# Skip mutations to the same value

View File

@ -90,6 +90,22 @@ CLOUD_ARCHIVE_POCKETS = {
'kilo/proposed': 'trusty-proposed/kilo',
'trusty-kilo/proposed': 'trusty-proposed/kilo',
'trusty-proposed/kilo': 'trusty-proposed/kilo',
# Liberty
'liberty': 'trusty-updates/liberty',
'trusty-liberty': 'trusty-updates/liberty',
'trusty-liberty/updates': 'trusty-updates/liberty',
'trusty-updates/liberty': 'trusty-updates/liberty',
'liberty/proposed': 'trusty-proposed/liberty',
'trusty-liberty/proposed': 'trusty-proposed/liberty',
'trusty-proposed/liberty': 'trusty-proposed/liberty',
# Mitaka
'mitaka': 'trusty-updates/mitaka',
'trusty-mitaka': 'trusty-updates/mitaka',
'trusty-mitaka/updates': 'trusty-updates/mitaka',
'trusty-updates/mitaka': 'trusty-updates/mitaka',
'mitaka/proposed': 'trusty-proposed/mitaka',
'trusty-mitaka/proposed': 'trusty-proposed/mitaka',
'trusty-proposed/mitaka': 'trusty-proposed/mitaka',
}
# The order of this list is very important. Handlers should be listed in from
@ -217,12 +233,12 @@ def apt_purge(packages, fatal=False):
def apt_mark(packages, mark, fatal=False):
"""Flag one or more packages using apt-mark"""
log("Marking {} as {}".format(packages, mark))
cmd = ['apt-mark', mark]
if isinstance(packages, six.string_types):
cmd.append(packages)
else:
cmd.extend(packages)
log("Holding {}".format(packages))
if fatal:
subprocess.check_call(cmd, universal_newlines=True)
@ -403,7 +419,7 @@ def plugins(fetch_handlers=None):
importlib.import_module(package),
classname)
plugin_list.append(handler_class())
except (ImportError, AttributeError):
except NotImplementedError:
# Skip missing plugins so that they can be ommitted from
# installation if desired
log("FetchHandler {} not found, skipping plugin".format(

View File

@ -108,7 +108,7 @@ class ArchiveUrlFetchHandler(BaseFetchHandler):
install_opener(opener)
response = urlopen(source)
try:
with open(dest, 'w') as dest_file:
with open(dest, 'wb') as dest_file:
dest_file.write(response.read())
except Exception as e:
if os.path.isfile(dest):

View File

@ -15,60 +15,50 @@
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import os
from subprocess import check_call
from charmhelpers.fetch import (
BaseFetchHandler,
UnhandledSource
UnhandledSource,
filter_installed_packages,
apt_install,
)
from charmhelpers.core.host import mkdir
import six
if six.PY3:
raise ImportError('bzrlib does not support Python3')
try:
from bzrlib.branch import Branch
from bzrlib import bzrdir, workingtree, errors
except ImportError:
from charmhelpers.fetch import apt_install
apt_install("python-bzrlib")
from bzrlib.branch import Branch
from bzrlib import bzrdir, workingtree, errors
if filter_installed_packages(['bzr']) != []:
apt_install(['bzr'])
if filter_installed_packages(['bzr']) != []:
raise NotImplementedError('Unable to install bzr')
class BzrUrlFetchHandler(BaseFetchHandler):
"""Handler for bazaar branches via generic and lp URLs"""
def can_handle(self, source):
url_parts = self.parse_url(source)
if url_parts.scheme not in ('bzr+ssh', 'lp'):
if url_parts.scheme not in ('bzr+ssh', 'lp', ''):
return False
elif not url_parts.scheme:
return os.path.exists(os.path.join(source, '.bzr'))
else:
return True
def branch(self, source, dest):
url_parts = self.parse_url(source)
# If we use lp:branchname scheme we need to load plugins
if not self.can_handle(source):
raise UnhandledSource("Cannot handle {}".format(source))
if url_parts.scheme == "lp":
from bzrlib.plugin import load_plugins
load_plugins()
try:
local_branch = bzrdir.BzrDir.create_branch_convenience(dest)
except errors.AlreadyControlDirError:
local_branch = Branch.open(dest)
try:
remote_branch = Branch.open(source)
remote_branch.push(local_branch)
tree = workingtree.WorkingTree.open(dest)
tree.update()
except Exception as e:
raise e
if os.path.exists(dest):
check_call(['bzr', 'pull', '--overwrite', '-d', dest, source])
else:
check_call(['bzr', 'branch', source, dest])
def install(self, source):
def install(self, source, dest=None):
url_parts = self.parse_url(source)
branch_name = url_parts.path.strip("/").split("/")[-1]
dest_dir = os.path.join(os.environ.get('CHARM_DIR'), "fetched",
branch_name)
if dest:
dest_dir = os.path.join(dest, branch_name)
else:
dest_dir = os.path.join(os.environ.get('CHARM_DIR'), "fetched",
branch_name)
if not os.path.exists(dest_dir):
mkdir(dest_dir, perms=0o755)
try:

View File

@ -15,24 +15,18 @@
# along with charm-helpers. If not, see <http://www.gnu.org/licenses/>.
import os
from subprocess import check_call, CalledProcessError
from charmhelpers.fetch import (
BaseFetchHandler,
UnhandledSource
UnhandledSource,
filter_installed_packages,
apt_install,
)
from charmhelpers.core.host import mkdir
import six
if six.PY3:
raise ImportError('GitPython does not support Python 3')
try:
from git import Repo
except ImportError:
from charmhelpers.fetch import apt_install
apt_install("python-git")
from git import Repo
from git.exc import GitCommandError # noqa E402
if filter_installed_packages(['git']) != []:
apt_install(['git'])
if filter_installed_packages(['git']) != []:
raise NotImplementedError('Unable to install git')
class GitUrlFetchHandler(BaseFetchHandler):
@ -40,19 +34,24 @@ class GitUrlFetchHandler(BaseFetchHandler):
def can_handle(self, source):
url_parts = self.parse_url(source)
# TODO (mattyw) no support for ssh git@ yet
if url_parts.scheme not in ('http', 'https', 'git'):
if url_parts.scheme not in ('http', 'https', 'git', ''):
return False
elif not url_parts.scheme:
return os.path.exists(os.path.join(source, '.git'))
else:
return True
def clone(self, source, dest, branch, depth=None):
def clone(self, source, dest, branch="master", depth=None):
if not self.can_handle(source):
raise UnhandledSource("Cannot handle {}".format(source))
if depth:
Repo.clone_from(source, dest, branch=branch, depth=depth)
if os.path.exists(dest):
cmd = ['git', '-C', dest, 'pull', source, branch]
else:
Repo.clone_from(source, dest, branch=branch)
cmd = ['git', 'clone', source, dest, '--branch', branch]
if depth:
cmd.extend(['--depth', depth])
check_call(cmd)
def install(self, source, branch="master", dest=None, depth=None):
url_parts = self.parse_url(source)
@ -62,11 +61,9 @@ class GitUrlFetchHandler(BaseFetchHandler):
else:
dest_dir = os.path.join(os.environ.get('CHARM_DIR'), "fetched",
branch_name)
if not os.path.exists(dest_dir):
mkdir(dest_dir, perms=0o755)
try:
self.clone(source, dest_dir, branch, depth)
except GitCommandError as e:
except CalledProcessError as e:
raise UnhandledSource(e)
except OSError as e:
raise UnhandledSource(e.strerror)