Skip to content

Commit

Permalink
ssh: support jump boxes
Browse files Browse the repository at this point in the history
Example runbook example,

dev:
  mock_tcp_ping: True
  jump_boxes:
    - private_key_file: $(admin_private_key_file)
      address: 20.236.23.33
    - private_key_file: $(admin_private_key_file)
      address: 20.236.23.40
  • Loading branch information
squirrelsc committed Jul 1, 2022
1 parent 650bf3f commit 2778cff
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 4 deletions.
30 changes: 30 additions & 0 deletions lisa/development.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


from typing import List, Optional

from lisa import schema

_development_settings: Optional[schema.Development] = None


def load_development_settings(runbook: Optional[schema.Development]) -> None:
global _development_settings
if runbook and runbook.enabled:
_development_settings = runbook


def is_mock_tcp_ping() -> bool:
return _development_settings is not None and _development_settings.mock_tcp_ping


def is_trace_enabled() -> bool:
return _development_settings is not None and _development_settings.enable_trace


def get_jump_boxes() -> List[schema.ProxyConnectionInfo]:
if _development_settings and _development_settings.jump_boxes:
return _development_settings.jump_boxes
else:
return []
5 changes: 4 additions & 1 deletion lisa/runners/lisa_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, cast

from lisa import SkippedException, notifier, schema, search_space
from lisa import SkippedException, development, notifier, schema, search_space
from lisa.action import ActionStatus
from lisa.environment import (
Environment,
Expand Down Expand Up @@ -51,6 +51,9 @@ def _initialize(self, *args: Any, **kwargs: Any) -> None:
platform_message = PlatformMessage(name=self.platform.type_name())
notifier.notify(platform_message)

# load development settings
development.load_development_settings(self._runbook.dev)

@property
def is_done(self) -> bool:
is_all_results_completed = all(
Expand Down
22 changes: 22 additions & 0 deletions lisa/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,27 @@ def __str__(self) -> str:
return f"{self.username}@{self.address}:{self.port}"


@dataclass_json()
@dataclass
class ProxyConnectionInfo(ConnectionInfo):
private_address: str = ""
private_port: int = field(
default=22,
metadata=field_metadata(
field_function=fields.Int, validate=validate.Range(min=1, max=65535)
),
)


@dataclass_json()
@dataclass
class Development:
enabled: bool = True
enable_trace: bool = False
mock_tcp_ping: bool = False
jump_boxes: List[ProxyConnectionInfo] = field(default_factory=list)


@dataclass_json()
@dataclass
class Runbook:
Expand All @@ -1372,6 +1393,7 @@ class Runbook:
testcase_raw: List[Any] = field(
default_factory=list, metadata=field_metadata(data_key=constants.TESTCASE)
)
dev: Optional[Development] = field(default=None)

def __post_init__(self, *args: Any, **kwargs: Any) -> None:
if not self.platform:
Expand Down
111 changes: 108 additions & 3 deletions lisa/util/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import socket
import sys
import time
from functools import partial
from pathlib import Path, PurePath
from time import sleep
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
Expand All @@ -17,12 +18,14 @@
from func_timeout import FunctionTimedOut, func_set_timeout # type: ignore
from paramiko.ssh_exception import NoValidConnectionsError, SSHException

from lisa import schema
from lisa import development, schema
from lisa.util import InitializableMixin, LisaException, TcpConnectionException

from .logger import Logger
from .logger import Logger, get_logger
from .perf_timer import create_timer

_get_jump_box_logger = partial(get_logger, name="jump_box")


def wait_tcp_port_ready(
address: str, port: int, log: Optional[Logger] = None, timeout: int = 300
Expand All @@ -35,6 +38,11 @@ def wait_tcp_port_ready(
times: int = 0
result: int = 0

if development.is_mock_tcp_ping():
# If it's True, it means the direct connection doesn't work. Return a
# mock value for test purpose.
return True, 0

timeout_timer = create_timer()
while timeout_timer.elapsed(False) < timeout:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_socket:
Expand Down Expand Up @@ -105,6 +113,7 @@ def generate_run_command(
def try_connect(
connection_info: schema.ConnectionInfo,
ssh_timeout: int = 300,
sock: Optional[Any] = None,
) -> Any:
# spur always run a posix command and will fail on Windows.
# So try with paramiko firstly.
Expand All @@ -127,6 +136,7 @@ def try_connect(
password=connection_info.password,
key_filename=connection_info.private_key_file,
banner_timeout=10,
sock=sock,
)

stdin, stdout, _ = paramiko_client.exec_command("cmd\n")
Expand Down Expand Up @@ -180,6 +190,8 @@ def __init__(self, connection_info: schema.ConnectionInfo) -> None:
self.is_remote = True
self._connection_info = connection_info
self._inner_shell: Optional[spur.SshShell] = None
self._jump_boxes: List[Any] = []
self._jump_box_sock: Any = None

paramiko_logger = logging.getLogger("paramiko")
paramiko_logger.setLevel(logging.WARN)
Expand All @@ -194,15 +206,23 @@ def _initialize(self, *args: Any, **kwargs: Any) -> None:
self._connection_info.port,
tcp_error_code,
)

sock = self._establish_jump_boxes(
address=self._connection_info.address,
port=self._connection_info.port,
)

try:
stdout = try_connect(self._connection_info)
stdout = try_connect(self._connection_info, sock=sock)
except Exception as identifier:
raise LisaException(
f"failed to connect SSH "
f"[{self._connection_info.address}:{self._connection_info.port}], "
f"{identifier.__class__.__name__}: {identifier}"
)

self._close_jump_boxes()

# Some windows doesn't end the text stream, so read first line only.
# it's enough to detect os.
stdout_content = stdout.readline()
Expand All @@ -215,6 +235,11 @@ def _initialize(self, *args: Any, **kwargs: Any) -> None:
self.is_posix = True
shell_type = spur.ssh.ShellTypes.sh

sock = self._establish_jump_boxes(
address=self._connection_info.address,
port=self._connection_info.port,
)

spur_kwargs = {
"hostname": self._connection_info.address,
"port": self._connection_info.port,
Expand All @@ -226,6 +251,7 @@ def _initialize(self, *args: Any, **kwargs: Any) -> None:
# IP in different time. If so, there is host key conflict. So do not
# load host keys to avoid this kind of error.
"load_system_host_keys": False,
"sock": sock,
}

spur_ssh_shell = spur.SshShell(shell_type=shell_type, **spur_kwargs)
Expand All @@ -241,6 +267,8 @@ def close(self) -> None:
self._inner_shell = None
self._is_initialized = False

self._close_jump_boxes()

@property
def is_connected(self) -> bool:
is_inner_shell_ready = False
Expand Down Expand Up @@ -499,6 +527,83 @@ def _purepath_to_str(
path = str(path)
return path

def _establish_jump_boxes(self, address: str, port: int) -> Any:
jump_boxes_runbook = development.get_jump_boxes()
sock: Any = None
is_trace_enabled = development.is_trace_enabled()
if is_trace_enabled:
jb_logger = _get_jump_box_logger()
jb_logger.debug(f"proxy sock: {sock}")

for index, runbook in enumerate(jump_boxes_runbook):
if is_trace_enabled:
jb_logger.debug(f"creating connection from source: {runbook} ")
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy())
client.connect(
hostname=runbook.address,
port=runbook.port,
username=runbook.username,
password=runbook.password,
key_filename=runbook.private_key_file,
banner_timeout=10,
sock=sock,
)

if index < len(jump_boxes_runbook) - 1:
next_hop = jump_boxes_runbook[index + 1]
dest_address = (
next_hop.private_address
if next_hop.private_address
else next_hop.address
)
dest_port = (
next_hop.private_port if next_hop.private_port else next_hop.port
)
else:
dest_address = address
dest_port = port

if is_trace_enabled:
jb_logger.debug(f"next hop: {dest_address}:{dest_port}")
sock = self._open_jump_box_channel(
client,
src_address=runbook.address,
src_port=runbook.port,
dest_address=dest_address,
dest_port=dest_port,
)
self._jump_boxes.append(client)

return sock

def _open_jump_box_channel(
self,
client: paramiko.SSHClient,
src_address: str,
src_port: int,
dest_address: str,
dest_port: int,
) -> Any:
transport = client.get_transport()
assert transport

sock = transport.open_channel(
kind="direct-tcpip",
src_addr=(src_address, src_port),
dest_addr=(dest_address, dest_port),
)

return sock

def _close_jump_boxes(self) -> None:
for index in reversed(range(len(self._jump_boxes))):
self._jump_boxes[index].close()
self._jump_boxes[index] = None

self._jump_boxes.clear()
self._jump_box_sock = None


class LocalShell(InitializableMixin):
def __init__(self) -> None:
Expand Down

0 comments on commit 2778cff

Please sign in to comment.