Added temperature

pull/19/head
M-Daniyal-123 2023-02-24 19:19:50 +05:00 zatwierdzone przez GitHub
rodzic 89f19f4d4d
commit 6487b41e03
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
1 zmienionych plików z 4 dodań i 4 usunięć

Wyświetl plik

@@ -40,20 +40,20 @@ def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
x = transformer_block(x, **block, n_head=n_head)
return layer_norm(x, **ln_f) @ wte.T
def generate(inputs, params, n_head, n_tokens_to_generate, temperature=1.0):
    """Autoregressively extend ``inputs`` by ``n_tokens_to_generate`` token ids.

    Each step runs ``gpt2`` over the whole sequence so far, chooses the next
    id from the logits of the last position, and appends it to ``inputs``
    (the caller's list is modified in place).

    Args:
        inputs: List of token ids forming the prompt; extended in place.
        params: Mapping unpacked as keyword arguments into ``gpt2``.
        n_head: Forwarded to ``gpt2`` as ``n_head``.
        n_tokens_to_generate: Number of decoding steps to run.
        temperature: Sampling temperature. ``0`` picks the highest-scoring
            token at every step (greedy decoding). A positive value samples
            from ``softmax(logits / temperature)``: values below 1 sharpen the
            distribution, values above 1 flatten it.

    Returns:
        Only the newly generated token ids, i.e. the last
        ``n_tokens_to_generate`` entries of ``inputs``.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    from tqdm import tqdm

    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")

    for _ in tqdm(range(n_tokens_to_generate), "generating"):
        logits = gpt2(inputs, **params, n_head=n_head)
        # float64 keeps the probabilities accurate enough for np.random.choice,
        # which rejects vectors that do not sum to 1 within a tight tolerance.
        last_logits = np.asarray(logits[-1], dtype=np.float64)

        if temperature == 0:
            # Zero temperature is the greedy limit; dividing by it is undefined.
            next_id = np.argmax(last_logits)
        else:
            # Temperature must scale the logits *before* softmax: dividing the
            # probabilities afterwards leaves their ranking untouched.
            scaled = last_logits / temperature
            # Softmax is shift-invariant, so removing the max changes nothing
            # mathematically but avoids overflow for very small temperatures.
            scaled -= scaled.max()
            probs = softmax(scaled)
            probs = probs / probs.sum()
            next_id = np.random.choice(probs.size, p=probs)

        inputs.append(int(next_id))

    return inputs[len(inputs) - n_tokens_to_generate :]
def main(
    prompt: str,
    n_tokens_to_generate: int = 40,
    model_size: str = "124M",
    models_dir: str = "models",
    temperature=1.0,
):
    """Generate a continuation of ``prompt`` and return it as decoded text.

    Loads the encoder, hyperparameters and parameters through
    ``load_encoder_hparams_and_params(model_size, models_dir)``, encodes the
    prompt, calls ``generate`` and decodes the ids it returns.

    Args:
        prompt: Text to continue.
        n_tokens_to_generate: Number of tokens requested from ``generate``.
        model_size: Passed to ``load_encoder_hparams_and_params``.
        models_dir: Passed to ``load_encoder_hparams_and_params``.
        temperature: Forwarded unchanged to ``generate``.

    Returns:
        The decoded text of the generated token ids.
    """
    from utils import load_encoder_hparams_and_params

    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)

    prompt_ids = encoder.encode(prompt)

    # Prompt plus requested tokens must stay strictly below hparams["n_ctx"]
    # (presumably the model's context length -- confirm against utils).
    requested_length = len(prompt_ids) + n_tokens_to_generate
    assert requested_length < hparams["n_ctx"]

    generated_ids = generate(
        prompt_ids, params, hparams["n_head"], n_tokens_to_generate, temperature
    )
    return encoder.decode(generated_ids)