From bf9d37ddef01f36520e2f1453ca8da4fc8941eb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20=C4=8Cert=C3=ADk?= Date: Thu, 16 Feb 2023 13:27:29 -0700 Subject: [PATCH] Convert the output to a list of ints In Jax, if you do list(np.array([1, 2, 3])) it becomes a list of 0-d arrays for each integer. We avoid this by explicitly casting each element to int. This patch works with both NumPy and Jax. Now gpt2.py successfully runs with both Jax and NumPy. --- gpt2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt2.py b/gpt2.py index eb5da99..52683df 100644 --- a/gpt2.py +++ b/gpt2.py @@ -91,7 +91,7 @@ def generate(inputs, params, n_head, n_tokens_to_generate): next_id = np.argmax(logits[-1]) # greedy sampling inputs = np.append(inputs, next_id) # append prediction to input - return list(inputs[len(inputs) - n_tokens_to_generate :]) # only return generated ids + return [int(x) for x in inputs[-n_tokens_to_generate:]] # only return generated ids def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):