kopia lustrzana https://github.com/jaymody/picoGPT
Add temperature parameter to generate() and main()
rodzic
79a40c0638
commit
ccd68b4d02
6
gpt2.py
6
gpt2.py
|
@@ -85,7 +85,7 @@ def gpt2(inputs, wte, wpe, blocks, ln_f, n_head): # [n_seq] -> [n_seq, n_vocab]
|
||||||
return x @ wte.T # [n_seq, n_embd] -> [n_seq, n_vocab]
|
return x @ wte.T # [n_seq, n_embd] -> [n_seq, n_vocab]
|
||||||
|
|
||||||
|
|
||||||
def generate(inputs, params, n_head, n_tokens_to_generate):
|
def generate(inputs, params, n_head, n_tokens_to_generate,temperature):
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop
|
for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop
|
||||||
|
@@ -96,7 +96,7 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
|
||||||
return inputs[len(inputs) - n_tokens_to_generate :] # only return generated ids
|
return inputs[len(inputs) - n_tokens_to_generate :] # only return generated ids
|
||||||
|
|
||||||
|
|
||||||
def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
|
def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models",temperature = 1.0):
|
||||||
from utils import load_encoder_hparams_and_params
|
from utils import load_encoder_hparams_and_params
|
||||||
|
|
||||||
# load encoder, hparams, and params from the released open-ai gpt-2 files
|
# load encoder, hparams, and params from the released open-ai gpt-2 files
|
||||||
|
@@ -109,7 +109,7 @@ def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M",
|
||||||
assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
|
assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
|
||||||
|
|
||||||
# generate output ids
|
# generate output ids
|
||||||
output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
|
output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate,temperature)
|
||||||
|
|
||||||
# decode the ids back into a string
|
# decode the ids back into a string
|
||||||
output_text = encoder.decode(output_ids)
|
output_text = encoder.decode(output_ids)
|
||||||
|
|
Ładowanie…
Reference in New Issue