Add env variable VALIDATE_JSON_RETREY

Add env variable VALIDATE_JSON_RETREY to configure number of retries when validating json object.
Add typing and comments.
pull/31/head^2^2
nervousapps 2023-04-18 11:42:09 +02:00 zatwierdzone przez GitHub
rodzic 2f5a026ff9
commit 9f75fc9123
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
1 zmienionych plików z 49 dodań i 39 usunięć

Wyświetl plik

@ -6,6 +6,7 @@ import shutil
import subprocess import subprocess
import sys import sys
import openai import openai
from typing import List, Dict
from termcolor import cprint from termcolor import cprint
from dotenv import load_dotenv from dotenv import load_dotenv
@ -14,14 +15,18 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY") openai.api_key = os.getenv("OPENAI_API_KEY")
# Default model is GPT-4
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-4") DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-4")
# Nb retries for json_validated_response, default to -1, infinite
VALIDATE_JSON_RETRY = int(os.getenv("VALIDATE_JSON_RETRY", -1))
with open("prompt.txt") as f: # Read the system prompt
with open(os.path.join(os.path.dirname(__file__), "prompt.txt"), 'r') as f:
SYSTEM_PROMPT = f.read() SYSTEM_PROMPT = f.read()
def run_script(script_name, script_args): def run_script(script_name: str, script_args: List) -> str:
script_args = [str(arg) for arg in script_args] script_args = [str(arg) for arg in script_args]
""" """
If script_name.endswith(".py") then run with python If script_name.endswith(".py") then run with python
@ -40,12 +45,14 @@ def run_script(script_name, script_args):
return result.decode("utf-8"), 0 return result.decode("utf-8"), 0
def json_validated_response(model, messages): def json_validated_response(model: str, messages: List[Dict], nb_retry: int = 0) -> Dict:
""" """
This function is needed because the API can return a non-json response. This function is needed because the API can return a non-json response.
This will run recursively until a valid json response is returned. This will run recursively VALIDATE_JSON_RETRY times.
todo: might want to stop after a certain number of retries If VALIDATE_JSON_RETRY is -1, it will run recursively until a valid json response is returned.
""" """
json_response = {}
if VALIDATE_JSON_RETRY == -1 or nb_retry < VALIDATE_JSON_RETRY:
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=model, model=model,
messages=messages, messages=messages,
@ -73,16 +80,19 @@ def json_validated_response(model, messages):
"content": "Your response could not be parsed by json.loads. Please restate your last message as pure JSON.", "content": "Your response could not be parsed by json.loads. Please restate your last message as pure JSON.",
} }
) )
# inc nb_retry
nb_retry+=1
# rerun the api call # rerun the api call
return json_validated_response(model, messages) return json_validated_response(model, messages, nb_retry)
except Exception as e: except Exception as e:
cprint(f"Unknown error: {e}", "red") cprint(f"Unknown error: {e}", "red")
cprint(f"\nGPT RESPONSE:\n\n{content}\n\n", "yellow") cprint(f"\nGPT RESPONSE:\n\n{content}\n\n", "yellow")
raise e raise e
# If not valid after VALIDATE_JSON_RETRY retries, return an empty object / or raise an exception and exit
return json_response return json_response
def send_error_to_gpt(file_path, args, error_message, model=DEFAULT_MODEL): def send_error_to_gpt(file_path: str, args: List, error_message: str, model: str = DEFAULT_MODEL) -> Dict:
with open(file_path, "r") as f: with open(file_path, "r") as f:
file_lines = f.readlines() file_lines = f.readlines()
@ -117,7 +127,7 @@ def send_error_to_gpt(file_path, args, error_message, model=DEFAULT_MODEL):
return json_validated_response(model, messages) return json_validated_response(model, messages)
def apply_changes(file_path, changes: list, confirm=False): def apply_changes(file_path: str, changes: List, confirm: bool = False):
""" """
Pass changes as loaded json (list of dicts) Pass changes as loaded json (list of dicts)
""" """