Added temperature to generation

pull/19/head
M-Daniyal-123 2023-02-24 19:17:44 +05:00 zatwierdzone przez GitHub
rodzic dfb5df895a
commit 89f19f4d4d
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
1 zmienionych plików z 4 dodań i 4 usunięć

Wyświetl plik

@ -83,18 +83,18 @@ def gpt2(inputs, wte, wpe, blocks, ln_f, n_head): # [n_seq] -> [n_seq, n_vocab]
return x @ wte.T # [n_seq, n_embd] -> [n_seq, n_vocab] return x @ wte.T # [n_seq, n_embd] -> [n_seq, n_vocab]
def generate(inputs, params, n_head, n_tokens_to_generate, temperature=1.0):
    """Auto-regressively generate `n_tokens_to_generate` new token ids.

    inputs: list[int] of prompt token ids (mutated in place: generated
        ids are appended to it).
    params: dict of model weight tensors, unpacked into gpt2().
    n_head: number of attention heads, forwarded to gpt2().
    n_tokens_to_generate: how many new tokens to produce.
    temperature: softmax temperature applied to the logits. Lower values
        sharpen the distribution (T -> 0 approaches greedy argmax), higher
        values flatten it. Defaults to 1.0 so existing callers that omit
        the argument keep working.

    Returns the list of the `n_tokens_to_generate` newly generated ids.
    """
    from tqdm import tqdm

    for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
        logits = gpt2(inputs, **params, n_head=n_head)  # model forward pass

        # BUG FIX: the committed code was np.argmax(softmax(logits[-1]) / temperature).
        # Dividing a probability vector by a positive scalar never changes its
        # argmax — and softmax is monotonic, so argmax(softmax(x/T)) == argmax(x)
        # for any T > 0. As written, `temperature` had NO effect on the output.
        # For temperature to mean anything, it must divide the *logits* before
        # softmax, and the next token must be *sampled* from that distribution.
        probs = softmax(logits[-1] / temperature)  # [n_vocab] distribution
        next_id = np.random.choice(len(probs), p=probs)  # temperature sampling

        inputs.append(int(next_id))  # append prediction to input

    return inputs[len(inputs) - n_tokens_to_generate:]  # only return generated ids
def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"): def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models",temperature = 1.0):
from utils import load_encoder_hparams_and_params from utils import load_encoder_hparams_and_params
# load encoder, hparams, and params from the released open-ai gpt-2 files # load encoder, hparams, and params from the released open-ai gpt-2 files
@ -107,7 +107,7 @@ def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M",
assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"] assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
# generate output ids # generate output ids
output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate) output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate,temperature)
# decode the ids back into a string # decode the ids back into a string
output_text = encoder.decode(output_ids) output_text = encoder.decode(output_ids)