Mirror of https://github.com/jaymody/picoGPT
Merge 20301c9d05 into 817292baea
commit 1fba64b3f4
@@ -0,0 +1,50 @@
import time

from gpt2 import main as gpt2_main
from utils import load_encoder_hparams_and_params


def benchmark_generation(prompt, n_tokens_to_generate, model_size, use_speculative):
    # Time one end-to-end run of gpt2.main (model loading included)
    start_time = time.time()
    gpt2_main(prompt, n_tokens_to_generate, model_size, use_generate_speculative=use_speculative)
    end_time = time.time()
    return end_time - start_time


def run_benchmark(prompt, n_tokens_to_generate, model_sizes):
    results = {}

    for model_size in model_sizes:
        print(f"Benchmarking {model_size} model...")

        # Warm-up run
        benchmark_generation(prompt, n_tokens_to_generate, model_size, False)
        benchmark_generation(prompt, n_tokens_to_generate, model_size, True)

        # Actual benchmark
        standard_time = benchmark_generation(prompt, n_tokens_to_generate, model_size, False)
        speculative_time = benchmark_generation(prompt, n_tokens_to_generate, model_size, True)

        improvement = (standard_time - speculative_time) / standard_time * 100
        results[model_size] = {
            "standard_time": standard_time,
            "speculative_time": speculative_time,
            "improvement": improvement,
        }

    return results


def main():
    prompt = "In a world where artificial intelligence has become ubiquitous"
    n_tokens_to_generate = 50
    model_sizes = ["124M", "355M"]

    results = run_benchmark(prompt, n_tokens_to_generate, model_sizes)

    print("\nBenchmark Results:")
    print("==================")
    for model_size, data in results.items():
        print(f"\nModel Size: {model_size}")
        print(f"Standard Generation Time: {data['standard_time']:.4f} seconds")
        print(f"Speculative Generation Time: {data['speculative_time']:.4f} seconds")
        print(f"Improvement: {data['improvement']:.2f}%")


if __name__ == "__main__":
    main()
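The new file's name is not visible in this view; assuming it is saved as benchmark.py at the repository root, it runs with no arguments:

    python benchmark.py

Each call to gpt2.main reloads the encoder and checkpoint via load_encoder_hparams_and_params (which in picoGPT also downloads the weights into models/ on first use), so the timings include model loading as well as token generation; the warm-up runs at least keep the one-time download out of the measured numbers.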
gpt2.py
@@ -94,12 +94,50 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
    return inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids


def generate_speculative(inputs, params, draft_params, n_head, draft_n_head, n_tokens_to_generate, n_speculative=3):
    from tqdm import tqdm

    prompt_len = len(inputs)  # remember where the prompt ends
    for _ in tqdm(range(n_tokens_to_generate), "generating"):
        if len(inputs) - prompt_len >= n_tokens_to_generate:
            break  # an accepted batch can add several tokens, so we may finish early

        # Generate speculative tokens using the (smaller) draft model
        draft_inputs = inputs.copy()
        draft_tokens = []
        for _ in range(n_speculative):
            # Use the draft model to greedily predict the next token
            draft_logits = gpt2(draft_inputs, **draft_params, n_head=draft_n_head)
            next_id = int(np.argmax(draft_logits[-1]))
            draft_tokens.append(next_id)
            draft_inputs.append(next_id)

        # Verify the speculative tokens with a single forward pass of the main model
        main_logits = gpt2(inputs + draft_tokens, **params, n_head=n_head)
        main_probs = softmax(main_logits[-n_speculative - 1 :])

        # Accept draft tokens as long as they agree with the main model's greedy choice
        accepted_tokens = 0
        for i, token in enumerate(draft_tokens):
            if np.argmax(main_probs[i]) == token:
                accepted_tokens += 1
            else:
                break

        # Add accepted tokens to the input
        inputs.extend(draft_tokens[:accepted_tokens])

        # If no tokens were accepted, fall back to the main model's own prediction
        if accepted_tokens == 0:
            inputs.append(int(np.argmax(main_probs[0])))

    # Return only the newly generated tokens
    return inputs[prompt_len : prompt_len + n_tokens_to_generate]


def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models", use_generate_speculative: bool = True):
    from utils import load_encoder_hparams_and_params

    # load encoder, hparams, and params from the released open-ai gpt-2 files
    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
    # also load the smallest checkpoint as the draft model; it may use a different head count
    _, draft_hparams, draft_params = load_encoder_hparams_and_params("124M", models_dir)

    # encode the input string using the BPE tokenizer
    input_ids = encoder.encode(prompt)

@@ -107,7 +145,10 @@ def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M",
    assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]

    # generate output ids
    if use_generate_speculative:
        output_ids = generate_speculative(input_ids, params, draft_params, hparams["n_head"], draft_hparams["n_head"], n_tokens_to_generate)
    else:
        output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)

    # decode the ids back into a string
    output_text = encoder.decode(output_ids)
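As a standalone sanity check, separate from the diff, the acceptance rule in generate_speculative can be exercised with made-up numbers. The logits below are hypothetical; note that softmax does not change the argmax, so comparing raw logits gives the same decision as the probabilities used in the real code:

import numpy as np

# hypothetical main-model logits at the three draft positions (vocabulary of size 3)
main_logits = np.array([
    [0.1, 2.0, 0.3],  # greedy pick: token 1
    [1.5, 0.2, 0.1],  # greedy pick: token 0
    [0.0, 0.1, 3.0],  # greedy pick: token 2
])
draft_tokens = [1, 0, 1]  # the draft model's greedy picks

# accept draft tokens until the first disagreement, mirroring generate_speculative
accepted = 0
for i, token in enumerate(draft_tokens):
    if np.argmax(main_logits[i]) == token:
        accepted += 1
    else:
        break

print(accepted)  # 2: the first two drafts match the main model, the third is rejected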
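Since gpt2.py drives main through fire.Fire, the new flag is also reachable from the command line; a sketch, assuming the flag keeps the spelling use_generate_speculative from the signature above:

    python gpt2.py "Alan Turing theorized that computers would one day become" --n_tokens_to_generate 40 --use_generate_speculative True

Passing --use_generate_speculative False falls back to the original greedy generate loop.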