Skip to content

Commit

Permalink
fix example
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman committed Aug 29, 2023
1 parent 5556027 commit 1613119
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
32 changes: 17 additions & 15 deletions example1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,19 @@

import mfpbench

# Optional imports here just to give typing and so you can look deeper if needed
from mfpbench.jahs import JAHSBenchmark, JAHSConfig, JAHSResult

if __name__ == "__main__":

seed = 724

# Just adding a type here if you want to explore it
benchmark = mfpbench.get("jahs_cifar10", seed=seed) # datadir = ...
benchmark = cast(JAHSBenchmark, benchmark)
# Just adding a type here if you want to explore it # get(..., datadir = ...)
benchmark: mfpbench.Benchmark = mfpbench.get("lcbench", task_id="3945", seed=seed)

min_fidelity, max_fidelity, fidelity_step = benchmark.fidelity_range

# benchmark = mfpbench.get("lcbench", seed=seed, datadir=datadir, task_id="3945")
# benchmark = cast(LCBench, benchmark)

# Get a random config just to see it
config: JAHSConfig = benchmark.sample()
config: mfpbench.Config = benchmark.sample()

# And the search space

Expand All @@ -30,20 +27,21 @@
assert exact_copy == config

# You can only mutate it by a copy to keep things consistent
new_copy = config.copy(TrivialAugment=True)
new_copy = config.copy(momentum=0.2)

# You can always validate a config to make sure you aren't doing anything wrong
new_copy.validate()

# Like in this case where we used a bad optmizer
bad_copy = config.copy(Optimizer="Adam")
bad_copy = config.copy(momentum=-10)
try:
bad_copy.validate()
except AssertionError:
pass

# Anyways, here's the results for the config
result: JAHSResult = benchmark.query(config, at=42)
result: mfpbench.Result
result = benchmark.query(config, at=42)
result = benchmark.query(config.dict(), at=42) # You can also use a dict
result = benchmark.query(benchmark.space.sample_configuration()) # Or configspace

Expand All @@ -53,9 +51,13 @@
# The full result object

# And if you need the full trajectory
results: list[JAHSResult] = benchmark.trajectory(config)
sliced_result = benchmark.trajectory(config, to=100)
sliced_result_2 = benchmark.trajectory(config, frm=50, to=100)
results: list[mfpbench.Result] = benchmark.trajectory(config)
sliced_result = benchmark.trajectory(config, to=max_fidelity)
sliced_result_2 = benchmark.trajectory(
config,
frm=max_fidelity // 2,
to=max_fidelity,
)

first = results[0]
last = results[-1]
Expand All @@ -65,7 +67,7 @@
best = sorted_result[0]

# Now here's 100 configs and we'll get the best configuration
configs: list[JAHSConfig] = benchmark.sample(100)
configs: list[mfpbench.Config] = benchmark.sample(100)

# Get all trajectories for each run
# [
Expand Down
1 change: 1 addition & 0 deletions mfpbench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
PD1translatewmt_xformer_64,
PD1uniref50_transformer_128,
)
from mfpbench.result import Result # noqa: F401
from mfpbench.synthetic.hartmann import (
MFHartmann3Benchmark,
MFHartmann3BenchmarkBad,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ tqdm = "*"
numpy = "1.*"
yahpo-gym = "1.0.1"
xgboost = "^1"
ConfigSpace = "<0.7"

[tool.poetry.group.jahs.dependencies]
jahs-bench = { git = "https://github.com/automl/jahs_bench_201.git", rev = "880fbcb35a83df7b6c02440a6c13adb921f54657" }
Expand Down

0 comments on commit 1613119

Please sign in to comment.