opencafe/cafe/plugins/sshv2/cafe/engine/sshv2/client.py

317 lines
11 KiB
Python

# Copyright 2015 Rackspace
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from socks import socket, create_connection
from StringIO import StringIO
from uuid import uuid4
import io
import time
from paramiko import AutoAddPolicy, RSAKey
from paramiko import py3compat
from paramiko.client import SSHClient as ParamikoSSHClient
from cafe.engine.sshv2.common import (
BaseSSHClass, _SSHLogger, DEFAULT_TIMEOUT, POLLING_RATE, CHANNEL_KEEPALIVE)
from cafe.engine.sshv2.models import ExecResponse
# HACK: pre-import dependencies that would otherwise be imported inside a
# thread during connect, which causes a deadlock.
# See https://github.com/paramiko/paramiko/issues/104
py3compat.u("")
# HACK, part two (also paramiko issue 104): load the cryptography default
# backend at import time.
from cryptography.hazmat.backends import default_backend
default_backend()
class ProxyTypes(object):
    """Integer codes for the proxy protocols accepted as ``proxy_type``.

    NOTE(review): presumably these mirror the proxy type constants of the
    ``socks`` module used by ``create_connection`` -- confirm.
    """

    SOCKS4 = 1
    SOCKS5 = 2
class ExtendedParamikoSSHClient(ParamikoSSHClient):
    """Paramiko SSH client extended with a time-limited command runner."""

    def execute_command(
            self, command, bufsize=-1, timeout=None, stdin_str="",
            stdin_file=None, raise_exceptions=False):
        """Run ``command`` on a new session channel and collect its output.

        :param command: command line handed to the channel's exec_command.
        :param bufsize: buffer size passed to the channel's makefile calls.
        :param timeout: seconds to wait for the exit status; a falsy value
            (None or 0) selects DEFAULT_TIMEOUT.
        :param stdin_str: data written to the command's stdin; replaced by
            the content of ``stdin_file`` when that is given.
        :param stdin_file: object with a ``read()`` method whose content is
            written to stdin instead of ``stdin_str``.
        :param raise_exceptions: accepted for interface compatibility; this
            method does not use it.
        :returns: tuple ``(stdin_str, stdout_str, stderr_str, exit_status)``.
        :raises socket.timeout: when the exit status is not ready before the
            deadline; the message carries the output gathered so far.
        """
        timeout = timeout or DEFAULT_TIMEOUT
        chan = self._transport.open_session()
        try:
            # Reads on the channel files give up after POLLING_RATE seconds
            # so the loop below can re-check the deadline.
            chan.settimeout(POLLING_RATE)
            chan.exec_command(command)
            if stdin_file is not None:
                stdin_str = stdin_file.read()
            stdin = chan.makefile("wb", bufsize)
            stdout = chan.makefile("rb", bufsize)
            stderr = chan.makefile_stderr("rb", bufsize)
            stdin.write(stdin_str)
            stdin.close()

            stdout_str = stderr_str = ""
            max_time = time.time() + timeout
            while not chan.exit_status_ready():
                stderr_str += self._read_channel(stderr)
                stdout_str += self._read_channel(stdout)
                if max_time < time.time():
                    raise socket.timeout(
                        "Command timed out\nSTDOUT:{0}\nSTDERR:{1}\n".format(
                            stdout_str, stderr_str))
            exit_status = chan.recv_exit_status()
            # Pick up whatever was produced after the last poll.
            stdout_str += self._read_channel(stdout)
            stderr_str += self._read_channel(stderr)
        finally:
            # Close the channel on every path, including the timeout raise,
            # so a failed command does not leak the session channel.
            chan.close()
        return stdin_str, stdout_str, stderr_str, exit_status

    def _read_channel(self, chan):
        """Read from a channel file, treating a poll timeout as no data.

        :param chan: file-like object returned by the channel's ``makefile``
            or ``makefile_stderr``.
        :returns: the data returned by ``read()``, or ``""`` when the read
            raises socket.timeout.
        """
        # NOTE(review): read() is called without a size; whether data read
        # before a timeout is kept or dropped depends on paramiko -- confirm.
        try:
            return chan.read()
        except socket.timeout:
            return ""
class SSHClient(BaseSSHClass):
    """Holds default SSH connection settings and opens connections on demand.

    Every public operation (``execute_command``, ``create_shell``,
    ``create_sftp``) opens a new connection through ``_connect``.
    """

    def __init__(
            self, hostname=None, port=22, username=None, password=None,
            accept_missing_host_key=True, timeout=None, compress=True,
            pkey=None, look_for_keys=False, allow_agent=False,
            key_filename=None, proxy_type=None, proxy_ip=None,
            proxy_port=None, sock=None):
        """Store the default connection settings; no connection is opened.

        :param hostname: host to connect to.
        :param port: port to connect to; converted with ``int()``.
        :param username: login name.
        :param password: login password.
        :param accept_missing_host_key: when truthy, connections install
            paramiko's AutoAddPolicy for unknown host keys.
        :param timeout: connect timeout; a falsy value selects
            DEFAULT_TIMEOUT.
        :param compress: forwarded to the paramiko ``connect`` call.
        :param pkey: private key material; see ``_connect`` for how it is
            converted.
        :param look_for_keys: forwarded to the paramiko ``connect`` call.
        :param allow_agent: forwarded to the paramiko ``connect`` call.
        :param key_filename: forwarded to the paramiko ``connect`` call.
        :param proxy_type: proxy protocol code (see ``ProxyTypes``).
        :param proxy_ip: proxy address.
        :param proxy_port: proxy port.
        :param sock: ready-made socket forwarded to ``connect``; when set,
            the proxy settings are not used.
        """
        super(SSHClient, self).__init__()
        self.accept_missing_host_key = accept_missing_host_key
        self.proxy_type = proxy_type
        self.proxy_ip = proxy_ip
        self.proxy_port = proxy_port
        # Keyword arguments used as defaults for every paramiko connect().
        self.connect_kwargs = {
            "timeout": timeout or DEFAULT_TIMEOUT,
            "hostname": hostname,
            "port": int(port),
            "username": username,
            "password": password,
            "compress": compress,
            "pkey": pkey,
            "look_for_keys": look_for_keys,
            "allow_agent": allow_agent,
            "key_filename": key_filename,
            "sock": sock,
        }

    @property
    def timeout(self):
        """Default timeout stored in ``connect_kwargs`` (None if removed)."""
        return self.connect_kwargs.get("timeout")

    @timeout.setter
    def timeout(self, value):
        self.connect_kwargs["timeout"] = value

    @timeout.deleter
    def timeout(self):
        # pop() with a default tolerates an already-removed timeout.
        self.connect_kwargs.pop("timeout", None)

    def _connect(
            self, hostname=None, port=None, username=None, password=None,
            accept_missing_host_key=None, timeout=None, compress=None,
            pkey=None, look_for_keys=None, allow_agent=None,
            key_filename=None, proxy_type=None, proxy_ip=None,
            proxy_port=None, sock=None):
        """Open and return a connected ExtendedParamikoSSHClient.

        Every argument is an optional per-call override; ``None`` means
        "use the value stored on this instance".

        :returns: a connected ExtendedParamikoSSHClient.
        """
        # Collect the overrides explicitly.  Calling locals() inside a dict
        # comprehension does not work: the comprehension has its own scope,
        # so the method arguments are invisible there and every override
        # would be silently discarded.
        overrides = {
            "hostname": hostname,
            "port": port,
            "username": username,
            "password": password,
            "timeout": timeout,
            "compress": compress,
            "pkey": pkey,
            "look_for_keys": look_for_keys,
            "allow_agent": allow_agent,
            "key_filename": key_filename,
            "sock": sock,
        }
        connect_kwargs = dict(self.connect_kwargs)
        connect_kwargs.update(
            (key, value) for key, value in overrides.items()
            if value is not None)
        connect_kwargs["port"] = int(connect_kwargs.get("port"))

        ssh = ExtendedParamikoSSHClient()
        # Either the instance setting or the override enables auto-adding;
        # an override of False therefore cannot disable an instance-level
        # True.
        if self.accept_missing_host_key or accept_missing_host_key:
            ssh.set_missing_host_key_policy(AutoAddPolicy())

        if connect_kwargs.get("pkey") is not None:
            # The key is converted to text and parsed as an RSA private key.
            connect_kwargs["pkey"] = RSAKey.from_private_key(
                io.StringIO(unicode(connect_kwargs["pkey"])))

        proxy_type = proxy_type or self.proxy_type
        proxy_ip = proxy_ip or self.proxy_ip
        proxy_port = proxy_port or self.proxy_port
        # An explicit socket wins; otherwise a socket is opened through the
        # proxy, but only when type, ip and port are all set.
        if (connect_kwargs.get("sock") is None and
                all([proxy_type, proxy_ip, proxy_port])):
            connect_kwargs["sock"] = create_connection(
                (connect_kwargs.get("hostname"), connect_kwargs.get("port")),
                proxy_type, proxy_ip, int(proxy_port))

        ssh.connect(**connect_kwargs)
        return ssh

    @_SSHLogger
    def execute_command(
            self, command, bufsize=-1, stdin_str="", stdin_file=None,
            **connect_kwargs):
        """Connect, run one command, disconnect, and return the result.

        :param command: command line to execute remotely.
        :param bufsize: buffer size forwarded to the command runner.
        :param stdin_str: data written to the command's stdin.
        :param stdin_file: object with ``read()`` used instead of
            ``stdin_str`` when given.
        :param connect_kwargs: per-call overrides forwarded to ``_connect``;
            a ``timeout`` entry also limits the command run time.
        :returns: ExecResponse with stdin, stdout, stderr and exit_status.
        """
        ssh_client = self._connect(**connect_kwargs)
        try:
            stdin, stdout, stderr, exit_status = ssh_client.execute_command(
                timeout=self._get_timeout(connect_kwargs.get("timeout")),
                command=command, bufsize=bufsize, stdin_str=stdin_str,
                stdin_file=stdin_file)
        finally:
            # Release the connection even when the command raises (for
            # example on timeout) so it is not leaked.
            ssh_client.close()
        return ExecResponse(
            stdin=stdin, stdout=stdout, stderr=stderr,
            exit_status=exit_status)

    def _get_timeout(self, timeout=None):
        """Return ``timeout`` unless it is None, else the instance timeout."""
        return timeout if timeout is not None else self.timeout

    @_SSHLogger
    def create_shell(self, keepalive=None, **connect_kwargs):
        """Connect and return an SSHShell bound to the new connection.

        :param keepalive: keepalive value forwarded to SSHShell.
        :param connect_kwargs: per-call overrides forwarded to ``_connect``.
        """
        connection = self._connect(**connect_kwargs)
        return SSHShell(
            connection, self._get_timeout(connect_kwargs.get("timeout")),
            keepalive)

    @_SSHLogger
    def create_sftp(self, keepalive=None, **connect_kwargs):
        """Connect and return an SFTPShell bound to the new connection.

        :param keepalive: keepalive value forwarded to SFTPShell.
        :param connect_kwargs: per-call overrides forwarded to ``_connect``.
        """
        connection = self._connect(**connect_kwargs)
        return SFTPShell(connection, keepalive)
class SFTPShell(BaseSSHClass):
    """SFTP session over an SSH connection.

    Attribute access for the names in ``_SFTP_METHODS`` is forwarded to the
    SFTP client obtained from ``connection.open_sftp()``.
    """

    # Names looked up on the underlying SFTP client instead of this object.
    # Built once here rather than on every attribute access.
    _SFTP_METHODS = frozenset([
        "chdir", "chmod", "chown", "file", "get", "getcwd", "getfo",
        "listdir", "listdir_attr", "listdir_iter", "lstat", "mkdir",
        "normalize", "open", "put", "putfo", "readlink", "remove",
        "rename", "rmdir", "stat", "symlink", "truncate", "unlink",
        "utime"])

    def __init__(self, connection=None, keepalive=None):
        """Open an SFTP session on ``connection``.

        :param connection: connected SSH client providing ``open_sftp()``.
        :param keepalive: keepalive value set on the transport; a falsy
            value selects CHANNEL_KEEPALIVE.
        """
        super(SFTPShell, self).__init__()
        self.connection = connection
        self.sftp = connection.open_sftp()
        self.sftp.get_channel().get_transport().set_keepalive(
            keepalive or CHANNEL_KEEPALIVE)
        self.chdir(".")

    def __getattribute__(self, name):
        """Forward the names in ``_SFTP_METHODS`` to ``self.sftp``."""
        if name in SFTPShell._SFTP_METHODS:
            return getattr(self.sftp, name)
        return super(SFTPShell, self).__getattribute__(name)

    def exists(self, path):
        """Return True when ``stat`` succeeds for the remote ``path``.

        :param path: remote path to check.
        :returns: True if the path exists, False if the stat fails with
            ENOENT.
        :raises IOError: for any stat failure other than ENOENT.
        """
        import errno
        try:
            self.sftp.stat(path)
        except IOError as e:
            # Only "no such file" means the path is absent; every other
            # failure is propagated to the caller.
            if e.errno != errno.ENOENT:
                raise
            return False
        return True

    def get_file(self, remote_path):
        """Download ``remote_path`` and return its content as a string."""
        buf = StringIO()
        self.getfo(remote_path, buf)
        return buf.getvalue()

    def write_file(self, data, remote_path):
        """Upload ``data`` to ``remote_path``; returns the ``putfo`` result."""
        return self.putfo(StringIO(data), remote_path)

    def close(self):
        """Close the SFTP session and the connection; safe to call twice."""
        # Attributes are deleted after closing so a repeated call skips
        # them.
        if hasattr(self, "sftp"):
            self.sftp.close()
            del self.sftp
        if hasattr(self, "connection"):
            self.connection.close()
            del self.connection
class SSHShell(BaseSSHClass):
    """Interactive shell channel on an SSH connection.

    Commands are wrapped between two ``echo`` lines carrying a random
    marker so their output and exit status can be cut out of the stream.
    """

    # Values for the ``timeout_action`` argument of execute_shell_command.
    RAISE = "RAISE"
    RAISE_DISCONNECT = "RAISE_DISCONNECT"

    def __init__(self, connection, timeout, keepalive=None):
        """Open a shell channel on ``connection``.

        :param connection: connected SSH client owning the transport.
        :param timeout: default command timeout in seconds.
        :param keepalive: keepalive value set on the transport; a falsy
            value selects CHANNEL_KEEPALIVE.
        """
        super(SSHShell, self).__init__()
        self.connection = connection
        self.timeout = timeout
        self.channel = self._create_channel()
        # Reads give up after POLLING_RATE seconds instead of blocking.
        self.channel.settimeout(POLLING_RATE)
        self._clear_channel()
        self.channel.get_transport().set_keepalive(
            keepalive or CHANNEL_KEEPALIVE)

    def close(self):
        """Close the channel and the connection; safe to call twice."""
        if hasattr(self, "channel"):
            self.channel.close()
            del self.channel
        if hasattr(self, "connection"):
            self.connection.close()
            del self.connection

    @_SSHLogger
    def execute_shell_command(
            self, cmd, timeout=None, timeout_action=RAISE_DISCONNECT,
            exception_on_timeout=True):
        """Run ``cmd`` in the shell and return its ExecResponse.

        :param cmd: command text; surrounding whitespace is stripped.
        :param timeout: seconds allowed for the whole call; None selects
            the shell's default timeout.
        :param timeout_action: RAISE_DISCONNECT closes the shell before the
            timeout is re-raised; any other value only re-raises.
        :param exception_on_timeout: accepted for interface compatibility;
            this method does not use it.
        :returns: ExecResponse with stdout, stderr and exit_status.
        :raises socket.timeout: when the deadline passes first.
        """
        max_time = time.time() + self._get_timeout(timeout)
        uuid = str(uuid4()).replace("-", "")
        # The marker is echoed before and after the command; the second
        # echo also prints the command's exit status.
        cmd = "echo {1}\n{0}\necho {1} $?\n".format(cmd.strip(), uuid)
        try:
            self._clear_channel()
            self._wait_for_active_shell(max_time)
            self.channel.send(cmd)
            response = self._read_shell_response(uuid, max_time)
        except socket.timeout:
            if timeout_action == self.RAISE_DISCONNECT:
                self.close()
            # Bare raise keeps the original traceback.
            raise
        return response

    def _create_channel(self):
        """Open a session channel on the transport and start a shell."""
        chan = self.connection._transport.open_session()
        chan.invoke_shell()
        return chan

    def _wait_for_active_shell(self, max_time):
        """Poll until the channel is ready to send.

        :raises socket.timeout: when ``max_time`` passes first.
        """
        while not self.channel.send_ready():
            time.sleep(POLLING_RATE)
            if max_time < time.time():
                raise socket.timeout("Timed out waiting for active shell")

    def _read_shell_response(self, uuid, max_time):
        """Read until both markers and the exit-status line have arrived.

        :param uuid: marker echoed before and after the command.
        :param max_time: absolute deadline as a ``time.time()`` value.
        :returns: ExecResponse; exit_status is None when the status text
            cannot be parsed as an integer.
        :raises socket.timeout: when the deadline passes first.
        """
        stdout = stderr = ""
        exit_status = None
        while max_time > time.time():
            stdout += self._read_channel(self.channel.recv)
            stderr += self._read_channel(self.channel.recv_stderr)
            if stdout.count(uuid) != 2:
                continue
            _, body, tail = stdout.split(uuid)
            # A read can end right after the second marker, before the
            # " <status>" text has arrived.  Wait for the end of that line
            # instead of reporting a spurious None exit status.
            if "\n" not in tail:
                continue
            stdout = body
            try:
                # Parse only the status line; anything after it is ignored.
                exit_status = int(tail.partition("\n")[0])
            except (ValueError, TypeError):
                exit_status = None
            break
        else:
            raise socket.timeout(
                "Command timed out\nSTDOUT:{0}\nSTDERR:{1}\n".format(
                    stdout, stderr))
        return ExecResponse(
            stdin=None, stdout=stdout.strip(), stderr=stderr,
            exit_status=exit_status)

    def _read_channel(self, read_func, buffsize=1024):
        """Call ``read_func(buffsize)``; return "" when it times out.

        :param read_func: channel read function such as ``recv``.
        :param buffsize: maximum amount requested from ``read_func``.
        """
        try:
            return read_func(buffsize)
        except socket.timeout:
            return ""

    def _clear_channel(self):
        """Read and discard one chunk from stdout and from stderr."""
        self._read_channel(self.channel.recv)
        self._read_channel(self.channel.recv_stderr)

    def _get_timeout(self, timeout=None):
        """Return ``timeout`` unless it is None, else the shell timeout."""
        return timeout if timeout is not None else self.timeout