"""Text-to-SQL "query re-engineering" agent built with LangGraph and Gemini.

The graph rewrites a natural-language query (NLQ) against table metadata,
critiques the rewrite, rewrites it again using the critique, and asks the LLM
whether the two rewrites are consistent. It stops when they are, or when the
iteration budget (``ReEngineerQuery.max_iterations``) is used up.
"""
import json
import os
import re
from typing import Annotated, TypedDict

from IPython.display import Image
from langchain.chat_models import ChatOpenAI  # not used below; kept from the notebook
from langchain_community.document_loaders import WebBaseLoader  # not used below
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph

with open('../api_keys.json', mode='r', encoding='utf-8') as f:
    api_keys = json.load(f)

# Module-level Gemini client initialized with the key from api_keys.json.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=api_keys['ashishjain1547'],
)


class AgentState(TypedDict):
    """State shared by all nodes of the re-engineering graph."""

    nlq: str                    # the user's natural-language question
    metadata: dict              # table metadata (the driver loads it from tables.json)
    refined_query: str          # output of the ``refine_query`` node
    further_refined_query: str  # output of the ``further_refine_query`` node
    feedback: str               # critique from ``evaluate_reengineered_query``
    itr_count: int              # number of rewrite nodes executed so far


def extract_triple_quoted_json(response_text):
    """Return the first ``{...}`` object wrapped in triple quotes, or None.

    Matches a JSON-looking object enclosed in either triple double-quotes or
    triple single-quotes; the braces themselves are included in the result.
    """
    pattern = r'(?:\'\'\'|""")\s*(\{.*?\})\s*(?:\'\'\'|""")'
    match = re.search(pattern, response_text, re.DOTALL)
    if match:
        return match.group(1)
    return None


# First JSON-looking object anywhere in a reply (non-greedy, spans newlines).
_JSON_OBJECT_RE = re.compile(r'\{.*?\}', re.DOTALL)


def _parse_consistency_answer(response_text):
    """Return the lower-cased ``answer`` value from the consistency reply.

    Candidates are, in order, a triple-quoted JSON object and then every
    ``{...}`` object in the text. Each candidate is tried as-is first and
    then with single quotes swapped for double quotes, to tolerate
    Python-style dicts. Returns ``None`` when no candidate parses to a dict
    containing ``'answer'``.
    """
    candidates = []
    quoted = extract_triple_quoted_json(response_text)
    if quoted:
        candidates.append(quoted)
    candidates.extend(_JSON_OBJECT_RE.findall(response_text))
    for candidate in candidates:
        for text in (candidate, candidate.replace("'", '"')):
            try:
                parsed = json.loads(text)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict) and 'answer' in parsed:
                return str(parsed['answer']).strip().lower()
    return None


META_PROMPT_TO_REENGINEER_NLQ = """
You are an expert SQLite query generator.
Based on the natural language query and the provided table metadata,
please reengineer the query to clearly specify:
- The specific table(s) that should be referenced,
- The calculations or aggregations to perform,
- The structure of the final SQL query.

NLQ: {nlq}

Table Metadata: {metadata}

Reengineered Query:
"""

EVALUATE_REENGINEERED_QUERY = """
You are an expert SQLite engineer grading an NLQ for correctness, completeness and clarity. \
Generate critique and recommendations for the NLQ so it can be easily converted to an SQL. \
Please evaluate the reengineered query below:
{refined_query}
"""

REENGINEER_QUERY_POST_FEEDBACK = """
You are an expert SQLite query generator.
Based on the natural language query, the provided table metadata and feedback,
please reengineer the query based on the feedback given to clearly specify:
- The specific table(s) that should be referenced,
- The calculations or aggregations to perform,
- The structure of the final SQL query.

NLQ: {nlq}

Table Metadata: {metadata}

Feedback: {feedback}

Reengineered Query:
"""

CHECK_CONSISTENCY = """
You are an expert database query evaluator. Your task is to evaluate two queries \
if they are consistent and mean the same thing. One query is the last query and the other is the reengineered query. \
Last Query: {refined_query}
Reengineered Query: {further_refined_query}
Is the reengineered query consistent with the last query? Return a JSON response with the key 'answer': 'yes' or 'no'."""


def _next_itr_count(state):
    """Return ``itr_count + 1``, treating a missing/None count as 0."""
    return (state.get("itr_count") or 0) + 1


class ReEngineerQuery:
    """LangGraph workflow that iteratively re-engineers an NLQ for SQL generation.

    Flow: refine_query -> evaluate_reengineered_query -> further_refine_query,
    then ``should_continue`` routes to END or back to refine_query.

    Args:
        model: Callable taking a prompt string and returning the LLM's reply
            text (e.g. ``MyGeminiChatModel``).
        max_iterations: Stop looping once ``itr_count`` reaches this value,
            even if the rewrites never agree. Each pass through the loop adds
            2 (one per rewrite node). ``None`` disables the cap.
    """

    def __init__(self, model, max_iterations=10):
        self.model = model
        self.max_iterations = max_iterations

        graph = StateGraph(AgentState)
        graph.add_node("refine_query", self.refine_query)
        graph.add_node("evaluate_reengineered_query", self.evaluate_reengineered_query)
        graph.add_node("further_refine_query", self.further_refine_query)

        # --- Edges ---
        graph.set_entry_point("refine_query")
        graph.add_edge("refine_query", "evaluate_reengineered_query")
        graph.add_edge("evaluate_reengineered_query", "further_refine_query")
        graph.add_conditional_edges(
            "further_refine_query",
            self.should_continue,
            {"end": END, "refine_query": "refine_query"},
        )

        # Compile the graph and store a reference for callers.
        self.graph = graph.compile()

    def refine_query(self, state):
        """Rewrite the NLQ into an explicit, SQL-oriented description.

        Returns a partial state update: ``refined_query`` and the incremented
        ``itr_count``.
        """
        prompt = META_PROMPT_TO_REENGINEER_NLQ.format(
            nlq=state.get("nlq", ""),
            metadata=state.get("metadata", ""),
        )
        response = self.model(prompt)
        return {"refined_query": response.strip(), "itr_count": _next_itr_count(state)}

    def evaluate_reengineered_query(self, state):
        """Ask the LLM to critique ``refined_query``; returns ``feedback``."""
        prompt = EVALUATE_REENGINEERED_QUERY.format(
            refined_query=state.get("refined_query", ""),
        )
        response = self.model(prompt)
        return {"feedback": response.strip()}

    def further_refine_query(self, state):
        """Rewrite the NLQ again, this time guided by ``feedback``.

        Prints the raw LLM reply, then returns ``further_refined_query`` and
        the incremented ``itr_count``.
        """
        prompt = REENGINEER_QUERY_POST_FEEDBACK.format(
            nlq=state.get("nlq", ""),
            metadata=state.get("metadata", ""),
            feedback=state.get("feedback", ""),
        )
        response = self.model(prompt)
        print(response)
        return {"further_refined_query": response.strip(), "itr_count": _next_itr_count(state)}

    def should_continue(self, state):
        """Route to ``"end"`` or back to ``"refine_query"``.

        Ends when the iteration cap is reached, or when the LLM answers "yes"
        (case-insensitively) to whether the two rewrites are consistent. An
        unparseable reply is treated as "not consistent" instead of raising.
        """
        itr_count = state.get("itr_count") or 0
        if self.max_iterations is not None and itr_count >= self.max_iterations:
            return "end"

        prompt = CHECK_CONSISTENCY.format(
            refined_query=state.get("refined_query", ""),
            further_refined_query=state.get("further_refined_query", ""),
        )
        response = self.model(prompt)
        print(response)
        answer = _parse_consistency_answer(response)
        print({"answer": answer})
        return "end" if answer == "yes" else "refine_query"


class MyGeminiChatModel:
    """Minimal callable wrapper: takes a prompt string, returns Gemini's reply text."""

    def __init__(self, api_key, model="gemini-2.0-flash"):
        self.client = ChatGoogleGenerativeAI(model=model, google_api_key=api_key)

    def bind_tools(self, tools):
        # For compatibility with how TranslatorCriticApp uses .bind_tools
        return self

    def __call__(self, prompt: str) -> str:
        # Use this instance's own client; previously the module-level ``llm``
        # was called, silently ignoring the api_key given to __init__.
        response = self.client.invoke(prompt)
        return response.content


model = MyGeminiChatModel(api_keys["ashishjain1547"])
app = ReEngineerQuery(model)

# Display the compiled graph as a PNG (notebook cell).
Image(app.graph.get_graph().draw_png())
# <IMG>  -- rendered graph diagram in the original notebook output

with open('tables.json', mode='r', encoding='utf-8') as f:
    metadata = json.load(f)

nlq = "Show me the orders from last year."
def build_initial_state(question, table_metadata):
    """Return a fresh ``AgentState`` dict for one run of the re-engineering graph.

    All intermediate fields start empty and ``itr_count`` starts at 0, so
    nothing from a previous question carries over into the next run.
    """
    return {
        "nlq": question,
        "metadata": table_metadata,
        "refined_query": "",
        "further_refined_query": "",
        "feedback": "",
        "itr_count": 0,
    }


# --- Example runs ----------------------------------------------------------
# The "Output" comments reproduce what the original notebook showed per run.

# ``nlq`` ("Show me the orders from last year.") is set in the previous cell.
result = app.graph.invoke(build_initial_state(nlq, metadata))
# Output:
# SELECT * FROM AB_ORDERS WHERE ORDERYEAR = CAST(STRFTIME('%Y', DATE('now', '-1 year')) AS INTEGER)

nlq = "Show all tables"
result = app.graph.invoke(build_initial_state(nlq, metadata))
# Output:
# To show all tables in the SQLite database, you can use the following SQL query:
# ```sql
# SELECT name FROM sqlite_master WHERE type='table';
# ```
# This query retrieves the names of all tables present in the SQLite database by
# querying the `sqlite_master` table where the `type` column is equal to 'table'.
# Knowing all the tables in the database is essential for database management,
# schema understanding, and querying purposes. It helps in identifying the
# available data structures, relationships between tables, and overall database
# organization. This information is crucial for developers, analysts, and
# administrators to effectively work with the database and perform various
# operations such as data retrieval, manipulation, and optimization.

nlq = "What's in orders table?"
result = app.graph.invoke(build_initial_state(nlq, metadata))
# Output:
# SELECT * FROM AB_ORDERS

nlq = "Show me top 5 categories with respect to orders."
result = app.graph.invoke(build_initial_state(nlq, metadata))
# Output:
# SELECT C.CATEGORYNAME, COUNT(O.ORDERID) AS ORDER_COUNT
# FROM AB_CATEGORIES C
# JOIN AB_PRODUCTS P ON C.CATEGORYID = P.CATEGORYID
# JOIN AB_ORDERDETAILS OD ON P.PRODUCTID = OD.PRODUCTID
# JOIN AB_ORDERS O ON OD.ORDERID = O.ORDERID
# GROUP BY C.CATEGORYNAME
# ORDER BY ORDER_COUNT DESC
# LIMIT 5;

nlq = "Which areas are dairy products sold?"
result = app.graph.invoke(build_initial_state(nlq, metadata))
# Output:
# SELECT DISTINCT C.CITY AS AREA
# FROM AB_CUSTOMERS C
# JOIN AB_ORDERS O ON C.CUSTOMERID = O.CUSTOMERID
# JOIN AB_ORDERDETAILS OD ON O.ORDERID = OD.ORDERID
# JOIN AB_PRODUCTS P ON OD.PRODUCTID = P.PRODUCTID
# JOIN AB_CATEGORIES CAT ON P.CATEGORYID = CAT.CATEGORYID
# WHERE CAT.CATEGORYNAME = 'Dairy Products';

nlq = "Compare orders from top two cities with respect to total sales."
result = app.graph.invoke(build_initial_state(nlq, metadata))
# Output:
# -- Query to compare total sales from the top two cities
# SELECT c.CITY AS City, SUM(p.PRICE * od.QUANTITY) AS TotalSales
# FROM AB_CUSTOMERS c
# JOIN AB_ORDERS o ON c.CUSTOMERID = o.CUSTOMERID
# JOIN AB_ORDERDETAILS od ON o.ORDERID = od.ORDERID
# JOIN AB_PRODUCTS p ON od.PRODUCTID = p.PRODUCTID
# GROUP BY c.CITY
# ORDER BY TotalSales DESC
# LIMIT 2;
Wednesday, March 5, 2025
Text-to-SQL Agent
To see all articles about technology: Index of Lessons in Technology
Subscribe to:
Post Comments (Atom)
No comments:
Post a Comment