From 6af232c907078cdb0bad1bfe793c4483db285861 Mon Sep 17 00:00:00 2001 From: Avram Lubkin Date: Thu, 17 Nov 2022 22:06:14 -0500 Subject: [PATCH] Typing fixes for new version of mypy --- lisa/environment.py | 6 +- lisa/node.py | 2 +- lisa/sut_orchestrator/aws/common.py | 13 +- lisa/sut_orchestrator/azure/common.py | 216 ++++++++++------------- lisa/sut_orchestrator/azure/platform_.py | 2 +- lisa/tools/bzip2.py | 1 + lisa/util/__init__.py | 16 +- 7 files changed, 108 insertions(+), 148 deletions(-) diff --git a/lisa/environment.py b/lisa/environment.py index a3dcf4c797..b5a01a8757 100644 --- a/lisa/environment.py +++ b/lisa/environment.py @@ -487,15 +487,13 @@ def load_environments( class EnvironmentHookSpec: @hookspec def get_environment_information(self, environment: Environment) -> Dict[str, str]: - ... + raise NotImplementedError class EnvironmentHookImpl: @hookimpl def get_environment_information(self, environment: Environment) -> Dict[str, str]: - information: Dict[str, str] = {} - information["name"] = environment.name - + information: Dict[str, str] = {"name": environment.name} if environment.nodes: node = environment.default_node try: diff --git a/lisa/node.py b/lisa/node.py index bd585b03ac..da8ad184aa 100644 --- a/lisa/node.py +++ b/lisa/node.py @@ -733,7 +733,7 @@ def quick_connect( class NodeHookSpec: @hookspec def get_node_information(self, node: Node) -> Dict[str, str]: - ... + raise NotImplementedError class NodeHookImpl: diff --git a/lisa/sut_orchestrator/aws/common.py b/lisa/sut_orchestrator/aws/common.py index b014494864..169bacdfd5 100644 --- a/lisa/sut_orchestrator/aws/common.py +++ b/lisa/sut_orchestrator/aws/common.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from dataclasses import InitVar, dataclass, field +from dataclasses import dataclass, field from typing import Dict, List, Optional from dataclasses_json import dataclass_json @@ -67,16 +67,13 @@ class AwsNodeSchema: data_disk_size: int = 32 disk_type: str = "" - # for marketplace image, which need to accept terms - _marketplace: InitVar[Optional[AwsVmMarketplaceSchema]] = None + def __post_init__(self) -> None: + # Caching for marketplace image + self._marketplace: Optional[AwsVmMarketplaceSchema] = None @property def marketplace(self) -> AwsVmMarketplaceSchema: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_marketplace"): - self._marketplace: Optional[AwsVmMarketplaceSchema] = None - - if not self._marketplace: + if self._marketplace is None: assert isinstance( self.marketplace_raw, str ), f"actual: {type(self.marketplace_raw)}" diff --git a/lisa/sut_orchestrator/azure/common.py b/lisa/sut_orchestrator/azure/common.py index 95029dadec..7566e9fdb5 100644 --- a/lisa/sut_orchestrator/azure/common.py +++ b/lisa/sut_orchestrator/azure/common.py @@ -3,7 +3,7 @@ import re import sys -from dataclasses import InitVar, dataclass, field +from dataclasses import dataclass, field from datetime import datetime, timedelta from pathlib import Path from threading import Lock @@ -187,13 +187,12 @@ class AzureNodeSchema: # image. is_linux: Optional[bool] = None - _marketplace: InitVar[Optional[AzureVmMarketplaceSchema]] = None + def __post_init__(self) -> None: + # Caching + self._marketplace: Optional[AzureVmMarketplaceSchema] = None + self._shared_gallery: Optional[SharedImageGallerySchema] = None + self._vhd: Optional[VhdSchema] = None - _shared_gallery: InitVar[Optional[SharedImageGallerySchema]] = None - - _vhd: InitVar[Optional[VhdSchema]] = None - - def __post_init__(self, *args: Any, **kwargs: Any) -> None: # trim whitespace of values. strip_strs( self, @@ -216,109 +215,96 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None: @property def marketplace(self) -> Optional[AzureVmMarketplaceSchema]: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_marketplace"): - self._marketplace: Optional[AzureVmMarketplaceSchema] = None - marketplace: Optional[AzureVmMarketplaceSchema] = self._marketplace - if not marketplace: - if isinstance(self.marketplace_raw, dict): + if self._marketplace is not None: + return self._marketplace + + if isinstance(self.marketplace_raw, dict): + # Users decide the cases of image names, + # inconsistent cases cause a mismatch error in notifiers. + # lower() normalizes the image names, it has no impact on deployment + self.marketplace_raw = { + k: v.lower() for k, v in self.marketplace_raw.items() + } + self._marketplace = schema.load_by_type( + AzureVmMarketplaceSchema, self.marketplace_raw + ) + # Validated marketplace_raw and filter out any unwanted content + self.marketplace_raw = self._marketplace.to_dict() # type: ignore + + elif self.marketplace_raw: + assert isinstance( + self.marketplace_raw, str + ), f"actual: {type(self.marketplace_raw)}" + + self.marketplace_raw = self.marketplace_raw.strip() + + if self.marketplace_raw: # Users decide the cases of image names, - # the inconsistent cases cause the mismatched error in notifiers. - # The lower() normalizes the image names, - # it has no impact on deployment. - self.marketplace_raw = dict( - (k, v.lower()) for k, v in self.marketplace_raw.items() - ) - marketplace = schema.load_by_type( - AzureVmMarketplaceSchema, self.marketplace_raw - ) - # this step makes marketplace_raw is validated, and - # filter out any unwanted content. - self.marketplace_raw = marketplace.to_dict() # type: ignore - elif self.marketplace_raw: - assert isinstance( - self.marketplace_raw, str - ), f"actual: {type(self.marketplace_raw)}" - - self.marketplace_raw = self.marketplace_raw.strip() - - if self.marketplace_raw: - # Users decide the cases of image names, - # the inconsistent cases cause the mismatched error in notifiers. - # The lower() normalizes the image names, - # it has no impact on deployment. - marketplace_strings = re.split( - r"[:\s]+", self.marketplace_raw.lower() + # inconsistent cases cause a mismatch error in notifiers. + # lower() normalizes the image names, it has no impact on deployment + marketplace_strings = re.split(r"[:\s]+", self.marketplace_raw.lower()) + + if len(marketplace_strings) != 4: + raise LisaException( + "Invalid value for the provided marketplace " + f"parameter: '{self.marketplace_raw}'." + "The marketplace parameter should be in the format: " + "' ' " + "or ':::'" ) + self._marketplace = AzureVmMarketplaceSchema(*marketplace_strings) + # marketplace_raw is used + self.marketplace_raw = ( + self._marketplace.to_dict() # type: ignore [attr-defined] + ) - if len(marketplace_strings) == 4: - marketplace = AzureVmMarketplaceSchema(*marketplace_strings) - # marketplace_raw is used - self.marketplace_raw = marketplace.to_dict() # type: ignore - else: - raise LisaException( - f"Invalid value for the provided marketplace " - f"parameter: '{self.marketplace_raw}'." - f"The marketplace parameter should be in the format: " - f"' ' " - f"or ':::'" - ) - self._marketplace = marketplace - return marketplace + return self._marketplace @marketplace.setter def marketplace(self, value: Optional[AzureVmMarketplaceSchema]) -> None: self._marketplace = value - if value is None: - self.marketplace_raw = None - else: - self.marketplace_raw = value.to_dict() # type: ignore + # dataclass_json doesn't use a protocol return type, so to_dict() is unknown + self.marketplace_raw = ( + None if value is None else value.to_dict() # type: ignore [attr-defined] + ) @property def shared_gallery(self) -> Optional[SharedImageGallerySchema]: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_shared_gallery"): - self._shared_gallery: Optional[SharedImageGallerySchema] = None - shared_gallery: Optional[SharedImageGallerySchema] = self._shared_gallery - if shared_gallery: - return shared_gallery + if self._shared_gallery is not None: + return self._shared_gallery + if isinstance(self.shared_gallery_raw, dict): # Users decide the cases of image names, - # the inconsistent cases cause the mismatched error in notifiers. - # The lower() normalizes the image names, - # it has no impact on deployment. - self.shared_gallery_raw = dict( - (k, v.lower()) for k, v in self.shared_gallery_raw.items() - ) - shared_gallery = schema.load_by_type( + # inconsistent cases cause a mismatch error in notifiers. + # lower() normalizes the image names, it has no impact on deployment + self.shared_gallery_raw = { + k: v.lower() for k, v in self.shared_gallery_raw.items() + } + + self._shared_gallery = schema.load_by_type( SharedImageGallerySchema, self.shared_gallery_raw ) - if not shared_gallery.subscription_id: - shared_gallery.subscription_id = self.subscription_id - # this step makes shared_gallery_raw is validated, and - # filter out any unwanted content. - self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore + if not self._shared_gallery.subscription_id: + self._shared_gallery.subscription_id = self.subscription_id + # Validated shared_gallery_raw and filter out any unwanted content + self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore + elif self.shared_gallery_raw: assert isinstance( self.shared_gallery_raw, str ), f"actual: {type(self.shared_gallery_raw)}" # Users decide the cases of image names, - # the inconsistent cases cause the mismatched error in notifiers. - # The lower() normalizes the image names, - # it has no impact on deployment. + # inconsistent cases cause a mismatch error in notifiers. + # lower() normalizes the image names, it has no impact on deployment shared_gallery_strings = re.split( r"[/]+", self.shared_gallery_raw.strip().lower() ) if len(shared_gallery_strings) == 5: - shared_gallery = SharedImageGallerySchema(*shared_gallery_strings) - # shared_gallery_raw is used - self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore + self._shared_gallery = SharedImageGallerySchema(*shared_gallery_strings) elif len(shared_gallery_strings) == 3: - shared_gallery = SharedImageGallerySchema( + self._shared_gallery = SharedImageGallerySchema( self.subscription_id, None, *shared_gallery_strings ) - # shared_gallery_raw is used - self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore else: raise LisaException( f"Invalid value for the provided shared gallery " @@ -328,51 +314,43 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]: f"/' or '/" f"/'" ) - self._shared_gallery = shared_gallery - return shared_gallery + self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore + + return self._shared_gallery @shared_gallery.setter def shared_gallery(self, value: Optional[SharedImageGallerySchema]) -> None: self._shared_gallery = value - if value is None: - self.shared_gallery_raw = None - else: - self.shared_gallery_raw = value.to_dict() # type: ignore + # dataclass_json doesn't use a protocol return type, so to_dict() is unknown + self.shared_gallery_raw = ( + None if value is None else value.to_dict() # type: ignore [attr-defined] + ) @property def vhd(self) -> Optional[VhdSchema]: - # this is a safe guard and prevent mypy error on typing - if not hasattr(self, "_vhd"): - self._vhd: Optional[VhdSchema] = None - vhd: Optional[VhdSchema] = self._vhd - if vhd: - return vhd + if self._vhd is not None: + return self._vhd + if isinstance(self.vhd_raw, dict): - vhd = schema.load_by_type(VhdSchema, self.vhd_raw) - add_secret(vhd.vhd_path, PATTERN_URL) - if vhd.vmgs_path: - add_secret(vhd.vmgs_path, PATTERN_URL) - # this step makes vhd_raw is validated, and - # filter out any unwanted content. - self.vhd_raw = vhd.to_dict() # type: ignore + self._vhd = schema.load_by_type(VhdSchema, self.vhd_raw) + add_secret(self._vhd.vhd_path, PATTERN_URL) + if self._vhd.vmgs_path: + add_secret(self._vhd.vmgs_path, PATTERN_URL) + # Validated vhd_raw and filter out any unwanted content + self.vhd_raw = self._vhd.to_dict() # type: ignore + elif self.vhd_raw is not None: assert isinstance(self.vhd_raw, str), f"actual: {type(self.vhd_raw)}" - vhd = VhdSchema(self.vhd_raw) - add_secret(vhd.vhd_path, PATTERN_URL) - self.vhd_raw = vhd.to_dict() # type: ignore - self._vhd = vhd - if vhd: - return vhd - else: - return None + self._vhd = VhdSchema(self.vhd_raw) + add_secret(self._vhd.vhd_path, PATTERN_URL) + self.vhd_raw = self._vhd.to_dict() # type: ignore + + return self._vhd @vhd.setter def vhd(self, value: Optional[VhdSchema]) -> None: self._vhd = value - if value is None: - self.vhd_raw = None - else: - self.vhd_raw = self._vhd.to_dict() # type: ignore + self.vhd_raw = None if value is None else self._vhd.to_dict() # type: ignore def get_image_name(self) -> str: result = "" @@ -383,7 +361,7 @@ def get_image_name(self) -> str: self.shared_gallery_raw, dict ), f"actual type: {type(self.shared_gallery_raw)}" if self.shared_gallery.resource_group_name: - result = "/".join([x for x in self.shared_gallery_raw.values()]) + result = "/".join(self.shared_gallery_raw.values()) else: result = ( f"{self.shared_gallery.image_gallery}/" @@ -394,7 +372,7 @@ def get_image_name(self) -> str: assert isinstance( self.marketplace_raw, dict ), f"actual type: {type(self.marketplace_raw)}" - result = " ".join([x for x in self.marketplace_raw.values()]) + result = " ".join(self.marketplace_raw.values()) return result @@ -418,9 +396,7 @@ def from_node_runbook(cls, runbook: AzureNodeSchema) -> "AzureNodeArmParameter": parameters["vhd_raw"] = parameters["vhd"] del parameters["vhd"] - arm_parameters = AzureNodeArmParameter(**parameters) - - return arm_parameters + return AzureNodeArmParameter(**parameters) class DataDiskCreateOption: diff --git a/lisa/sut_orchestrator/azure/platform_.py b/lisa/sut_orchestrator/azure/platform_.py index cd292c2eab..3a1566c275 100644 --- a/lisa/sut_orchestrator/azure/platform_.py +++ b/lisa/sut_orchestrator/azure/platform_.py @@ -2104,7 +2104,7 @@ def _get_vhd_os_disk_size(self, blob_url: str) -> int: assert properties.size, f"fail to get blob size of {blob_url}" # Azure requires only megabyte alignment of vhds, round size up # for cases where the size is megabyte aligned - return math.ceil(properties.size / 1024 / 1024 / 1024) + return int(math.ceil(properties.size / 1024 / 1024 / 1024)) def _get_sig_info( self, shared_image: SharedImageGallerySchema diff --git a/lisa/tools/bzip2.py b/lisa/tools/bzip2.py index 525a50ac8d..341431c79f 100644 --- a/lisa/tools/bzip2.py +++ b/lisa/tools/bzip2.py @@ -9,6 +9,7 @@ class Bzip2(Tool): def command(self) -> str: return "bzip2" + @property def can_install(self) -> bool: return True diff --git a/lisa/util/__init__.py b/lisa/util/__init__.py index 1dcebe3910..a80cdf49cb 100644 --- a/lisa/util/__init__.py +++ b/lisa/util/__init__.py @@ -161,9 +161,7 @@ def __init__(self, os: "OperatingSystem", message: str = "") -> None: self.version = os.information.full_version self.kernel_version = "" if hasattr(os, "get_kernel_information"): - self.kernel_version = ( - os.get_kernel_information().raw_version # type: ignore - ) + self.kernel_version = os.get_kernel_information().raw_version self._extended_message = message def __str__(self) -> str: @@ -506,18 +504,8 @@ def find_group_in_lines( def deep_update_dict(src: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]: - if ( - dest is None - or isinstance(dest, int) - or isinstance(dest, bool) - or isinstance(dest, float) - or isinstance(dest, str) - ): - result = dest - else: + if isinstance(dest, dict): result = dest.copy() - - if isinstance(result, dict): for key, value in src.items(): if isinstance(value, dict) and key in dest: value = deep_update_dict(value, dest[key])