kopia lustrzana https://github.com/kartoza/docker-osm
fixed database querying
rodzic
fa5d2f5f82
commit
d62ff8558b
|
@ -1,8 +1,8 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
from utils.database import Database
|
||||
from .function_descriptions.database_function_descriptions import database_function_descriptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -15,17 +15,18 @@ class DatabaseAgent:
|
|||
def get_table_from_database(self, query):
|
||||
return {"name": "get_table_from_database", "query": query}
|
||||
|
||||
def __init__(self, client, schema=None):
|
||||
def __init__(self, client, model_version, schema=None):
|
||||
self.model_version = model_version
|
||||
self.client = client
|
||||
self.schema = schema
|
||||
self.tools = self.get_function_descriptions()
|
||||
self.tools = database_function_descriptions
|
||||
self.messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""You are a helpful assistant that answers questions about data in the 'osm' schema of a PostGIS database.
|
||||
|
||||
When responding with sql queries, you must use the 'osm' schema desgignation.
|
||||
When responding with sql queries, you must use the 'osm' schema desgignation. If the user is requesting map features such as buildings, roads, or airports,
|
||||
you will construct a query that returns geojson.
|
||||
|
||||
An example prompt from the user is: "Add buildings to the map. Ensure the query is restricted to the following bbox: 30, 50, 31, 51"
|
||||
|
||||
|
@ -65,6 +66,7 @@ class DatabaseAgent:
|
|||
)
|
||||
table_name = response.choices[0].message.content
|
||||
logger.info(f"table_name in DatabaseAgent is: {table_name}")
|
||||
logger.info(f"bbox in DatabaseAgent is: {bbox}")
|
||||
map_context = f"Ensure the query is restricted to the following bbox: {bbox}"
|
||||
db = Database(
|
||||
database=os.getenv("POSTGRES_DBNAME"),
|
||||
|
@ -74,13 +76,13 @@ class DatabaseAgent:
|
|||
port=os.getenv("POSTGRES_PORT")
|
||||
)
|
||||
column_names = db.get_column_names(table_name)
|
||||
logger.info(f"column_names in DatabaseAgent is: {column_names}")
|
||||
db.close()
|
||||
table_name_context = f"Generate your query using the following table name: {table_name} and the appropriate column names: {column_names}"
|
||||
|
||||
payload = message + "\n" + map_context + "\n" + table_name_context
|
||||
logger.info(f"payload in DatabaseAgent is: {payload}")
|
||||
self.messages.append({
|
||||
"role": "user",
|
||||
"content": message + "\n" + map_context + "\n" + table_name_context,
|
||||
"content": payload,
|
||||
})
|
||||
|
||||
# this will be the function gpt will call if it
|
||||
|
@ -89,14 +91,16 @@ class DatabaseAgent:
|
|||
|
||||
try:
|
||||
logger.info("Calling OpenAI API in DatabaseAgent...")
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_version,
|
||||
messages=self.messages,
|
||||
tools=self.tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
logger.info(f"Response from OpenAI in DatabaseAgent: {response}")
|
||||
response_message = response.choices[0].message
|
||||
logger.info(f"response_message in DatabaseAgent is: {response_message}")
|
||||
logger.info(f"Response_message in DatabaseAgent is: {response_message}")
|
||||
tool_calls = response_message.tool_calls
|
||||
|
||||
if tool_calls:
|
||||
|
@ -121,36 +125,4 @@ class DatabaseAgent:
|
|||
return {"error": "Failed to get response from OpenAI in NavigationAgent: " + str(e)}, 500
|
||||
return {"response": function_response}
|
||||
|
||||
def get_function_descriptions(self):
|
||||
return [
|
||||
{
|
||||
"name": "get_geojson_from_database",
|
||||
"description": """Retrieve geojson spatial data using PostGIS SQL.""",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": """SQL query to get geojson from the database.
|
||||
The query shall be returned in string format as a single command."""
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "get_table_from_database",
|
||||
"description": """Retrieve a non-spatial table using PostgreSQL SQL.""",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": """SQL query that gets the answer to the user's question or task.
|
||||
The query shall be returned in string format as a single command.""",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
get_geojson_from_database = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_geojson_from_database",
|
||||
"description": """Retrieve geojson spatial data using PostGIS SQL.""",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": """SQL query to get geojson from the database.
|
||||
The query shall be returned in string format as a single command."""
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
get_table_from_database = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_table_from_database",
|
||||
"description": """Retrieve a non-spatial table using PostgreSQL SQL.""",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": """SQL query that gets the answer to the user's question or task.
|
||||
The query shall be returned in string format as a single command.""",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
database_function_descriptions = [
|
||||
get_geojson_from_database,
|
||||
get_table_from_database,
|
||||
]
|
|
@ -26,13 +26,12 @@ class NavigationAgent:
|
|||
self.messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are a helpful assistant in navigating a maplibre map. When tasked to do something, you will call the appropriate function to navigate the map.
|
||||
|
||||
Examples tasks to go to a location include: 'Go to Tokyo', 'Where is the Eiffel Tower?', 'Navigate to New York City'
|
||||
|
||||
Examples tasks to pan in a direction include: 'Pan north 3 km', 'Pan ne 1 kilometer', 'Pan southwest', 'Pan east 2 kilometers', 'Pan south 5 kilometers', 'Pan west 1 kilometer', 'Pan northwest 2 kilometers', 'Pan southeast 3 kilometers'
|
||||
|
||||
Examples tasks to zoom in or zoom out include: 'Zoom in 2 zoom levels.', 'Zoom out 3', 'zoom out', 'zoom in', 'move closer', 'move in', 'get closer', 'move further away', 'get further away', 'back out', 'closer' """,
|
||||
"content": """You are a helpful assistant in navigating a maplibre map. When tasked to do something, you will call the
|
||||
appropriate function to navigate the map. Examples tasks to go to a location include: 'Go to Tokyo', 'Where is the Eiffel Tower?',
|
||||
'Navigate to New York City'. Examples tasks to pan in a direction include: 'Pan north 3 km', 'Pan ne 1 kilometer', 'Pan southwest',
|
||||
'Pan east 2 kilometers', 'Pan south 5 kilometers', 'Pan west 1 kilometer', 'Pan northwest 2 kilometers', 'Pan southeast 3 kilometers'.
|
||||
Examples tasks to zoom in or zoom out include: 'Zoom in 2 zoom levels.', 'Zoom out 3', 'zoom out', 'zoom in',
|
||||
'move closer', 'move in', 'get closer', 'move further away', 'get further away', 'back out', 'closer' """,
|
||||
},
|
||||
]
|
||||
self.available_functions = {
|
||||
|
@ -67,6 +66,7 @@ class NavigationAgent:
|
|||
tool_choice="auto",
|
||||
)
|
||||
response_message = response.choices[0].message
|
||||
logger.info(f"Response from OpenAI in NavigationAgent: {response_message}")
|
||||
tool_calls = response_message.tool_calls
|
||||
|
||||
if tool_calls:
|
||||
|
|
|
@ -14,10 +14,12 @@ from agents.database_agent import DatabaseAgent
|
|||
from utils.database import Database
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(name)s %(threadName)s : %(message)s',
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(levelname)s %(name)s : %(message)s',
|
||||
handlers=[logging.StreamHandler()])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
load_dotenv()
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
@ -44,12 +46,15 @@ def get_database_schema():
|
|||
return schema
|
||||
|
||||
schema = get_database_schema()
|
||||
database_agent = DatabaseAgent(client, model_version=model_version, schema=schema)
|
||||
database_agent = DatabaseAgent(client, model_version, schema=schema)
|
||||
|
||||
@app.route('/get_query', methods=['POST'])
|
||||
def get_query():
|
||||
logger.info("In get_query route...")
|
||||
message = request.json.get('message', '')
|
||||
logger.info(f"Received message in /get_query route...: {message}")
|
||||
bbox = request.json.get('bbox', '')
|
||||
logger.info(f"Received bbox in /get_query route...: {bbox}")
|
||||
return jsonify(database_agent.listen(message, bbox))
|
||||
|
||||
@app.route('/get_table_name', methods=['GET'])
|
||||
|
@ -164,7 +169,7 @@ def upload_audio():
|
|||
file = audio_file
|
||||
)
|
||||
logging.info(f"Received transcript: {transcript}")
|
||||
message = transcript['text']
|
||||
message = transcript.text
|
||||
#delete the audio
|
||||
os.remove(os.path.join(UPLOAD_FOLDER, "user_audio.webm"))
|
||||
return message
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
[ZoneTransfer]
|
||||
ZoneId=3
|
||||
ReferrerUrl=https://download.geofabrik.de/europe.html
|
||||
HostUrl=https://download.geofabrik.de/europe/liechtenstein-latest.osm.pbf
|
Ładowanie…
Reference in New Issue