diff --git a/gpt2.py b/gpt2.py
index 62549bc..e227ec2 100644
--- a/gpt2.py
+++ b/gpt2.py
@@ -89,9 +89,9 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
     for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
         logits = gpt2(inputs, **params, n_head=n_head)  # model forward pass
         next_id = np.argmax(logits[-1])  # greedy sampling
-        inputs = np.append(inputs, [next_id])  # append prediction to input
+        inputs.append(int(next_id))  # append prediction to input
 
-    return list(inputs[len(inputs) - n_tokens_to_generate :])  # only return generated ids
+    return inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids
 
 
 def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
diff --git a/gpt2_pico.py b/gpt2_pico.py
index 904a97d..6c4bf80 100644
--- a/gpt2_pico.py
+++ b/gpt2_pico.py
@@ -45,8 +45,8 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
     for _ in tqdm(range(n_tokens_to_generate), "generating"):
         logits = gpt2(inputs, **params, n_head=n_head)
         next_id = np.argmax(logits[-1])
-        inputs = np.append(inputs, [next_id])
-    return list(inputs[len(inputs) - n_tokens_to_generate :])
+        inputs.append(int(next_id))
+    return inputs[len(inputs) - n_tokens_to_generate :]
 
 def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
     from utils import load_encoder_hparams_and_params