Mirror of https://github.com/jaymody/picoGPT
Added temperature
parent 89f19f4d4d
commit 6487b41e03
@@ -40,20 +40,20 @@ def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
         x = transformer_block(x, **block, n_head=n_head)
     return layer_norm(x, **ln_f) @ wte.T
 
-def generate(inputs, params, n_head, n_tokens_to_generate):
+def generate(inputs, params, n_head, n_tokens_to_generate,temperature = 1.0):
     from tqdm import tqdm
     for _ in tqdm(range(n_tokens_to_generate), "generating"):
         logits = gpt2(inputs, **params, n_head=n_head)
-        next_id = np.argmax(logits[-1])
+        next_id = np.argmax(softmax(logits[-1])/temperature)
         inputs.append(int(next_id))
     return inputs[len(inputs) - n_tokens_to_generate :]
 
-def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
+def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models",temperature = 1.0):
     from utils import load_encoder_hparams_and_params
     encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
     input_ids = encoder.encode(prompt)
     assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
-    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
+    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate,temperature)
     output_text = encoder.decode(output_ids)
     return output_text
 
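For reference, temperature is conventionally applied to the logits before the softmax (dividing the softmax output by a positive constant leaves the argmax unchanged), and the scaled distribution is then sampled rather than arg-maxed. A minimal sketch of that variant, reusing the gpt2 and softmax functions from this file; the function name, the rng seed, and the sampling loop below are illustrative assumptions, not part of this commit:

import numpy as np

def sample_with_temperature(inputs, params, n_head, n_tokens_to_generate, temperature=1.0, seed=0):
    # Illustrative sketch (not part of this commit): scale the logits by
    # 1/temperature before the softmax, then sample from the distribution.
    rng = np.random.default_rng(seed)
    for _ in range(n_tokens_to_generate):
        logits = gpt2(inputs, **params, n_head=n_head)  # forward pass from this file
        probs = softmax(logits[-1] / temperature)       # temperature < 1 sharpens, > 1 flattens
        next_id = rng.choice(len(probs), p=probs)       # sample instead of argmax
        inputs.append(int(next_id))
    return inputs[len(inputs) - n_tokens_to_generate:]

With temperature = 1.0 this samples from the model's raw next-token distribution; as temperature approaches 0 it approaches the greedy argmax behaviour of the original generate.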