BlenderGPT/lib/openai/wandb_logger.py

301 wiersze
10 KiB
Python

try:
import wandb
WANDB_AVAILABLE = True
except:
WANDB_AVAILABLE = False
if WANDB_AVAILABLE:
import datetime
import io
import json
import re
from pathlib import Path
from openai import File, FineTune
from openai.datalib import numpy as np
from openai.datalib import pandas as pd
class WandbLogger:
"""
Log fine-tunes to [Weights & Biases](https://wandb.me/openai-docs)
"""
if not WANDB_AVAILABLE:
print("Logging requires wandb to be installed. Run `pip install wandb`.")
else:
_wandb_api = None
_logged_in = False
@classmethod
def sync(
cls,
id=None,
n_fine_tunes=None,
project="GPT-3",
entity=None,
force=False,
**kwargs_wandb_init,
):
"""
Sync fine-tunes to Weights & Biases.
:param id: The id of the fine-tune (optional)
:param n_fine_tunes: Number of most recent fine-tunes to log when an id is not provided. By default, every fine-tune is synced.
:param project: Name of the project where you're sending runs. By default, it is "GPT-3".
:param entity: Username or team name where you're sending runs. By default, your default entity is used, which is usually your username.
:param force: Forces logging and overwrite existing wandb run of the same fine-tune.
"""
if not WANDB_AVAILABLE:
return
if id:
fine_tune = FineTune.retrieve(id=id)
fine_tune.pop("events", None)
fine_tunes = [fine_tune]
else:
# get list of fine_tune to log
fine_tunes = FineTune.list()
if not fine_tunes or fine_tunes.get("data") is None:
print("No fine-tune has been retrieved")
return
fine_tunes = fine_tunes["data"][
-n_fine_tunes if n_fine_tunes is not None else None :
]
# log starting from oldest fine_tune
show_individual_warnings = (
False if id is None and n_fine_tunes is None else True
)
fine_tune_logged = [
cls._log_fine_tune(
fine_tune,
project,
entity,
force,
show_individual_warnings,
**kwargs_wandb_init,
)
for fine_tune in fine_tunes
]
if not show_individual_warnings and not any(fine_tune_logged):
print("No new successful fine-tunes were found")
return "🎉 wandb sync completed successfully"
@classmethod
def _log_fine_tune(
cls,
fine_tune,
project,
entity,
force,
show_individual_warnings,
**kwargs_wandb_init,
):
fine_tune_id = fine_tune.get("id")
status = fine_tune.get("status")
# check run completed successfully
if status != "succeeded":
if show_individual_warnings:
print(
f'Fine-tune {fine_tune_id} has the status "{status}" and will not be logged'
)
return
# check results are present
try:
results_id = fine_tune["result_files"][0]["id"]
results = File.download(id=results_id).decode("utf-8")
except:
if show_individual_warnings:
print(f"Fine-tune {fine_tune_id} has no results and will not be logged")
return
# check run has not been logged already
run_path = f"{project}/{fine_tune_id}"
if entity is not None:
run_path = f"{entity}/{run_path}"
wandb_run = cls._get_wandb_run(run_path)
if wandb_run:
wandb_status = wandb_run.summary.get("status")
if show_individual_warnings:
if wandb_status == "succeeded":
print(
f"Fine-tune {fine_tune_id} has already been logged successfully at {wandb_run.url}"
)
if not force:
print(
'Use "--force" in the CLI or "force=True" in python if you want to overwrite previous run'
)
else:
print(
f"A run for fine-tune {fine_tune_id} was previously created but didn't end successfully"
)
if wandb_status != "succeeded" or force:
print(
f"A new wandb run will be created for fine-tune {fine_tune_id} and previous run will be overwritten"
)
if wandb_status == "succeeded" and not force:
return
# start a wandb run
wandb.init(
job_type="fine-tune",
config=cls._get_config(fine_tune),
project=project,
entity=entity,
name=fine_tune_id,
id=fine_tune_id,
**kwargs_wandb_init,
)
# log results
df_results = pd.read_csv(io.StringIO(results))
for _, row in df_results.iterrows():
metrics = {k: v for k, v in row.items() if not np.isnan(v)}
step = metrics.pop("step")
if step is not None:
step = int(step)
wandb.log(metrics, step=step)
fine_tuned_model = fine_tune.get("fine_tuned_model")
if fine_tuned_model is not None:
wandb.summary["fine_tuned_model"] = fine_tuned_model
# training/validation files and fine-tune details
cls._log_artifacts(fine_tune, project, entity)
# mark run as complete
wandb.summary["status"] = "succeeded"
wandb.finish()
return True
@classmethod
def _ensure_logged_in(cls):
if not cls._logged_in:
if wandb.login():
cls._logged_in = True
else:
raise Exception("You need to log in to wandb")
@classmethod
def _get_wandb_run(cls, run_path):
cls._ensure_logged_in()
try:
if cls._wandb_api is None:
cls._wandb_api = wandb.Api()
return cls._wandb_api.run(run_path)
except Exception:
return None
@classmethod
def _get_wandb_artifact(cls, artifact_path):
cls._ensure_logged_in()
try:
if cls._wandb_api is None:
cls._wandb_api = wandb.Api()
return cls._wandb_api.artifact(artifact_path)
except Exception:
return None
@classmethod
def _get_config(cls, fine_tune):
config = dict(fine_tune)
for key in ("training_files", "validation_files", "result_files"):
if config.get(key) and len(config[key]):
config[key] = config[key][0]
if config.get("created_at"):
config["created_at"] = datetime.datetime.fromtimestamp(config["created_at"])
return config
@classmethod
def _log_artifacts(cls, fine_tune, project, entity):
# training/validation files
training_file = (
fine_tune["training_files"][0]
if fine_tune.get("training_files") and len(fine_tune["training_files"])
else None
)
validation_file = (
fine_tune["validation_files"][0]
if fine_tune.get("validation_files") and len(fine_tune["validation_files"])
else None
)
for file, prefix, artifact_type in (
(training_file, "train", "training_files"),
(validation_file, "valid", "validation_files"),
):
if file is not None:
cls._log_artifact_inputs(file, prefix, artifact_type, project, entity)
# fine-tune details
fine_tune_id = fine_tune.get("id")
artifact = wandb.Artifact(
"fine_tune_details",
type="fine_tune_details",
metadata=fine_tune,
)
with artifact.new_file(
"fine_tune_details.json", mode="w", encoding="utf-8"
) as f:
json.dump(fine_tune, f, indent=2)
wandb.run.log_artifact(
artifact,
aliases=["latest", fine_tune_id],
)
@classmethod
def _log_artifact_inputs(cls, file, prefix, artifact_type, project, entity):
file_id = file["id"]
filename = Path(file["filename"]).name
stem = Path(file["filename"]).stem
# get input artifact
artifact_name = f"{prefix}-{filename}"
# sanitize name to valid wandb artifact name
artifact_name = re.sub(r"[^a-zA-Z0-9_\-.]", "_", artifact_name)
artifact_alias = file_id
artifact_path = f"{project}/{artifact_name}:{artifact_alias}"
if entity is not None:
artifact_path = f"{entity}/{artifact_path}"
artifact = cls._get_wandb_artifact(artifact_path)
# create artifact if file not already logged previously
if artifact is None:
# get file content
try:
file_content = File.download(id=file_id).decode("utf-8")
except:
print(
f"File {file_id} could not be retrieved. Make sure you are allowed to download training/validation files"
)
return
artifact = wandb.Artifact(artifact_name, type=artifact_type, metadata=file)
with artifact.new_file(filename, mode="w", encoding="utf-8") as f:
f.write(file_content)
# create a Table
try:
table, n_items = cls._make_table(file_content)
artifact.add(table, stem)
wandb.config.update({f"n_{prefix}": n_items})
artifact.metadata["items"] = n_items
except:
print(f"File {file_id} could not be read as a valid JSON file")
else:
# log number of items
wandb.config.update({f"n_{prefix}": artifact.metadata.get("items")})
wandb.run.use_artifact(artifact, aliases=["latest", artifact_alias])
@classmethod
def _make_table(cls, file_content):
df = pd.read_json(io.StringIO(file_content), orient="records", lines=True)
return wandb.Table(dataframe=df), len(df)