working on databaseagent, still some problems

pull/137/head
= 2023-11-08 23:12:27 -05:00
rodzic ea5ed6e9e0
commit fa5d2f5f82
10 zmienionych plików z 103 dodań i 143 usunięć

Wyświetl plik

@ -34,6 +34,8 @@ services:
- settings-data:/home/settings - settings-data:/home/settings
healthcheck: healthcheck:
test: "pg_isready -d ${POSTGRES_DBNAME}" test: "pg_isready -d ${POSTGRES_DBNAME}"
ports:
- "${POSTGRES_PORT}:5432"
imposm: imposm:
image: kartoza/docker-osm:imposm-latest image: kartoza/docker-osm:imposm-latest

Wyświetl plik

@ -1,4 +1,3 @@
import openai
import json import json
import logging import logging
import os import os
@ -16,10 +15,11 @@ class DatabaseAgent:
def get_table_from_database(self, query): def get_table_from_database(self, query):
return {"name": "get_table_from_database", "query": query} return {"name": "get_table_from_database", "query": query}
def __init__(self, model_version="gpt-3.5-turbo-0613", schema=None): def __init__(self, client, schema=None):
self.model_version = model_version self.model_version = model_version
self.client = client
self.schema = schema self.schema = schema
self.function_descriptions = self.get_function_descriptions() self.tools = self.get_function_descriptions()
self.messages = [ self.messages = [
{ {
"role": "system", "role": "system",
@ -27,12 +27,7 @@ class DatabaseAgent:
When responding with sql queries, you must use the 'osm' schema desgignation. When responding with sql queries, you must use the 'osm' schema desgignation.
Favor SQL queries that return polygons first, then lines, then points. An example prompt from the user is: "Add buildings to the map. Ensure the query is restricted to the following bbox: 30, 50, 31, 51"
An example prompt from the user is:
'add buildings to the map
Ensure the query is restricted to the following bbox: 30, 50, 31, 51'
You would respond with: You would respond with:
'SELECT ST_AsGeoJSON(geometry) FROM osm.osm_buildings WHERE ST_Intersects(geometry, ST_MakeEnvelope(30, 50, 31, 51, 4326))'. 'SELECT ST_AsGeoJSON(geometry) FROM osm.osm_buildings WHERE ST_Intersects(geometry, ST_MakeEnvelope(30, 50, 31, 51, 4326))'.
@ -50,92 +45,77 @@ class DatabaseAgent:
} }
def listen(self, message, bbox): def listen(self, message, bbox):
logger.info(f"In DatabaseAgent.listen()...message is: {message}")
logger.info(f"In DatabaseAgent.listen()...bbox is: {bbox}")
data = {'message': message}
response = requests.get("http://localhost:5000/get_table_name", json=data)
"""Listen to a message from the user.""" """Listen to a message from the user."""
map_context = f"Ensure the query is restricted to the following bbox: {bbox}" #data = {'message': message}
table_name_context = "" table_names = [table['table_name'] for table in self.schema]
prefixed_message = f"Choose the most likely table the following text is referring to from this list:\m {table_names}.\n"
final_message = prefixed_message + message
# use openai to choose the most likely table name from the schema
response = self.client.chat.completions.create(
model=self.model_version,
messages=[
{"role": "system", "content": "You are a helpful assistant that chooses a table name from a list. Only respond with the table name."},
{"role": "user", "content": final_message},
],
temperature=0,
max_tokens=32,
frequency_penalty=0,
presence_penalty=0,
)
table_name = response.choices[0].message.content
logger.info(f"table_name in DatabaseAgent is: {table_name}")
map_context = f"Ensure the query is restricted to the following bbox: {bbox}"
db = Database(
database=os.getenv("POSTGRES_DBNAME"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASS"),
host=os.getenv("POSTGRES_HOST"),
port=os.getenv("POSTGRES_PORT")
)
column_names = db.get_column_names(table_name)
logger.info(f"column_names in DatabaseAgent is: {column_names}")
db.close()
table_name_context = f"Generate your query using the following table name: {table_name} and the appropriate column names: {column_names}"
# attempt to get the tablename from the user message
# we do this to avoid sending the entire schema to openai
if response.status_code == 200:
logger.info(f"Response from /get_table_name route: {response}")
response_data = response.json()
logger.info(f"Response message from /get_table_name route: {response_data}")
table_name = response_data.get('choices', [{}])[0].get('message', {}).get('content', '')
if table_name:
logger.info(f"Table name: {table_name}")
db = Database(
database=os.getenv("POSTGRES_DBNAME"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASS"),
host=os.getenv("POSTGRES_HOST"),
port=os.getenv("POSTGRES_PORT")
)
column_names = db.get_column_names(table_name)
db.close()
table_name_context = f"Generate your query using the following table name: {table_name} and the appropriate column names: {column_names}"
#remove the last item in self.messages
if len(self.messages) > 1:
self.messages.pop()
self.messages.append({ self.messages.append({
"role": "user", "role": "user",
"content": message + "\n" + map_context + "\n" + table_name_context, "content": message + "\n" + map_context + "\n" + table_name_context,
}) })
# self.messages.append({
# "role": "user",
# "content": map_context,
# })
logger.info(f"DatabaseAgent self.messages: {self.messages}")
logger.info(f"DatabaseAgent self.function_descriptions: {self.function_descriptions}")
# this will be the function gpt will call if it # this will be the function gpt will call if it
# determines that the user wants to call a function # determines that the user wants to call a function
function_response = None function_response = None
try: try:
response = openai.ChatCompletion.create( logger.info("Calling OpenAI API in DatabaseAgent...")
response = self.client.chat.completions.create(
model=self.model_version, model=self.model_version,
messages=self.messages, messages=self.messages,
functions=self.function_descriptions, tools=self.tools,
function_call="auto", tool_choice="auto",
temperature=0,
max_tokens=256,
top_p=1,
frequency_penalty=0,
presence_penalty=0
) )
response_message = response.choices[0].message
logger.info(f"response_message in DatabaseAgent is: {response_message}")
tool_calls = response_message.tool_calls
response_message = response["choices"][0]["message"] if tool_calls:
self.messages.append(response_message)
logger.info(f"Response from OpenAI in DatabaseAgent: {response_message}") for tool_call in tool_calls:
function_name = tool_call.function.name
if response_message.get("function_call"): function_to_call = self.available_functions[function_name]
function_name = response_message["function_call"]["name"] function_args = json.loads(tool_call.function.arguments)
function_name = function_name.strip() function_response = function_to_call(**function_args)
logger.info(f"Function name: {function_name}") self.messages.append(
function_to_call = self.available_functions[function_name] {
logger.info(f"Function to call: {function_to_call}") "tool_call_id": tool_call.id,
function_args = json.loads(response_message["function_call"]["arguments"]) "role": "tool",
logger.info(f"Function args: {function_args}") "name": function_name,
# determine the function to call "content": json.dumps(function_response),
function_response = function_to_call(**function_args) }
)
### This is blowing up my context window logger.info("Sucessful StyleAgent task completion.")
# self.messages.append(response_message)
# self.messages.append({
# "role": "function",
# "name": function_name,
# "content": function_response,
# })
return {"response": function_response} return {"response": function_response}
elif response_message.get("content"):
return {"response": response_message["content"]}
else:
return {"response": "I'm sorry, I don't understand."}
except Exception as e: except Exception as e:
return {"error": "Failed to get response from OpenAI in NavigationAgent: " + str(e)}, 500 return {"error": "Failed to get response from OpenAI in NavigationAgent: " + str(e)}, 500
@ -145,13 +125,13 @@ class DatabaseAgent:
return [ return [
{ {
"name": "get_geojson_from_database", "name": "get_geojson_from_database",
"description": """Retrieve geojson sptatial data using PostGIS SQL.""", "description": """Retrieve geojson spatial data using PostGIS SQL.""",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"query": { "query": {
"type": "string", "type": "string",
"description": f"""SQL query to get geojson from the database. "description": """SQL query to get geojson from the database.
The query shall be returned in string format as a single command.""" The query shall be returned in string format as a single command."""
}, },
}, },

Wyświetl plik

@ -10,8 +10,8 @@ class MapInfoAgent:
def select_layer_name(self, layer_name): def select_layer_name(self, layer_name):
return {"name": "select_layer_name", "layer_name": layer_name} return {"name": "select_layer_name", "layer_name": layer_name}
def __init__(self, openai, model_version="gpt-3.5-turbo-0613"): def __init__(self, client, model_version):
self.openai = openai self.client = client
self.model_version = model_version self.model_version = model_version
self.tools = map_info_function_descriptions self.tools = map_info_function_descriptions
self.messages = [ self.messages = [
@ -49,7 +49,7 @@ class MapInfoAgent:
function_response = None function_response = None
try: try:
response = self.openai.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_version, model=self.model_version,
messages=self.messages, messages=self.messages,
tools=self.tools, tools=self.tools,
@ -80,25 +80,23 @@ class MapInfoAgent:
# what they meant. In general, it's pretty good at this unless there are multiple layers with similar names, in which case # what they meant. In general, it's pretty good at this unless there are multiple layers with similar names, in which case
# it just chooses one. # it just chooses one.
if function_name == "select_layer_name": if function_name == "select_layer_name":
logger.info(f"Sending layer name retrieval request to OpenAI...")
prompt = f"Please select a layer name from the following list that is closest to the text '{function_response['layer_name']}': {str(layer_names)}\n Only state the layer name in your response." prompt = f"""Please select a layer name from the following list that is closest to the
logger.info(f"Prompt to OpenAI: {prompt}") text '{function_response['layer_name']}': {str(layer_names)}\n
Only state the layer name in your response."""
messages = [ messages = [
{ {
"role": "user", "role": "user",
"content": prompt, "content": prompt,
}, },
] ]
second_response = self.openai.chat.completions.create( second_response = self.client.chat.completions.create(
model=self.model_version, model=self.model_version,
messages=messages, messages=messages,
) )
logger.info(f"Second response from OpenAI in MapInfoAgent: {second_response}")
second_response_message = second_response.choices[0].message.content second_response_message = second_response.choices[0].message.content
logger.info(f"Second response message from OpenAI in MapInfoAgent: {second_response_message}")
logger.info(f"Function Response bofore setting the layer name: {function_response}")
function_response['layer_name'] = second_response_message function_response['layer_name'] = second_response_message
logger.info(f"Function response after call to select a layername: {function_response}")
return {"response": function_response} return {"response": function_response}
elif response_message.get("content"): elif response_message.get("content"):
return {"response": response_message["content"]} return {"response": response_message["content"]}

Wyświetl plik

@ -1,14 +1,13 @@
import json import json
#import openai
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MarshallAgent: class MarshallAgent:
"""A Marshall agent that has function descriptions for choosing the appropriate agent for a specified task.""" """A Marshall agent that has function descriptions for choosing the appropriate agent for a specified task."""
def __init__(self, openai, model_version="gpt-3.5-turbo-0613"): def __init__(self, client, model_version):
self.model_version = model_version self.model_version = model_version
self.openai = openai self.client = client
self.tools = [ self.tools = [
{ {
"type": "function", "type": "function",
@ -31,7 +30,7 @@ class MarshallAgent:
self.system_message = """You are a helpful assistant that decides which agent to use for a specified task. self.system_message = """You are a helpful assistant that decides which agent to use for a specified task.
For tasks related to adding layers and other geospatial data to the map, use the DatabaseAgent. For tasks related to adding layers and other geospatial data to the map, use the DatabaseAgent.
Examples include 'add buildings to the map' and 'get landuse polygons within this extent'. Examples include 'add buildings to the map', 'show industrial buildings', and 'get landuse polygons within this extent'.
For tasks that ask to change the style of a map, such as opacity, color, or line width, you will For tasks that ask to change the style of a map, such as opacity, color, or line width, you will
use the StyleAgent. Examples StyleAgent prompts include 'change color to green', 'opacity 45%' use the StyleAgent. Examples StyleAgent prompts include 'change color to green', 'opacity 45%'
@ -52,18 +51,14 @@ class MarshallAgent:
self.available_functions = { self.available_functions = {
"choose_agent": self.choose_agent, "choose_agent": self.choose_agent,
} }
self.logger = logging.getLogger(__name__)
def choose_agent(self, agent_name): def choose_agent(self, agent_name):
return {"name": "choose_agent", "agent_name": agent_name} return {"name": "choose_agent", "agent_name": agent_name}
def listen(self, message): def listen(self, message):
self.logger.info(f"In MarshallAgent.listen()...message is: {message}") logger.info(f"In MarshallAgent.listen()...message is: {message}")
"""Listen to a message from the user.""" """Listen to a message from the user."""
# # Remove the last item in self.messages. Our agent has no memory
# if len(self.messages) > 1:
# self.messages.pop()
self.messages.append({ self.messages.append({
"role": "user", "role": "user",
@ -75,7 +70,7 @@ class MarshallAgent:
function_response = None function_response = None
try: try:
response = self.openai.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_version, model=self.model_version,
messages=self.messages, messages=self.messages,
tools=self.tools, tools=self.tools,
@ -100,11 +95,7 @@ class MarshallAgent:
"content": json.dumps(function_response), "content": json.dumps(function_response),
} }
) )
# second_response = self.openai.chat.completions.create( logger.info(f"Sucessful MarshallAgent task completion: {function_response}")
# model=self.model_version,
# messages=self.messages,
# )
logger.info(f"Sucessful MarallAgent task completion: {function_response}")
return {"response": function_response} return {"response": function_response}
except Exception as e: except Exception as e:

Wyświetl plik

@ -1,6 +1,6 @@
import logging import logging
from .function_descriptions.navigation_function_descriptions import navigation_function_descriptions from .function_descriptions.navigation_function_descriptions import navigation_function_descriptions
#import openai #import client
import json import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,8 +20,8 @@ class NavigationAgent:
def zoom_out(self, zoom_levels=1): def zoom_out(self, zoom_levels=1):
return {"name": "zoom_out", "zoom_levels": zoom_levels} return {"name": "zoom_out", "zoom_levels": zoom_levels}
def __init__(self, openai, model_version="gpt-3.5-turbo-0613"): def __init__(self, client, model_version):
self.openai = openai self.client = client
self.model_version = model_version self.model_version = model_version
self.messages = [ self.messages = [
{ {
@ -42,7 +42,6 @@ class NavigationAgent:
"zoom_out": self.zoom_out, "zoom_out": self.zoom_out,
} }
self.tools = navigation_function_descriptions self.tools = navigation_function_descriptions
logger.info(f"self.tools in NavigationAgent is: {self.tools}")
def listen(self, message): def listen(self, message):
logging.info(f"In NavigationAgent...message is: {message}") logging.info(f"In NavigationAgent...message is: {message}")
@ -61,15 +60,13 @@ class NavigationAgent:
try: try:
logger.info("Calling OpenAI API in NavigationAgent...") logger.info("Calling OpenAI API in NavigationAgent...")
response = self.openai.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_version, model=self.model_version,
messages=self.messages, messages=self.messages,
tools=self.tools, tools=self.tools,
tool_choice="auto", tool_choice="auto",
) )
logger.info(f"response in NavigationAgent is: {response}")
response_message = response.choices[0].message response_message = response.choices[0].message
logger.info(f"response_message in NavigationAgent is: {response_message}")
tool_calls = response_message.tool_calls tool_calls = response_message.tool_calls
if tool_calls: if tool_calls:

Wyświetl plik

@ -19,8 +19,8 @@ class StyleAgent:
def set_visibility(self, layer_name, visibility): def set_visibility(self, layer_name, visibility):
return {"name": "set_visibility", "layer_name": layer_name, "visibility": visibility} return {"name": "set_visibility", "layer_name": layer_name, "visibility": visibility}
def __init__(self, openai, model_version="gpt-3.5-turbo-0613"): def __init__(self, client, model_version):
self.openai = openai self.client = client
self.model_version = model_version self.model_version = model_version
self.tools = style_function_descriptions self.tools = style_function_descriptions
@ -57,15 +57,13 @@ class StyleAgent:
try: try:
logger.info("Calling OpenAI API in StyleAgent...") logger.info("Calling OpenAI API in StyleAgent...")
response = self.openai.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_version, model=self.model_version,
messages=self.messages, messages=self.messages,
tools=self.tools, tools=self.tools,
tool_choice="auto", tool_choice="auto",
) )
logger.info(f"response in StyleAgent is: {response}")
response_message = response.choices[0].message response_message = response.choices[0].message
logger.info(f"response_message in StyleAgent is: {response_message}")
tool_calls = response_message.tool_calls tool_calls = response_message.tool_calls
if tool_calls: if tool_calls:

Wyświetl plik

@ -22,14 +22,14 @@ load_dotenv()
app = Flask(__name__) app = Flask(__name__)
CORS(app) CORS(app)
openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
model_version = os.getenv("OPENAI_MODEL_VERSION") model_version = os.getenv("OPENAI_MODEL_VERSION")
UPLOAD_FOLDER = 'uploads/audio' UPLOAD_FOLDER = 'uploads/audio'
navigation_agent = NavigationAgent(openai, model_version=model_version) navigation_agent = NavigationAgent(client, model_version=model_version)
marshall_agent = MarshallAgent(openai, model_version=model_version) marshall_agent = MarshallAgent(client, model_version=model_version)
style_agent = StyleAgent(openai, model_version=model_version) style_agent = StyleAgent(client, model_version=model_version)
map_info_agent = MapInfoAgent(openai, model_version=model_version) map_info_agent = MapInfoAgent(client, model_version=model_version)
def get_database_schema(): def get_database_schema():
db = Database( db = Database(
@ -44,15 +44,12 @@ def get_database_schema():
return schema return schema
schema = get_database_schema() schema = get_database_schema()
database_agent = DatabaseAgent(model_version=model_version, schema=schema) database_agent = DatabaseAgent(client, model_version=model_version, schema=schema)
@app.route('/get_query', methods=['POST']) @app.route('/get_query', methods=['POST'])
def get_query(): def get_query():
logging.info(f"Received request in /get_query route...: {request}")
message = request.json.get('message', '') message = request.json.get('message', '')
bbox = request.json.get('bbox', '') bbox = request.json.get('bbox', '')
logging.info(f"Received message in /get_query route...: {message}")
logging.info(f"Received bbox in /get_query route...: {bbox}")
return jsonify(database_agent.listen(message, bbox)) return jsonify(database_agent.listen(message, bbox))
@app.route('/get_table_name', methods=['GET']) @app.route('/get_table_name', methods=['GET'])
@ -62,17 +59,16 @@ def get_table_name():
prefixed_message = f"Choose the most likely table the following text is referring to from this list:\m {table_names}.\n" prefixed_message = f"Choose the most likely table the following text is referring to from this list:\m {table_names}.\n"
final_message = prefixed_message + message final_message = prefixed_message + message
logging.info(f"Received message in /get_table_name route...: {final_message}") logging.info(f"Received message in /get_table_name route...: {final_message}")
response = openai.ChatCompletion.create( response = client.chat.completions.create(
model=model_version, model=model_version,
messages=[ messages=[
{"role": "system", "content": "You are a helpful assistant that chooses a table name from a list. Only respond with the table name."}, {"role": "system", "content": "You are a helpful assistant that chooses a table name from a list. Only respond with the table name."},
{"role": "user", "content": final_message}, {"role": "user", "content": final_message},
], ],
temperature=0, temperature=0,
max_tokens=256, max_tokens=32,
top_p=1,
frequency_penalty=0, frequency_penalty=0,
presence_penalty=0 presence_penalty=0,
) )
logging.info(f"Response from OpenAI in /get_table_name route: {response}") logging.info(f"Response from OpenAI in /get_table_name route: {response}")
#response_message = response["choices"][0]["message"] #response_message = response["choices"][0]["message"]
@ -163,7 +159,10 @@ def upload_audio():
audio_file = request.files['audio'] audio_file = request.files['audio']
audio_file.save(os.path.join(UPLOAD_FOLDER, "user_audio.webm")) audio_file.save(os.path.join(UPLOAD_FOLDER, "user_audio.webm"))
audio_file=open(os.path.join(UPLOAD_FOLDER, "user_audio.webm"), 'rb') audio_file=open(os.path.join(UPLOAD_FOLDER, "user_audio.webm"), 'rb')
transcript = openai.audio.transcribe("whisper-1", audio_file) transcript = client.audio.transcriptions.create(
model="whisper-1",
file = audio_file
)
logging.info(f"Received transcript: {transcript}") logging.info(f"Received transcript: {transcript}")
message = transcript['text'] message = transcript['text']
#delete the audio #delete the audio

Wyświetl plik

@ -270,7 +270,6 @@
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
response_data = data.response; response_data = data.response;
console.log(response_data);
chatbox.value += databaseResponseAgent.handleResponse(userMessage, response_data); chatbox.value += databaseResponseAgent.handleResponse(userMessage, response_data);
return; return;
}) })

File diff suppressed because one or more lines are too long

Wyświetl plik

@ -0,0 +1,4 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=https://download.geofabrik.de/europe.html
HostUrl=https://download.geofabrik.de/europe/liechtenstein-latest.osm.pbf