Yash Jain 2024-11-17 16:46:59 +05:30 committed by GitHub
commit 1fba64b3f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 94 additions and 3 deletions

View file

@@ -0,0 +1,50 @@
import time

from gpt2 import main as gpt2_main


def benchmark_generation(prompt, n_tokens_to_generate, model_size, use_speculative):
    start_time = time.time()
    gpt2_main(prompt, n_tokens_to_generate, model_size, use_generate_speculative=use_speculative)
    end_time = time.time()
    return end_time - start_time


def run_benchmark(prompt, n_tokens_to_generate, model_sizes):
    results = {}
    for model_size in model_sizes:
        print(f"Benchmarking {model_size} model...")

        # warm-up runs (one per code path) so one-time costs such as model
        # loading don't skew the timings
        benchmark_generation(prompt, n_tokens_to_generate, model_size, False)
        benchmark_generation(prompt, n_tokens_to_generate, model_size, True)

        # actual benchmark
        standard_time = benchmark_generation(prompt, n_tokens_to_generate, model_size, False)
        speculative_time = benchmark_generation(prompt, n_tokens_to_generate, model_size, True)

        improvement = (standard_time - speculative_time) / standard_time * 100
        results[model_size] = {
            "standard_time": standard_time,
            "speculative_time": speculative_time,
            "improvement": improvement,
        }
    return results


def main():
    prompt = "In a world where artificial intelligence has become ubiquitous"
    n_tokens_to_generate = 50
    model_sizes = ["124M", "355M"]

    results = run_benchmark(prompt, n_tokens_to_generate, model_sizes)

    print("\nBenchmark Results:")
    print("==================")
    for model_size, data in results.items():
        print(f"\nModel Size: {model_size}")
        print(f"Standard Generation Time: {data['standard_time']:.4f} seconds")
        print(f"Speculative Generation Time: {data['speculative_time']:.4f} seconds")
        print(f"Improvement: {data['improvement']:.2f}%")


if __name__ == "__main__":
    main()
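
The new file's name is not shown in the capture above; assuming it is saved as benchmark.py alongside gpt2.py and utils.py from the repo, a minimal usage sketch (module name assumed, not confirmed by the commit):

    from benchmark import run_benchmark  # hypothetical module name

    # one timed run per configuration after a single warm-up pass, so the
    # numbers will vary somewhat between invocations
    results = run_benchmark("The meaning of life is", 20, ["124M"])
    print(f"{results['124M']['improvement']:.1f}% change with speculative decoding")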

gpt2.py (47 changed lines)
View file

@@ -94,12 +94,50 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
    return inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids


def generate_speculative(inputs, params, draft_params, n_head, n_tokens_to_generate, n_speculative=3):
    from tqdm import tqdm

    prompt_len = len(inputs)
    n_generated = 0
    with tqdm(total=n_tokens_to_generate, desc="generating") as pbar:
        # each iteration accepts between 1 and n_speculative tokens, so loop
        # until enough tokens exist rather than a fixed number of times
        while n_generated < n_tokens_to_generate:
            # generate speculative tokens greedily using the draft model
            draft_inputs = inputs.copy()
            draft_tokens = []
            for _ in range(n_speculative):
                draft_logits = gpt2(draft_inputs, **draft_params, n_head=n_head)
                next_id = int(np.argmax(draft_logits[-1]))
                draft_tokens.append(next_id)
                draft_inputs.append(next_id)

            # verify all speculative tokens with one main-model forward pass
            main_logits = gpt2(inputs + draft_tokens, **params, n_head=n_head)
            main_probs = softmax(main_logits[-n_speculative - 1 :])

            # accept draft tokens from the left while they match the main
            # model's own greedy prediction; stop at the first mismatch
            accepted_tokens = 0
            for i, token in enumerate(draft_tokens):
                if np.argmax(main_probs[i]) == token:
                    accepted_tokens += 1
                else:
                    break

            # add the accepted prefix to the running sequence
            inputs.extend(draft_tokens[:accepted_tokens])
            # if no draft token was accepted, fall back to the main model's
            # prediction so that every iteration makes progress
            if accepted_tokens == 0:
                inputs.append(int(np.argmax(main_probs[0])))
                accepted_tokens = 1
            n_generated += accepted_tokens
            pbar.update(accepted_tokens)

    # return only the generated ids, trimming any overshoot from the last batch
    return inputs[prompt_len : prompt_len + n_tokens_to_generate]
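
For intuition, the acceptance rule above only compares greedy picks; here is a self-contained toy walk-through with made-up token ids (illustrative values, not part of the commit):

    import numpy as np

    # stand-ins for model outputs: the draft model proposed three tokens, and
    # main_argmax[i] is the main model's greedy choice at position i
    draft_tokens = [17, 42, 99]
    main_argmax = np.array([17, 42, 7, 23])  # last entry is the fallback position

    accepted_tokens = 0
    for i, token in enumerate(draft_tokens):
        if main_argmax[i] == token:
            accepted_tokens += 1
        else:
            break

    print(accepted_tokens)  # 2: tokens 17 and 42 match, 99 is rejected

Because a draft token is only accepted when it equals the main model's own greedy choice, the final output is identical to what the main model alone would produce; the speedup comes from verifying several proposed tokens in a single main-model forward pass.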
def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models", use_generate_speculative: bool = True):
    from utils import load_encoder_hparams_and_params

    # load encoder, hparams, and params from the released open-ai gpt-2 files
    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
    # the smallest model (124M) serves as the draft model for speculative decoding
    _, _, draft_params = load_encoder_hparams_and_params("124M", models_dir)

    # encode the input string using the BPE tokenizer
    input_ids = encoder.encode(prompt)
@@ -107,7 +145,10 @@ def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M",
    assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]

    # generate output ids, using speculative decoding when requested
    if use_generate_speculative:
        output_ids = generate_speculative(input_ids, params, draft_params, hparams["n_head"], n_tokens_to_generate)
    else:
        output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)

    # decode the ids back into a string
    output_text = encoder.decode(output_ids)
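
A quick sketch of exercising the new flag from Python, assuming main() goes on to return the decoded text as in upstream picoGPT (the prompt is illustrative):

    from gpt2 import main

    prompt = "Alan Turing theorized that computers would one day become"
    # same prompt through both code paths; with greedy acceptance the two
    # outputs should match, and only the wall-clock time should differ
    fast = main(prompt, n_tokens_to_generate=8, use_generate_speculative=True)
    slow = main(prompt, n_tokens_to_generate=8, use_generate_speculative=False)
    assert fast == slow

Note that with the default model_size of "124M" the draft and the main model are the same network, so the speculative path mostly measures its own overhead; a larger model_size such as "355M" (as in the benchmark script above) is where acceptance-rate savings can actually show up.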