Convert the output to a list of ints

In Jax, if you do list(np.array([1, 2, 3])) it becomes a list of 0-d
arrays for each integer. We avoid this by explicitly casting each
element to int. This patch works with both NumPy and Jax.

Now gpt2.py successfully runs with both Jax and NumPy.
pull/10/head
Ondřej Čertík 2023-02-16 13:27:29 -07:00
rodzic c456f120ae
commit bf9d37ddef
1 zmienionych plików z 1 dodań i 1 usunięć

Wyświetl plik

@ -91,7 +91,7 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
next_id = np.argmax(logits[-1]) # greedy sampling
inputs = np.append(inputs, next_id) # append prediction to input
return list(inputs[len(inputs) - n_tokens_to_generate :]) # only return generated ids
return [int(x) for x in inputs[-n_tokens_to_generate:]] # only return generated ids
def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):