2026-01-14 19:25:22 +08:00
|
|
|
"""
|
|
|
|
|
LangGraph workflow definition for B2B Shopping AI Assistant
|
|
|
|
|
"""
|
|
|
|
|
from typing import Literal
|
|
|
|
|
import httpx
|
|
|
|
|
|
|
|
|
|
from langgraph.graph import StateGraph, END
|
|
|
|
|
|
|
|
|
|
from .state import AgentState, ConversationState, mark_finished, add_tool_result, set_response
|
2026-01-16 16:28:47 +08:00
|
|
|
# Agent modules are imported lazily inside create_agent_graph() to avoid circular imports
|
2026-01-14 19:25:22 +08:00
|
|
|
from config import settings
|
|
|
|
|
from utils.logger import get_logger
|
|
|
|
|
|
|
|
|
|
# Module-level logger; call sites in this file pass context as keyword arguments.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============ Node Functions ============
|
|
|
|
|
|
|
|
|
|
async def receive_message(state: AgentState) -> AgentState:
    """Entry node of the workflow: record the incoming user message.

    Appends ``state["current_message"]`` to ``state["messages"]`` as a
    ``user`` turn and sets ``state["state"]`` to
    ``ConversationState.INITIAL``.

    Args:
        state: Workflow state; it is modified in place.

    Returns:
        The same state object that was passed in.
    """
    incoming = state["current_message"]

    logger.info(
        "Receiving message",
        conversation_id=state["conversation_id"],
        message_length=len(incoming),
    )

    # Record the user's turn in the running history.
    user_turn = {"role": "user", "content": incoming}
    state["messages"].append(user_turn)

    state["state"] = ConversationState.INITIAL.value
    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def call_mcp_tools(state: AgentState) -> AgentState:
    """Execute the pending MCP tool calls and record each outcome.

    Each entry of ``state["tool_calls"]`` is POSTed, one after another,
    to ``{server_url}/tools/{tool_name}`` with its ``arguments`` as the
    JSON body. The outcome of every call (success or failure) is stored
    through ``add_tool_result``; a failing call never aborts the rest of
    the batch. ``state["tool_calls"]`` is emptied once the batch is done.

    Args:
        state: Workflow state holding the pending ``tool_calls``.

    Returns:
        The state produced by the successive ``add_tool_result`` calls,
        with ``tool_calls`` reset to an empty list. When there is nothing
        to execute, the input state is returned untouched.
    """
    pending = state["tool_calls"]
    if not pending:
        logger.debug("No tool calls to execute")
        return state

    logger.info("Executing MCP tools", tool_count=len(pending))

    # MCP server URL mapping
    mcp_servers = {
        "strapi": settings.strapi_mcp_url,
        "order": settings.order_mcp_url,
        "aftersale": settings.aftersale_mcp_url,
        "product": settings.product_mcp_url,
    }

    async with httpx.AsyncClient(timeout=30.0) as client:
        for tool_call in pending:
            try:
                server = tool_call["server"]
                tool_name = tool_call["tool_name"]
                arguments = tool_call["arguments"]
            except (KeyError, TypeError) as exc:
                # A malformed entry is reported as a failed tool result so
                # that the remaining calls still run and the pending list
                # is still cleared at the end.
                if isinstance(tool_call, dict):
                    label = tool_call.get("tool_name", "unknown")
                else:
                    label = "unknown"
                logger.error("Malformed tool call", tool=label, error=repr(exc))
                state = add_tool_result(
                    state,
                    tool_name=label,
                    success=False,
                    data=None,
                    error=f"Malformed tool call: {exc!r}",
                )
                continue

            server_url = mcp_servers.get(server)
            if not server_url:
                state = add_tool_result(
                    state,
                    tool_name=tool_name,
                    success=False,
                    data=None,
                    error=f"Unknown MCP server: {server}",
                )
                continue

            try:
                # Call MCP tool endpoint
                response = await client.post(
                    f"{server_url}/tools/{tool_name}",
                    json=arguments,
                )
                response.raise_for_status()

                result = response.json()
                state = add_tool_result(
                    state,
                    tool_name=tool_name,
                    success=True,
                    data=result,
                )

                logger.debug(
                    "Tool executed successfully",
                    tool=tool_name,
                    server=server,
                )

            except httpx.HTTPStatusError as e:
                logger.error(
                    "Tool HTTP error",
                    tool=tool_name,
                    status=e.response.status_code,
                )
                state = add_tool_result(
                    state,
                    tool_name=tool_name,
                    success=False,
                    data=None,
                    error=f"HTTP {e.response.status_code}",
                )

            except Exception as e:
                # Some exceptions (timeouts in particular) can stringify to
                # an empty message; fall back to the class name so the
                # recorded error is never blank.
                reason = str(e) or type(e).__name__
                logger.error("Tool execution failed", tool=tool_name, error=reason)
                state = add_tool_result(
                    state,
                    tool_name=tool_name,
                    success=False,
                    data=None,
                    error=reason,
                )

    # Clear pending tool calls
    state["tool_calls"] = []

    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def human_handoff(state: AgentState) -> AgentState:
    """Prepare the conversation for transfer to a human agent.

    Sets ``state["state"]`` to ``ConversationState.HUMAN_REVIEW`` and
    stores, via ``set_response``, a customer-facing message announcing
    the transfer together with its reason.

    Args:
        state: Workflow state; ``handoff_reason`` is optional.

    Returns:
        The state returned by ``set_response``.
    """
    logger.info(
        "Human handoff requested",
        conversation_id=state["conversation_id"],
        reason=state.get("handoff_reason"),
    )

    state["state"] = ConversationState.HUMAN_REVIEW.value

    # Generate handoff message.
    # `or` rather than a `.get()` default: the key may be present with a
    # None/empty value, which would otherwise be rendered literally
    # ("None") in the message shown to the customer.
    reason = state.get("handoff_reason") or "您的问题需要人工客服协助"
    state = set_response(
        state,
        f"正在为您转接人工客服,请稍候。\n转接原因:{reason}\n\n人工客服将尽快为您服务。",
    )

    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def send_response(state: AgentState) -> AgentState:
    """Final node: append the reply to the history and mark completion.

    When ``state["response"]`` holds a non-empty value it is appended to
    ``state["messages"]`` as an ``assistant`` turn. The state is then
    passed through ``mark_finished``.

    Args:
        state: Workflow state; ``response`` may be missing, ``None`` or
            empty, in which case nothing is added to the history.

    Returns:
        The state returned by ``mark_finished``.
    """
    # Normalise a missing *or* None response to "" so that len() below
    # cannot raise TypeError when the key exists but holds None.
    reply = state.get("response") or ""

    logger.info(
        "Sending response",
        conversation_id=state["conversation_id"],
        response_length=len(reply),
    )

    # Add assistant response to history
    if reply:
        state["messages"].append({
            "role": "assistant",
            "content": reply,
        })

    state = mark_finished(state)
    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def handle_error(state: AgentState) -> AgentState:
    """Log the workflow error and answer with a generic apology.

    The value of ``state["error"]`` (if any) is logged, a fixed apology
    is stored through ``set_response``, and the result is passed through
    ``mark_finished``.

    Args:
        state: Workflow state; must contain ``conversation_id``.

    Returns:
        The state returned by ``mark_finished``.
    """
    logger.error(
        "Workflow error",
        conversation_id=state["conversation_id"],
        error=state.get("error"),
    )

    apology = "抱歉,处理您的请求时遇到了问题。请稍后重试,或联系人工客服获取帮助。"
    return mark_finished(set_response(state, apology))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============ Routing Functions ============
|
|
|
|
|
|
|
|
|
|
def should_call_tools(state: AgentState) -> Literal["call_tools", "send_response", "back_to_agent"]:
    """Pick the next node after an agent has run.

    Returns ``"call_tools"`` when ``state["tool_calls"]`` is non-empty,
    otherwise ``"send_response"``. The remaining cases differ only in
    what gets logged: a ready response, an ``AWAITING_INFO`` state, or
    an unexpected state (logged as a warning).

    Note: ``"back_to_agent"`` is part of the declared return type but is
    never returned by this function.
    """
    pending_calls = state.get("tool_calls")
    reply = state.get("response")
    phase = state.get("state")

    logger.debug(
        "Checking if tools should be called",
        conversation_id=state.get("conversation_id"),
        has_tool_calls=bool(pending_calls),
        tool_calls_count=len(pending_calls) if "tool_calls" in state else 0,
        has_response=bool(reply),
        state_value=phase,
    )

    # Pending tool calls take priority over everything else.
    if pending_calls:
        logger.info(
            "Routing to tool execution",
            tool_count=len(pending_calls),
        )
        return "call_tools"

    # Every remaining path ends at send_response; only the log differs.
    if reply:
        logger.debug("Routing to send_response (has response)")
    elif phase == ConversationState.AWAITING_INFO.value:
        logger.debug("Routing to send_response (awaiting info)")
    else:
        logger.warning("Unexpected state, routing to send_response", state=phase)

    return "send_response"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def after_tools(state: AgentState) -> str:
    """Choose the agent node that should receive the tool results.

    Looks up ``state["current_agent"]`` and returns the matching graph
    node name. A missing or unrecognised agent falls back to
    ``"customer_service_agent"``.
    """
    node_for_agent = {
        "customer_service": "customer_service_agent",
        "order": "order_agent",
        "aftersale": "aftersale_agent",
        "product": "product_agent",
    }

    owner = state.get("current_agent")
    try:
        # Hand the results back to whichever agent requested the tools.
        return node_for_agent[owner]
    except KeyError:
        return "customer_service_agent"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_completion(state: AgentState) -> Literal["continue", "end", "error"]:
    """Decide whether the workflow continues, ends, or failed.

    Checks are made in this order: a truthy ``error`` yields
    ``"error"``; a truthy ``finished`` yields ``"end"``; reaching the
    step limit (``step_count`` >= ``max_steps``, defaults 0 and 10)
    logs a warning and yields ``"end"``; otherwise ``"continue"``.
    """
    if state.get("error"):
        return "error"

    if state.get("finished"):
        return "end"

    steps_taken = state.get("step_count", 0)
    step_budget = state.get("max_steps", 10)
    if steps_taken < step_budget:
        return "continue"

    # Step budget exhausted: stop instead of looping further.
    logger.warning("Max steps reached", conversation_id=state["conversation_id"])
    return "end"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============ Graph Construction ============
|
|
|
|
|
|
|
|
|
|
def create_agent_graph() -> StateGraph:
    """Assemble and compile the agent workflow graph.

    Layout: ``receive`` -> ``classify`` -> one of the four specialist
    agents or ``human_handoff``. Each specialist is followed by
    ``should_call_tools``, which leads to ``call_tools`` or
    ``send_response``; ``call_tools`` hands back to a specialist via
    ``after_tools``. ``human_handoff`` goes to ``send_response``, and
    both ``send_response`` and ``handle_error`` lead to ``END``.

    Note: no edge added in this function leads into ``handle_error``.

    Returns:
        The result of ``graph.compile()``.
    """
    # Imported here rather than at module scope to avoid circular imports.
    from agents.router import classify_intent, route_by_intent
    from agents.customer_service import customer_service_agent
    from agents.order import order_agent
    from agents.aftersale import aftersale_agent
    from agents.product import product_agent

    workflow = StateGraph(AgentState)

    # Node name -> handler, registered in this order.
    node_handlers = {
        "receive": receive_message,
        "classify": classify_intent,
        "customer_service_agent": customer_service_agent,
        "order_agent": order_agent,
        "aftersale_agent": aftersale_agent,
        "product_agent": product_agent,
        "call_tools": call_mcp_tools,
        "human_handoff": human_handoff,
        "send_response": send_response,
        "handle_error": handle_error,
    }
    for node_name, handler in node_handlers.items():
        workflow.add_node(node_name, handler)

    workflow.set_entry_point("receive")
    workflow.add_edge("receive", "classify")

    specialist_nodes = [
        "customer_service_agent",
        "order_agent",
        "aftersale_agent",
        "product_agent",
    ]
    # Router return values are the node names themselves.
    identity_routes = {name: name for name in specialist_nodes}

    # Intent classification fans out to a specialist or to a human.
    workflow.add_conditional_edges(
        "classify",
        route_by_intent,
        {**identity_routes, "human_handoff": "human_handoff"},
    )

    # After each specialist: run tools, reply, or loop back to itself.
    for specialist in specialist_nodes:
        workflow.add_conditional_edges(
            specialist,
            should_call_tools,
            {
                "call_tools": "call_tools",
                "send_response": "send_response",
                "back_to_agent": specialist,
            },
        )

    # Tool results return to the agent selected by after_tools.
    workflow.add_conditional_edges("call_tools", after_tools, dict(identity_routes))

    # Terminal wiring.
    workflow.add_edge("human_handoff", "send_response")
    workflow.add_edge("handle_error", END)
    workflow.add_edge("send_response", END)

    return workflow.compile()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level cache for the compiled workflow; filled on first use.
_compiled_graph = None


def get_agent_graph():
    """Return the cached compiled graph, building it on the first call."""
    global _compiled_graph
    if _compiled_graph is not None:
        return _compiled_graph

    _compiled_graph = create_agent_graph()
    return _compiled_graph
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def process_message(
    conversation_id: str,
    user_id: str,
    account_id: str,
    message: str,
    history: list[dict] = None,
    context: dict = None,
    user_token: str = None
) -> AgentState:
    """Process a user message through the agent workflow.

    Builds the initial state with ``create_initial_state``, runs the
    compiled graph with ``ainvoke`` and returns the resulting state.
    Any exception raised while running the graph is caught: the initial
    state is returned instead, with ``error`` set, a generic apology as
    ``response`` and ``finished`` set to ``True``.

    Args:
        conversation_id: Chatwoot conversation ID
        user_id: User identifier
        account_id: B2B account identifier
        message: User's message
        history: Previous conversation history
        context: Existing conversation context
        user_token: User JWT token for API calls

    Returns:
        Final agent state with response
    """
    from .state import create_initial_state

    # Create initial state
    initial_state = create_initial_state(
        conversation_id=conversation_id,
        user_id=user_id,
        account_id=account_id,
        current_message=message,
        messages=history,
        context=context,
        user_token=user_token
    )

    # Get compiled graph
    graph = get_agent_graph()

    # Run the workflow (only the first 100 characters of the message are logged)
    logger.info(
        "Starting workflow",
        conversation_id=conversation_id,
        message=message[:100]
    )

    try:
        final_state = await graph.ainvoke(initial_state)
    except Exception as e:
        # Some exceptions stringify to "", which would leave
        # state["error"] falsy; fall back to the exception class name.
        error_text = str(e) or type(e).__name__
        logger.error(
            "Workflow failed",
            conversation_id=conversation_id,
            error=error_text
        )
        initial_state["error"] = error_text
        initial_state["response"] = "抱歉,处理您的请求时遇到了问题。请稍后重试。"
        initial_state["finished"] = True
        return initial_state

    logger.info(
        "Workflow completed",
        conversation_id=conversation_id,
        intent=final_state.get("intent"),
        steps=final_state.get("step_count")
    )

    return final_state
|