kopia lustrzana https://github.com/kartoza/docker-osm
working on databaseagent, still some problems
rodzic
ea5ed6e9e0
commit
fa5d2f5f82
|
@ -34,6 +34,8 @@ services:
|
||||||
- settings-data:/home/settings
|
- settings-data:/home/settings
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: "pg_isready -d ${POSTGRES_DBNAME}"
|
test: "pg_isready -d ${POSTGRES_DBNAME}"
|
||||||
|
ports:
|
||||||
|
- "${POSTGRES_PORT}:5432"
|
||||||
|
|
||||||
imposm:
|
imposm:
|
||||||
image: kartoza/docker-osm:imposm-latest
|
image: kartoza/docker-osm:imposm-latest
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import openai
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
@ -16,10 +15,11 @@ class DatabaseAgent:
|
||||||
def get_table_from_database(self, query):
|
def get_table_from_database(self, query):
|
||||||
return {"name": "get_table_from_database", "query": query}
|
return {"name": "get_table_from_database", "query": query}
|
||||||
|
|
||||||
def __init__(self, model_version="gpt-3.5-turbo-0613", schema=None):
|
def __init__(self, client, schema=None):
|
||||||
self.model_version = model_version
|
self.model_version = model_version
|
||||||
|
self.client = client
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.function_descriptions = self.get_function_descriptions()
|
self.tools = self.get_function_descriptions()
|
||||||
self.messages = [
|
self.messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
@ -27,12 +27,7 @@ class DatabaseAgent:
|
||||||
|
|
||||||
When responding with sql queries, you must use the 'osm' schema desgignation.
|
When responding with sql queries, you must use the 'osm' schema desgignation.
|
||||||
|
|
||||||
Favor SQL queries that return polygons first, then lines, then points.
|
An example prompt from the user is: "Add buildings to the map. Ensure the query is restricted to the following bbox: 30, 50, 31, 51"
|
||||||
|
|
||||||
An example prompt from the user is:
|
|
||||||
|
|
||||||
'add buildings to the map
|
|
||||||
Ensure the query is restricted to the following bbox: 30, 50, 31, 51'
|
|
||||||
|
|
||||||
You would respond with:
|
You would respond with:
|
||||||
'SELECT ST_AsGeoJSON(geometry) FROM osm.osm_buildings WHERE ST_Intersects(geometry, ST_MakeEnvelope(30, 50, 31, 51, 4326))'.
|
'SELECT ST_AsGeoJSON(geometry) FROM osm.osm_buildings WHERE ST_Intersects(geometry, ST_MakeEnvelope(30, 50, 31, 51, 4326))'.
|
||||||
|
@ -50,92 +45,77 @@ class DatabaseAgent:
|
||||||
}
|
}
|
||||||
|
|
||||||
def listen(self, message, bbox):
|
def listen(self, message, bbox):
|
||||||
logger.info(f"In DatabaseAgent.listen()...message is: {message}")
|
|
||||||
logger.info(f"In DatabaseAgent.listen()...bbox is: {bbox}")
|
|
||||||
data = {'message': message}
|
|
||||||
response = requests.get("http://localhost:5000/get_table_name", json=data)
|
|
||||||
"""Listen to a message from the user."""
|
"""Listen to a message from the user."""
|
||||||
map_context = f"Ensure the query is restricted to the following bbox: {bbox}"
|
#data = {'message': message}
|
||||||
table_name_context = ""
|
table_names = [table['table_name'] for table in self.schema]
|
||||||
|
prefixed_message = f"Choose the most likely table the following text is referring to from this list:\m {table_names}.\n"
|
||||||
|
final_message = prefixed_message + message
|
||||||
|
|
||||||
|
# use openai to choose the most likely table name from the schema
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model_version,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a helpful assistant that chooses a table name from a list. Only respond with the table name."},
|
||||||
|
{"role": "user", "content": final_message},
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=32,
|
||||||
|
frequency_penalty=0,
|
||||||
|
presence_penalty=0,
|
||||||
|
)
|
||||||
|
table_name = response.choices[0].message.content
|
||||||
|
logger.info(f"table_name in DatabaseAgent is: {table_name}")
|
||||||
|
map_context = f"Ensure the query is restricted to the following bbox: {bbox}"
|
||||||
|
db = Database(
|
||||||
|
database=os.getenv("POSTGRES_DBNAME"),
|
||||||
|
user=os.getenv("POSTGRES_USER"),
|
||||||
|
password=os.getenv("POSTGRES_PASS"),
|
||||||
|
host=os.getenv("POSTGRES_HOST"),
|
||||||
|
port=os.getenv("POSTGRES_PORT")
|
||||||
|
)
|
||||||
|
column_names = db.get_column_names(table_name)
|
||||||
|
logger.info(f"column_names in DatabaseAgent is: {column_names}")
|
||||||
|
db.close()
|
||||||
|
table_name_context = f"Generate your query using the following table name: {table_name} and the appropriate column names: {column_names}"
|
||||||
|
|
||||||
# attempt to get the tablename from the user message
|
|
||||||
# we do this to avoid sending the entire schema to openai
|
|
||||||
if response.status_code == 200:
|
|
||||||
logger.info(f"Response from /get_table_name route: {response}")
|
|
||||||
response_data = response.json()
|
|
||||||
logger.info(f"Response message from /get_table_name route: {response_data}")
|
|
||||||
table_name = response_data.get('choices', [{}])[0].get('message', {}).get('content', '')
|
|
||||||
if table_name:
|
|
||||||
logger.info(f"Table name: {table_name}")
|
|
||||||
db = Database(
|
|
||||||
database=os.getenv("POSTGRES_DBNAME"),
|
|
||||||
user=os.getenv("POSTGRES_USER"),
|
|
||||||
password=os.getenv("POSTGRES_PASS"),
|
|
||||||
host=os.getenv("POSTGRES_HOST"),
|
|
||||||
port=os.getenv("POSTGRES_PORT")
|
|
||||||
)
|
|
||||||
column_names = db.get_column_names(table_name)
|
|
||||||
db.close()
|
|
||||||
table_name_context = f"Generate your query using the following table name: {table_name} and the appropriate column names: {column_names}"
|
|
||||||
|
|
||||||
#remove the last item in self.messages
|
|
||||||
if len(self.messages) > 1:
|
|
||||||
self.messages.pop()
|
|
||||||
self.messages.append({
|
self.messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": message + "\n" + map_context + "\n" + table_name_context,
|
"content": message + "\n" + map_context + "\n" + table_name_context,
|
||||||
})
|
})
|
||||||
# self.messages.append({
|
|
||||||
# "role": "user",
|
|
||||||
# "content": map_context,
|
|
||||||
# })
|
|
||||||
logger.info(f"DatabaseAgent self.messages: {self.messages}")
|
|
||||||
logger.info(f"DatabaseAgent self.function_descriptions: {self.function_descriptions}")
|
|
||||||
# this will be the function gpt will call if it
|
# this will be the function gpt will call if it
|
||||||
# determines that the user wants to call a function
|
# determines that the user wants to call a function
|
||||||
function_response = None
|
function_response = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = openai.ChatCompletion.create(
|
logger.info("Calling OpenAI API in DatabaseAgent...")
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
model=self.model_version,
|
model=self.model_version,
|
||||||
messages=self.messages,
|
messages=self.messages,
|
||||||
functions=self.function_descriptions,
|
tools=self.tools,
|
||||||
function_call="auto",
|
tool_choice="auto",
|
||||||
temperature=0,
|
|
||||||
max_tokens=256,
|
|
||||||
top_p=1,
|
|
||||||
frequency_penalty=0,
|
|
||||||
presence_penalty=0
|
|
||||||
)
|
)
|
||||||
|
response_message = response.choices[0].message
|
||||||
|
logger.info(f"response_message in DatabaseAgent is: {response_message}")
|
||||||
|
tool_calls = response_message.tool_calls
|
||||||
|
|
||||||
response_message = response["choices"][0]["message"]
|
if tool_calls:
|
||||||
|
self.messages.append(response_message)
|
||||||
logger.info(f"Response from OpenAI in DatabaseAgent: {response_message}")
|
for tool_call in tool_calls:
|
||||||
|
function_name = tool_call.function.name
|
||||||
if response_message.get("function_call"):
|
function_to_call = self.available_functions[function_name]
|
||||||
function_name = response_message["function_call"]["name"]
|
function_args = json.loads(tool_call.function.arguments)
|
||||||
function_name = function_name.strip()
|
function_response = function_to_call(**function_args)
|
||||||
logger.info(f"Function name: {function_name}")
|
self.messages.append(
|
||||||
function_to_call = self.available_functions[function_name]
|
{
|
||||||
logger.info(f"Function to call: {function_to_call}")
|
"tool_call_id": tool_call.id,
|
||||||
function_args = json.loads(response_message["function_call"]["arguments"])
|
"role": "tool",
|
||||||
logger.info(f"Function args: {function_args}")
|
"name": function_name,
|
||||||
# determine the function to call
|
"content": json.dumps(function_response),
|
||||||
function_response = function_to_call(**function_args)
|
}
|
||||||
|
)
|
||||||
### This is blowing up my context window
|
logger.info("Sucessful StyleAgent task completion.")
|
||||||
# self.messages.append(response_message)
|
|
||||||
# self.messages.append({
|
|
||||||
# "role": "function",
|
|
||||||
# "name": function_name,
|
|
||||||
# "content": function_response,
|
|
||||||
# })
|
|
||||||
|
|
||||||
return {"response": function_response}
|
return {"response": function_response}
|
||||||
elif response_message.get("content"):
|
|
||||||
return {"response": response_message["content"]}
|
|
||||||
else:
|
|
||||||
return {"response": "I'm sorry, I don't understand."}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"error": "Failed to get response from OpenAI in NavigationAgent: " + str(e)}, 500
|
return {"error": "Failed to get response from OpenAI in NavigationAgent: " + str(e)}, 500
|
||||||
|
@ -145,13 +125,13 @@ class DatabaseAgent:
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"name": "get_geojson_from_database",
|
"name": "get_geojson_from_database",
|
||||||
"description": """Retrieve geojson sptatial data using PostGIS SQL.""",
|
"description": """Retrieve geojson spatial data using PostGIS SQL.""",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"query": {
|
"query": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": f"""SQL query to get geojson from the database.
|
"description": """SQL query to get geojson from the database.
|
||||||
The query shall be returned in string format as a single command."""
|
The query shall be returned in string format as a single command."""
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
@ -10,8 +10,8 @@ class MapInfoAgent:
|
||||||
def select_layer_name(self, layer_name):
|
def select_layer_name(self, layer_name):
|
||||||
return {"name": "select_layer_name", "layer_name": layer_name}
|
return {"name": "select_layer_name", "layer_name": layer_name}
|
||||||
|
|
||||||
def __init__(self, openai, model_version="gpt-3.5-turbo-0613"):
|
def __init__(self, client, model_version):
|
||||||
self.openai = openai
|
self.client = client
|
||||||
self.model_version = model_version
|
self.model_version = model_version
|
||||||
self.tools = map_info_function_descriptions
|
self.tools = map_info_function_descriptions
|
||||||
self.messages = [
|
self.messages = [
|
||||||
|
@ -49,7 +49,7 @@ class MapInfoAgent:
|
||||||
function_response = None
|
function_response = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.openai.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model_version,
|
model=self.model_version,
|
||||||
messages=self.messages,
|
messages=self.messages,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
|
@ -80,25 +80,23 @@ class MapInfoAgent:
|
||||||
# what they meant. In general, it's pretty good at this unless there are multiple layers with similar names, in which case
|
# what they meant. In general, it's pretty good at this unless there are multiple layers with similar names, in which case
|
||||||
# it just chooses one.
|
# it just chooses one.
|
||||||
if function_name == "select_layer_name":
|
if function_name == "select_layer_name":
|
||||||
logger.info(f"Sending layer name retrieval request to OpenAI...")
|
|
||||||
prompt = f"Please select a layer name from the following list that is closest to the text '{function_response['layer_name']}': {str(layer_names)}\n Only state the layer name in your response."
|
prompt = f"""Please select a layer name from the following list that is closest to the
|
||||||
logger.info(f"Prompt to OpenAI: {prompt}")
|
text '{function_response['layer_name']}': {str(layer_names)}\n
|
||||||
|
Only state the layer name in your response."""
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": prompt,
|
"content": prompt,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
second_response = self.openai.chat.completions.create(
|
second_response = self.client.chat.completions.create(
|
||||||
model=self.model_version,
|
model=self.model_version,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
)
|
||||||
logger.info(f"Second response from OpenAI in MapInfoAgent: {second_response}")
|
|
||||||
second_response_message = second_response.choices[0].message.content
|
second_response_message = second_response.choices[0].message.content
|
||||||
logger.info(f"Second response message from OpenAI in MapInfoAgent: {second_response_message}")
|
|
||||||
logger.info(f"Function Response bofore setting the layer name: {function_response}")
|
|
||||||
function_response['layer_name'] = second_response_message
|
function_response['layer_name'] = second_response_message
|
||||||
logger.info(f"Function response after call to select a layername: {function_response}")
|
|
||||||
return {"response": function_response}
|
return {"response": function_response}
|
||||||
elif response_message.get("content"):
|
elif response_message.get("content"):
|
||||||
return {"response": response_message["content"]}
|
return {"response": response_message["content"]}
|
||||||
|
|
|
@ -1,14 +1,13 @@
|
||||||
import json
|
import json
|
||||||
#import openai
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class MarshallAgent:
|
class MarshallAgent:
|
||||||
"""A Marshall agent that has function descriptions for choosing the appropriate agent for a specified task."""
|
"""A Marshall agent that has function descriptions for choosing the appropriate agent for a specified task."""
|
||||||
def __init__(self, openai, model_version="gpt-3.5-turbo-0613"):
|
def __init__(self, client, model_version):
|
||||||
self.model_version = model_version
|
self.model_version = model_version
|
||||||
self.openai = openai
|
self.client = client
|
||||||
self.tools = [
|
self.tools = [
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
|
@ -31,7 +30,7 @@ class MarshallAgent:
|
||||||
self.system_message = """You are a helpful assistant that decides which agent to use for a specified task.
|
self.system_message = """You are a helpful assistant that decides which agent to use for a specified task.
|
||||||
|
|
||||||
For tasks related to adding layers and other geospatial data to the map, use the DatabaseAgent.
|
For tasks related to adding layers and other geospatial data to the map, use the DatabaseAgent.
|
||||||
Examples include 'add buildings to the map' and 'get landuse polygons within this extent'.
|
Examples include 'add buildings to the map', 'show industrial buildings', and 'get landuse polygons within this extent'.
|
||||||
|
|
||||||
For tasks that ask to change the style of a map, such as opacity, color, or line width, you will
|
For tasks that ask to change the style of a map, such as opacity, color, or line width, you will
|
||||||
use the StyleAgent. Examples StyleAgent prompts include 'change color to green', 'opacity 45%'
|
use the StyleAgent. Examples StyleAgent prompts include 'change color to green', 'opacity 45%'
|
||||||
|
@ -52,18 +51,14 @@ class MarshallAgent:
|
||||||
self.available_functions = {
|
self.available_functions = {
|
||||||
"choose_agent": self.choose_agent,
|
"choose_agent": self.choose_agent,
|
||||||
}
|
}
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def choose_agent(self, agent_name):
|
def choose_agent(self, agent_name):
|
||||||
return {"name": "choose_agent", "agent_name": agent_name}
|
return {"name": "choose_agent", "agent_name": agent_name}
|
||||||
|
|
||||||
def listen(self, message):
|
def listen(self, message):
|
||||||
self.logger.info(f"In MarshallAgent.listen()...message is: {message}")
|
logger.info(f"In MarshallAgent.listen()...message is: {message}")
|
||||||
"""Listen to a message from the user."""
|
"""Listen to a message from the user."""
|
||||||
|
|
||||||
# # Remove the last item in self.messages. Our agent has no memory
|
|
||||||
# if len(self.messages) > 1:
|
|
||||||
# self.messages.pop()
|
|
||||||
|
|
||||||
self.messages.append({
|
self.messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
@ -75,7 +70,7 @@ class MarshallAgent:
|
||||||
function_response = None
|
function_response = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.openai.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model_version,
|
model=self.model_version,
|
||||||
messages=self.messages,
|
messages=self.messages,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
|
@ -100,11 +95,7 @@ class MarshallAgent:
|
||||||
"content": json.dumps(function_response),
|
"content": json.dumps(function_response),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# second_response = self.openai.chat.completions.create(
|
logger.info(f"Sucessful MarshallAgent task completion: {function_response}")
|
||||||
# model=self.model_version,
|
|
||||||
# messages=self.messages,
|
|
||||||
# )
|
|
||||||
logger.info(f"Sucessful MarallAgent task completion: {function_response}")
|
|
||||||
return {"response": function_response}
|
return {"response": function_response}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
from .function_descriptions.navigation_function_descriptions import navigation_function_descriptions
|
from .function_descriptions.navigation_function_descriptions import navigation_function_descriptions
|
||||||
#import openai
|
#import client
|
||||||
import json
|
import json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -20,8 +20,8 @@ class NavigationAgent:
|
||||||
def zoom_out(self, zoom_levels=1):
|
def zoom_out(self, zoom_levels=1):
|
||||||
return {"name": "zoom_out", "zoom_levels": zoom_levels}
|
return {"name": "zoom_out", "zoom_levels": zoom_levels}
|
||||||
|
|
||||||
def __init__(self, openai, model_version="gpt-3.5-turbo-0613"):
|
def __init__(self, client, model_version):
|
||||||
self.openai = openai
|
self.client = client
|
||||||
self.model_version = model_version
|
self.model_version = model_version
|
||||||
self.messages = [
|
self.messages = [
|
||||||
{
|
{
|
||||||
|
@ -42,7 +42,6 @@ class NavigationAgent:
|
||||||
"zoom_out": self.zoom_out,
|
"zoom_out": self.zoom_out,
|
||||||
}
|
}
|
||||||
self.tools = navigation_function_descriptions
|
self.tools = navigation_function_descriptions
|
||||||
logger.info(f"self.tools in NavigationAgent is: {self.tools}")
|
|
||||||
|
|
||||||
def listen(self, message):
|
def listen(self, message):
|
||||||
logging.info(f"In NavigationAgent...message is: {message}")
|
logging.info(f"In NavigationAgent...message is: {message}")
|
||||||
|
@ -61,15 +60,13 @@ class NavigationAgent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("Calling OpenAI API in NavigationAgent...")
|
logger.info("Calling OpenAI API in NavigationAgent...")
|
||||||
response = self.openai.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model_version,
|
model=self.model_version,
|
||||||
messages=self.messages,
|
messages=self.messages,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
)
|
)
|
||||||
logger.info(f"response in NavigationAgent is: {response}")
|
|
||||||
response_message = response.choices[0].message
|
response_message = response.choices[0].message
|
||||||
logger.info(f"response_message in NavigationAgent is: {response_message}")
|
|
||||||
tool_calls = response_message.tool_calls
|
tool_calls = response_message.tool_calls
|
||||||
|
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
|
|
|
@ -19,8 +19,8 @@ class StyleAgent:
|
||||||
def set_visibility(self, layer_name, visibility):
|
def set_visibility(self, layer_name, visibility):
|
||||||
return {"name": "set_visibility", "layer_name": layer_name, "visibility": visibility}
|
return {"name": "set_visibility", "layer_name": layer_name, "visibility": visibility}
|
||||||
|
|
||||||
def __init__(self, openai, model_version="gpt-3.5-turbo-0613"):
|
def __init__(self, client, model_version):
|
||||||
self.openai = openai
|
self.client = client
|
||||||
self.model_version = model_version
|
self.model_version = model_version
|
||||||
|
|
||||||
self.tools = style_function_descriptions
|
self.tools = style_function_descriptions
|
||||||
|
@ -57,15 +57,13 @@ class StyleAgent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("Calling OpenAI API in StyleAgent...")
|
logger.info("Calling OpenAI API in StyleAgent...")
|
||||||
response = self.openai.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model_version,
|
model=self.model_version,
|
||||||
messages=self.messages,
|
messages=self.messages,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
)
|
)
|
||||||
logger.info(f"response in StyleAgent is: {response}")
|
|
||||||
response_message = response.choices[0].message
|
response_message = response.choices[0].message
|
||||||
logger.info(f"response_message in StyleAgent is: {response_message}")
|
|
||||||
tool_calls = response_message.tool_calls
|
tool_calls = response_message.tool_calls
|
||||||
|
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
|
|
|
@ -22,14 +22,14 @@ load_dotenv()
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
CORS(app)
|
CORS(app)
|
||||||
|
|
||||||
openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||||
model_version = os.getenv("OPENAI_MODEL_VERSION")
|
model_version = os.getenv("OPENAI_MODEL_VERSION")
|
||||||
UPLOAD_FOLDER = 'uploads/audio'
|
UPLOAD_FOLDER = 'uploads/audio'
|
||||||
|
|
||||||
navigation_agent = NavigationAgent(openai, model_version=model_version)
|
navigation_agent = NavigationAgent(client, model_version=model_version)
|
||||||
marshall_agent = MarshallAgent(openai, model_version=model_version)
|
marshall_agent = MarshallAgent(client, model_version=model_version)
|
||||||
style_agent = StyleAgent(openai, model_version=model_version)
|
style_agent = StyleAgent(client, model_version=model_version)
|
||||||
map_info_agent = MapInfoAgent(openai, model_version=model_version)
|
map_info_agent = MapInfoAgent(client, model_version=model_version)
|
||||||
|
|
||||||
def get_database_schema():
|
def get_database_schema():
|
||||||
db = Database(
|
db = Database(
|
||||||
|
@ -44,15 +44,12 @@ def get_database_schema():
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
schema = get_database_schema()
|
schema = get_database_schema()
|
||||||
database_agent = DatabaseAgent(model_version=model_version, schema=schema)
|
database_agent = DatabaseAgent(client, model_version=model_version, schema=schema)
|
||||||
|
|
||||||
@app.route('/get_query', methods=['POST'])
|
@app.route('/get_query', methods=['POST'])
|
||||||
def get_query():
|
def get_query():
|
||||||
logging.info(f"Received request in /get_query route...: {request}")
|
|
||||||
message = request.json.get('message', '')
|
message = request.json.get('message', '')
|
||||||
bbox = request.json.get('bbox', '')
|
bbox = request.json.get('bbox', '')
|
||||||
logging.info(f"Received message in /get_query route...: {message}")
|
|
||||||
logging.info(f"Received bbox in /get_query route...: {bbox}")
|
|
||||||
return jsonify(database_agent.listen(message, bbox))
|
return jsonify(database_agent.listen(message, bbox))
|
||||||
|
|
||||||
@app.route('/get_table_name', methods=['GET'])
|
@app.route('/get_table_name', methods=['GET'])
|
||||||
|
@ -62,17 +59,16 @@ def get_table_name():
|
||||||
prefixed_message = f"Choose the most likely table the following text is referring to from this list:\m {table_names}.\n"
|
prefixed_message = f"Choose the most likely table the following text is referring to from this list:\m {table_names}.\n"
|
||||||
final_message = prefixed_message + message
|
final_message = prefixed_message + message
|
||||||
logging.info(f"Received message in /get_table_name route...: {final_message}")
|
logging.info(f"Received message in /get_table_name route...: {final_message}")
|
||||||
response = openai.ChatCompletion.create(
|
response = client.chat.completions.create(
|
||||||
model=model_version,
|
model=model_version,
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a helpful assistant that chooses a table name from a list. Only respond with the table name."},
|
{"role": "system", "content": "You are a helpful assistant that chooses a table name from a list. Only respond with the table name."},
|
||||||
{"role": "user", "content": final_message},
|
{"role": "user", "content": final_message},
|
||||||
],
|
],
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=256,
|
max_tokens=32,
|
||||||
top_p=1,
|
|
||||||
frequency_penalty=0,
|
frequency_penalty=0,
|
||||||
presence_penalty=0
|
presence_penalty=0,
|
||||||
)
|
)
|
||||||
logging.info(f"Response from OpenAI in /get_table_name route: {response}")
|
logging.info(f"Response from OpenAI in /get_table_name route: {response}")
|
||||||
#response_message = response["choices"][0]["message"]
|
#response_message = response["choices"][0]["message"]
|
||||||
|
@ -163,7 +159,10 @@ def upload_audio():
|
||||||
audio_file = request.files['audio']
|
audio_file = request.files['audio']
|
||||||
audio_file.save(os.path.join(UPLOAD_FOLDER, "user_audio.webm"))
|
audio_file.save(os.path.join(UPLOAD_FOLDER, "user_audio.webm"))
|
||||||
audio_file=open(os.path.join(UPLOAD_FOLDER, "user_audio.webm"), 'rb')
|
audio_file=open(os.path.join(UPLOAD_FOLDER, "user_audio.webm"), 'rb')
|
||||||
transcript = openai.audio.transcribe("whisper-1", audio_file)
|
transcript = client.audio.transcriptions.create(
|
||||||
|
model="whisper-1",
|
||||||
|
file = audio_file
|
||||||
|
)
|
||||||
logging.info(f"Received transcript: {transcript}")
|
logging.info(f"Received transcript: {transcript}")
|
||||||
message = transcript['text']
|
message = transcript['text']
|
||||||
#delete the audio
|
#delete the audio
|
||||||
|
|
|
@ -270,7 +270,6 @@
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
response_data = data.response;
|
response_data = data.response;
|
||||||
console.log(response_data);
|
|
||||||
chatbox.value += databaseResponseAgent.handleResponse(userMessage, response_data);
|
chatbox.value += databaseResponseAgent.handleResponse(userMessage, response_data);
|
||||||
return;
|
return;
|
||||||
})
|
})
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,4 @@
|
||||||
|
[ZoneTransfer]
|
||||||
|
ZoneId=3
|
||||||
|
ReferrerUrl=https://download.geofabrik.de/europe.html
|
||||||
|
HostUrl=https://download.geofabrik.de/europe/liechtenstein-latest.osm.pbf
|
Ładowanie…
Reference in New Issue