Skip to content

Commit

Permalink
Typing fixes for new version of mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
avylove committed Jan 23, 2023
1 parent 93f4efd commit fa09b81
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 96 deletions.
6 changes: 2 additions & 4 deletions lisa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,13 @@ def load_environments(
class EnvironmentHookSpec:
@hookspec
def get_environment_information(self, environment: Environment) -> Dict[str, str]:
...
raise NotImplementedError


class EnvironmentHookImpl:
@hookimpl
def get_environment_information(self, environment: Environment) -> Dict[str, str]:
information: Dict[str, str] = {}
information["name"] = environment.name

information: Dict[str, str] = {"name": environment.name}
if environment.nodes:
node = environment.default_node
try:
Expand Down
13 changes: 6 additions & 7 deletions lisa/sut_orchestrator/aws/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import InitVar, dataclass, field
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from dataclasses_json import dataclass_json
Expand Down Expand Up @@ -67,16 +67,15 @@ class AwsNodeSchema:
data_disk_size: int = 32
disk_type: str = ""

# for marketplace image, which need to accept terms
_marketplace: InitVar[Optional[AwsVmMarketplaceSchema]] = None
def __post_init__(self) -> None:

# Caching for marketplace image
self._marketplace: Optional[AwsVmMarketplaceSchema] = None

@property
def marketplace(self) -> AwsVmMarketplaceSchema:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_marketplace"):
self._marketplace: Optional[AwsVmMarketplaceSchema] = None

if not self._marketplace:
if self._marketplace is None:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"
Expand Down
142 changes: 71 additions & 71 deletions lisa/sut_orchestrator/azure/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import re
import sys
from dataclasses import InitVar, dataclass, field
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from threading import Lock
Expand Down Expand Up @@ -175,11 +175,12 @@ class AzureNodeSchema:
# image.
is_linux: Optional[bool] = None

_marketplace: InitVar[Optional[AzureVmMarketplaceSchema]] = None
def __post_init__(self) -> None:

_shared_gallery: InitVar[Optional[SharedImageGallerySchema]] = None
# Caching
self._marketplace: Optional[AzureVmMarketplaceSchema] = None
self._shared_gallery: Optional[SharedImageGallerySchema] = None

def __post_init__(self, *args: Any, **kwargs: Any) -> None:
# trim whitespace of values.
strip_strs(
self,
Expand All @@ -201,80 +202,78 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:

@property
def marketplace(self) -> Optional[AzureVmMarketplaceSchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_marketplace"):
self._marketplace: Optional[AzureVmMarketplaceSchema] = None
marketplace: Optional[AzureVmMarketplaceSchema] = self._marketplace
if not marketplace:
if isinstance(self.marketplace_raw, dict):

if self._marketplace is not None:
return self._marketplace

if isinstance(self.marketplace_raw, dict):
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.marketplace_raw = {
k: v.lower() for k, v in self.marketplace_raw.items()
}
self._marketplace = schema.load_by_type(
AzureVmMarketplaceSchema, self.marketplace_raw
)
# This step makes sure marketplace_raw is validated, and
# filters out any unwanted content.
self.marketplace_raw = self._marketplace.to_dict() # type: ignore

elif self.marketplace_raw:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"

self.marketplace_raw = self.marketplace_raw.strip()

if self.marketplace_raw:
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.marketplace_raw = dict(
(k, v.lower()) for k, v in self.marketplace_raw.items()
)
marketplace = schema.load_by_type(
AzureVmMarketplaceSchema, self.marketplace_raw
)
# this step makes marketplace_raw is validated, and
# filter out any unwanted content.
self.marketplace_raw = marketplace.to_dict() # type: ignore
elif self.marketplace_raw:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"

self.marketplace_raw = self.marketplace_raw.strip()

if self.marketplace_raw:
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
marketplace_strings = re.split(
r"[:\s]+", self.marketplace_raw.lower()
marketplace_strings = re.split(r"[:\s]+", self.marketplace_raw.lower())

if len(marketplace_strings) != 4:
raise LisaException(
"Invalid value for the provided marketplace "
f"parameter: '{self.marketplace_raw}'."
"The marketplace parameter should be in the format: "
"'<Publisher> <Offer> <Sku> <Version>' "
"or '<Publisher>:<Offer>:<Sku>:<Version>'"
)
self._marketplace = AzureVmMarketplaceSchema(*marketplace_strings)
# marketplace_raw is used
self.marketplace_raw = (
self._marketplace.to_dict() # type: ignore [attr-defined]
)

if len(marketplace_strings) == 4:
marketplace = AzureVmMarketplaceSchema(*marketplace_strings)
# marketplace_raw is used
self.marketplace_raw = marketplace.to_dict() # type: ignore
else:
raise LisaException(
f"Invalid value for the provided marketplace "
f"parameter: '{self.marketplace_raw}'."
f"The marketplace parameter should be in the format: "
f"'<Publisher> <Offer> <Sku> <Version>' "
f"or '<Publisher>:<Offer>:<Sku>:<Version>'"
)
self._marketplace = marketplace
return marketplace
return self._marketplace

@marketplace.setter
def marketplace(self, value: Optional[AzureVmMarketplaceSchema]) -> None:
self._marketplace = value
if value is None:
self.marketplace_raw = None
else:
self.marketplace_raw = value.to_dict() # type: ignore
# dataclass_json doesn't use a protocol return type, so to_dict() is unknown
self.marketplace_raw = (
None if value is None else value.to_dict() # type: ignore [attr-defined]
)

@property
def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_shared_gallery"):
self._shared_gallery: Optional[SharedImageGallerySchema] = None
shared_gallery: Optional[SharedImageGallerySchema] = self._shared_gallery
if shared_gallery:
return shared_gallery

if self._shared_gallery is not None:
return self._shared_gallery

if isinstance(self.shared_gallery_raw, dict):
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.shared_gallery_raw = dict(
(k, v.lower()) for k, v in self.shared_gallery_raw.items()
)
self.shared_gallery_raw = {
k: v.lower() for k, v in self.shared_gallery_raw.items()
}

shared_gallery = schema.load_by_type(
SharedImageGallerySchema, self.shared_gallery_raw
)
Expand All @@ -283,6 +282,8 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
# this step makes shared_gallery_raw is validated, and
# filter out any unwanted content.
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
self._shared_gallery = shared_gallery

elif self.shared_gallery_raw:
assert isinstance(
self.shared_gallery_raw, str
Expand All @@ -299,11 +300,12 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
# shared_gallery_raw is used
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
elif len(shared_gallery_strings) == 3:
shared_gallery = SharedImageGallerySchema(
self._shared_gallery = SharedImageGallerySchema(
self.subscription_id, None, *shared_gallery_strings
)
# shared_gallery_raw is used
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore

else:
raise LisaException(
f"Invalid value for the provided shared gallery "
Expand All @@ -313,16 +315,16 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
f"<image_definition>/<image_version>' or '<image_gallery>/"
f"<image_definition>/<image_version>'"
)
self._shared_gallery = shared_gallery
return shared_gallery

return self._shared_gallery

@shared_gallery.setter
def shared_gallery(self, value: Optional[SharedImageGallerySchema]) -> None:
self._shared_gallery = value
if value is None:
self.shared_gallery_raw = None
else:
self.shared_gallery_raw = value.to_dict() # type: ignore
# dataclass_json doesn't use a protocol return type, so to_dict() is unknown
self.shared_gallery_raw = (
None if value is None else value.to_dict() # type: ignore [attr-defined]
)

def get_image_name(self) -> str:
result = ""
Expand Down Expand Up @@ -365,9 +367,7 @@ def from_node_runbook(cls, runbook: AzureNodeSchema) -> "AzureNodeArmParameter":
parameters["shared_gallery_raw"] = parameters["shared_gallery"]
del parameters["shared_gallery"]

arm_parameters = AzureNodeArmParameter(**parameters)

return arm_parameters
return AzureNodeArmParameter(**parameters)


class DataDiskCreateOption:
Expand Down
1 change: 1 addition & 0 deletions lisa/tools/bzip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class Bzip2(Tool):
def command(self) -> str:
return "bzip2"

@property
def can_install(self) -> bool:
return True

Expand Down
17 changes: 3 additions & 14 deletions lisa/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ def __init__(self, os: "OperatingSystem", message: str = "") -> None:
self.version = os.information.full_version
self.kernel_version = ""
if hasattr(os, "get_kernel_information"):
self.kernel_version = (
os.get_kernel_information().raw_version # type: ignore
)
self.kernel_version = os.get_kernel_information().raw_version
self._extended_message = message

def __str__(self) -> str:
Expand Down Expand Up @@ -505,18 +503,9 @@ def find_group_in_lines(


def deep_update_dict(src: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]:
if (
dest is None
or isinstance(dest, int)
or isinstance(dest, bool)
or isinstance(dest, float)
or isinstance(dest, str)
):
result = dest
else:
result = dest.copy()

if isinstance(result, dict):
if isinstance(dest, dict):
result = dest.copy()
for key, value in src.items():
if isinstance(value, dict) and key in dest:
value = deep_update_dict(value, dest[key])
Expand Down

0 comments on commit fa09b81

Please sign in to comment.