recursive calls if not json parsable

- makes recursive calls to API (with a comment about it not being parsable) if response was not parsable
- pass prompt.txt as system prompt
- use env var for `DEFAULT_MODEL`
- use env var for OPENAI_API_KEY
pull/13/head
Felix Boehme 2023-04-13 11:29:06 -04:00
rodzic 7c072fba2a
commit 0656a83da7
1 zmienionych plików z 74 dodań i 29 usunięć

Wyświetl plik

@ -1,18 +1,27 @@
import difflib import difflib
import fire
import json import json
import os import os
import shutil import shutil
import subprocess import subprocess
import sys import sys
import openai
import fire
from dotenv import load_dotenv
from termcolor import cprint from termcolor import cprint
from dotenv import load_dotenv from dotenv import load_dotenv
# Set up the OpenAI API
load_dotenv() load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")
import openai
openai.api_key = os.environ.get("OPENAI_API_KEY")
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-4")
with open("prompt.txt") as f:
SYSTEM_PROMPT = f.read()
def run_script(script_name, script_args): def run_script(script_name, script_args):
@ -26,7 +35,47 @@ def run_script(script_name, script_args):
return result.decode("utf-8"), 0 return result.decode("utf-8"), 0
def send_error_to_gpt(file_path, args, error_message, model): def send_error_to_gpt(file_path, args, error_message, model=DEFAULT_MODEL):
def json_validated_response(model, messages):
"""
This function is needed because the API can return a non-json response.
This will run recursively until a valid json response is returned.
"""
response = openai.ChatCompletion.create(
model=model,
messages=messages,
temperature=0.5,
)
messages.append(response.choices[0].message)
content = response.choices[0].message.content
# see if json can be parsed
try:
json_start_index = content.index(
"["
) # find the starting position of the JSON data
json_data = content[
json_start_index:
] # extract the JSON data from the response string
json_response = json.loads(json_data)
except (json.decoder.JSONDecodeError, ValueError) as e:
cprint(f"{e}. Re-running the query.", "red")
# debug
cprint(f"\n\GPT RESPONSE:\n\n{content}\n\n", "yellow")
# append a user message that says the json is invalid
messages.append(
{
"role": "user",
"content": "Your response could not be parsed by json.loads. Please restate your last message as pure JSON.",
}
)
# rerun the api call
return json_validated_response(model, messages)
except Exception as e:
cprint(f"Unknown error: {e}", "red")
cprint(f"\n\GPT RESPONSE:\n\n{content}\n\n", "yellow")
raise e
return json_response
with open(file_path, "r") as f: with open(file_path, "r") as f:
file_lines = f.readlines() file_lines = f.readlines()
@ -35,12 +84,7 @@ def send_error_to_gpt(file_path, args, error_message, model):
file_with_lines.append(str(i + 1) + ": " + line) file_with_lines.append(str(i + 1) + ": " + line)
file_with_lines = "".join(file_with_lines) file_with_lines = "".join(file_with_lines)
with open("prompt.txt") as f:
initial_prompt_text = f.read()
prompt = ( prompt = (
initial_prompt_text +
"\n\n"
"Here is the script that needs fixing:\n\n" "Here is the script that needs fixing:\n\n"
f"{file_with_lines}\n\n" f"{file_with_lines}\n\n"
"Here are the arguments it was provided:\n\n" "Here are the arguments it was provided:\n\n"
@ -52,26 +96,27 @@ def send_error_to_gpt(file_path, args, error_message, model):
) )
# print(prompt) # print(prompt)
response = openai.ChatCompletion.create( messages = [
model=model, {
messages=[ "role": "system",
{ "content": SYSTEM_PROMPT,
"role": "user", },
"content": prompt, {
} "role": "user",
], "content": prompt,
temperature=1.0, },
) ]
return response.choices[0].message.content.strip() return json_validated_response(model, messages)
def apply_changes(file_path, changes_json): def apply_changes(file_path, changes: list):
"""
Pass changes as loaded json (list of dicts)
"""
with open(file_path, "r") as f: with open(file_path, "r") as f:
original_file_lines = f.readlines() original_file_lines = f.readlines()
changes = json.loads(changes_json)
# Filter out explanation elements # Filter out explanation elements
operation_changes = [change for change in changes if "operation" in change] operation_changes = [change for change in changes if "operation" in change]
explanations = [ explanations = [
@ -114,7 +159,7 @@ def apply_changes(file_path, changes_json):
print(line, end="") print(line, end="")
def main(script_name, *script_args, revert=False, model="gpt-4"): def main(script_name, *script_args, revert=False, model=DEFAULT_MODEL):
if revert: if revert:
backup_file = script_name + ".bak" backup_file = script_name + ".bak"
if os.path.exists(backup_file): if os.path.exists(backup_file):
@ -140,11 +185,11 @@ def main(script_name, *script_args, revert=False, model="gpt-4"):
print("Output:", output) print("Output:", output)
json_response = send_error_to_gpt( json_response = send_error_to_gpt(
file_path=script_name, file_path=script_name,
args=script_args, args=script_args,
error_message=output, error_message=output,
model=model,
) )
apply_changes(script_name, json_response) apply_changes(script_name, json_response)
cprint("Changes applied. Rerunning...", "blue") cprint("Changes applied. Rerunning...", "blue")