From f7dfc78ffaaf6f2f2523a8c04f3893e42d7ed995 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20=C4=8Cert=C3=ADk?= Date: Thu, 16 Feb 2023 15:26:49 -0700 Subject: [PATCH] Fix gpt2.py to work with Jax (#10) --- gpt2.py | 4 ++-- gpt2_pico.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gpt2.py b/gpt2.py index 62549bc..e227ec2 100644 --- a/gpt2.py +++ b/gpt2.py @@ -89,9 +89,9 @@ def generate(inputs, params, n_head, n_tokens_to_generate): for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop logits = gpt2(inputs, **params, n_head=n_head) # model forward pass next_id = np.argmax(logits[-1]) # greedy sampling - inputs = np.append(inputs, [next_id]) # append prediction to input + inputs.append(int(next_id)) # append prediction to input - return list(inputs[len(inputs) - n_tokens_to_generate :]) # only return generated ids + return inputs[len(inputs) - n_tokens_to_generate :] # only return generated ids def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"): diff --git a/gpt2_pico.py b/gpt2_pico.py index 904a97d..6c4bf80 100644 --- a/gpt2_pico.py +++ b/gpt2_pico.py @@ -45,8 +45,8 @@ def generate(inputs, params, n_head, n_tokens_to_generate): for _ in tqdm(range(n_tokens_to_generate), "generating"): logits = gpt2(inputs, **params, n_head=n_head) next_id = np.argmax(logits[-1]) - inputs = np.append(inputs, [next_id]) - return list(inputs[len(inputs) - n_tokens_to_generate :]) + inputs.append(int(next_id)) + return inputs[len(inputs) - n_tokens_to_generate :] def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"): from utils import load_encoder_hparams_and_params