# fuel-octane/octane/util/ssh.py
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import contextlib
import io
import logging
import os
import pipes
import random
import shutil
import threading
import paramiko
from paramiko import channel
from octane import magic_consts
from octane.util import subprocess
LOG = logging.getLogger(__name__)
PIPE = subprocess.PIPE
class _cache(object):
def __init__(self, new):
self.new = new
self.cache = {}
self.lock = threading.Lock()
self.invalidate = []
self.check_fn = None
def __call__(self, node):
node_id = node.data['id']
try:
obj = self.cache[node_id]
except KeyError:
obj = None
else:
if not self.check_fn or self.check_fn(node, obj):
return obj
# Now obj is either bad old obj or None
with self.lock:
try:
new_obj = self.cache[node_id]
except KeyError:
pass # Need to just create a new one
else:
if new_obj is not obj:
return new_obj # Someone already created a new one
# We're going to replace this obj, invalidate other caches
for cache in self.invalidate:
with cache.lock:
cache.cache.pop(node_id, None)
new_obj = self.new(node)
self.cache[node_id] = new_obj
return new_obj
def check(self, fn):
self.check_fn = fn
return fn
@_cache
def get_client(node):
LOG.info("Creating new SSH connection to node %s", node.data['id'])
creds = get_env_credentials(node.env)
params = {
'username': creds['user'] if creds else 'root',
'key_filename': magic_consts.SSH_KEYS,
}
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(node.data['ip'], **params)
return client
@get_client.check
def _check_client(node, client):
t = client.get_transport()
if t:
# Send normal keepalive packet, but wait for result to let socket die
t.global_request('keepalive@lag.net', wait=True)
if t.is_active():
return True
LOG.info("SSH connection to node %s died, reconnecting", node.data['id'])
return False
class ChannelFile(io.IOBase, channel.ChannelFile):
pass
class ChannelStderrFile(io.IOBase, channel.ChannelStderrFile):
pass
class _LogPipe(subprocess._BaseLogPipe):
def __init__(self, level, pipe, parse_levels=False):
super(_LogPipe, self).__init__(level, parse_levels=parse_levels)
self._pipe = pipe
def pipe(self):
return self._pipe
class SSHPopen(subprocess.BasePopen):
def __init__(self, name, cmd, popen_kwargs):
self.node = popen_kwargs.pop('node')
for key in ['stdin', 'stdout', 'stderr']:
assert popen_kwargs.get(key) in [None, PIPE]
super(SSHPopen, self).__init__(name, cmd, popen_kwargs)
as_root = popen_kwargs.get('as_root', True)
transport = get_client(self.node).get_transport()
username = transport.get_username()
if username != 'root' and as_root:
cmd = ['sudo', '--'] + cmd
self._channel = transport.open_session()
self._channel.exec_command(" ".join(map(pipes.quote, cmd)))
self.name = "%s[at node-%d]" % (self.name, self.node.data['id'])
if 'stdin' not in self.popen_kwargs:
self.close_stdin()
else:
self.stdin = ChannelFile(self._channel, 'wb')
stdout = ChannelFile(self._channel, 'rb')
if 'stdout' not in self.popen_kwargs:
self._pipe_stdout = _LogPipe(logging.INFO, stdout)
self._pipe_stdout.start(self.name + " stdout")
else:
self._pipe_stdout = None
self.stdout = stdout
stderr = ChannelStderrFile(self._channel, 'rb')
stderr_level = self.popen_kwargs.pop('stderr_log_level', logging.ERROR)
if 'stderr' not in self.popen_kwargs:
self._pipe_stderr = _LogPipe(
stderr_level, stderr,
parse_levels=popen_kwargs.get('parse_levels', False),
)
self._pipe_stderr.start(self.name + " stderr")
else:
self._pipe_stderr = None
self.stderr = stderr
def poll(self):
if self._channel.exit_status_ready():
return self._channel.recv_exit_status()
else:
return None
def wait(self):
return self._channel.recv_exit_status()
def terminate(self):
self._channel.close()
def close_stdin(self):
self._channel.shutdown_write()
def communicate(self):
if self.stdin:
self.close_stdin()
if self.stdout:
stdout = self.stdout.read()
else:
stdout = None
if self.stderr:
stderr = self.stderr.read()
else:
stderr = None
return stdout, stderr
def popen(cmd, **kwargs):
return subprocess.popen(cmd, popen_class=SSHPopen, **kwargs)
def call(cmd, **kwargs):
return subprocess.call(cmd, popen_class=SSHPopen, **kwargs)
def call_output(cmd, **kwargs):
return subprocess.call_output(cmd, popen_class=SSHPopen, **kwargs)
@_cache
def _get_sftp(node):
transport = get_client(node).get_transport()
username = transport.get_username()
if username != 'root':
LOG.info('Run sftp server as root on node %s', node.data['hostname'])
channel = transport.open_channel('session')
channel.exec_command('sudo ' + magic_consts.SFTP_SERVER_BIN)
return paramiko.SFTPClient(channel)
return paramiko.SFTPClient.from_transport(transport)
get_client.invalidate.append(_get_sftp)
def sftp(node):
get_client(node) # ensure we're still connected
return _get_sftp(node)
@contextlib.contextmanager
def update_file(sftp, filename):
old = sftp.open(filename, 'r')
try:
temp_filename = '%s.octane.%08x' % (filename,
random.randrange(1 << 8 * 4))
new = sftp.open(temp_filename, 'wx')
except IOError: # we're unlucky, try other name (or fail)
temp_filename = '%s.octane.%08x' % (filename,
random.randrange(1 << 8 * 4))
new = sftp.open(temp_filename, 'wx')
with contextlib.nested(old, new):
try:
yield old, new
except subprocess.DontUpdateException:
sftp.unlink(temp_filename)
return
except Exception:
sftp.unlink(temp_filename)
raise
stat = old.stat()
new.chmod(stat.st_mode)
new.chown(stat.st_uid, stat.st_gid)
bak_filename = filename + '.octane.bak'
sftp.rename(filename, bak_filename)
sftp.rename(temp_filename, filename)
sftp.unlink(bak_filename)
@contextlib.contextmanager
def tempdir(node):
out = call_output(['mktemp', '-d'], node=node)
dirname = out[:-1]
try:
yield dirname
finally:
call(['rm', '-rf', dirname], node=node)
@contextlib.contextmanager
def applied_patches(cwd, node, *patches):
patched_files = []
try:
for path in patches:
with open(path, "rb") as patch:
with popen(
["patch", "-N", "-p1", "-d", cwd],
node=node, stdin=PIPE) as proc:
shutil.copyfileobj(patch, proc.stdin)
patched_files.append(path)
yield
finally:
patched_files.reverse()
for path in patched_files:
with open(path, "rb") as patch:
with popen(
["patch", "-R", "-p1", "-d", cwd],
node=node, stdin=PIPE) as proc:
shutil.copyfileobj(patch, proc.stdin)
def get_env_credentials(env):
attrs = env.get_attributes()
editable = attrs['editable'].get('service_user')
if not editable:
return None
return {
'user': editable['name']['value'],
'password': editable['password']['value'],
}
def remove_all_files_from_dirs(dir_names, node):
ftp = sftp(node)
for dir_name in dir_names:
for filename in ftp.listdir(dir_name):
path = os.path.join(dir_name, filename)
ftp.unlink(path)
def write_content_to_file(sftp, filename, content):
with sftp.open(filename, 'w') as f:
f.write(content)