kopia lustrzana https://github.com/kartoza/docker-osm
fixed database querying
rodzic
fa5d2f5f82
commit
d62ff8558b
|
@ -1,8 +1,8 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import requests
|
|
||||||
from utils.database import Database
|
from utils.database import Database
|
||||||
|
from .function_descriptions.database_function_descriptions import database_function_descriptions
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -15,17 +15,18 @@ class DatabaseAgent:
|
||||||
def get_table_from_database(self, query):
|
def get_table_from_database(self, query):
|
||||||
return {"name": "get_table_from_database", "query": query}
|
return {"name": "get_table_from_database", "query": query}
|
||||||
|
|
||||||
def __init__(self, client, schema=None):
|
def __init__(self, client, model_version, schema=None):
|
||||||
self.model_version = model_version
|
self.model_version = model_version
|
||||||
self.client = client
|
self.client = client
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.tools = self.get_function_descriptions()
|
self.tools = database_function_descriptions
|
||||||
self.messages = [
|
self.messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": f"""You are a helpful assistant that answers questions about data in the 'osm' schema of a PostGIS database.
|
"content": f"""You are a helpful assistant that answers questions about data in the 'osm' schema of a PostGIS database.
|
||||||
|
|
||||||
When responding with sql queries, you must use the 'osm' schema desgignation.
|
When responding with sql queries, you must use the 'osm' schema desgignation. If the user is requesting map features such as buildings, roads, or airports,
|
||||||
|
you will construct a query that returns geojson.
|
||||||
|
|
||||||
An example prompt from the user is: "Add buildings to the map. Ensure the query is restricted to the following bbox: 30, 50, 31, 51"
|
An example prompt from the user is: "Add buildings to the map. Ensure the query is restricted to the following bbox: 30, 50, 31, 51"
|
||||||
|
|
||||||
|
@ -65,6 +66,7 @@ class DatabaseAgent:
|
||||||
)
|
)
|
||||||
table_name = response.choices[0].message.content
|
table_name = response.choices[0].message.content
|
||||||
logger.info(f"table_name in DatabaseAgent is: {table_name}")
|
logger.info(f"table_name in DatabaseAgent is: {table_name}")
|
||||||
|
logger.info(f"bbox in DatabaseAgent is: {bbox}")
|
||||||
map_context = f"Ensure the query is restricted to the following bbox: {bbox}"
|
map_context = f"Ensure the query is restricted to the following bbox: {bbox}"
|
||||||
db = Database(
|
db = Database(
|
||||||
database=os.getenv("POSTGRES_DBNAME"),
|
database=os.getenv("POSTGRES_DBNAME"),
|
||||||
|
@ -74,13 +76,13 @@ class DatabaseAgent:
|
||||||
port=os.getenv("POSTGRES_PORT")
|
port=os.getenv("POSTGRES_PORT")
|
||||||
)
|
)
|
||||||
column_names = db.get_column_names(table_name)
|
column_names = db.get_column_names(table_name)
|
||||||
logger.info(f"column_names in DatabaseAgent is: {column_names}")
|
|
||||||
db.close()
|
db.close()
|
||||||
table_name_context = f"Generate your query using the following table name: {table_name} and the appropriate column names: {column_names}"
|
table_name_context = f"Generate your query using the following table name: {table_name} and the appropriate column names: {column_names}"
|
||||||
|
payload = message + "\n" + map_context + "\n" + table_name_context
|
||||||
|
logger.info(f"payload in DatabaseAgent is: {payload}")
|
||||||
self.messages.append({
|
self.messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": message + "\n" + map_context + "\n" + table_name_context,
|
"content": payload,
|
||||||
})
|
})
|
||||||
|
|
||||||
# this will be the function gpt will call if it
|
# this will be the function gpt will call if it
|
||||||
|
@ -89,14 +91,16 @@ class DatabaseAgent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("Calling OpenAI API in DatabaseAgent...")
|
logger.info("Calling OpenAI API in DatabaseAgent...")
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model_version,
|
model=self.model_version,
|
||||||
messages=self.messages,
|
messages=self.messages,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
)
|
)
|
||||||
|
logger.info(f"Response from OpenAI in DatabaseAgent: {response}")
|
||||||
response_message = response.choices[0].message
|
response_message = response.choices[0].message
|
||||||
logger.info(f"response_message in DatabaseAgent is: {response_message}")
|
logger.info(f"Response_message in DatabaseAgent is: {response_message}")
|
||||||
tool_calls = response_message.tool_calls
|
tool_calls = response_message.tool_calls
|
||||||
|
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
|
@ -121,36 +125,4 @@ class DatabaseAgent:
|
||||||
return {"error": "Failed to get response from OpenAI in NavigationAgent: " + str(e)}, 500
|
return {"error": "Failed to get response from OpenAI in NavigationAgent: " + str(e)}, 500
|
||||||
return {"response": function_response}
|
return {"response": function_response}
|
||||||
|
|
||||||
def get_function_descriptions(self):
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"name": "get_geojson_from_database",
|
|
||||||
"description": """Retrieve geojson spatial data using PostGIS SQL.""",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {
|
|
||||||
"type": "string",
|
|
||||||
"description": """SQL query to get geojson from the database.
|
|
||||||
The query shall be returned in string format as a single command."""
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["query"],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "get_table_from_database",
|
|
||||||
"description": """Retrieve a non-spatial table using PostgreSQL SQL.""",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {
|
|
||||||
"type": "string",
|
|
||||||
"description": """SQL query that gets the answer to the user's question or task.
|
|
||||||
The query shall be returned in string format as a single command.""",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
get_geojson_from_database = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_geojson_from_database",
|
||||||
|
"description": """Retrieve geojson spatial data using PostGIS SQL.""",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": """SQL query to get geojson from the database.
|
||||||
|
The query shall be returned in string format as a single command."""
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
get_table_from_database = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_table_from_database",
|
||||||
|
"description": """Retrieve a non-spatial table using PostgreSQL SQL.""",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": """SQL query that gets the answer to the user's question or task.
|
||||||
|
The query shall be returned in string format as a single command.""",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
database_function_descriptions = [
|
||||||
|
get_geojson_from_database,
|
||||||
|
get_table_from_database,
|
||||||
|
]
|
|
@ -26,13 +26,12 @@ class NavigationAgent:
|
||||||
self.messages = [
|
self.messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": """You are a helpful assistant in navigating a maplibre map. When tasked to do something, you will call the appropriate function to navigate the map.
|
"content": """You are a helpful assistant in navigating a maplibre map. When tasked to do something, you will call the
|
||||||
|
appropriate function to navigate the map. Examples tasks to go to a location include: 'Go to Tokyo', 'Where is the Eiffel Tower?',
|
||||||
Examples tasks to go to a location include: 'Go to Tokyo', 'Where is the Eiffel Tower?', 'Navigate to New York City'
|
'Navigate to New York City'. Examples tasks to pan in a direction include: 'Pan north 3 km', 'Pan ne 1 kilometer', 'Pan southwest',
|
||||||
|
'Pan east 2 kilometers', 'Pan south 5 kilometers', 'Pan west 1 kilometer', 'Pan northwest 2 kilometers', 'Pan southeast 3 kilometers'.
|
||||||
Examples tasks to pan in a direction include: 'Pan north 3 km', 'Pan ne 1 kilometer', 'Pan southwest', 'Pan east 2 kilometers', 'Pan south 5 kilometers', 'Pan west 1 kilometer', 'Pan northwest 2 kilometers', 'Pan southeast 3 kilometers'
|
Examples tasks to zoom in or zoom out include: 'Zoom in 2 zoom levels.', 'Zoom out 3', 'zoom out', 'zoom in',
|
||||||
|
'move closer', 'move in', 'get closer', 'move further away', 'get further away', 'back out', 'closer' """,
|
||||||
Examples tasks to zoom in or zoom out include: 'Zoom in 2 zoom levels.', 'Zoom out 3', 'zoom out', 'zoom in', 'move closer', 'move in', 'get closer', 'move further away', 'get further away', 'back out', 'closer' """,
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
self.available_functions = {
|
self.available_functions = {
|
||||||
|
@ -67,6 +66,7 @@ class NavigationAgent:
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
)
|
)
|
||||||
response_message = response.choices[0].message
|
response_message = response.choices[0].message
|
||||||
|
logger.info(f"Response from OpenAI in NavigationAgent: {response_message}")
|
||||||
tool_calls = response_message.tool_calls
|
tool_calls = response_message.tool_calls
|
||||||
|
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
|
|
|
@ -14,10 +14,12 @@ from agents.database_agent import DatabaseAgent
|
||||||
from utils.database import Database
|
from utils.database import Database
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG,
|
logging.basicConfig(level=logging.INFO,
|
||||||
format='%(asctime)s %(levelname)s %(name)s %(threadName)s : %(message)s',
|
format='%(levelname)s %(name)s : %(message)s',
|
||||||
handlers=[logging.StreamHandler()])
|
handlers=[logging.StreamHandler()])
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
CORS(app)
|
CORS(app)
|
||||||
|
@ -44,12 +46,15 @@ def get_database_schema():
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
schema = get_database_schema()
|
schema = get_database_schema()
|
||||||
database_agent = DatabaseAgent(client, model_version=model_version, schema=schema)
|
database_agent = DatabaseAgent(client, model_version, schema=schema)
|
||||||
|
|
||||||
@app.route('/get_query', methods=['POST'])
|
@app.route('/get_query', methods=['POST'])
|
||||||
def get_query():
|
def get_query():
|
||||||
|
logger.info("In get_query route...")
|
||||||
message = request.json.get('message', '')
|
message = request.json.get('message', '')
|
||||||
|
logger.info(f"Received message in /get_query route...: {message}")
|
||||||
bbox = request.json.get('bbox', '')
|
bbox = request.json.get('bbox', '')
|
||||||
|
logger.info(f"Received bbox in /get_query route...: {bbox}")
|
||||||
return jsonify(database_agent.listen(message, bbox))
|
return jsonify(database_agent.listen(message, bbox))
|
||||||
|
|
||||||
@app.route('/get_table_name', methods=['GET'])
|
@app.route('/get_table_name', methods=['GET'])
|
||||||
|
@ -164,7 +169,7 @@ def upload_audio():
|
||||||
file = audio_file
|
file = audio_file
|
||||||
)
|
)
|
||||||
logging.info(f"Received transcript: {transcript}")
|
logging.info(f"Received transcript: {transcript}")
|
||||||
message = transcript['text']
|
message = transcript.text
|
||||||
#delete the audio
|
#delete the audio
|
||||||
os.remove(os.path.join(UPLOAD_FOLDER, "user_audio.webm"))
|
os.remove(os.path.join(UPLOAD_FOLDER, "user_audio.webm"))
|
||||||
return message
|
return message
|
||||||
|
|
|
@ -1,4 +0,0 @@
|
||||||
[ZoneTransfer]
|
|
||||||
ZoneId=3
|
|
||||||
ReferrerUrl=https://download.geofabrik.de/europe.html
|
|
||||||
HostUrl=https://download.geofabrik.de/europe/liechtenstein-latest.osm.pbf
|
|
Ładowanie…
Reference in New Issue