# assistant/agent/core/graph.py
"""
LangGraph workflow definition for B2B Shopping AI Assistant
"""
from typing import Literal, Optional
import httpx
from langgraph.graph import StateGraph, END
from .state import AgentState, ConversationState, mark_finished, add_tool_result, set_response
# Deferred imports below (in function bodies) avoid circular dependencies
from config import settings
from utils.logger import get_logger
logger = get_logger(__name__)
# ============ Node Functions ============
async def receive_message(state: AgentState) -> AgentState:
    """Entry node of the workflow: record the incoming user message.

    Appends the current message to the conversation transcript and resets
    the conversation state to INITIAL before routing continues.
    """
    incoming = state["current_message"]
    logger.info(
        "Receiving message",
        conversation_id=state["conversation_id"],
        message_length=len(incoming)
    )
    # Record the user turn in the running transcript.
    state["messages"].append({"role": "user", "content": incoming})
    state["state"] = ConversationState.INITIAL.value
    return state
async def call_mcp_tools(state: AgentState) -> AgentState:
    """Execute every pending MCP tool call and record the results.

    Each entry in state["tool_calls"] is dispatched over HTTP to the MCP
    server it names; successes and failures are appended to the state via
    add_tool_result, and the pending list is cleared afterwards.
    """
    pending = state["tool_calls"]
    if not pending:
        logger.debug("No tool calls to execute")
        return state

    logger.info(
        "Executing MCP tools",
        tool_count=len(pending)
    )

    # Logical server name -> configured base URL.
    server_urls = {
        "strapi": settings.strapi_mcp_url,
        "order": settings.order_mcp_url,
        "aftersale": settings.aftersale_mcp_url,
        "product": settings.product_mcp_url,
    }

    async with httpx.AsyncClient(timeout=30.0) as client:
        for call in pending:
            server = call["server"]
            tool_name = call["tool_name"]
            base_url = server_urls.get(server)

            if not base_url:
                # Unrecognized server name: record a failure and move on.
                state = add_tool_result(
                    state,
                    tool_name=tool_name,
                    success=False,
                    data=None,
                    error=f"Unknown MCP server: {server}",
                )
                continue

            try:
                # POST the arguments to the server's tool endpoint.
                response = await client.post(
                    f"{base_url}/tools/{tool_name}",
                    json=call["arguments"],
                )
                response.raise_for_status()
                state = add_tool_result(
                    state,
                    tool_name=tool_name,
                    success=True,
                    data=response.json(),
                )
                logger.debug(
                    "Tool executed successfully",
                    tool=tool_name,
                    server=server,
                )
            except httpx.HTTPStatusError as exc:
                # Server answered with an error status.
                logger.error(
                    "Tool HTTP error",
                    tool=tool_name,
                    status=exc.response.status_code,
                )
                state = add_tool_result(
                    state,
                    tool_name=tool_name,
                    success=False,
                    data=None,
                    error=f"HTTP {exc.response.status_code}",
                )
            except Exception as exc:
                # Network errors, JSON decode failures, etc.
                logger.error("Tool execution failed", tool=tool_name, error=str(exc))
                state = add_tool_result(
                    state,
                    tool_name=tool_name,
                    success=False,
                    data=None,
                    error=str(exc),
                )

    # Everything has been dispatched; nothing is left pending.
    state["tool_calls"] = []
    return state
async def human_handoff(state: AgentState) -> AgentState:
    """Prepare the conversation for takeover by a human agent.

    Marks the conversation as under human review and composes the
    handoff notice shown to the user.
    """
    logger.info(
        "Human handoff requested",
        conversation_id=state["conversation_id"],
        reason=state.get("handoff_reason")
    )
    state["state"] = ConversationState.HUMAN_REVIEW.value

    # Fall back to a generic reason when none was recorded.
    reason = state.get("handoff_reason", "您的问题需要人工客服协助")
    notice = f"正在为您转接人工客服,请稍候。\n转接原因:{reason}\n\n人工客服将尽快为您服务。"
    return set_response(state, notice)
async def send_response(state: AgentState) -> AgentState:
    """Terminal node: append the assistant reply and mark the run finished."""
    logger.info(
        "Sending response",
        conversation_id=state["conversation_id"],
        response_length=len(state.get("response", ""))
    )
    # Only record a transcript entry when there is an actual reply.
    reply = state.get("response")
    if reply:
        state["messages"].append({"role": "assistant", "content": reply})
    return mark_finished(state)
async def handle_error(state: AgentState) -> AgentState:
    """Fallback node: log the recorded error and reply with an apology."""
    logger.error(
        "Workflow error",
        conversation_id=state["conversation_id"],
        error=state.get("error")
    )
    apology = "抱歉,处理您的请求时遇到了问题。请稍后重试,或联系人工客服获取帮助。"
    return mark_finished(set_response(state, apology))
# ============ Routing Functions ============
def should_call_tools(state: AgentState) -> Literal["call_tools", "send_response", "back_to_agent"]:
    """Route after an agent node: execute tools or deliver the response.

    NOTE(review): "back_to_agent" appears in the return type and in the
    graph's edge mapping but is never returned here — confirm whether the
    agent re-entry path is still intended.
    """
    logger.debug(
        "Checking if tools should be called",
        conversation_id=state.get("conversation_id"),
        has_tool_calls=bool(state.get("tool_calls")),
        tool_calls_count=len(state.get("tool_calls", [])),
        has_response=bool(state.get("response")),
        state_value=state.get("state")
    )
    # Pending tool calls take priority over everything else.
    if state.get("tool_calls"):
        logger.info("Routing to tool execution", tool_count=len(state["tool_calls"]))
        return "call_tools"
    # A prepared response goes straight out.
    if state.get("response"):
        logger.debug("Routing to send_response (has response)")
        return "send_response"
    # Waiting on the user: deliver the clarifying question.
    if state.get("state") == ConversationState.AWAITING_INFO.value:
        logger.debug("Routing to send_response (awaiting info)")
        return "send_response"
    # Nothing actionable — fall through to sending whatever we have.
    logger.warning("Unexpected state, routing to send_response", state=state.get("state"))
    return "send_response"
def after_tools(state: AgentState) -> str:
    """Pick the agent node that should consume fresh tool results.

    Falls back to the customer-service agent when the current agent is
    unknown or unset.
    """
    routes = {
        "customer_service": "customer_service_agent",
        "order": "order_agent",
        "aftersale": "aftersale_agent",
        "product": "product_agent",
    }
    return routes.get(state.get("current_agent"), "customer_service_agent")
def check_completion(state: AgentState) -> Literal["continue", "end", "error"]:
    """Decide whether the workflow keeps going, stops, or errors out."""
    # Any recorded error routes to the error handler.
    if state.get("error"):
        return "error"
    # Explicit completion flag (set by mark_finished).
    if state.get("finished"):
        return "end"
    # Safety valve: cap the number of workflow steps.
    if state.get("step_count", 0) >= state.get("max_steps", 10):
        logger.warning("Max steps reached", conversation_id=state["conversation_id"])
        return "end"
    return "continue"
# ============ Graph Construction ============
def create_agent_graph() -> StateGraph:
    """Build and compile the main agent workflow graph.

    Returns:
        Compiled LangGraph workflow
    """
    # Imported lazily to avoid a circular dependency with the agents package.
    from agents.router import classify_intent, route_by_intent
    from agents.customer_service import customer_service_agent
    from agents.order import order_agent
    from agents.aftersale import aftersale_agent
    from agents.product import product_agent

    # The four specialist agents, keyed by their node names.
    agent_nodes = {
        "customer_service_agent": customer_service_agent,
        "order_agent": order_agent,
        "aftersale_agent": aftersale_agent,
        "product_agent": product_agent,
    }

    graph = StateGraph(AgentState)

    # Register all workflow nodes.
    graph.add_node("receive", receive_message)
    graph.add_node("classify", classify_intent)
    for node_name, node_fn in agent_nodes.items():
        graph.add_node(node_name, node_fn)
    graph.add_node("call_tools", call_mcp_tools)
    graph.add_node("human_handoff", human_handoff)
    graph.add_node("send_response", send_response)
    graph.add_node("handle_error", handle_error)

    graph.set_entry_point("receive")
    graph.add_edge("receive", "classify")

    # Intent classification fans out to the specialist agents or a human.
    intent_targets = {name: name for name in agent_nodes}
    intent_targets["human_handoff"] = "human_handoff"
    graph.add_conditional_edges("classify", route_by_intent, intent_targets)

    # After each agent: either run tools, send the reply, or loop back.
    for node_name in agent_nodes:
        graph.add_conditional_edges(
            node_name,
            should_call_tools,
            {
                "call_tools": "call_tools",
                "send_response": "send_response",
                "back_to_agent": node_name,
            },
        )

    # Tool results return to the agent that requested them.
    graph.add_conditional_edges(
        "call_tools",
        after_tools,
        {name: name for name in agent_nodes},
    )

    # Handoff goes out through the response node; both terminal paths end.
    graph.add_edge("human_handoff", "send_response")
    graph.add_edge("handle_error", END)
    graph.add_edge("send_response", END)

    return graph.compile()
# Global compiled graph
_compiled_graph = None
def get_agent_graph():
    """Return the compiled agent graph, building it on first use."""
    global _compiled_graph
    if _compiled_graph is None:
        # Compile lazily so importing this module stays cheap.
        _compiled_graph = create_agent_graph()
    return _compiled_graph
async def process_message(
    conversation_id: str,
    user_id: str,
    account_id: str,
    message: str,
    history: Optional[list[dict]] = None,
    context: Optional[dict] = None,
    user_token: Optional[str] = None
) -> AgentState:
    """Process a user message through the agent workflow.

    Args:
        conversation_id: Chatwoot conversation ID
        user_id: User identifier
        account_id: B2B account identifier
        message: User's message
        history: Previous conversation history (optional)
        context: Existing conversation context (optional)
        user_token: User JWT token for downstream API calls (optional)

    Returns:
        Final agent state with response. On workflow failure, a finished
        state carrying a generic apology response is returned instead of
        raising, so the caller can always deliver a reply.
    """
    # Deferred import to avoid a circular dependency with .state.
    from .state import create_initial_state

    initial_state = create_initial_state(
        conversation_id=conversation_id,
        user_id=user_id,
        account_id=account_id,
        current_message=message,
        messages=history,
        context=context,
        user_token=user_token
    )

    graph = get_agent_graph()

    logger.info(
        "Starting workflow",
        conversation_id=conversation_id,
        message=message[:100]  # truncate to keep log lines bounded
    )
    try:
        # Keep the try body minimal: only the workflow invocation can fail here.
        final_state = await graph.ainvoke(initial_state)
    except Exception as e:
        # Fail soft: surface the error in the state rather than raising.
        logger.error("Workflow failed", error=str(e))
        initial_state["error"] = str(e)
        initial_state["response"] = "抱歉,处理您的请求时遇到了问题。请稍后重试。"
        initial_state["finished"] = True
        return initial_state

    logger.info(
        "Workflow completed",
        conversation_id=conversation_id,
        intent=final_state.get("intent"),
        steps=final_state.get("step_count")
    )
    return final_state