diff --git a/wolverine.py b/wolverine.py index f1d910d..d9eff4a 100644 --- a/wolverine.py +++ b/wolverine.py @@ -1,18 +1,27 @@ import difflib -import fire import json import os import shutil import subprocess import sys -import openai + +import fire +from dotenv import load_dotenv from termcolor import cprint from dotenv import load_dotenv -# Set up the OpenAI API load_dotenv() -openai.api_key = os.getenv("OPENAI_API_KEY") + +import openai + +openai.api_key = os.environ.get("OPENAI_API_KEY") + +DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-4") + + +with open("prompt.txt") as f: + SYSTEM_PROMPT = f.read() def run_script(script_name, script_args): @@ -26,7 +35,47 @@ def run_script(script_name, script_args): return result.decode("utf-8"), 0 -def send_error_to_gpt(file_path, args, error_message, model): +def send_error_to_gpt(file_path, args, error_message, model=DEFAULT_MODEL): + def json_validated_response(model, messages): + """ + This function is needed because the API can return a non-json response. + This will run recursively until a valid json response is returned. + """ + response = openai.ChatCompletion.create( + model=model, + messages=messages, + temperature=0.5, + ) + messages.append(response.choices[0].message) + content = response.choices[0].message.content + # see if json can be parsed + try: + json_start_index = content.index( + "[" + ) # find the starting position of the JSON data + json_data = content[ + json_start_index: + ] # extract the JSON data from the response string + json_response = json.loads(json_data) + except (json.decoder.JSONDecodeError, ValueError) as e: + cprint(f"{e}. Re-running the query.", "red") + # debug + cprint(f"\n\GPT RESPONSE:\n\n{content}\n\n", "yellow") + # append a user message that says the json is invalid + messages.append( + { + "role": "user", + "content": "Your response could not be parsed by json.loads. Please restate your last message as pure JSON.", + } + ) + # rerun the api call + return json_validated_response(model, messages) + except Exception as e: + cprint(f"Unknown error: {e}", "red") + cprint(f"\n\GPT RESPONSE:\n\n{content}\n\n", "yellow") + raise e + return json_response + with open(file_path, "r") as f: file_lines = f.readlines() @@ -35,12 +84,7 @@ def send_error_to_gpt(file_path, args, error_message, model): file_with_lines.append(str(i + 1) + ": " + line) file_with_lines = "".join(file_with_lines) - with open("prompt.txt") as f: - initial_prompt_text = f.read() - prompt = ( - initial_prompt_text + - "\n\n" "Here is the script that needs fixing:\n\n" f"{file_with_lines}\n\n" "Here are the arguments it was provided:\n\n" @@ -52,26 +96,27 @@ def send_error_to_gpt(file_path, args, error_message, model): ) # print(prompt) - response = openai.ChatCompletion.create( - model=model, - messages=[ - { - "role": "user", - "content": prompt, - } - ], - temperature=1.0, - ) + messages = [ + { + "role": "system", + "content": SYSTEM_PROMPT, + }, + { + "role": "user", + "content": prompt, + }, + ] - return response.choices[0].message.content.strip() + return json_validated_response(model, messages) -def apply_changes(file_path, changes_json): +def apply_changes(file_path, changes: list): + """ + Pass changes as loaded json (list of dicts) + """ with open(file_path, "r") as f: original_file_lines = f.readlines() - changes = json.loads(changes_json) - # Filter out explanation elements operation_changes = [change for change in changes if "operation" in change] explanations = [ @@ -114,7 +159,7 @@ def apply_changes(file_path, changes_json): print(line, end="") -def main(script_name, *script_args, revert=False, model="gpt-4"): +def main(script_name, *script_args, revert=False, model=DEFAULT_MODEL): if revert: backup_file = script_name + ".bak" if os.path.exists(backup_file): @@ -140,11 +185,11 @@ def main(script_name, *script_args, revert=False, model="gpt-4"): print("Output:", output) json_response = send_error_to_gpt( - file_path=script_name, - args=script_args, - error_message=output, - model=model, + file_path=script_name, + args=script_args, + error_message=output, ) + apply_changes(script_name, json_response) cprint("Changes applied. Rerunning...", "blue")