Skip to content

Commit

Permalink
Add type annotations for data module (#634)
Browse files Browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Aug 30, 2024
1 parent e0b89e3 commit 6d14f98
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 32 deletions.
54 changes: 27 additions & 27 deletions amlb/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
which can also be encoded (``y_enc``, ``X_enc``)
- **Feature** provides metadata for a given feature/column as well as encoding functions.
"""
from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum, auto
from enum import Enum
import logging
from typing import List, Union
from typing import List, Union, Iterable

import numpy as np
import pandas as pd
Expand All @@ -31,7 +33,7 @@

class Feature:

def __init__(self, index, name, data_type, values=None, has_missing_values=False, is_target=False):
def __init__(self, index: int, name: str, data_type: str | None, values: Iterable[str] | None = None, has_missing_values: bool = False, is_target: bool = False):
"""
:param index: index of the feature in the full data frame.
:param name: name of the feature.
Expand All @@ -43,64 +45,63 @@ def __init__(self, index, name, data_type, values=None, has_missing_values=False
self.index = index
self.name = name
self.data_type = data_type.lower() if data_type is not None else None
self.values = values
self.values = values # type: ignore # https://github.com/python/mypy/issues/3004
self.has_missing_values = has_missing_values
self.is_target = is_target
# print(self)

def is_categorical(self, strict=True):
def is_categorical(self, strict: bool = True) -> bool:
if strict:
return self.data_type == 'category'
else:
return self.data_type is not None and not self.is_numerical()
return self.data_type is not None and not self.is_numerical()

def is_numerical(self):
def is_numerical(self) -> bool:
return self.data_type in ['int', 'float', 'number']

@lazy_property
def label_encoder(self):
def label_encoder(self) -> Encoder:
return Encoder('label' if self.values is not None else 'no-op',
target=self.is_target,
encoded_type=int if self.is_target and not self.is_numerical() else float,
missing_values=[None, np.nan, pd.NA],
missing_policy='mask' if self.has_missing_values else 'ignore',
normalize_fn=self.normalize
normalize_fn=Feature.normalize
).fit(self.values)

@lazy_property
def one_hot_encoder(self):
def one_hot_encoder(self) -> Encoder:
return Encoder('one-hot' if self.values is not None else 'no-op',
target=self.is_target,
encoded_type=int if self.is_target and not self.is_numerical() else float,
missing_values=[None, np.nan, pd.NA],
missing_policy='mask' if self.has_missing_values else 'ignore',
normalize_fn=self.normalize
normalize_fn=Feature.normalize
).fit(self.values)

def normalize(self, arr):
@staticmethod
def normalize(arr: Iterable[str]) -> np.ndarray:
return np.char.lower(np.char.strip(np.asarray(arr).astype(str)))

@property
def values(self):
def values(self) -> list[str] | None:
return self._values

@values.setter
def values(self, values):
self._values = self.normalize(values).tolist() if values is not None else None
def values(self, values: Iterable[str]) -> None:
self._values = Feature.normalize(values).tolist() if values is not None else None

def __repr__(self):
def __repr__(self) -> str:
return repr_def(self, 'all')


class Datasplit(ABC):

def __init__(self, dataset, format):
def __init__(self, dataset: Dataset, file_format: str):
"""
:param format: the default format of the data file, obtained through the 'path' property.
:param file_format: the default format of the data file, obtained through the 'path' property.
"""
super().__init__()
self.dataset = dataset
self.format = format
self.format = file_format

@property
def path(self) -> str:
Expand Down Expand Up @@ -137,7 +138,7 @@ def y(self) -> DF:
"""
:return:the target column as a pandas DataFrame: if you need a Series, just call `y.squeeze()`.
"""
return self.data.iloc[:, [self.dataset.target.index]]
return self.data.iloc[:, [self.dataset.target.index]] # type: ignore

@lazy_property
@profile(logger=log)
Expand All @@ -164,7 +165,7 @@ def y_enc(self) -> AM:
return self.data_enc[:, self.dataset.target.index]

@profile(logger=log)
def release(self, properties=None):
def release(self, properties: Iterable[str] | None = None) -> None:
clear_cache(self, properties)


Expand All @@ -177,7 +178,7 @@ class DatasetType(Enum):

class Dataset(ABC):

def __init__(self):
def __init__(self) -> None:
super().__init__()

@property
Expand Down Expand Up @@ -228,11 +229,10 @@ def target(self) -> Feature:
pass

@profile(logger=log)
def release(self, properties=None):
def release(self) -> None:
"""
Call this to release cached properties and optimize memory once in-memory data are not needed anymore.
:param properties:
"""
self.train.release()
self.test.release()
clear_cache(self, properties)
clear_cache(self)
10 changes: 5 additions & 5 deletions amlb/datasets/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ def __repr__(self):

class FileDatasplit(Datasplit):

def __init__(self, dataset: FileDataset, format: str, path: str):
super().__init__(dataset, format)
def __init__(self, dataset: FileDataset, file_format: str, path: str):
super().__init__(dataset, file_format)
self._path = path
self._data = {format: path}
self._data = {file_format: path}

def data_path(self, format):
supported_formats = [cls.format for cls in __file_converters__]
Expand Down Expand Up @@ -267,7 +267,7 @@ def __init__(self, train_path, test_path,
class ArffDatasplit(FileDatasplit):

def __init__(self, dataset, path):
super().__init__(dataset, format='arff', path=path)
super().__init__(dataset, file_format='arff', path=path)
self._ds = None

def _ensure_loaded(self):
Expand Down Expand Up @@ -419,7 +419,7 @@ def compute_seasonal_error(self):
class CsvDatasplit(FileDatasplit):

def __init__(self, dataset, path, timestamp_column=None):
super().__init__(dataset, format='csv', path=path)
super().__init__(dataset, file_format='csv', path=path)
self._ds = None
self.timestamp_column = timestamp_column

Expand Down

0 comments on commit 6d14f98

Please sign in to comment.