kopia lustrzana https://github.com/jaymody/picoGPT
Convert the output to a list of ints
In Jax, if you do list(np.array([1, 2, 3])) it becomes a list of 0-d arrays for each integer. We avoid this by explicitly casting each element to int. This patch works with both NumPy and Jax. Now gpt2.py successfully runs with both Jax and NumPy.pull/10/head
rodzic
c456f120ae
commit
bf9d37ddef
2
gpt2.py
2
gpt2.py
|
@ -91,7 +91,7 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
|
|||
next_id = np.argmax(logits[-1]) # greedy sampling
|
||||
inputs = np.append(inputs, next_id) # append prediction to input
|
||||
|
||||
return list(inputs[len(inputs) - n_tokens_to_generate :]) # only return generated ids
|
||||
return [int(x) for x in inputs[-n_tokens_to_generate:]] # only return generated ids
|
||||
|
||||
|
||||
def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
|
||||
|
|
Ładowanie…
Reference in New Issue