Skip to content

Commit

Permalink
[LoginNodes] Allow multiple login node pools (#6389)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjmakin committed Aug 13, 2024
1 parent 9882f24 commit d1af051
Show file tree
Hide file tree
Showing 19 changed files with 462 additions and 245 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ CHANGELOG
- Add support for ap-southeast-3 region.
- Add security groups to login node network load balancer.
- Add `AllowedIps` configuration for login nodes.
- Allow multiple login node pools.

**BUG FIXES**
- Fix validator `EfaPlacementGroupValidator` so that it does not suggest to configure a Placement Group when Capacity Blocks are used.
Expand Down
31 changes: 20 additions & 11 deletions cli/src/pcluster/api/controllers/cluster_operations_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,18 +265,27 @@ def describe_cluster(cluster_name, region=None):

def _get_login_nodes(cluster):
login_nodes_status = cluster.login_nodes_status

# TODO Fix once the API models are updated to support multiple pools in describe-cluster response
if login_nodes_status.get_login_nodes_pool_available():
status = LoginNodesState.FAILED
if login_nodes_status.get_status() == LoginNodesPoolState.ACTIVE:
status = LoginNodesState.ACTIVE
elif login_nodes_status.get_status() == LoginNodesPoolState.PENDING:
status = LoginNodesState.PENDING
login_nodes = LoginNodesPool(status=status)
login_nodes.address = login_nodes_status.get_address()
login_nodes.scheme = login_nodes_status.get_scheme()
login_nodes.healthy_nodes = login_nodes_status.get_healthy_nodes()
login_nodes.unhealthy_nodes = login_nodes_status.get_unhealthy_nodes()
return login_nodes
login_nodes = []

for _pool_name, pool_status in login_nodes_status.get_pool_status_dict().items():
status = LoginNodesState.FAILED
if pool_status.get_status() == LoginNodesPoolState.ACTIVE:
status = LoginNodesState.ACTIVE
elif pool_status.get_status() == LoginNodesPoolState.PENDING:
status = LoginNodesState.PENDING
pool = LoginNodesPool(status=status)
# pool.name = pool_name
pool.address = pool_status.get_address()
pool.scheme = pool_status.get_scheme()
pool.healthy_nodes = pool_status.get_healthy_nodes()
pool.unhealthy_nodes = pool_status.get_unhealthy_nodes()
login_nodes.append(pool)
break

return login_nodes[0]
return None


Expand Down
5 changes: 5 additions & 0 deletions cli/src/pcluster/config/cluster_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,11 @@ def __init__(
super().__init__(**kwargs)
self.pools = pools

def _register_validators(self, context: ValidatorContext = None): # noqa: D102 #pylint: disable=unused-argument
self._register_validator(
DuplicateNameValidator, name_list=[pool.name for pool in self.pools], resource_name="Pool"
)


class HeadNode(Resource):
"""Represent the Head Node resource."""
Expand Down
35 changes: 32 additions & 3 deletions cli/src/pcluster/config/update_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,25 @@ def condition_checker_login_nodes_stop_policy(_, patch):
return not patch.cluster.has_running_login_nodes()


def condition_checker_login_nodes_pool_stop_policy(change, patch):
"""Check if login nodes are running in the pool in which update was requested."""
pool_name = get_pool_name_from_change_paths(change)
return not patch.cluster.has_running_login_nodes(pool_name=pool_name)


def get_pool_name_from_change_paths(change):
"""
Return the name of the pool in which update was requested.
Example path=['LoginNodes', 'Pools[pool2]', 'Ssh'].
"""
for path in change.path:
if re.search("Pools\\[", path):
_, pool_name = extract_type_and_name_from_path(path)
return pool_name
return ""


# Common fail_reason messages
UpdatePolicy.FAIL_REASONS = {
"ebs_volume_resize": "Updating the file system after a resize operation requires commands specific to your "
Expand All @@ -501,6 +520,9 @@ def condition_checker_login_nodes_stop_policy(_, patch):
"pcluster update-compute-fleet command and then run an update with the --force-update flag"
),
"login_nodes_running": lambda change, patch: "The update is not supported when login nodes are running",
"login_nodes_pool_running": lambda change, patch: (
f"The update is not supported when login nodes in {get_pool_name_from_change_paths(change)} are running"
),
"compute_or_login_nodes_running": lambda change, patch: (
"The update is not supported when compute or login nodes are running"
),
Expand Down Expand Up @@ -669,8 +691,7 @@ def condition_checker_login_nodes_stop_policy(_, patch):
condition_checker=condition_checker_managed_fsx,
)


# Update policy for updating LoginNodes / Pools
# Update policy for adding or removing a login node pool
UpdatePolicy.LOGIN_NODES_POOLS = UpdatePolicy(
name="LOGIN_NODES_POOLS_UPDATE_POLICY",
level=6,
Expand All @@ -679,7 +700,7 @@ def condition_checker_login_nodes_stop_policy(_, patch):
action_needed=UpdatePolicy.ACTIONS_NEEDED["login_nodes_stop"],
)

# Update supported only with all login nodes down
# Update supported only with all login nodes in the cluster down
UpdatePolicy.LOGIN_NODES_STOP = UpdatePolicy(
name="LOGIN_NODES_STOP",
level=10,
Expand All @@ -688,6 +709,14 @@ def condition_checker_login_nodes_stop_policy(_, patch):
action_needed=UpdatePolicy.ACTIONS_NEEDED["login_nodes_stop"],
)

# Update supported only with all login nodes in the pool down
UpdatePolicy.LOGIN_NODES_POOL_STOP = UpdatePolicy(
name="LOGIN_NODES_POOL_STOP",
level=10,
condition_checker=condition_checker_login_nodes_pool_stop_policy,
fail_reason=UpdatePolicy.FAIL_REASONS["login_nodes_pool_running"],
action_needed=UpdatePolicy.ACTIONS_NEEDED["login_nodes_stop"],
)

# Update supported only with all computre and login nodes down
UpdatePolicy.COMPUTE_AND_LOGIN_NODES_STOP = UpdatePolicy(
Expand Down
22 changes: 13 additions & 9 deletions cli/src/pcluster/models/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,11 @@ def compute_fleet_status(self) -> ComputeFleetStatus:

@property
def login_nodes_status(self):
"""Status of the login nodes pool."""
"""Status of the login nodes."""
login_nodes_status = LoginNodesStatus(self.stack_name)
if self.stack.scheduler == "slurm" and self.config.login_nodes:
# This approach works since by design we have now only one pool.
# We should fix this if we want to add more than a login nodes pool per cluster.
login_nodes_status.retrieve_data(self.config.login_nodes.pools[0].name)
login_node_pool_names = [pool.name for pool in self.config.login_nodes.pools]
login_nodes_status.retrieve_data(login_node_pool_names)
return login_nodes_status

@property
Expand Down Expand Up @@ -755,13 +754,18 @@ def has_running_capacity(self, updated_value: bool = False) -> bool:
)
return self.__has_running_capacity

def has_running_login_nodes(self, updated_value: bool = False) -> bool:
"""Return True if the cluster has running login nodes. Note: the value will be cached."""
def has_running_login_nodes(self, updated_value: bool = False, pool_name: str = None) -> bool:
"""
Return True if the cluster has running login nodes, or a specific pool if a pool name is provided.
Note: the value will be cached.
"""
healthy_nodes = self.login_nodes_status.get_healthy_nodes(pool_name=pool_name)
unhealthy_nodes = self.login_nodes_status.get_unhealthy_nodes(pool_name=pool_name)

if self.__has_running_login_nodes is None or updated_value:
self.__has_running_login_nodes = (
self.login_nodes_status.get_healthy_nodes() is not None
and self.login_nodes_status.get_unhealthy_nodes() is not None
and self.login_nodes_status.get_healthy_nodes() + self.login_nodes_status.get_unhealthy_nodes() != 0
healthy_nodes is not None and unhealthy_nodes is not None and healthy_nodes + unhealthy_nodes != 0
)
return self.__has_running_login_nodes

Expand Down
117 changes: 89 additions & 28 deletions cli/src/pcluster/models/login_nodes_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,57 +28,57 @@ def __str__(self):
return str(self.value)


class LoginNodesStatus:
"""Represents the status of the cluster login nodes pools."""
class PoolStatus:
"""Represents the status of a pool of login nodes."""

def __init__(self, stack_name):
self._stack_name = stack_name
self._login_nodes_pool_name = None
self._login_nodes_pool_available = False
self._load_balancer_arn = None
self._target_group_arn = None
self._status = None
def __init__(self, stack_name, pool_name):
self._dns_name = None
self._status = None
self._scheme = None
self._pool_name = pool_name
self._pool_available = False
self._stack_name = stack_name
self._healthy_nodes = None
self._unhealthy_nodes = None
self._load_balancer_arn = None
self._target_group_arn = None
self._retrieve_data()

def __str__(self):
return (
f'("status": "{self._status}", "address": "{self._dns_name}", "scheme": "{self._scheme}", '
f'"healthyNodes": "{self._healthy_nodes}", "unhealthy_nodes": "{self._unhealthy_nodes}")'
f'"healthy_nodes": "{self._healthy_nodes}", "unhealthy_nodes": "{self._unhealthy_nodes}"),'
)

def get_login_nodes_pool_available(self):
"""Return the status of a login nodes fleet."""
return self._login_nodes_pool_available
def get_healthy_nodes(self):
"""Return the number of healthy nodes of the login node pool."""
return self._healthy_nodes

def get_unhealthy_nodes(self):
"""Return the number of unhealthy nodes of the login node pool."""
return self._unhealthy_nodes

def get_pool_available(self):
"""Return true if the pool is available."""
return self._pool_available

def get_status(self):
"""Return the status of a login nodes fleet."""
"""Return the status of the login node pool."""
return self._status

def get_address(self):
"""Return the single connection address of a login nodes fleet."""
"""Return the connection addresses of the login node pool."""
return self._dns_name

def get_scheme(self):
"""Return the schema of a login nodes fleet."""
"""Return the schema of the login node pool."""
return self._scheme

def get_healthy_nodes(self):
"""Return the number of healthy nodes of a login nodes fleet."""
return self._healthy_nodes

def get_unhealthy_nodes(self):
"""Return the number of unhealthy nodes of a login nodes fleet."""
return self._unhealthy_nodes

def retrieve_data(self, login_nodes_pool_name):
def _retrieve_data(self):
"""Initialize the class with the information related to the login nodes pool."""
self._login_nodes_pool_name = login_nodes_pool_name
self._retrieve_assigned_load_balancer()
if self._load_balancer_arn:
self._login_nodes_pool_available = True
self._pool_available = True
self._populate_target_groups()
self._populate_target_group_health()

Expand All @@ -98,7 +98,7 @@ def _load_balancer_arn_from_tags(self, tags_list):
for tags in tags_list:
if self._key_value_tag_found(
tags, "parallelcluster:cluster-name", self._stack_name
) and self._key_value_tag_found(tags, "parallelcluster:login-nodes-pool", self._login_nodes_pool_name):
) and self._key_value_tag_found(tags, "parallelcluster:login-nodes-pool", self._pool_name):
self._load_balancer_arn = tags.get("ResourceArn")
break

Expand Down Expand Up @@ -153,3 +153,64 @@ def _populate_target_group_health(self):
"This is expected if login nodes pool creation/deletion is in progress",
e,
)


class LoginNodesStatus:
"""Represents the status of the cluster login nodes pools."""

def __init__(self, stack_name):
self._stack_name = stack_name
self._pool_status_dict = dict()
self._login_nodes_pool_available = False
self._total_healthy_nodes = None
self._total_unhealthy_nodes = None

def __str__(self):
out = ""
for pool_status in self._pool_status_dict.values():
out += str(pool_status)
return out

def get_login_nodes_pool_available(self):
"""Return true if a pool is available in the login nodes fleet."""
return self._login_nodes_pool_available

def get_pool_status_dict(self):
"""Return a dictionary mapping each login node pool name to respective pool status."""
return self._pool_status_dict

def get_healthy_nodes(self, pool_name=None):
"""Return the total number of healthy login nodes in the cluster or a specific pool."""
healthy_nodes = (
self._pool_status_dict.get(pool_name).get_healthy_nodes() if pool_name else self._total_healthy_nodes
)
return healthy_nodes

def get_unhealthy_nodes(self, pool_name=None):
"""Return the total number of unhealthy login nodes in the cluster or a specific pool."""
unhealthy_nodes = (
self._pool_status_dict.get(pool_name).get_unhealthy_nodes() if pool_name else self._total_unhealthy_nodes
)
return unhealthy_nodes

def retrieve_data(self, login_node_pool_names):
"""Initialize the class with the information related to the login node fleet."""
for pool_name in login_node_pool_names:
self._pool_status_dict[pool_name] = PoolStatus(self._stack_name, pool_name)
self._total_healthy_nodes = sum(
(
pool_status.get_healthy_nodes()
for pool_status in self._pool_status_dict.values()
if pool_status.get_healthy_nodes()
)
)
self._total_unhealthy_nodes = sum(
(
pool_status.get_unhealthy_nodes()
for pool_status in self._pool_status_dict.values()
if pool_status.get_unhealthy_nodes()
)
)
self._login_nodes_pool_available = any(
(pool_status.get_pool_available() for pool_status in self._pool_status_dict.values())
)
Loading

0 comments on commit d1af051

Please sign in to comment.