"""
ZhipuAI LLM Client for B2B Shopping AI Assistant
"""
import asyncio
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Optional

from zhipuai import ZhipuAI

from config import settings
from utils.logger import get_logger
# Module-level logger. The keyword-argument logging calls in this module
# (e.g. logger.info(..., model=...)) suggest utils.logger returns a
# structlog-style logger — TODO confirm against utils/logger.py.
logger = get_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class Message:
    """A single turn in a chat conversation."""

    # Speaker role: one of "system", "user", or "assistant".
    role: str
    # Plain-text body of the message.
    content: str
@dataclass
class LLMResponse:
    """Normalized result of a chat-completion call."""

    # Assistant reply text.
    content: str
    # Why generation stopped (e.g. "stop", "tool_calls") as reported by the API.
    finish_reason: str
    # Token accounting: prompt_tokens / completion_tokens / total_tokens.
    usage: dict[str, int]
class ZhipuLLMClient:
    """Wrapper around the ZhipuAI chat-completion SDK.

    The underlying SDK client is synchronous, so the async public methods
    dispatch the blocking network call to a worker thread via
    ``asyncio.to_thread`` instead of stalling the event loop.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None
    ):
        """Initialize ZhipuAI client.

        Args:
            api_key: ZhipuAI API key, defaults to settings.zhipu_api_key
            model: Model name, defaults to settings.zhipu_model
        """
        self.api_key = api_key or settings.zhipu_api_key
        self.model = model or settings.zhipu_model
        self._client = ZhipuAI(api_key=self.api_key)
        logger.info("ZhipuAI client initialized", model=self.model)

    @staticmethod
    def _format_messages(messages: list[Message]) -> list[dict[str, str]]:
        """Convert Message dataclasses into the plain dicts the SDK expects."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    @staticmethod
    def _build_response(response: Any) -> LLMResponse:
        """Map the SDK response's first choice and usage onto LLMResponse.

        ``content`` is normalized to ``""`` when the SDK returns None (e.g.
        when the model answered only with tool calls), so LLMResponse.content
        always honours its declared ``str`` type.
        """
        choice = response.choices[0]
        return LLMResponse(
            content=choice.message.content or "",
            finish_reason=choice.finish_reason,
            usage={
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            },
        )

    @staticmethod
    def _extract_tool_calls(message: Any) -> Optional[list[dict[str, Any]]]:
        """Convert SDK tool-call objects into plain dicts; None when absent."""
        if not getattr(message, 'tool_calls', None):
            return None
        return [
            {
                "id": tc.id,
                "type": tc.type,
                "function": {
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                },
            }
            for tc in message.tool_calls
        ]

    async def chat(
        self,
        messages: list[Message],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 0.9,
        **kwargs: Any
    ) -> LLMResponse:
        """Send chat completion request.

        Args:
            messages: List of chat messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Top-p sampling parameter
            **kwargs: Additional parameters forwarded to the SDK

        Returns:
            LLM response with content and metadata

        Raises:
            Exception: any SDK error is logged and re-raised unchanged.
        """
        logger.debug(
            "Sending chat request",
            model=self.model,
            message_count=len(messages),
            temperature=temperature
        )

        try:
            # The SDK call is blocking network I/O — run it off the event loop.
            response = await asyncio.to_thread(
                self._client.chat.completions.create,
                model=self.model,
                messages=self._format_messages(messages),
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                **kwargs
            )
        except Exception as e:
            logger.error("Chat request failed", error=str(e))
            raise

        result = self._build_response(response)
        logger.debug(
            "Chat response received",
            finish_reason=result.finish_reason,
            total_tokens=result.usage["total_tokens"]
        )
        return result

    async def chat_with_tools(
        self,
        messages: list[Message],
        tools: list[dict[str, Any]],
        temperature: float = 0.7,
        **kwargs: Any
    ) -> tuple[LLMResponse, Optional[list[dict[str, Any]]]]:
        """Send chat completion request with tool calling.

        Args:
            messages: List of chat messages
            tools: List of tool definitions
            temperature: Sampling temperature
            **kwargs: Additional parameters forwarded to the SDK

        Returns:
            Tuple of (LLM response, tool calls if any)

        Raises:
            Exception: any SDK error is logged and re-raised unchanged.
        """
        logger.debug(
            "Sending chat request with tools",
            model=self.model,
            tool_count=len(tools)
        )

        try:
            # The SDK call is blocking network I/O — run it off the event loop.
            response = await asyncio.to_thread(
                self._client.chat.completions.create,
                model=self.model,
                messages=self._format_messages(messages),
                tools=tools,
                temperature=temperature,
                **kwargs
            )
        except Exception as e:
            logger.error("Chat with tools request failed", error=str(e))
            raise

        result = self._build_response(response)
        tool_calls = self._extract_tool_calls(response.choices[0].message)
        if tool_calls:
            logger.debug("Tool calls received", tool_count=len(tool_calls))

        return result, tool_calls
# Lazily-created, process-wide client shared by all callers.
llm_client: Optional[ZhipuLLMClient] = None


def get_llm_client() -> ZhipuLLMClient:
    """Return the shared LLM client, constructing it on first use."""
    global llm_client
    if llm_client is not None:
        return llm_client
    llm_client = ZhipuLLMClient()
    return llm_client