From d62ff8558b04b34afbe09469dd806c32d8b9e749 Mon Sep 17 00:00:00 2001 From: shanergi Date: Thu, 9 Nov 2023 13:33:14 -0500 Subject: [PATCH] fixed database querying --- flask_app/agents/database_agent.py | 54 +++++-------------- .../database_function_descriptions.py | 42 +++++++++++++++ flask_app/agents/navigation_agent.py | 14 ++--- flask_app/app.py | 13 +++-- ...echtenstein-latest.osm.pbf:Zone.Identifier | 4 -- 5 files changed, 71 insertions(+), 56 deletions(-) create mode 100644 flask_app/agents/function_descriptions/database_function_descriptions.py delete mode 100644 settings/liechtenstein-latest.osm.pbf:Zone.Identifier diff --git a/flask_app/agents/database_agent.py b/flask_app/agents/database_agent.py index 1deba18..3d4c184 100644 --- a/flask_app/agents/database_agent.py +++ b/flask_app/agents/database_agent.py @@ -1,8 +1,8 @@ import json import logging import os -import requests from utils.database import Database +from .function_descriptions.database_function_descriptions import database_function_descriptions logger = logging.getLogger(__name__) @@ -15,17 +15,18 @@ class DatabaseAgent: def get_table_from_database(self, query): return {"name": "get_table_from_database", "query": query} - def __init__(self, client, schema=None): + def __init__(self, client, model_version, schema=None): self.model_version = model_version self.client = client self.schema = schema - self.tools = self.get_function_descriptions() + self.tools = database_function_descriptions self.messages = [ { "role": "system", "content": f"""You are a helpful assistant that answers questions about data in the 'osm' schema of a PostGIS database. - When responding with sql queries, you must use the 'osm' schema desgignation. + When responding with sql queries, you must use the 'osm' schema desgignation. If the user is requesting map features such as buildings, roads, or airports, + you will construct a query that returns geojson. An example prompt from the user is: "Add buildings to the map. Ensure the query is restricted to the following bbox: 30, 50, 31, 51" @@ -65,6 +66,7 @@ class DatabaseAgent: ) table_name = response.choices[0].message.content logger.info(f"table_name in DatabaseAgent is: {table_name}") + logger.info(f"bbox in DatabaseAgent is: {bbox}") map_context = f"Ensure the query is restricted to the following bbox: {bbox}" db = Database( database=os.getenv("POSTGRES_DBNAME"), @@ -74,13 +76,13 @@ class DatabaseAgent: port=os.getenv("POSTGRES_PORT") ) column_names = db.get_column_names(table_name) - logger.info(f"column_names in DatabaseAgent is: {column_names}") db.close() table_name_context = f"Generate your query using the following table name: {table_name} and the appropriate column names: {column_names}" - + payload = message + "\n" + map_context + "\n" + table_name_context + logger.info(f"payload in DatabaseAgent is: {payload}") self.messages.append({ "role": "user", - "content": message + "\n" + map_context + "\n" + table_name_context, + "content": payload, }) # this will be the function gpt will call if it @@ -89,14 +91,16 @@ class DatabaseAgent: try: logger.info("Calling OpenAI API in DatabaseAgent...") + response = self.client.chat.completions.create( model=self.model_version, messages=self.messages, tools=self.tools, tool_choice="auto", ) + logger.info(f"Response from OpenAI in DatabaseAgent: {response}") response_message = response.choices[0].message - logger.info(f"response_message in DatabaseAgent is: {response_message}") + logger.info(f"Response_message in DatabaseAgent is: {response_message}") tool_calls = response_message.tool_calls if tool_calls: @@ -121,36 +125,4 @@ class DatabaseAgent: return {"error": "Failed to get response from OpenAI in NavigationAgent: " + str(e)}, 500 return {"response": function_response} - def get_function_descriptions(self): - return [ - { - "name": "get_geojson_from_database", - "description": """Retrieve geojson spatial data using PostGIS SQL.""", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": """SQL query to get geojson from the database. - The query shall be returned in string format as a single command.""" - }, - }, - "required": ["query"], - } - }, - { - "name": "get_table_from_database", - "description": """Retrieve a non-spatial table using PostgreSQL SQL.""", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": """SQL query that gets the answer to the user's question or task. - The query shall be returned in string format as a single command.""", - }, - }, - "required": ["query"], - }, - }, - ] \ No newline at end of file + \ No newline at end of file diff --git a/flask_app/agents/function_descriptions/database_function_descriptions.py b/flask_app/agents/function_descriptions/database_function_descriptions.py new file mode 100644 index 0000000..381bf07 --- /dev/null +++ b/flask_app/agents/function_descriptions/database_function_descriptions.py @@ -0,0 +1,42 @@ +get_geojson_from_database = { + "type": "function", + "function": { + "name": "get_geojson_from_database", + "description": """Retrieve geojson spatial data using PostGIS SQL.""", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": """SQL query to get geojson from the database. + The query shall be returned in string format as a single command.""" + }, + }, + "required": ["query"], + }, + }, +} + +get_table_from_database = { + "type": "function", + "function": { + "name": "get_table_from_database", + "description": """Retrieve a non-spatial table using PostgreSQL SQL.""", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": """SQL query that gets the answer to the user's question or task. + The query shall be returned in string format as a single command.""", + }, + }, + "required": ["query"], + }, + }, +} + +database_function_descriptions = [ + get_geojson_from_database, + get_table_from_database, +] \ No newline at end of file diff --git a/flask_app/agents/navigation_agent.py b/flask_app/agents/navigation_agent.py index 8ded914..c3c4796 100644 --- a/flask_app/agents/navigation_agent.py +++ b/flask_app/agents/navigation_agent.py @@ -26,13 +26,12 @@ class NavigationAgent: self.messages = [ { "role": "system", - "content": """You are a helpful assistant in navigating a maplibre map. When tasked to do something, you will call the appropriate function to navigate the map. - - Examples tasks to go to a location include: 'Go to Tokyo', 'Where is the Eiffel Tower?', 'Navigate to New York City' - - Examples tasks to pan in a direction include: 'Pan north 3 km', 'Pan ne 1 kilometer', 'Pan southwest', 'Pan east 2 kilometers', 'Pan south 5 kilometers', 'Pan west 1 kilometer', 'Pan northwest 2 kilometers', 'Pan southeast 3 kilometers' - - Examples tasks to zoom in or zoom out include: 'Zoom in 2 zoom levels.', 'Zoom out 3', 'zoom out', 'zoom in', 'move closer', 'move in', 'get closer', 'move further away', 'get further away', 'back out', 'closer' """, + "content": """You are a helpful assistant in navigating a maplibre map. When tasked to do something, you will call the + appropriate function to navigate the map. Examples tasks to go to a location include: 'Go to Tokyo', 'Where is the Eiffel Tower?', + 'Navigate to New York City'. Examples tasks to pan in a direction include: 'Pan north 3 km', 'Pan ne 1 kilometer', 'Pan southwest', + 'Pan east 2 kilometers', 'Pan south 5 kilometers', 'Pan west 1 kilometer', 'Pan northwest 2 kilometers', 'Pan southeast 3 kilometers'. + Examples tasks to zoom in or zoom out include: 'Zoom in 2 zoom levels.', 'Zoom out 3', 'zoom out', 'zoom in', + 'move closer', 'move in', 'get closer', 'move further away', 'get further away', 'back out', 'closer' """, }, ] self.available_functions = { @@ -67,6 +66,7 @@ class NavigationAgent: tool_choice="auto", ) response_message = response.choices[0].message + logger.info(f"Response from OpenAI in NavigationAgent: {response_message}") tool_calls = response_message.tool_calls if tool_calls: diff --git a/flask_app/app.py b/flask_app/app.py index 185767e..4aeeee0 100644 --- a/flask_app/app.py +++ b/flask_app/app.py @@ -14,10 +14,12 @@ from agents.database_agent import DatabaseAgent from utils.database import Database import logging -logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s %(levelname)s %(name)s %(threadName)s : %(message)s', +logging.basicConfig(level=logging.INFO, + format='%(levelname)s %(name)s : %(message)s', handlers=[logging.StreamHandler()]) +logger = logging.getLogger(__name__) + load_dotenv() app = Flask(__name__) CORS(app) @@ -44,12 +46,15 @@ def get_database_schema(): return schema schema = get_database_schema() -database_agent = DatabaseAgent(client, model_version=model_version, schema=schema) +database_agent = DatabaseAgent(client, model_version, schema=schema) @app.route('/get_query', methods=['POST']) def get_query(): + logger.info("In get_query route...") message = request.json.get('message', '') + logger.info(f"Received message in /get_query route...: {message}") bbox = request.json.get('bbox', '') + logger.info(f"Received bbox in /get_query route...: {bbox}") return jsonify(database_agent.listen(message, bbox)) @app.route('/get_table_name', methods=['GET']) @@ -164,7 +169,7 @@ def upload_audio(): file = audio_file ) logging.info(f"Received transcript: {transcript}") - message = transcript['text'] + message = transcript.text #delete the audio os.remove(os.path.join(UPLOAD_FOLDER, "user_audio.webm")) return message diff --git a/settings/liechtenstein-latest.osm.pbf:Zone.Identifier b/settings/liechtenstein-latest.osm.pbf:Zone.Identifier deleted file mode 100644 index 4f52817..0000000 --- a/settings/liechtenstein-latest.osm.pbf:Zone.Identifier +++ /dev/null @@ -1,4 +0,0 @@ -[ZoneTransfer] -ZoneId=3 -ReferrerUrl=https://download.geofabrik.de/europe.html -HostUrl=https://download.geofabrik.de/europe/liechtenstein-latest.osm.pbf