Added Temperature

pull/20/head
M-Daniyal-123 2023-02-24 18:41:54 +05:00 committed by GitHub
parent 79a40c0638
commit ccd68b4d02
No key for this signature was found in the database
GPG key ID: 4AEE18F83AFDEB23
1 changed file with 3 additions and 3 deletions


@@ -85,7 +85,7 @@ def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):  # [n_seq] -> [n_seq, n_vocab]
     return x @ wte.T  # [n_seq, n_embd] -> [n_seq, n_vocab]


-def generate(inputs, params, n_head, n_tokens_to_generate):
+def generate(inputs, params, n_head, n_tokens_to_generate,temperature):
     from tqdm import tqdm

     for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
@@ -96,7 +96,7 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
     return inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids


-def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
+def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models",temperature = 1.0):
     from utils import load_encoder_hparams_and_params

     # load encoder, hparams, and params from the released open-ai gpt-2 files
@@ -109,7 +109,7 @@ def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M",
     assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]

     # generate output ids
-    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
+    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate,temperature)

     # decode the ids back into a string
     output_text = encoder.decode(output_ids)
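The three changed lines only thread a temperature argument from main through to generate; the hunks shown here do not include the point where the logits are actually scaled. As a rough illustration only, the sketch below shows the usual way such a parameter is applied inside a decode loop: divide the final-position logits by the temperature before turning them into probabilities and sampling. The softmax and sample_next_token helpers are illustrative assumptions, not part of this commit; note that with a greedy np.argmax decode, dividing the logits by a positive temperature does not change the chosen token, so temperature only has an effect once sampling is used.

# Illustrative sketch (not part of this commit): typical temperature-scaled sampling.
import numpy as np

def softmax(x):
    # numerically stable softmax over the last axis
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def sample_next_token(logits, temperature=1.0):
    # scale the last position's logits, then sample a token id from the resulting distribution
    probs = softmax(logits[-1] / temperature)  # [n_vocab]
    return int(np.random.choice(len(probs), p=probs))

Lower temperatures sharpen the distribution toward the argmax token; higher temperatures flatten it toward uniform.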