Files
assistant/agent/core/llm.py

170 lines
4.9 KiB
Python
Raw Normal View History

"""
ZhipuAI LLM Client for B2B Shopping AI Assistant
"""
import concurrent.futures
from typing import Any, Optional
from dataclasses import dataclass
from zhipuai import ZhipuAI
from config import settings
from utils.logger import get_logger
logger = get_logger(__name__)
@dataclass
class Message:
"""Chat message structure"""
role: str # "system", "user", "assistant"
content: str
@dataclass
class LLMResponse:
"""LLM response structure"""
content: str
finish_reason: str
usage: dict[str, int]
class ZhipuLLMClient:
"""ZhipuAI LLM Client wrapper"""
DEFAULT_TIMEOUT = 30 # seconds
def __init__(
self,
api_key: Optional[str] = None,
model: Optional[str] = None,
timeout: Optional[int] = None
):
self.api_key = api_key or settings.zhipu_api_key
self.model = model or settings.zhipu_model
self.timeout = timeout or self.DEFAULT_TIMEOUT
self._client = ZhipuAI(api_key=self.api_key)
logger.info("ZhipuAI client initialized", model=self.model, timeout=self.timeout)
async def chat(
self,
messages: list[Message],
temperature: float = 0.7,
max_tokens: int = 2048,
top_p: float = 0.9,
**kwargs: Any
) -> LLMResponse:
"""Send chat completion request"""
formatted_messages = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
logger.info(
"Sending chat request",
model=self.model,
message_count=len(messages),
temperature=temperature
)
def _make_request():
return self._client.chat.completions.create(
model=self.model,
messages=formatted_messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
**kwargs
)
try:
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(_make_request)
response = future.result(timeout=self.timeout)
choice = response.choices[0]
content = choice.message.content
logger.info(
"Chat response received",
finish_reason=choice.finish_reason,
content_length=len(content) if content else 0,
usage=response.usage.__dict__ if hasattr(response, 'usage') else {}
)
if not content:
logger.warning("LLM returned empty content")
return LLMResponse(
content=content or "",
finish_reason=choice.finish_reason,
usage={
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens
}
)
except concurrent.futures.TimeoutError:
logger.error("Chat request timed out", timeout=self.timeout)
raise TimeoutError(f"Request timed out after {self.timeout} seconds")
except Exception as e:
logger.error("Chat request failed", error=str(e))
raise
async def chat_with_tools(
self,
messages: list[Message],
tools: list[dict[str, Any]],
temperature: float = 0.7,
**kwargs: Any
) -> tuple[LLMResponse, None]:
"""Send chat completion request with tool calling"""
formatted_messages = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
logger.info(
"Sending chat request with tools",
model=self.model,
tool_count=len(tools)
)
try:
response = self._client.chat.completions.create(
model=self.model,
messages=formatted_messages,
tools=tools,
temperature=temperature,
**kwargs
)
choice = response.choices[0]
content = choice.message.content or ""
return LLMResponse(
content=content,
finish_reason=choice.finish_reason,
usage={
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens
}
), None
except Exception as e:
logger.error("Chat with tools request failed", error=str(e))
raise
llm_client: Optional[ZhipuLLMClient] = None
def get_llm_client() -> ZhipuLLMClient:
"""Get or create global LLM client instance"""
global llm_client
if llm_client is None:
llm_client = ZhipuLLMClient()
return llm_client