kopia lustrzana https://github.com/gd3kr/BlenderGPT
301 wiersze
10 KiB
Python
301 wiersze
10 KiB
Python
try:
|
|
import wandb
|
|
|
|
WANDB_AVAILABLE = True
|
|
except:
|
|
WANDB_AVAILABLE = False
|
|
|
|
|
|
if WANDB_AVAILABLE:
|
|
import datetime
|
|
import io
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
|
|
from openai import File, FineTune
|
|
from openai.datalib import numpy as np
|
|
from openai.datalib import pandas as pd
|
|
|
|
|
|
class WandbLogger:
|
|
"""
|
|
Log fine-tunes to [Weights & Biases](https://wandb.me/openai-docs)
|
|
"""
|
|
|
|
if not WANDB_AVAILABLE:
|
|
print("Logging requires wandb to be installed. Run `pip install wandb`.")
|
|
else:
|
|
_wandb_api = None
|
|
_logged_in = False
|
|
|
|
@classmethod
|
|
def sync(
|
|
cls,
|
|
id=None,
|
|
n_fine_tunes=None,
|
|
project="GPT-3",
|
|
entity=None,
|
|
force=False,
|
|
**kwargs_wandb_init,
|
|
):
|
|
"""
|
|
Sync fine-tunes to Weights & Biases.
|
|
:param id: The id of the fine-tune (optional)
|
|
:param n_fine_tunes: Number of most recent fine-tunes to log when an id is not provided. By default, every fine-tune is synced.
|
|
:param project: Name of the project where you're sending runs. By default, it is "GPT-3".
|
|
:param entity: Username or team name where you're sending runs. By default, your default entity is used, which is usually your username.
|
|
:param force: Forces logging and overwrite existing wandb run of the same fine-tune.
|
|
"""
|
|
|
|
if not WANDB_AVAILABLE:
|
|
return
|
|
|
|
if id:
|
|
fine_tune = FineTune.retrieve(id=id)
|
|
fine_tune.pop("events", None)
|
|
fine_tunes = [fine_tune]
|
|
|
|
else:
|
|
# get list of fine_tune to log
|
|
fine_tunes = FineTune.list()
|
|
if not fine_tunes or fine_tunes.get("data") is None:
|
|
print("No fine-tune has been retrieved")
|
|
return
|
|
fine_tunes = fine_tunes["data"][
|
|
-n_fine_tunes if n_fine_tunes is not None else None :
|
|
]
|
|
|
|
# log starting from oldest fine_tune
|
|
show_individual_warnings = (
|
|
False if id is None and n_fine_tunes is None else True
|
|
)
|
|
fine_tune_logged = [
|
|
cls._log_fine_tune(
|
|
fine_tune,
|
|
project,
|
|
entity,
|
|
force,
|
|
show_individual_warnings,
|
|
**kwargs_wandb_init,
|
|
)
|
|
for fine_tune in fine_tunes
|
|
]
|
|
|
|
if not show_individual_warnings and not any(fine_tune_logged):
|
|
print("No new successful fine-tunes were found")
|
|
|
|
return "🎉 wandb sync completed successfully"
|
|
|
|
@classmethod
|
|
def _log_fine_tune(
|
|
cls,
|
|
fine_tune,
|
|
project,
|
|
entity,
|
|
force,
|
|
show_individual_warnings,
|
|
**kwargs_wandb_init,
|
|
):
|
|
fine_tune_id = fine_tune.get("id")
|
|
status = fine_tune.get("status")
|
|
|
|
# check run completed successfully
|
|
if status != "succeeded":
|
|
if show_individual_warnings:
|
|
print(
|
|
f'Fine-tune {fine_tune_id} has the status "{status}" and will not be logged'
|
|
)
|
|
return
|
|
|
|
# check results are present
|
|
try:
|
|
results_id = fine_tune["result_files"][0]["id"]
|
|
results = File.download(id=results_id).decode("utf-8")
|
|
except:
|
|
if show_individual_warnings:
|
|
print(f"Fine-tune {fine_tune_id} has no results and will not be logged")
|
|
return
|
|
|
|
# check run has not been logged already
|
|
run_path = f"{project}/{fine_tune_id}"
|
|
if entity is not None:
|
|
run_path = f"{entity}/{run_path}"
|
|
wandb_run = cls._get_wandb_run(run_path)
|
|
if wandb_run:
|
|
wandb_status = wandb_run.summary.get("status")
|
|
if show_individual_warnings:
|
|
if wandb_status == "succeeded":
|
|
print(
|
|
f"Fine-tune {fine_tune_id} has already been logged successfully at {wandb_run.url}"
|
|
)
|
|
if not force:
|
|
print(
|
|
'Use "--force" in the CLI or "force=True" in python if you want to overwrite previous run'
|
|
)
|
|
else:
|
|
print(
|
|
f"A run for fine-tune {fine_tune_id} was previously created but didn't end successfully"
|
|
)
|
|
if wandb_status != "succeeded" or force:
|
|
print(
|
|
f"A new wandb run will be created for fine-tune {fine_tune_id} and previous run will be overwritten"
|
|
)
|
|
if wandb_status == "succeeded" and not force:
|
|
return
|
|
|
|
# start a wandb run
|
|
wandb.init(
|
|
job_type="fine-tune",
|
|
config=cls._get_config(fine_tune),
|
|
project=project,
|
|
entity=entity,
|
|
name=fine_tune_id,
|
|
id=fine_tune_id,
|
|
**kwargs_wandb_init,
|
|
)
|
|
|
|
# log results
|
|
df_results = pd.read_csv(io.StringIO(results))
|
|
for _, row in df_results.iterrows():
|
|
metrics = {k: v for k, v in row.items() if not np.isnan(v)}
|
|
step = metrics.pop("step")
|
|
if step is not None:
|
|
step = int(step)
|
|
wandb.log(metrics, step=step)
|
|
fine_tuned_model = fine_tune.get("fine_tuned_model")
|
|
if fine_tuned_model is not None:
|
|
wandb.summary["fine_tuned_model"] = fine_tuned_model
|
|
|
|
# training/validation files and fine-tune details
|
|
cls._log_artifacts(fine_tune, project, entity)
|
|
|
|
# mark run as complete
|
|
wandb.summary["status"] = "succeeded"
|
|
|
|
wandb.finish()
|
|
return True
|
|
|
|
@classmethod
|
|
def _ensure_logged_in(cls):
|
|
if not cls._logged_in:
|
|
if wandb.login():
|
|
cls._logged_in = True
|
|
else:
|
|
raise Exception("You need to log in to wandb")
|
|
|
|
@classmethod
|
|
def _get_wandb_run(cls, run_path):
|
|
cls._ensure_logged_in()
|
|
try:
|
|
if cls._wandb_api is None:
|
|
cls._wandb_api = wandb.Api()
|
|
return cls._wandb_api.run(run_path)
|
|
except Exception:
|
|
return None
|
|
|
|
@classmethod
|
|
def _get_wandb_artifact(cls, artifact_path):
|
|
cls._ensure_logged_in()
|
|
try:
|
|
if cls._wandb_api is None:
|
|
cls._wandb_api = wandb.Api()
|
|
return cls._wandb_api.artifact(artifact_path)
|
|
except Exception:
|
|
return None
|
|
|
|
@classmethod
|
|
def _get_config(cls, fine_tune):
|
|
config = dict(fine_tune)
|
|
for key in ("training_files", "validation_files", "result_files"):
|
|
if config.get(key) and len(config[key]):
|
|
config[key] = config[key][0]
|
|
if config.get("created_at"):
|
|
config["created_at"] = datetime.datetime.fromtimestamp(config["created_at"])
|
|
return config
|
|
|
|
@classmethod
|
|
def _log_artifacts(cls, fine_tune, project, entity):
|
|
# training/validation files
|
|
training_file = (
|
|
fine_tune["training_files"][0]
|
|
if fine_tune.get("training_files") and len(fine_tune["training_files"])
|
|
else None
|
|
)
|
|
validation_file = (
|
|
fine_tune["validation_files"][0]
|
|
if fine_tune.get("validation_files") and len(fine_tune["validation_files"])
|
|
else None
|
|
)
|
|
for file, prefix, artifact_type in (
|
|
(training_file, "train", "training_files"),
|
|
(validation_file, "valid", "validation_files"),
|
|
):
|
|
if file is not None:
|
|
cls._log_artifact_inputs(file, prefix, artifact_type, project, entity)
|
|
|
|
# fine-tune details
|
|
fine_tune_id = fine_tune.get("id")
|
|
artifact = wandb.Artifact(
|
|
"fine_tune_details",
|
|
type="fine_tune_details",
|
|
metadata=fine_tune,
|
|
)
|
|
with artifact.new_file(
|
|
"fine_tune_details.json", mode="w", encoding="utf-8"
|
|
) as f:
|
|
json.dump(fine_tune, f, indent=2)
|
|
wandb.run.log_artifact(
|
|
artifact,
|
|
aliases=["latest", fine_tune_id],
|
|
)
|
|
|
|
@classmethod
|
|
def _log_artifact_inputs(cls, file, prefix, artifact_type, project, entity):
|
|
file_id = file["id"]
|
|
filename = Path(file["filename"]).name
|
|
stem = Path(file["filename"]).stem
|
|
|
|
# get input artifact
|
|
artifact_name = f"{prefix}-{filename}"
|
|
# sanitize name to valid wandb artifact name
|
|
artifact_name = re.sub(r"[^a-zA-Z0-9_\-.]", "_", artifact_name)
|
|
artifact_alias = file_id
|
|
artifact_path = f"{project}/{artifact_name}:{artifact_alias}"
|
|
if entity is not None:
|
|
artifact_path = f"{entity}/{artifact_path}"
|
|
artifact = cls._get_wandb_artifact(artifact_path)
|
|
|
|
# create artifact if file not already logged previously
|
|
if artifact is None:
|
|
# get file content
|
|
try:
|
|
file_content = File.download(id=file_id).decode("utf-8")
|
|
except:
|
|
print(
|
|
f"File {file_id} could not be retrieved. Make sure you are allowed to download training/validation files"
|
|
)
|
|
return
|
|
artifact = wandb.Artifact(artifact_name, type=artifact_type, metadata=file)
|
|
with artifact.new_file(filename, mode="w", encoding="utf-8") as f:
|
|
f.write(file_content)
|
|
|
|
# create a Table
|
|
try:
|
|
table, n_items = cls._make_table(file_content)
|
|
artifact.add(table, stem)
|
|
wandb.config.update({f"n_{prefix}": n_items})
|
|
artifact.metadata["items"] = n_items
|
|
except:
|
|
print(f"File {file_id} could not be read as a valid JSON file")
|
|
else:
|
|
# log number of items
|
|
wandb.config.update({f"n_{prefix}": artifact.metadata.get("items")})
|
|
|
|
wandb.run.use_artifact(artifact, aliases=["latest", artifact_alias])
|
|
|
|
@classmethod
|
|
def _make_table(cls, file_content):
|
|
df = pd.read_json(io.StringIO(file_content), orient="records", lines=True)
|
|
return wandb.Table(dataframe=df), len(df)
|