kopia lustrzana https://github.com/jaymody/picoGPT
Bug fix.
rodzic
bf118a3660
commit
4cd64933bb
2
main.py
2
main.py
|
@ -83,10 +83,10 @@ def main(prompt, models_dir, model_size, n_tokens_to_generate):
|
||||||
assert model_size in ["124M", "355M", "774M", "1558M"]
|
assert model_size in ["124M", "355M", "774M", "1558M"]
|
||||||
|
|
||||||
model_dir = os.path.join(models_dir, model_size)
|
model_dir = os.path.join(models_dir, model_size)
|
||||||
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
|
|
||||||
if not os.path.isdir(model_dir):
|
if not os.path.isdir(model_dir):
|
||||||
os.makedirs(model_dir)
|
os.makedirs(model_dir)
|
||||||
download_gpt2_files(model_size, model_dir)
|
download_gpt2_files(model_size, model_dir)
|
||||||
|
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
|
||||||
|
|
||||||
with open(os.path.join(model_dir, "hparams.json")) as file:
|
with open(os.path.join(model_dir, "hparams.json")) as file:
|
||||||
hparams = json.load(file)
|
hparams = json.load(file)
|
||||||
|
|
Ładowanie…
Reference in New Issue