Skip to content

Commit

Permalink
Typing fixes for new version of mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
avylove committed Apr 26, 2023
1 parent c057eb5 commit d2befb6
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 202 deletions.
6 changes: 2 additions & 4 deletions lisa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,15 +493,13 @@ def load_environments(
class EnvironmentHookSpec:
@hookspec
def get_environment_information(self, environment: Environment) -> Dict[str, str]:
...
raise NotImplementedError


class EnvironmentHookImpl:
@hookimpl
def get_environment_information(self, environment: Environment) -> Dict[str, str]:
information: Dict[str, str] = {}
information["name"] = environment.name

information: Dict[str, str] = {"name": environment.name}
if environment.nodes:
node = environment.default_node
try:
Expand Down
2 changes: 1 addition & 1 deletion lisa/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def quick_connect(
class NodeHookSpec:
@hookspec
def get_node_information(self, node: Node) -> Dict[str, str]:
...
raise NotImplementedError


class NodeHookImpl:
Expand Down
13 changes: 5 additions & 8 deletions lisa/sut_orchestrator/aws/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import InitVar, dataclass, field
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from dataclasses_json import dataclass_json
Expand Down Expand Up @@ -67,16 +67,13 @@ class AwsNodeSchema:
data_disk_size: int = 32
disk_type: str = ""

# for marketplace image, which need to accept terms
_marketplace: InitVar[Optional[AwsVmMarketplaceSchema]] = None
def __post_init__(self) -> None:
# Caching for marketplace image
self._marketplace: Optional[AwsVmMarketplaceSchema] = None

@property
def marketplace(self) -> AwsVmMarketplaceSchema:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_marketplace"):
self._marketplace: Optional[AwsVmMarketplaceSchema] = None

if not self._marketplace:
if self._marketplace is None:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"
Expand Down
216 changes: 96 additions & 120 deletions lisa/sut_orchestrator/azure/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import re
import sys
from dataclasses import InitVar, dataclass, field
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from functools import lru_cache
from pathlib import Path
Expand Down Expand Up @@ -225,13 +225,12 @@ class AzureNodeSchema:
# image.
is_linux: Optional[bool] = None

_marketplace: InitVar[Optional[AzureVmMarketplaceSchema]] = None
def __post_init__(self) -> None:
# Caching
self._marketplace: Optional[AzureVmMarketplaceSchema] = None
self._shared_gallery: Optional[SharedImageGallerySchema] = None
self._vhd: Optional[VhdSchema] = None

_shared_gallery: InitVar[Optional[SharedImageGallerySchema]] = None

_vhd: InitVar[Optional[VhdSchema]] = None

def __post_init__(self, *args: Any, **kwargs: Any) -> None:
# trim whitespace of values.
strip_strs(
self,
Expand All @@ -254,109 +253,96 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:

@property
def marketplace(self) -> Optional[AzureVmMarketplaceSchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_marketplace"):
self._marketplace: Optional[AzureVmMarketplaceSchema] = None
marketplace: Optional[AzureVmMarketplaceSchema] = self._marketplace
if not marketplace:
if isinstance(self.marketplace_raw, dict):
if self._marketplace is not None:
return self._marketplace

if isinstance(self.marketplace_raw, dict):
# Users decide the cases of image names,
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
self.marketplace_raw = {
k: v.lower() for k, v in self.marketplace_raw.items()
}
self._marketplace = schema.load_by_type(
AzureVmMarketplaceSchema, self.marketplace_raw
)
# Validated marketplace_raw and filter out any unwanted content
self.marketplace_raw = self._marketplace.to_dict() # type: ignore

elif self.marketplace_raw:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"

self.marketplace_raw = self.marketplace_raw.strip()

if self.marketplace_raw:
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.marketplace_raw = dict(
(k, v.lower()) for k, v in self.marketplace_raw.items()
)
marketplace = schema.load_by_type(
AzureVmMarketplaceSchema, self.marketplace_raw
)
# this step makes marketplace_raw is validated, and
# filter out any unwanted content.
self.marketplace_raw = marketplace.to_dict() # type: ignore
elif self.marketplace_raw:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"

self.marketplace_raw = self.marketplace_raw.strip()

if self.marketplace_raw:
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
marketplace_strings = re.split(
r"[:\s]+", self.marketplace_raw.lower()
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
marketplace_strings = re.split(r"[:\s]+", self.marketplace_raw.lower())

if len(marketplace_strings) != 4:
raise LisaException(
"Invalid value for the provided marketplace "
f"parameter: '{self.marketplace_raw}'."
"The marketplace parameter should be in the format: "
"'<Publisher> <Offer> <Sku> <Version>' "
"or '<Publisher>:<Offer>:<Sku>:<Version>'"
)
self._marketplace = AzureVmMarketplaceSchema(*marketplace_strings)
# marketplace_raw is used
self.marketplace_raw = (
self._marketplace.to_dict() # type: ignore [attr-defined]
)

if len(marketplace_strings) == 4:
marketplace = AzureVmMarketplaceSchema(*marketplace_strings)
# marketplace_raw is used
self.marketplace_raw = marketplace.to_dict() # type: ignore
else:
raise LisaException(
f"Invalid value for the provided marketplace "
f"parameter: '{self.marketplace_raw}'."
f"The marketplace parameter should be in the format: "
f"'<Publisher> <Offer> <Sku> <Version>' "
f"or '<Publisher>:<Offer>:<Sku>:<Version>'"
)
self._marketplace = marketplace
return marketplace
return self._marketplace

@marketplace.setter
def marketplace(self, value: Optional[AzureVmMarketplaceSchema]) -> None:
self._marketplace = value
if value is None:
self.marketplace_raw = None
else:
self.marketplace_raw = value.to_dict() # type: ignore
# dataclass_json doesn't use a protocol return type, so to_dict() is unknown
self.marketplace_raw = (
None if value is None else value.to_dict() # type: ignore [attr-defined]
)

@property
def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_shared_gallery"):
self._shared_gallery: Optional[SharedImageGallerySchema] = None
shared_gallery: Optional[SharedImageGallerySchema] = self._shared_gallery
if shared_gallery:
return shared_gallery
if self._shared_gallery is not None:
return self._shared_gallery

if isinstance(self.shared_gallery_raw, dict):
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.shared_gallery_raw = dict(
(k, v.lower()) for k, v in self.shared_gallery_raw.items()
)
shared_gallery = schema.load_by_type(
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
self.shared_gallery_raw = {
k: v.lower() for k, v in self.shared_gallery_raw.items()
}

self._shared_gallery = schema.load_by_type(
SharedImageGallerySchema, self.shared_gallery_raw
)
if not shared_gallery.subscription_id:
shared_gallery.subscription_id = self.subscription_id
# this step makes shared_gallery_raw is validated, and
# filter out any unwanted content.
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
if not self._shared_gallery.subscription_id:
self._shared_gallery.subscription_id = self.subscription_id
# Validated shared_gallery_raw and filter out any unwanted content
self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore

elif self.shared_gallery_raw:
assert isinstance(
self.shared_gallery_raw, str
), f"actual: {type(self.shared_gallery_raw)}"
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
shared_gallery_strings = re.split(
r"[/]+", self.shared_gallery_raw.strip().lower()
)
if len(shared_gallery_strings) == 5:
shared_gallery = SharedImageGallerySchema(*shared_gallery_strings)
# shared_gallery_raw is used
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
self._shared_gallery = SharedImageGallerySchema(*shared_gallery_strings)
elif len(shared_gallery_strings) == 3:
shared_gallery = SharedImageGallerySchema(
self._shared_gallery = SharedImageGallerySchema(
self.subscription_id, None, *shared_gallery_strings
)
# shared_gallery_raw is used
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
else:
raise LisaException(
f"Invalid value for the provided shared gallery "
Expand All @@ -366,51 +352,43 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
f"<image_definition>/<image_version>' or '<image_gallery>/"
f"<image_definition>/<image_version>'"
)
self._shared_gallery = shared_gallery
return shared_gallery
self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore

return self._shared_gallery

@shared_gallery.setter
def shared_gallery(self, value: Optional[SharedImageGallerySchema]) -> None:
self._shared_gallery = value
if value is None:
self.shared_gallery_raw = None
else:
self.shared_gallery_raw = value.to_dict() # type: ignore
# dataclass_json doesn't use a protocol return type, so to_dict() is unknown
self.shared_gallery_raw = (
None if value is None else value.to_dict() # type: ignore [attr-defined]
)

@property
def vhd(self) -> Optional[VhdSchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_vhd"):
self._vhd: Optional[VhdSchema] = None
vhd: Optional[VhdSchema] = self._vhd
if vhd:
return vhd
if self._vhd is not None:
return self._vhd

if isinstance(self.vhd_raw, dict):
vhd = schema.load_by_type(VhdSchema, self.vhd_raw)
add_secret(vhd.vhd_path, PATTERN_URL)
if vhd.vmgs_path:
add_secret(vhd.vmgs_path, PATTERN_URL)
# this step makes vhd_raw is validated, and
# filter out any unwanted content.
self.vhd_raw = vhd.to_dict() # type: ignore
self._vhd = schema.load_by_type(VhdSchema, self.vhd_raw)
add_secret(self._vhd.vhd_path, PATTERN_URL)
if self._vhd.vmgs_path:
add_secret(self._vhd.vmgs_path, PATTERN_URL)
# Validated vhd_raw and filter out any unwanted content
self.vhd_raw = self._vhd.to_dict() # type: ignore

elif self.vhd_raw is not None:
assert isinstance(self.vhd_raw, str), f"actual: {type(self.vhd_raw)}"
vhd = VhdSchema(self.vhd_raw)
add_secret(vhd.vhd_path, PATTERN_URL)
self.vhd_raw = vhd.to_dict() # type: ignore
self._vhd = vhd
if vhd:
return vhd
else:
return None
self._vhd = VhdSchema(self.vhd_raw)
add_secret(self._vhd.vhd_path, PATTERN_URL)
self.vhd_raw = self._vhd.to_dict() # type: ignore

return self._vhd

@vhd.setter
def vhd(self, value: Optional[VhdSchema]) -> None:
self._vhd = value
if value is None:
self.vhd_raw = None
else:
self.vhd_raw = self._vhd.to_dict() # type: ignore
self.vhd_raw = None if value is None else self._vhd.to_dict() # type: ignore

def get_image_name(self) -> str:
result = ""
Expand All @@ -421,7 +399,7 @@ def get_image_name(self) -> str:
self.shared_gallery_raw, dict
), f"actual type: {type(self.shared_gallery_raw)}"
if self.shared_gallery.resource_group_name:
result = "/".join([x for x in self.shared_gallery_raw.values()])
result = "/".join(self.shared_gallery_raw.values())
else:
result = (
f"{self.shared_gallery.image_gallery}/"
Expand All @@ -432,7 +410,7 @@ def get_image_name(self) -> str:
assert isinstance(
self.marketplace_raw, dict
), f"actual type: {type(self.marketplace_raw)}"
result = " ".join([x for x in self.marketplace_raw.values()])
result = " ".join(self.marketplace_raw.values())
return result


Expand All @@ -457,9 +435,7 @@ def from_node_runbook(cls, runbook: AzureNodeSchema) -> "AzureNodeArmParameter":
parameters["vhd_raw"] = parameters["vhd"]
del parameters["vhd"]

arm_parameters = AzureNodeArmParameter(**parameters)

return arm_parameters
return AzureNodeArmParameter(**parameters)


class DataDiskCreateOption:
Expand Down
Loading

0 comments on commit d2befb6

Please sign in to comment.