Add env variable VALIDATE_JSON_RETREY

Add env variable VALIDATE_JSON_RETREY to configure number of retries when validating json object.
Add typing and comments.
pull/31/head^2^2
nervousapps 2023-04-18 11:42:09 +02:00 zatwierdzone przez GitHub
rodzic 2f5a026ff9
commit 9f75fc9123
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
1 zmienionych plików z 49 dodań i 39 usunięć

Wyświetl plik

@ -6,6 +6,7 @@ import shutil
import subprocess import subprocess
import sys import sys
import openai import openai
from typing import List, Dict
from termcolor import cprint from termcolor import cprint
from dotenv import load_dotenv from dotenv import load_dotenv
@ -14,14 +15,18 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY") openai.api_key = os.getenv("OPENAI_API_KEY")
# Default model is GPT-4
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-4") DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-4")
# Nb retries for json_validated_response, default to -1, infinite
VALIDATE_JSON_RETRY = int(os.getenv("VALIDATE_JSON_RETRY", -1))
with open("prompt.txt") as f: # Read the system prompt
with open(os.path.join(os.path.dirname(__file__), "prompt.txt"), 'r') as f:
SYSTEM_PROMPT = f.read() SYSTEM_PROMPT = f.read()
def run_script(script_name, script_args): def run_script(script_name: str, script_args: List) -> str:
script_args = [str(arg) for arg in script_args] script_args = [str(arg) for arg in script_args]
""" """
If script_name.endswith(".py") then run with python If script_name.endswith(".py") then run with python
@ -40,49 +45,54 @@ def run_script(script_name, script_args):
return result.decode("utf-8"), 0 return result.decode("utf-8"), 0
def json_validated_response(model, messages): def json_validated_response(model: str, messages: List[Dict], nb_retry: int = 0) -> Dict:
""" """
This function is needed because the API can return a non-json response. This function is needed because the API can return a non-json response.
This will run recursively until a valid json response is returned. This will run recursively VALIDATE_JSON_RETRY times.
todo: might want to stop after a certain number of retries If VALIDATE_JSON_RETRY is -1, it will run recursively until a valid json response is returned.
""" """
response = openai.ChatCompletion.create( json_response = {}
model=model, if VALIDATE_JSON_RETRY == -1 or nb_retry < VALIDATE_JSON_RETRY:
messages=messages, response = openai.ChatCompletion.create(
temperature=0.5, model=model,
) messages=messages,
messages.append(response.choices[0].message) temperature=0.5,
content = response.choices[0].message.content
# see if json can be parsed
try:
json_start_index = content.index(
"["
) # find the starting position of the JSON data
json_data = content[
json_start_index:
] # extract the JSON data from the response string
json_response = json.loads(json_data)
except (json.decoder.JSONDecodeError, ValueError) as e:
cprint(f"{e}. Re-running the query.", "red")
# debug
cprint(f"\nGPT RESPONSE:\n\n{content}\n\n", "yellow")
# append a user message that says the json is invalid
messages.append(
{
"role": "user",
"content": "Your response could not be parsed by json.loads. Please restate your last message as pure JSON.",
}
) )
# rerun the api call messages.append(response.choices[0].message)
return json_validated_response(model, messages) content = response.choices[0].message.content
except Exception as e: # see if json can be parsed
cprint(f"Unknown error: {e}", "red") try:
cprint(f"\nGPT RESPONSE:\n\n{content}\n\n", "yellow") json_start_index = content.index(
raise e "["
) # find the starting position of the JSON data
json_data = content[
json_start_index:
] # extract the JSON data from the response string
json_response = json.loads(json_data)
except (json.decoder.JSONDecodeError, ValueError) as e:
cprint(f"{e}. Re-running the query.", "red")
# debug
cprint(f"\nGPT RESPONSE:\n\n{content}\n\n", "yellow")
# append a user message that says the json is invalid
messages.append(
{
"role": "user",
"content": "Your response could not be parsed by json.loads. Please restate your last message as pure JSON.",
}
)
# inc nb_retry
nb_retry+=1
# rerun the api call
return json_validated_response(model, messages, nb_retry)
except Exception as e:
cprint(f"Unknown error: {e}", "red")
cprint(f"\nGPT RESPONSE:\n\n{content}\n\n", "yellow")
raise e
# If not valid after VALIDATE_JSON_RETRY retries, return an empty object / or raise an exception and exit
return json_response return json_response
def send_error_to_gpt(file_path, args, error_message, model=DEFAULT_MODEL): def send_error_to_gpt(file_path: str, args: List, error_message: str, model: str = DEFAULT_MODEL) -> Dict:
with open(file_path, "r") as f: with open(file_path, "r") as f:
file_lines = f.readlines() file_lines = f.readlines()
@ -117,7 +127,7 @@ def send_error_to_gpt(file_path, args, error_message, model=DEFAULT_MODEL):
return json_validated_response(model, messages) return json_validated_response(model, messages)
def apply_changes(file_path, changes: list, confirm=False): def apply_changes(file_path: str, changes: List, confirm: bool = False):
""" """
Pass changes as loaded json (list of dicts) Pass changes as loaded json (list of dicts)
""" """