Skip to content

Commit

Permalink
fix(tabular-benchmark): Always just use "id" for tabular ids
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman committed Oct 31, 2023
1 parent 41a0124 commit aca8ca4
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 83 deletions.
27 changes: 14 additions & 13 deletions src/mfpbench/lcbench_tabular/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,27 +115,28 @@ def _get_raw_lcbench_space(
@dataclass(frozen=True, eq=False, unsafe_hash=True)  # type: ignore[misc]
class LCBenchTabularConfig(TabularConfig):
    """A single hyperparameter configuration from the LCBench tabular benchmark.

    The first group of fields varies across configs and is always present in the
    table. The second group is constant across the benchmark tables and hence
    optional, defaulting to ``None`` when the constant columns have been removed.

    NOTE(review): the previous revision declared the constant fields twice —
    once as required (``str``) and once as optional (``str | None = None``).
    Since Python keeps a re-annotated name at its original insertion position,
    that placed defaulted fields before required ones and made ``@dataclass``
    raise ``TypeError: non-default argument follows default argument``. The
    duplicate required declarations are removed here.
    """

    # Varying hyperparameters (always present in the table)
    batch_size: int
    max_dropout: float
    max_units: int
    num_layers: int
    learning_rate: float
    momentum: float
    weight_decay: float
    # All of these are constant and hence optional
    loss: str | None = None  # This is the name of the loss function used, not a float
    imputation_strategy: str | None = None
    learning_rate_scheduler: str | None = None
    network: str | None = None
    normalization_strategy: str | None = None
    optimizer: str | None = None
    cosine_annealing_T_max: int | None = None
    cosine_annealing_eta_min: float | None = None
    activation: str | None = None
    mlp_shape: str | None = None


@dataclass(frozen=True) # type: ignore[misc]
class LCBenchTabularResult(Result[LCBenchTabularConfig, int]):
time: float
loss: float
val_accuracy: float
val_cross_entropy: float
val_balanced_accuracy: float
Expand Down Expand Up @@ -233,7 +234,7 @@ def __init__(
task_id: str,
datadir: str | Path | None = None,
*,
remove_constants: bool = True,
remove_constants: bool = False,
seed: int | None = None,
prior: str | Path | LCBenchTabularConfig | Mapping[str, Any] | None = None,
perturb_prior: float | None = None,
Expand Down Expand Up @@ -295,8 +296,8 @@ def __init__(
super().__init__(
table=table, # type: ignore
name=benchmark_task_name,
config_name="config_id",
fidelity_name=cls.fidelity_name,
id_key="id",
fidelity_key=cls.fidelity_name,
result_keys=LCBenchTabularResult.names(),
config_keys=LCBenchTabularConfig.names(),
remove_constants=remove_constants,
Expand Down
6 changes: 3 additions & 3 deletions src/mfpbench/setup_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def _process(cls, path: Path) -> None:
config: dict = config_data["config"]

log_data: dict = config_data["log"]
loss: list[float] = log_data["Train/loss"]
loss: list[str] = log_data["Train/loss"] # Name of the loss
val_ce: list[float] = log_data["Train/val_cross_entropy"]
val_acc: list[float] = log_data["Train/val_accuracy"]
val_bal_acc: list[float] = log_data["Train/val_balanced_accuracy"]
Expand Down Expand Up @@ -240,7 +240,7 @@ def _process(cls, path: Path) -> None:
)
# These are single valued but this will make them as a list into
# the dataframe
df = df.assign(**{"config_id": config_id, **config})
df = df.assign(**{"id": config_id, **config})

config_frames_for_dataset.append(df)

Expand All @@ -249,7 +249,7 @@ def _process(cls, path: Path) -> None:
df_for_dataset = (
pd.concat(config_frames_for_dataset, ignore_index=True)
.convert_dtypes()
.set_index(["config_id", "epoch"])
.set_index(["id", "epoch"])
.sort_index()
)
table_path = path / f"{dataset_name}.parquet"
Expand Down
154 changes: 87 additions & 67 deletions src/mfpbench/tabular.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from datetime import datetime
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Mapping, Sequence, TypeVar, overload
from typing_extensions import override
Expand All @@ -26,10 +25,10 @@


class TabularBenchmark(Benchmark[CTabular, R, F]):
config_name: str
id_key: str
"""The column in the table that contains the config id. Will be set to the index"""

fidelity_name: str
fidelity_key: str
"""The name of the fidelity used in this benchmark"""

config_keys: Sequence[str]
Expand All @@ -41,20 +40,23 @@ class TabularBenchmark(Benchmark[CTabular, R, F]):
table: pd.DataFrame
"""The table of results used for this benchmark"""

configs: Mapping[str, CTabular]
"""The configs used in this benchmark"""

# The config and result type of this benchmark
Config: type[CTabular]
Result: type[R]

# Whether this benchmark has conditionals in it or not
has_conditionals: bool = False

def __init__( # noqa: PLR0913, C901
def __init__( # noqa: PLR0913
self,
name: str,
table: pd.DataFrame,
*,
config_name: str,
fidelity_name: str,
id_key: str,
fidelity_key: str,
result_keys: Sequence[str],
config_keys: Sequence[str],
remove_constants: bool = False,
Expand All @@ -68,8 +70,8 @@ def __init__( # noqa: PLR0913, C901
Args:
name: The name of this benchmark.
table: The table to use for the benchmark.
config_name: The column in the table that contains the config id
fidelity_name: The column in the table that contains the fidelity
id_key: The column in the table that contains the config id
fidelity_key: The column in the table that contains the fidelity
result_keys: The columns in the table that contain the results
config_keys: The columns in the table that contain the config values
remove_constants: Remove constant config columns from the data or not.
Expand All @@ -86,6 +88,31 @@ def __init__( # noqa: PLR0913, C901
seed: The seed to use for the benchmark.
"""
cls = self.__class__

# Make sure we work with a clean slate, no issue with index.
table = table.reset_index()

# Make sure all the keys they specified exist
if id_key not in table.columns:
raise ValueError(f"'{id_key=}' not in columns {table.columns}")

if fidelity_key not in table.columns:
raise ValueError(f"'{fidelity_key=}' not in columns {table.columns}")

if not all(key in table.columns for key in result_keys):
raise ValueError(f"{result_keys=} not in columns {table.columns}")

if not all(key in table.columns for key in config_keys):
raise ValueError(f"{config_keys=} not in columns {table.columns}")

# Make sure that the column `id` only exist if it's the `id_key`
if "id" in table.columns and id_key != "id":
raise ValueError(
f"Can't have `id` in the columns if it's not the {id_key=}."
" Please drop it or rename it.",
)

# Remove constants from the table
if remove_constants:

def is_constant(_s: pd.Series) -> bool:
Expand All @@ -98,45 +125,24 @@ def is_constant(_s: pd.Series) -> bool:
table = table.drop(columns=constant_cols) # type: ignore
config_keys = [k for k in config_keys if k not in constant_cols]

# If the table isn't indexed, index it
index_cols = [config_name, fidelity_name]
if table.index.names != index_cols:
# Only drop the index if it's not relevant.
relevant_cols: list[str] = [ # type: ignore
*list(index_cols), # type: ignore
*list(result_keys),
*list(config_keys),
]
relevant = any(name in relevant_cols for name in table.index.names)
table = table.reset_index(drop=not relevant)

if config_name not in table.columns:
raise ValueError(f"{config_name=} not in columns {table.columns}")
if fidelity_name not in table.columns:
raise ValueError(f"{fidelity_name=} not in columns {table.columns}")
# Remap their id column to `id`
table = table.rename(columns={id_key: "id"})

table = table.set_index(index_cols)
table = table.sort_index()
# Index the table
index_cols: list[str] = ["id", fidelity_key]

# Make sure all keys are in the table
for key in chain(result_keys, config_keys):
if key not in table.columns:
raise ValueError(f"{key=} not in columns {table.columns}")

# Make sure the keyword "id" is not in the columns as we use it to
# identify configs
if "id" in table.columns:
raise ValueError(f"{table.columns=} contains 'id'. Please rename it")

# Make sure we have equidistance fidelities for all configs
fidelity_values = table.index.get_level_values(fidelity_name)
fidelity_counts = fidelity_values.value_counts()
if not (fidelity_counts == fidelity_counts.iloc[0]).all():
raise ValueError(f"{fidelity_name=} not uniform. \n{fidelity_counts}")
# Drop all the columns that are not relevant
relevant_cols: list[str] = [
*index_cols,
*result_keys,
*config_keys,
]
table = table[relevant_cols] # type: ignore
table = table.set_index(index_cols).sort_index()

# We now have the following table
#
# config_id fidelity | **metric, **config_values
# id fidelity | **metric, **config_values
# 0 0 |
# 1 |
# 2 |
Expand All @@ -145,38 +151,41 @@ def is_constant(_s: pd.Series) -> bool:
# 2 |
# ...

# Make sure we have equidistance fidelities for all configs
fidelity_values = table.index.get_level_values(fidelity_key)
fidelity_counts = fidelity_values.value_counts()
if not (fidelity_counts == fidelity_counts.iloc[0]).all():
raise ValueError(f"{fidelity_key=} not uniform. \n{fidelity_counts}")

sorted_fids = sorted(fidelity_values.unique())
start = sorted_fids[0]
end = sorted_fids[-1]
step = sorted_fids[1] - sorted_fids[0]

# Here we get all the unique configs
# config_id fidelity | **metric, **config_values
# id fidelity | **metric, **config_values
# 0 0 |
# 1 0 |
# ...
config_id_table = table.groupby(level=config_name).agg("first")
id_table = table.groupby(level=id_key).agg("first")
configs = {
str(config_id): cls.Config.from_dict(
{
**row[config_keys].to_dict(), # type: ignore
"id": str(config_id),
},
)
for config_id, row in config_id_table.iterrows()
for config_id, row in id_table.iterrows()
}

fidelity_values = table.index.get_level_values(fidelity_name).unique()

# We just assume equidistant fidelities
sorted_fids = sorted(fidelity_values)
start = sorted_fids[0]
end = sorted_fids[-1]
step = sorted_fids[1] - sorted_fids[0]

# Create the configuration space
if space is None:
space = ConfigurationSpace(name, seed=seed)

self.table = table
self.configs = configs
self.fidelity_name = fidelity_name
self.config_name = config_name
self.fidelity_key = fidelity_key
self.id_key = id_key
self.config_keys = sorted(config_keys)
self.result_keys = sorted(result_keys)
self.fidelity_range = (start, end, step) # type: ignore
Expand Down Expand Up @@ -279,23 +288,35 @@ def _find_config(
self,
config: CTabular | Mapping[str, Any] | str | int,
) -> CTabular:
# It's an integer but likely meant to be a string
# We don't do any numeric based lookups
if isinstance(config, int):
config = str(config)

# It's a key into the self.configs dict
if isinstance(config, str):
return self.configs[config]

# If it's a Config, that's fine
if isinstance(config, self.Config):
return config

if self.config_name in config:
_id = config[self.config_name]
return self.configs[_id]
# At this point, we assume we're basically dealing with a dictionary
assert isinstance(config, Mapping)

# Not sure how that ended up there, but we can at least handle that
if self.id_key in config:
_real_config_id = str(config[self.id_key])
return self.configs[_real_config_id]

# A plain "id" key may also appear in the mapping; handle that as well
if "id" in config:
_id = config["id"]
return self.configs[_id]

# Alright, nothing worked, here we try to match the actual hyperparameter
# values to what we have in our known configs and attempt to get the
# id that way
match = first_true(
self.configs.values(),
pred=lambda c: c == config, # type: ignore
Expand Down Expand Up @@ -407,8 +428,8 @@ def __init__( # noqa: PLR0913
table: pd.DataFrame,
*,
name: str | None = None,
fidelity_name: str,
config_name: str,
id_key: str,
fidelity_key: str,
result_keys: Sequence[str],
config_keys: Sequence[str],
result_mapping: (dict[str, str | Callable[[pd.DataFrame], Any]] | None) = None,
Expand All @@ -424,9 +445,8 @@ def __init__( # noqa: PLR0913
table: The table to use for the benchmark
name: The name of the benchmark. If None, will be set to
`unknown-{datetime.now().isoformat()}`
fidelity_name: The column in the table that contains the fidelity
config_name: The column in the table that contains the config id
id_key: The column in the table that contains the config id
fidelity_key: The column in the table that contains the fidelity
result_keys: The columns in the table that contain the results
config_keys: The columns in the table that contain the config values
result_mapping: A mapping from the result keys to the table keys.
Expand Down Expand Up @@ -467,8 +487,8 @@ def __init__( # noqa: PLR0913
super().__init__(
name=name,
table=table,
config_name=config_name,
fidelity_name=fidelity_name,
id_key=id_key,
fidelity_key=fidelity_key,
result_keys=[*result_keys, *_result_mapping.keys()],
config_keys=config_keys,
remove_constants=remove_constants,
Expand All @@ -485,8 +505,8 @@ def __init__( # noqa: PLR0913
table = pd.read_parquet(path)
benchmark = GenericTabularBenchmark(
table=table,
fidelity_name="epoch",
config_name="config_id",
id_key="id",
fidelity_key="epoch",
result_keys=[
"time",
"val_accuracy",
Expand Down

0 comments on commit aca8ca4

Please sign in to comment.