fixed database querying

pull/137/head
shanergi 2023-11-09 13:33:14 -05:00
rodzic fa5d2f5f82
commit d62ff8558b
5 zmienionych plików z 71 dodań i 56 usunięć

Wyświetl plik

@ -1,8 +1,8 @@
import json
import logging
import os
import requests
from utils.database import Database
from .function_descriptions.database_function_descriptions import database_function_descriptions
logger = logging.getLogger(__name__)
@ -15,17 +15,18 @@ class DatabaseAgent:
def get_table_from_database(self, query):
return {"name": "get_table_from_database", "query": query}
def __init__(self, client, schema=None):
def __init__(self, client, model_version, schema=None):
self.model_version = model_version
self.client = client
self.schema = schema
self.tools = self.get_function_descriptions()
self.tools = database_function_descriptions
self.messages = [
{
"role": "system",
"content": f"""You are a helpful assistant that answers questions about data in the 'osm' schema of a PostGIS database.
When responding with sql queries, you must use the 'osm' schema desgignation.
When responding with sql queries, you must use the 'osm' schema desgignation. If the user is requesting map features such as buildings, roads, or airports,
you will construct a query that returns geojson.
An example prompt from the user is: "Add buildings to the map. Ensure the query is restricted to the following bbox: 30, 50, 31, 51"
@ -65,6 +66,7 @@ class DatabaseAgent:
)
table_name = response.choices[0].message.content
logger.info(f"table_name in DatabaseAgent is: {table_name}")
logger.info(f"bbox in DatabaseAgent is: {bbox}")
map_context = f"Ensure the query is restricted to the following bbox: {bbox}"
db = Database(
database=os.getenv("POSTGRES_DBNAME"),
@ -74,13 +76,13 @@ class DatabaseAgent:
port=os.getenv("POSTGRES_PORT")
)
column_names = db.get_column_names(table_name)
logger.info(f"column_names in DatabaseAgent is: {column_names}")
db.close()
table_name_context = f"Generate your query using the following table name: {table_name} and the appropriate column names: {column_names}"
payload = message + "\n" + map_context + "\n" + table_name_context
logger.info(f"payload in DatabaseAgent is: {payload}")
self.messages.append({
"role": "user",
"content": message + "\n" + map_context + "\n" + table_name_context,
"content": payload,
})
# this will be the function gpt will call if it
@ -89,14 +91,16 @@ class DatabaseAgent:
try:
logger.info("Calling OpenAI API in DatabaseAgent...")
response = self.client.chat.completions.create(
model=self.model_version,
messages=self.messages,
tools=self.tools,
tool_choice="auto",
)
logger.info(f"Response from OpenAI in DatabaseAgent: {response}")
response_message = response.choices[0].message
logger.info(f"response_message in DatabaseAgent is: {response_message}")
logger.info(f"Response_message in DatabaseAgent is: {response_message}")
tool_calls = response_message.tool_calls
if tool_calls:
@ -121,36 +125,4 @@ class DatabaseAgent:
return {"error": "Failed to get response from OpenAI in NavigationAgent: " + str(e)}, 500
return {"response": function_response}
def get_function_descriptions(self):
return [
{
"name": "get_geojson_from_database",
"description": """Retrieve geojson spatial data using PostGIS SQL.""",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": """SQL query to get geojson from the database.
The query shall be returned in string format as a single command."""
},
},
"required": ["query"],
}
},
{
"name": "get_table_from_database",
"description": """Retrieve a non-spatial table using PostgreSQL SQL.""",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": """SQL query that gets the answer to the user's question or task.
The query shall be returned in string format as a single command.""",
},
},
"required": ["query"],
},
},
]

Wyświetl plik

@ -0,0 +1,42 @@
get_geojson_from_database = {
"type": "function",
"function": {
"name": "get_geojson_from_database",
"description": """Retrieve geojson spatial data using PostGIS SQL.""",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": """SQL query to get geojson from the database.
The query shall be returned in string format as a single command."""
},
},
"required": ["query"],
},
},
}
get_table_from_database = {
"type": "function",
"function": {
"name": "get_table_from_database",
"description": """Retrieve a non-spatial table using PostgreSQL SQL.""",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": """SQL query that gets the answer to the user's question or task.
The query shall be returned in string format as a single command.""",
},
},
"required": ["query"],
},
},
}
database_function_descriptions = [
get_geojson_from_database,
get_table_from_database,
]

Wyświetl plik

@ -26,13 +26,12 @@ class NavigationAgent:
self.messages = [
{
"role": "system",
"content": """You are a helpful assistant in navigating a maplibre map. When tasked to do something, you will call the appropriate function to navigate the map.
Examples tasks to go to a location include: 'Go to Tokyo', 'Where is the Eiffel Tower?', 'Navigate to New York City'
Examples tasks to pan in a direction include: 'Pan north 3 km', 'Pan ne 1 kilometer', 'Pan southwest', 'Pan east 2 kilometers', 'Pan south 5 kilometers', 'Pan west 1 kilometer', 'Pan northwest 2 kilometers', 'Pan southeast 3 kilometers'
Examples tasks to zoom in or zoom out include: 'Zoom in 2 zoom levels.', 'Zoom out 3', 'zoom out', 'zoom in', 'move closer', 'move in', 'get closer', 'move further away', 'get further away', 'back out', 'closer' """,
"content": """You are a helpful assistant in navigating a maplibre map. When tasked to do something, you will call the
appropriate function to navigate the map. Examples tasks to go to a location include: 'Go to Tokyo', 'Where is the Eiffel Tower?',
'Navigate to New York City'. Examples tasks to pan in a direction include: 'Pan north 3 km', 'Pan ne 1 kilometer', 'Pan southwest',
'Pan east 2 kilometers', 'Pan south 5 kilometers', 'Pan west 1 kilometer', 'Pan northwest 2 kilometers', 'Pan southeast 3 kilometers'.
Examples tasks to zoom in or zoom out include: 'Zoom in 2 zoom levels.', 'Zoom out 3', 'zoom out', 'zoom in',
'move closer', 'move in', 'get closer', 'move further away', 'get further away', 'back out', 'closer' """,
},
]
self.available_functions = {
@ -67,6 +66,7 @@ class NavigationAgent:
tool_choice="auto",
)
response_message = response.choices[0].message
logger.info(f"Response from OpenAI in NavigationAgent: {response_message}")
tool_calls = response_message.tool_calls
if tool_calls:

Wyświetl plik

@ -14,10 +14,12 @@ from agents.database_agent import DatabaseAgent
from utils.database import Database
import logging
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(name)s %(threadName)s : %(message)s',
logging.basicConfig(level=logging.INFO,
format='%(levelname)s %(name)s : %(message)s',
handlers=[logging.StreamHandler()])
logger = logging.getLogger(__name__)
load_dotenv()
app = Flask(__name__)
CORS(app)
@ -44,12 +46,15 @@ def get_database_schema():
return schema
schema = get_database_schema()
database_agent = DatabaseAgent(client, model_version=model_version, schema=schema)
database_agent = DatabaseAgent(client, model_version, schema=schema)
@app.route('/get_query', methods=['POST'])
def get_query():
logger.info("In get_query route...")
message = request.json.get('message', '')
logger.info(f"Received message in /get_query route...: {message}")
bbox = request.json.get('bbox', '')
logger.info(f"Received bbox in /get_query route...: {bbox}")
return jsonify(database_agent.listen(message, bbox))
@app.route('/get_table_name', methods=['GET'])
@ -164,7 +169,7 @@ def upload_audio():
file = audio_file
)
logging.info(f"Received transcript: {transcript}")
message = transcript['text']
message = transcript.text
#delete the audio
os.remove(os.path.join(UPLOAD_FOLDER, "user_audio.webm"))
return message

Wyświetl plik

@ -1,4 +0,0 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=https://download.geofabrik.de/europe.html
HostUrl=https://download.geofabrik.de/europe/liechtenstein-latest.osm.pbf