Skip to content

Commit

Permalink
Specify User Agent for Console CLI requests (#952)
Browse files Browse the repository at this point in the history
Allows setting a global user agent on CDK deployments that will be configured in a client_options section in the services.yaml. This user agent will currently be provided to the Traffic Replayer for its replayed requests (existing logic) and to the Console CLI for all of its boto3 requests and python requests library requests. This can also be extended to other migrations as support is added.

---------

Signed-off-by: Tanner Lewis <[email protected]>
  • Loading branch information
lewijacn committed Sep 20, 2024
1 parent b1ce98b commit 62d38b5
Show file tree
Hide file tree
Showing 25 changed files with 646 additions and 294 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- [Metadata Migration](#metadata-migration)
- [Replay](#replay)
- [Kafka](#kafka)
- [Client Options](#client-options)
- [Usage](#usage)
- [Library](#library)
- [CLI](#cli)
Expand Down Expand Up @@ -82,6 +83,8 @@ metadata_migration:
kafka:
broker_endpoints: "kafka:9092"
standard:
client_options:
user_agent_extra: "test-user-agent-v1.0"
```
## Services.yaml spec
Expand Down Expand Up @@ -225,13 +228,19 @@ Exactly one of the following blocks must be present:

A Kafka cluster is used in the capture and replay stage of the migration to store recorded requests and responses before they're replayed. While it's not necessary for a user to directly interact with the Kafka cluster in most cases, there are a handful of commands that can be helpful for checking on the status or resetting state that are exposed by the Console CLI.

- `broker_endpoints`: required, comma-separated list of kafaka broker endpoints
- `broker_endpoints`: required, comma-separated list of kafka broker endpoints

Exactly one of the following keys must be present, but both are nullable (they don't have or need any additional parameters).

- `msk`: the Kafka instance is deployed as AWS Managed Service Kafka
- `standard`: the Kafka instance is deployed as a standard Kafka cluster (e.g. on Docker)

### Client Options

Client options are global settings that are applied to different clients used throughout this library

- `user_agent_extra`: optional, a user agent string that will be appended to the `User-Agent` header of all requests from this library

## Usage

### Library
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from console_link.models.snapshot import Snapshot
from console_link.models.replayer_base import Replayer
from console_link.models.kafka import Kafka
from console_link.models.client_options import ClientOptions

import yaml
from cerberus import Validator
Expand All @@ -25,7 +26,8 @@
"snapshot": {"type": "dict", "required": False},
"metadata_migration": {"type": "dict", "required": False},
"replay": {"type": "dict", "required": False},
"kafka": {"type": "dict", "required": False}
"kafka": {"type": "dict", "required": False},
"client_options": {"type": "dict", "required": False},
}


Expand All @@ -38,6 +40,7 @@ class Environment:
metadata: Optional[Metadata] = None
replay: Optional[Replayer] = None
kafka: Optional[Kafka] = None
client_options: Optional[ClientOptions] = None

def __init__(self, config_file: str):
logger.info(f"Loading config file: {config_file}")
Expand All @@ -50,23 +53,29 @@ def __init__(self, config_file: str):
logger.error(f"Config file validation errors: {v.errors}")
raise ValueError("Invalid config file", v.errors)

if 'client_options' in self.config:
self.client_options: ClientOptions = ClientOptions(self.config["client_options"])

if 'source_cluster' in self.config:
self.source_cluster = Cluster(self.config["source_cluster"])
self.source_cluster = Cluster(config=self.config["source_cluster"],
client_options=self.client_options)
logger.info(f"Source cluster initialized: {self.source_cluster.endpoint}")
else:
logger.info("No source cluster provided")

# At some point, target and replayers should be stored as pairs, but for the time being
# we can probably assume one target cluster.
if 'target_cluster' in self.config:
self.target_cluster: Cluster = Cluster(self.config["target_cluster"])
self.target_cluster: Cluster = Cluster(config=self.config["target_cluster"],
client_options=self.client_options)
logger.info(f"Target cluster initialized: {self.target_cluster.endpoint}")
else:
logger.warning("No target cluster provided. This may prevent other actions from proceeding.")

if 'metrics_source' in self.config:
self.metrics_source: MetricsSource = get_metrics_source(
self.config["metrics_source"]
config=self.config["metrics_source"],
client_options=self.client_options
)
logger.info(f"Metrics source initialized: {self.metrics_source}")
else:
Expand All @@ -75,13 +84,14 @@ def __init__(self, config_file: str):
if 'backfill' in self.config:
self.backfill: Backfill = get_backfill(self.config["backfill"],
source_cluster=self.source_cluster,
target_cluster=self.target_cluster)
target_cluster=self.target_cluster,
client_options=self.client_options)
logger.info(f"Backfill migration initialized: {self.backfill}")
else:
logger.info("No backfill provided")

if 'replay' in self.config:
self.replay: Replayer = get_replayer(self.config["replay"])
self.replay: Replayer = get_replayer(self.config["replay"], client_options=self.client_options)
logger.info(f"Replay initialized: {self.replay}")

if 'snapshot' in self.config:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from console_link.models.client_options import ClientOptions
from console_link.models.osi_utils import (create_pipeline_from_env, start_pipeline, stop_pipeline,
OpenSearchIngestionMigrationProps)
from console_link.models.cluster import Cluster
from console_link.models.backfill_base import Backfill
from console_link.models.command_result import CommandResult
from typing import Dict
from typing import Dict, Optional
from cerberus import Validator
import boto3

from console_link.models.utils import create_boto3_client

OSI_SCHEMA = {
'pipeline_role_arn': {
Expand Down Expand Up @@ -61,15 +62,17 @@ class OpenSearchIngestionBackfill(Backfill):
A migration manager for an OpenSearch Ingestion pipeline.
"""

def __init__(self, config: Dict, source_cluster: Cluster, target_cluster: Cluster) -> None:
def __init__(self, config: Dict, source_cluster: Cluster, target_cluster: Cluster,
client_options: Optional[ClientOptions] = None) -> None:
super().__init__(config)
self.client_options = client_options
config = config["opensearch_ingestion"]

v = Validator(OSI_SCHEMA)
if not v.validate(config):
raise ValueError("Invalid config file for OpenSearchIngestion migration", v.errors)
self.osi_props = OpenSearchIngestionMigrationProps(config=config)
self.osi_client = boto3.client('osis')
self.osi_client = create_boto3_client(aws_service_name='osis', client_options=self.client_options)
self.source_cluster = source_cluster
self.target_cluster = target_cluster

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import requests

from console_link.models.backfill_base import Backfill, BackfillStatus
from console_link.models.client_options import ClientOptions
from console_link.models.cluster import Cluster
from console_link.models.schema_tools import contains_one_of
from console_link.models.command_result import CommandResult
Expand Down Expand Up @@ -87,14 +88,17 @@ def scale(self, units: int, *args, **kwargs) -> CommandResult:


class ECSRFSBackfill(RFSBackfill):
def __init__(self, config: Dict, target_cluster: Cluster) -> None:
def __init__(self, config: Dict, target_cluster: Cluster, client_options: Optional[ClientOptions] = None) -> None:
super().__init__(config)
self.client_options = client_options
self.target_cluster = target_cluster
self.default_scale = self.config["reindex_from_snapshot"].get("scale", 1)

self.ecs_config = self.config["reindex_from_snapshot"]["ecs"]
self.ecs_client = ECSService(self.ecs_config["cluster_name"], self.ecs_config["service_name"],
self.ecs_config.get("aws_region", None))
self.ecs_client = ECSService(cluster_name=self.ecs_config["cluster_name"],
service_name=self.ecs_config["service_name"],
aws_region=self.ecs_config.get("aws_region", None),
client_options=self.client_options)

def start(self, *args, **kwargs) -> CommandResult:
logger.info(f"Starting RFS backfill by setting desired count to {self.default_scale} instances")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Dict, Optional
import logging
from cerberus import Validator

logger = logging.getLogger(__name__)

SCHEMA = {
"client_options": {
"type": "dict",
"schema": {
"user_agent_extra": {"type": "string", "required": False},
},
}
}


class ClientOptions:
"""
Options that can be configured for boto3 and request library clients.
"""

user_agent_extra: Optional[str] = None

def __init__(self, config: Dict) -> None:
logger.info(f"Initializing client options with config: {config}")
v = Validator(SCHEMA)
if not v.validate({'client_options': config}):
raise ValueError("Invalid config file for client options", v.errors)

self.user_agent_extra = config.get("user_agent_extra", None)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from requests.auth import HTTPBasicAuth
from requests_auth_aws_sigv4 import AWSSigV4

from console_link.models.client_options import ClientOptions
from console_link.models.schema_tools import contains_one_of
from console_link.models.utils import create_boto3_client, append_user_agent_header_for_requests

requests.packages.urllib3.disable_warnings() # ignore: type

Expand Down Expand Up @@ -79,8 +81,9 @@ class Cluster:
auth_type: Optional[AuthMethod] = None
auth_details: Optional[Dict[str, Any]] = None
allow_insecure: bool = False
client_options: Optional[ClientOptions] = None

def __init__(self, config: Dict) -> None:
def __init__(self, config: Dict, client_options: Optional[ClientOptions] = None) -> None:
logger.info(f"Initializing cluster with config: {config}")
v = Validator(SCHEMA)
if not v.validate({'cluster': config}):
Expand All @@ -97,6 +100,7 @@ def __init__(self, config: Dict) -> None:
elif 'sigv4' in config:
self.auth_type = AuthMethod.SIGV4
self.auth_details = config["sigv4"] if config["sigv4"] is not None else {}
self.client_options = client_options

def get_basic_auth_password(self) -> str:
"""This method will return the basic auth password, if basic auth is enabled.
Expand All @@ -108,11 +112,11 @@ def get_basic_auth_password(self) -> str:
return self.auth_details["password"]
# Pull password from AWS Secrets Manager
assert "password_from_secret_arn" in self.auth_details # for mypy's sake
client = boto3.client('secretsmanager')
client = create_boto3_client(aws_service_name="secretsmanager", client_options=self.client_options)
password = client.get_secret_value(SecretId=self.auth_details["password_from_secret_arn"])
return password["SecretString"]

def _get_sigv4_details(self, force_region=False) -> tuple[str, str]:
def _get_sigv4_details(self, force_region=False) -> tuple[str, Optional[str]]:
"""Return the service signing name and region name. If force_region is true,
it will instantiate a boto3 session to guarantee that the region is not None.
This will fail if AWS credentials are not available.
Expand Down Expand Up @@ -145,9 +149,14 @@ def call_api(self, path, method: HttpMethod = HttpMethod.GET, data=None, headers
"""
if session is None:
session = requests.Session()

auth = self._generate_auth_object()

request_headers = headers
if self.client_options and self.client_options.user_agent_extra:
user_agent_extra = self.client_options.user_agent_extra
request_headers = append_user_agent_header_for_requests(headers=headers, user_agent_extra=user_agent_extra)

# Extract query parameters from kwargs
params = kwargs.get('params', {})

Expand All @@ -159,7 +168,7 @@ def call_api(self, path, method: HttpMethod = HttpMethod.GET, data=None, headers
params=params,
auth=auth,
data=data,
headers=headers,
headers=request_headers,
timeout=timeout
)
logger.info(f"Received response: {r.status_code} {method.name} {self.endpoint}{path} - {r.text[:1000]}")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import logging
from typing import NamedTuple, Optional

import boto3

from console_link.models.command_result import CommandResult
from console_link.models.utils import AWSAPIError, raise_for_aws_api_error

from console_link.models.utils import AWSAPIError, raise_for_aws_api_error, create_boto3_client

logger = logging.getLogger(__name__)

Expand All @@ -20,13 +17,15 @@ def __str__(self):


class ECSService:
def __init__(self, cluster_name, service_name, aws_region=None):
def __init__(self, cluster_name, service_name, aws_region=None, client_options=None):
self.cluster_name = cluster_name
self.service_name = service_name
self.aws_region = aws_region
self.client_options = client_options

logger.info(f"Creating ECS client for region {aws_region}, if specified")
self.client = boto3.client("ecs", region_name=self.aws_region)
self.client = create_boto3_client(aws_service_name="ecs", region=self.aws_region,
client_options=self.client_options)

def set_desired_count(self, desired_count: int) -> CommandResult:
logger.info(f"Setting desired count for service {self.service_name} to {desired_count}")
Expand All @@ -47,7 +46,7 @@ def set_desired_count(self, desired_count: int) -> CommandResult:
desired_count = response["service"]["desiredCount"]
return CommandResult(True, f"Service {self.service_name} set to {desired_count} desired count."
f" Currently {running_count} running and {pending_count} pending.")

def get_instance_statuses(self) -> Optional[InstanceStatuses]:
logger.info(f"Getting instance statuses for service {self.service_name}")
response = self.client.describe_services(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum
from typing import Dict, Optional

from console_link.models.client_options import ClientOptions
from console_link.models.replayer_docker import DockerReplayer
from console_link.models.metrics_source import CloudwatchMetricsSource, PrometheusMetricsSource
from console_link.models.backfill_base import Backfill
Expand Down Expand Up @@ -55,9 +56,9 @@ def get_snapshot(config: Dict, source_cluster: Cluster):
raise UnsupportedSnapshotError(next(iter(config.keys())))


def get_replayer(config: Dict):
def get_replayer(config: Dict, client_options: Optional[ClientOptions] = None):
if 'ecs' in config:
return ECSReplayer(config)
return ECSReplayer(config=config, client_options=client_options)
if 'docker' in config:
return DockerReplayer(config)
logger.error(f"An unsupported replayer type was provided: {config.keys()}")
Expand All @@ -74,7 +75,8 @@ def get_kafka(config: Dict):
raise UnsupportedKafkaError(', '.join(config.keys()))


def get_backfill(config: Dict, source_cluster: Optional[Cluster], target_cluster: Optional[Cluster]) -> Backfill:
def get_backfill(config: Dict, source_cluster: Optional[Cluster], target_cluster: Optional[Cluster],
client_options: Optional[ClientOptions] = None) -> Backfill:
if BackfillType.opensearch_ingestion.name in config:
if source_cluster is None:
raise ValueError("source_cluster must be provided for OpenSearch Ingestion backfill")
Expand All @@ -83,7 +85,8 @@ def get_backfill(config: Dict, source_cluster: Optional[Cluster], target_cluster
logger.debug("Creating OpenSearch Ingestion backfill instance")
return OpenSearchIngestionBackfill(config=config,
source_cluster=source_cluster,
target_cluster=target_cluster)
target_cluster=target_cluster,
client_options=client_options)
elif BackfillType.reindex_from_snapshot.name in config:
if target_cluster is None:
raise ValueError("target_cluster must be provided for RFS backfill")
Expand All @@ -95,17 +98,18 @@ def get_backfill(config: Dict, source_cluster: Optional[Cluster], target_cluster
elif 'ecs' in config[BackfillType.reindex_from_snapshot.name]:
logger.debug("Creating ECS RFS backfill instance")
return ECSRFSBackfill(config=config,
target_cluster=target_cluster)
target_cluster=target_cluster,
client_options=client_options)

logger.error(f"An unsupported backfill source type was provided: {config.keys()}")
raise UnsupportedBackfillTypeError(', '.join(config.keys()))


def get_metrics_source(config):
def get_metrics_source(config, client_options: Optional[ClientOptions] = None):
if 'prometheus' in config:
return PrometheusMetricsSource(config)
return PrometheusMetricsSource(config=config, client_options=client_options)
elif 'cloudwatch' in config:
return CloudwatchMetricsSource(config)
return CloudwatchMetricsSource(config=config, client_options=client_options)
else:
logger.error(f"An unsupported metrics source type was provided: {config.keys()}")
raise UnsupportedMetricsSourceError(', '.join(config.keys()))
Loading

0 comments on commit 62d38b5

Please sign in to comment.