Skip to content

Commit

Permalink
model search with model upload to best run
Browse files Browse the repository at this point in the history
  • Loading branch information
jrybicki-jsc committed Feb 6, 2024
1 parent 2c4f7db commit 66c3d30
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions model_search_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,7 @@ def uploat_to_mlflow(temp_dir, **context):
print(f"Experiment {experiment_name} was not found, creating new")
experiment_id = client.create_experiment(experiment_name)

run = client.create_run(experiment_id)
print(f"Uploading to experiment {experiment_name}/{experiment_id}/{run.info.run_id}")

print("Uploading model")
client.log_artifact(
run_id=run.info.run_id,
local_path=os.path.join(temp_dir, 'model.dat'),
artifact_path="model",
)
print(f"Uploading to experiment {experiment_name}/{experiment_id}")

print("Uploading model search results")
df = pd.read_csv(os.path.join(temp_dir, 'pd.csv'), index_col=0)
Expand All @@ -86,7 +78,7 @@ def uploat_to_mlflow(temp_dir, **context):
metrics=['mean_test_score', 'mean_fit_time']

for i, p in enumerate(dct['params'].values()):
with mlflow.start_run(experiment_id=experiment_id):
with mlflow.start_run(experiment_id=experiment_id) as run:
p = json.loads(p.replace('\'', '"'))
for parname, parvalue in p.items():
mlflow.log_param(key=parname, value=parvalue)
Expand All @@ -98,6 +90,15 @@ def uploat_to_mlflow(temp_dir, **context):
print(f"Logging metric {m} {dct[m][i]}")
mlflow.log_metric(key=m, value=dct[m][i])

if dct['rank_test_score'][i]==1:
print('This is the best model')
print("Uploading model to run: ", run.info.run_id)
mlflow.log_artifact(
local_path=os.path.join(temp_dir, 'model.dat'),
artifact_path="model",
)


#clean up
shutil.rmtree(temp_dir)

Expand Down

0 comments on commit 66c3d30

Please sign in to comment.