""" LangGraph workflow definition for B2B Shopping AI Assistant """ from typing import Literal import httpx from langgraph.graph import StateGraph, END from .state import AgentState, ConversationState, mark_finished, add_tool_result, set_response from agents.router import classify_intent, route_by_intent from agents.customer_service import customer_service_agent from agents.order import order_agent from agents.aftersale import aftersale_agent from agents.product import product_agent from config import settings from utils.logger import get_logger logger = get_logger(__name__) # ============ Node Functions ============ async def receive_message(state: AgentState) -> AgentState: """Receive and preprocess incoming message This is the entry point of the workflow. """ logger.info( "Receiving message", conversation_id=state["conversation_id"], message_length=len(state["current_message"]) ) # Add user message to history state["messages"].append({ "role": "user", "content": state["current_message"] }) state["state"] = ConversationState.INITIAL.value return state async def call_mcp_tools(state: AgentState) -> AgentState: """Execute pending MCP tool calls Calls the appropriate MCP server based on the tool_calls in state. """ if not state["tool_calls"]: logger.debug("No tool calls to execute") return state logger.info( "Executing MCP tools", tool_count=len(state["tool_calls"]) ) # MCP server URL mapping mcp_servers = { "strapi": settings.strapi_mcp_url, "order": settings.order_mcp_url, "aftersale": settings.aftersale_mcp_url, "product": settings.product_mcp_url } async with httpx.AsyncClient(timeout=30.0) as client: for tool_call in state["tool_calls"]: server = tool_call["server"] tool_name = tool_call["tool_name"] arguments = tool_call["arguments"] server_url = mcp_servers.get(server) if not server_url: state = add_tool_result( state, tool_name=tool_name, success=False, data=None, error=f"Unknown MCP server: {server}" ) continue try: # Call MCP tool endpoint response = await client.post( f"{server_url}/tools/{tool_name}", json=arguments ) response.raise_for_status() result = response.json() state = add_tool_result( state, tool_name=tool_name, success=True, data=result ) logger.debug( "Tool executed successfully", tool=tool_name, server=server ) except httpx.HTTPStatusError as e: logger.error( "Tool HTTP error", tool=tool_name, status=e.response.status_code ) state = add_tool_result( state, tool_name=tool_name, success=False, data=None, error=f"HTTP {e.response.status_code}" ) except Exception as e: logger.error("Tool execution failed", tool=tool_name, error=str(e)) state = add_tool_result( state, tool_name=tool_name, success=False, data=None, error=str(e) ) # Clear pending tool calls state["tool_calls"] = [] return state async def human_handoff(state: AgentState) -> AgentState: """Handle transfer to human agent Sets up the state for human intervention. """ logger.info( "Human handoff requested", conversation_id=state["conversation_id"], reason=state.get("handoff_reason") ) state["state"] = ConversationState.HUMAN_REVIEW.value # Generate handoff message reason = state.get("handoff_reason", "您的问题需要人工客服协助") state = set_response( state, f"正在为您转接人工客服,请稍候。\n转接原因:{reason}\n\n人工客服将尽快为您服务。" ) return state async def send_response(state: AgentState) -> AgentState: """Finalize and send response This is the final node that marks processing as complete. """ logger.info( "Sending response", conversation_id=state["conversation_id"], response_length=len(state.get("response", "")) ) # Add assistant response to history if state.get("response"): state["messages"].append({ "role": "assistant", "content": state["response"] }) state = mark_finished(state) return state async def handle_error(state: AgentState) -> AgentState: """Handle errors in the workflow""" logger.error( "Workflow error", conversation_id=state["conversation_id"], error=state.get("error") ) state = set_response( state, "抱歉,处理您的请求时遇到了问题。请稍后重试,或联系人工客服获取帮助。" ) state = mark_finished(state) return state # ============ Routing Functions ============ def should_call_tools(state: AgentState) -> Literal["call_tools", "send_response", "back_to_agent"]: """Determine if tools need to be called""" # If there are pending tool calls, execute them if state.get("tool_calls"): return "call_tools" # If we have a response ready, send it if state.get("response"): return "send_response" # If we're waiting for info, send the question if state.get("state") == ConversationState.AWAITING_INFO.value: return "send_response" # Otherwise, something went wrong return "send_response" def after_tools(state: AgentState) -> str: """Route after tool execution Returns the agent that should process the tool results. """ current_agent = state.get("current_agent") # Route back to the agent that made the tool call agent_mapping = { "customer_service": "customer_service_agent", "order": "order_agent", "aftersale": "aftersale_agent", "product": "product_agent" } return agent_mapping.get(current_agent, "customer_service_agent") def check_completion(state: AgentState) -> Literal["continue", "end", "error"]: """Check if workflow should continue or end""" # Check for errors if state.get("error"): return "error" # Check if finished if state.get("finished"): return "end" # Check step limit if state.get("step_count", 0) >= state.get("max_steps", 10): logger.warning("Max steps reached", conversation_id=state["conversation_id"]) return "end" return "continue" # ============ Graph Construction ============ def create_agent_graph() -> StateGraph: """Create the main agent workflow graph Returns: Compiled LangGraph workflow """ # Create graph with AgentState graph = StateGraph(AgentState) # Add nodes graph.add_node("receive", receive_message) graph.add_node("classify", classify_intent) graph.add_node("customer_service_agent", customer_service_agent) graph.add_node("order_agent", order_agent) graph.add_node("aftersale_agent", aftersale_agent) graph.add_node("product_agent", product_agent) graph.add_node("call_tools", call_mcp_tools) graph.add_node("human_handoff", human_handoff) graph.add_node("send_response", send_response) graph.add_node("handle_error", handle_error) # Set entry point graph.set_entry_point("receive") # Add edges graph.add_edge("receive", "classify") # Conditional routing based on intent graph.add_conditional_edges( "classify", route_by_intent, { "customer_service_agent": "customer_service_agent", "order_agent": "order_agent", "aftersale_agent": "aftersale_agent", "product_agent": "product_agent", "human_handoff": "human_handoff" } ) # After each agent, check if tools need to be called for agent_node in ["customer_service_agent", "order_agent", "aftersale_agent", "product_agent"]: graph.add_conditional_edges( agent_node, should_call_tools, { "call_tools": "call_tools", "send_response": "send_response", "back_to_agent": agent_node } ) # After tool execution, route back to appropriate agent graph.add_conditional_edges( "call_tools", after_tools, { "customer_service_agent": "customer_service_agent", "order_agent": "order_agent", "aftersale_agent": "aftersale_agent", "product_agent": "product_agent" } ) # Human handoff leads to send response graph.add_edge("human_handoff", "send_response") # Error handling graph.add_edge("handle_error", END) # Final node graph.add_edge("send_response", END) return graph.compile() # Global compiled graph _compiled_graph = None def get_agent_graph(): """Get or create the compiled agent graph""" global _compiled_graph if _compiled_graph is None: _compiled_graph = create_agent_graph() return _compiled_graph async def process_message( conversation_id: str, user_id: str, account_id: str, message: str, history: list[dict] = None, context: dict = None ) -> AgentState: """Process a user message through the agent workflow Args: conversation_id: Chatwoot conversation ID user_id: User identifier account_id: B2B account identifier message: User's message history: Previous conversation history context: Existing conversation context Returns: Final agent state with response """ from .state import create_initial_state # Create initial state initial_state = create_initial_state( conversation_id=conversation_id, user_id=user_id, account_id=account_id, current_message=message, messages=history, context=context ) # Get compiled graph graph = get_agent_graph() # Run the workflow logger.info( "Starting workflow", conversation_id=conversation_id, message=message[:100] ) try: final_state = await graph.ainvoke(initial_state) logger.info( "Workflow completed", conversation_id=conversation_id, intent=final_state.get("intent"), steps=final_state.get("step_count") ) return final_state except Exception as e: logger.error("Workflow failed", error=str(e)) initial_state["error"] = str(e) initial_state["response"] = "抱歉,处理您的请求时遇到了问题。请稍后重试。" initial_state["finished"] = True return initial_state