feat: 增强 Agent 系统和完善项目结构
主要改进: - Agent 增强: 订单查询、售后支持、客服路由等功能优化 - 新增语言检测和 Token 管理模块 - 改进 Chatwoot webhook 处理和用户标识 - MCP 服务器增强: 订单 MCP 和 Strapi MCP 功能扩展 - 新增商城客户端、知识库、缓存和同步模块 - 添加多语言提示词系统 (YAML) - 完善项目结构: 整理文档、脚本和测试文件 - 新增调试和测试工具脚本 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -13,3 +13,7 @@ python-dotenv>=1.0.0
|
||||
|
||||
# Logging
|
||||
structlog>=24.1.0
|
||||
|
||||
# Web Framework
|
||||
fastapi>=0.100.0
|
||||
uvicorn>=0.23.0
|
||||
|
||||
@@ -19,8 +19,17 @@ class Settings(BaseSettings):
|
||||
"""Server configuration"""
|
||||
hyperf_api_url: str
|
||||
hyperf_api_token: str
|
||||
|
||||
# Mall API 配置
|
||||
mall_api_url: str = "https://apicn.qa1.gaia888.com"
|
||||
mall_api_token: str = ""
|
||||
mall_tenant_id: str = "2"
|
||||
mall_currency_code: str = "EUR"
|
||||
mall_language_id: str = "1"
|
||||
mall_source: str = "us.qa1.gaia888.com"
|
||||
|
||||
log_level: str = "INFO"
|
||||
|
||||
|
||||
model_config = ConfigDict(env_file=".env")
|
||||
|
||||
|
||||
@@ -31,12 +40,35 @@ mcp = FastMCP(
|
||||
"Order Management"
|
||||
)
|
||||
|
||||
# Tool registry for HTTP access
|
||||
_tools = {}
|
||||
|
||||
|
||||
# Hyperf client for this server
|
||||
from shared.hyperf_client import HyperfClient
|
||||
hyperf = HyperfClient(settings.hyperf_api_url, settings.hyperf_api_token)
|
||||
|
||||
# Mall API client
|
||||
from shared.mall_client import MallClient
|
||||
mall = MallClient(
|
||||
api_url=getattr(settings, 'mall_api_url', 'https://apicn.qa1.gaia888.com'),
|
||||
api_token=getattr(settings, 'mall_api_token', ''),
|
||||
tenant_id=getattr(settings, 'mall_tenant_id', '2'),
|
||||
currency_code=getattr(settings, 'mall_currency_code', 'EUR'),
|
||||
language_id=getattr(settings, 'mall_language_id', '1'),
|
||||
source=getattr(settings, 'mall_source', 'us.qa1.gaia888.com')
|
||||
)
|
||||
|
||||
|
||||
def register_tool(name: str):
|
||||
"""Register a tool for HTTP access"""
|
||||
def decorator(func):
|
||||
_tools[name] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
@register_tool("query_order")
|
||||
@mcp.tool()
|
||||
async def query_order(
|
||||
user_id: str,
|
||||
@@ -96,6 +128,7 @@ async def query_order(
|
||||
}
|
||||
|
||||
|
||||
@register_tool("track_logistics")
|
||||
@mcp.tool()
|
||||
async def track_logistics(
|
||||
order_id: str,
|
||||
@@ -134,6 +167,7 @@ async def track_logistics(
|
||||
}
|
||||
|
||||
|
||||
@register_tool("modify_order")
|
||||
@mcp.tool()
|
||||
async def modify_order(
|
||||
order_id: str,
|
||||
@@ -177,6 +211,7 @@ async def modify_order(
|
||||
}
|
||||
|
||||
|
||||
@register_tool("cancel_order")
|
||||
@mcp.tool()
|
||||
async def cancel_order(
|
||||
order_id: str,
|
||||
@@ -217,17 +252,18 @@ async def cancel_order(
|
||||
}
|
||||
|
||||
|
||||
@register_tool("get_invoice")
|
||||
@mcp.tool()
|
||||
async def get_invoice(
|
||||
order_id: str,
|
||||
invoice_type: str = "normal"
|
||||
) -> dict:
|
||||
"""Get invoice for an order
|
||||
|
||||
|
||||
Args:
|
||||
order_id: Order ID
|
||||
invoice_type: Invoice type ('normal' for regular invoice, 'vat' for VAT invoice)
|
||||
|
||||
|
||||
Returns:
|
||||
Invoice information and download URL
|
||||
"""
|
||||
@@ -236,7 +272,7 @@ async def get_invoice(
|
||||
f"/orders/{order_id}/invoice",
|
||||
params={"type": invoice_type}
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"order_id": order_id,
|
||||
@@ -255,7 +291,64 @@ async def get_invoice(
|
||||
}
|
||||
|
||||
|
||||
@register_tool("get_mall_order")
|
||||
@mcp.tool()
|
||||
async def get_mall_order(
|
||||
order_id: str,
|
||||
user_token: str = None,
|
||||
user_id: str = None,
|
||||
account_id: str = None
|
||||
) -> dict:
|
||||
"""Query order from Mall API by order ID
|
||||
|
||||
从商城 API 查询订单详情
|
||||
|
||||
Args:
|
||||
order_id: 订单号 (e.g., "202071324")
|
||||
user_token: 用户 JWT token(可选,如果提供则使用该 token 进行查询)
|
||||
user_id: 用户 ID(自动注入,此工具不使用)
|
||||
account_id: 账户 ID(自动注入,此工具不使用)
|
||||
|
||||
Returns:
|
||||
订单详情,包含订单号、状态、商品信息、金额、物流信息等
|
||||
Order details including order ID, status, items, amount, logistics info, etc.
|
||||
"""
|
||||
try:
|
||||
# 如果提供了 user_token,使用用户自己的 token
|
||||
if user_token:
|
||||
client = MallClient(
|
||||
api_url=settings.mall_api_url,
|
||||
api_token=user_token,
|
||||
tenant_id=settings.mall_tenant_id,
|
||||
currency_code=settings.mall_currency_code,
|
||||
language_id=settings.mall_language_id,
|
||||
source=settings.mall_source
|
||||
)
|
||||
else:
|
||||
# 否则使用默认的 mall 实例
|
||||
client = mall
|
||||
|
||||
result = await client.get_order_by_id(order_id)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"order": result,
|
||||
"order_id": order_id
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"order_id": order_id
|
||||
}
|
||||
finally:
|
||||
# 如果创建了临时客户端,关闭它
|
||||
if user_token:
|
||||
await client.close()
|
||||
|
||||
|
||||
# Health check endpoint
|
||||
@register_tool("health_check")
|
||||
@mcp.tool()
|
||||
async def health_check() -> dict:
|
||||
"""Check server health status"""
|
||||
@@ -268,17 +361,75 @@ async def health_check() -> dict:
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
# Create FastAPI app from MCP
|
||||
app = mcp.http_app()
|
||||
|
||||
# Add health endpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
# Health check endpoint
|
||||
async def health_check(request):
|
||||
return JSONResponse({"status": "healthy"})
|
||||
|
||||
# Add the route to the app
|
||||
from starlette.routing import Route
|
||||
app.router.routes.append(Route('/health', health_check, methods=['GET']))
|
||||
|
||||
|
||||
# Tool execution endpoint
|
||||
async def execute_tool(request: Request):
|
||||
"""Execute an MCP tool via HTTP"""
|
||||
tool_name = request.path_params["tool_name"]
|
||||
|
||||
try:
|
||||
# Get arguments from request body
|
||||
arguments = await request.json()
|
||||
|
||||
# Get tool function from registry
|
||||
if tool_name not in _tools:
|
||||
return JSONResponse({
|
||||
"success": False,
|
||||
"error": f"Tool '{tool_name}' not found"
|
||||
}, status_code=404)
|
||||
|
||||
tool_obj = _tools[tool_name]
|
||||
|
||||
# Call the tool with arguments
|
||||
# FastMCP FunctionTool.run() takes a dict of arguments
|
||||
tool_result = await tool_obj.run(arguments)
|
||||
|
||||
# Extract content from ToolResult
|
||||
# ToolResult.content is a list of TextContent objects with a 'text' attribute
|
||||
if tool_result.content and len(tool_result.content) > 0:
|
||||
content = tool_result.content[0].text
|
||||
# Try to parse as JSON if possible
|
||||
try:
|
||||
import json
|
||||
result = json.loads(content)
|
||||
except:
|
||||
result = content
|
||||
else:
|
||||
result = None
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"result": result
|
||||
})
|
||||
except TypeError as e:
|
||||
return JSONResponse({
|
||||
"success": False,
|
||||
"error": f"Invalid arguments: {str(e)}"
|
||||
}, status_code=400)
|
||||
except Exception as e:
|
||||
return JSONResponse({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status_code=500)
|
||||
|
||||
# Create routes list
|
||||
routes = [
|
||||
Route('/health', health_check, methods=['GET']),
|
||||
Route('/tools/{tool_name}', execute_tool, methods=['POST'])
|
||||
]
|
||||
|
||||
# Create app from MCP with custom routes
|
||||
app = mcp.http_app()
|
||||
|
||||
# Add our custom routes to the existing app
|
||||
for route in routes:
|
||||
app.router.routes.append(route)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8002)
|
||||
|
||||
180
mcp_servers/shared/mall_client.py
Normal file
180
mcp_servers/shared/mall_client.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Mall API Client for MCP Servers
|
||||
用于调用商城 API,包括订单查询等接口
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
import httpx
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class MallSettings(BaseSettings):
|
||||
"""Mall API configuration"""
|
||||
mall_api_url: Optional[str] = None
|
||||
mall_api_token: Optional[str] = None
|
||||
mall_tenant_id: str = "2"
|
||||
mall_currency_code: str = "EUR"
|
||||
mall_language_id: str = "1"
|
||||
mall_source: str = "us.qa1.gaia888.com"
|
||||
|
||||
model_config = ConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore"
|
||||
)
|
||||
|
||||
|
||||
settings = MallSettings()
|
||||
|
||||
|
||||
class MallClient:
|
||||
"""Async client for Mall API"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_url: Optional[str] = None,
|
||||
api_token: Optional[str] = None,
|
||||
tenant_id: Optional[str] = None,
|
||||
currency_code: Optional[str] = None,
|
||||
language_id: Optional[str] = None,
|
||||
source: Optional[str] = None
|
||||
):
|
||||
self.api_url = (api_url or settings.mall_api_url or "").rstrip("/")
|
||||
self.api_token = api_token or settings.mall_api_token or ""
|
||||
self.tenant_id = tenant_id or settings.mall_tenant_id
|
||||
self.currency_code = currency_code or settings.mall_currency_code
|
||||
self.language_id = language_id or settings.mall_language_id
|
||||
self.source = source or settings.mall_source
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create HTTP client with default headers"""
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.api_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Device-Type": "pc",
|
||||
"tenant-Id": self.tenant_id,
|
||||
"currency-code": self.currency_code,
|
||||
"language-id": self.language_id,
|
||||
"source": self.source,
|
||||
},
|
||||
timeout=30.0
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def close(self):
|
||||
"""Close HTTP client"""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
params: Optional[dict[str, Any]] = None,
|
||||
json: Optional[dict[str, Any]] = None,
|
||||
headers: Optional[dict[str, str]] = None
|
||||
) -> dict[str, Any]:
|
||||
"""Make API request and handle response
|
||||
|
||||
Args:
|
||||
method: HTTP method
|
||||
endpoint: API endpoint (e.g., "/mall/api/order/show")
|
||||
params: Query parameters
|
||||
json: JSON body
|
||||
headers: Additional headers
|
||||
|
||||
Returns:
|
||||
Response data
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
# Merge additional headers
|
||||
request_headers = {}
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=endpoint,
|
||||
params=params,
|
||||
json=json,
|
||||
headers=request_headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Mall API 返回格式: {"code": 200, "msg": "success", "result": {...}}
|
||||
# 检查 API 错误
|
||||
if data.get("code") != 200:
|
||||
raise Exception(f"API Error [{data.get('code')}]: {data.get('msg') or data.get('message')}")
|
||||
|
||||
# 返回 result 字段或整个 data
|
||||
return data.get("result", data)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
endpoint: str,
|
||||
params: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
"""GET request"""
|
||||
return await self.request("GET", endpoint, params=params, **kwargs)
|
||||
|
||||
async def post(
|
||||
self,
|
||||
endpoint: str,
|
||||
json: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
"""POST request"""
|
||||
return await self.request("POST", endpoint, json=json, **kwargs)
|
||||
|
||||
# ============ Order APIs ============
|
||||
|
||||
async def get_order_by_id(
|
||||
self,
|
||||
order_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""Query order by order ID
|
||||
|
||||
根据订单号查询订单详情
|
||||
|
||||
Args:
|
||||
order_id: 订单号 (e.g., "202071324")
|
||||
|
||||
Returns:
|
||||
订单详情,包含订单号、状态、商品信息、金额、物流信息等
|
||||
Order details including order ID, status, items, amount, logistics info, etc.
|
||||
|
||||
Example:
|
||||
>>> client = MallClient()
|
||||
>>> order = await client.get_order_by_id("202071324")
|
||||
>>> print(order["order_id"])
|
||||
"""
|
||||
try:
|
||||
result = await self.get(
|
||||
"/mall/api/order/show",
|
||||
params={"orderId": order_id}
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise Exception(f"查询订单失败 (Query order failed): {str(e)}")
|
||||
|
||||
|
||||
# Global Mall client instance
|
||||
mall_client: Optional[MallClient] = None
|
||||
|
||||
|
||||
def get_mall_client() -> MallClient:
|
||||
"""Get or create global Mall client instance"""
|
||||
global mall_client
|
||||
if mall_client is None:
|
||||
mall_client = MallClient()
|
||||
return mall_client
|
||||
@@ -11,7 +11,9 @@ class StrapiSettings(BaseSettings):
|
||||
"""Strapi configuration"""
|
||||
strapi_api_url: str
|
||||
strapi_api_token: str
|
||||
|
||||
sync_on_startup: bool = True # Run initial sync on startup
|
||||
sync_interval_minutes: int = 60 # Sync interval in minutes
|
||||
|
||||
model_config = ConfigDict(env_file=".env")
|
||||
|
||||
|
||||
|
||||
161
mcp_servers/strapi_mcp/cache.py
Normal file
161
mcp_servers/strapi_mcp/cache.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Redis Cache for Strapi MCP Server
|
||||
"""
|
||||
import json
|
||||
import hashlib
|
||||
from typing import Any, Optional, Callable
|
||||
from redis import asyncio as aioredis
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class CacheSettings(BaseSettings):
|
||||
"""Cache configuration"""
|
||||
redis_host: str = "localhost"
|
||||
redis_port: int = 6379
|
||||
redis_password: Optional[str] = None
|
||||
redis_db: int = 1 # 使用不同的 DB 避免 key 冲突
|
||||
cache_ttl: int = 3600 # 默认缓存 1 小时
|
||||
|
||||
model_config = ConfigDict(env_file=".env")
|
||||
|
||||
|
||||
cache_settings = CacheSettings()
|
||||
|
||||
|
||||
class StrapiCache:
|
||||
"""Redis cache wrapper for Strapi responses"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
password: Optional[str] = None,
|
||||
db: Optional[int] = None,
|
||||
ttl: Optional[int] = None
|
||||
):
|
||||
self.host = host or cache_settings.redis_host
|
||||
self.port = port or cache_settings.redis_port
|
||||
self.password = password or cache_settings.redis_password
|
||||
self.db = db or cache_settings.redis_db
|
||||
self.ttl = ttl or cache_settings.cache_ttl
|
||||
self._redis: Optional[aioredis.Redis] = None
|
||||
|
||||
async def _get_redis(self) -> aioredis.Redis:
|
||||
"""Get or create Redis connection"""
|
||||
if self._redis is None:
|
||||
self._redis = aioredis.from_url(
|
||||
f"redis://{':' + self.password if self.password else ''}@{self.host}:{self.port}/{self.db}",
|
||||
encoding="utf-8",
|
||||
decode_responses=True
|
||||
)
|
||||
return self._redis
|
||||
|
||||
def _generate_key(self, category: str, locale: str, **kwargs) -> str:
|
||||
"""Generate cache key from parameters"""
|
||||
# 创建唯一 key
|
||||
key_parts = [category, locale]
|
||||
for k, v in sorted(kwargs.items()):
|
||||
key_parts.append(f"{k}:{v}")
|
||||
key_string = ":".join(key_parts)
|
||||
|
||||
# 使用 MD5 hash 缩短 key 长度
|
||||
return f"strapi:{hashlib.md5(key_string.encode()).hexdigest()}"
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""Get value from cache"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
value = await redis.get(key)
|
||||
if value:
|
||||
return json.loads(value)
|
||||
except Exception:
|
||||
# Redis 不可用时降级,不影响业务
|
||||
pass
|
||||
return None
|
||||
|
||||
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
|
||||
"""Set value in cache"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
ttl = ttl or self.ttl
|
||||
await redis.setex(key, ttl, json.dumps(value, ensure_ascii=False))
|
||||
return True
|
||||
except Exception:
|
||||
# Redis 不可用时降级
|
||||
return False
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete value from cache"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
await redis.delete(key)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def clear_pattern(self, pattern: str) -> int:
|
||||
"""Clear all keys matching pattern"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
keys = await redis.keys(f"{pattern}*")
|
||||
if keys:
|
||||
await redis.delete(*keys)
|
||||
return len(keys)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
async def close(self):
|
||||
"""Close Redis connection"""
|
||||
if self._redis:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
|
||||
|
||||
# 全局缓存实例
|
||||
cache = StrapiCache()
|
||||
|
||||
|
||||
async def cached_query(
|
||||
cache_key: str,
|
||||
query_func: Callable,
|
||||
ttl: Optional[int] = None
|
||||
) -> Any:
|
||||
"""Execute cached query
|
||||
|
||||
Args:
|
||||
cache_key: Cache key
|
||||
query_func: Async function to fetch data
|
||||
ttl: Cache TTL in seconds (overrides default)
|
||||
|
||||
Returns:
|
||||
Cached or fresh data
|
||||
"""
|
||||
# Try to get from cache
|
||||
cached_value = await cache.get(cache_key)
|
||||
if cached_value is not None:
|
||||
return cached_value
|
||||
|
||||
# Cache miss, execute query
|
||||
result = await query_func()
|
||||
|
||||
# Store in cache
|
||||
if result is not None:
|
||||
await cache.set(cache_key, result, ttl)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def clear_strapi_cache(pattern: Optional[str] = None) -> int:
|
||||
"""Clear Strapi cache
|
||||
|
||||
Args:
|
||||
pattern: Key pattern to clear (default: all strapi keys)
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
if pattern:
|
||||
return await cache.clear_pattern(f"strapi:{pattern}")
|
||||
else:
|
||||
return await cache.clear_pattern("strapi:")
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
HTTP Routes for Strapi MCP Server
|
||||
Provides direct HTTP access to knowledge base functions
|
||||
Provides direct HTTP access to knowledge base functions (with local cache)
|
||||
"""
|
||||
from typing import Optional, List
|
||||
import httpx
|
||||
@@ -11,6 +11,7 @@ from pydantic_settings import BaseSettings
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from config_loader import load_config, get_category_endpoint
|
||||
from knowledge_base import get_kb
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -18,6 +19,8 @@ class Settings(BaseSettings):
|
||||
strapi_api_url: str
|
||||
strapi_api_token: str = ""
|
||||
log_level: str = "INFO"
|
||||
sync_on_startup: bool = True # Run initial sync on startup
|
||||
sync_interval_minutes: int = 60 # Sync interval in minutes
|
||||
|
||||
model_config = ConfigDict(env_file=".env")
|
||||
|
||||
@@ -45,6 +48,16 @@ async def get_company_info_http(section: str = "contact", locale: str = "en"):
|
||||
locale: Language locale (default: en)
|
||||
Supported: en, nl, de, es, fr, it, tr
|
||||
"""
|
||||
# Try local knowledge base first
|
||||
kb = get_kb()
|
||||
try:
|
||||
local_result = kb.get_company_info(section, locale)
|
||||
if local_result["success"]:
|
||||
return local_result
|
||||
except Exception as e:
|
||||
print(f"Local KB error: {e}")
|
||||
|
||||
# Fallback to Strapi API
|
||||
try:
|
||||
# Map section names to API endpoints
|
||||
section_map = {
|
||||
@@ -96,6 +109,12 @@ async def get_company_info_http(section: str = "contact", locale: str = "en"):
|
||||
"content": profile.get("content")
|
||||
}
|
||||
|
||||
# Save to local cache for next time
|
||||
try:
|
||||
kb.save_company_info(section, locale, result_data)
|
||||
except Exception as e:
|
||||
print(f"Failed to save to local cache: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": result_data
|
||||
@@ -116,13 +135,23 @@ async def query_faq_http(
|
||||
locale: str = "en",
|
||||
limit: int = 10
|
||||
):
|
||||
"""Get FAQ by category - HTTP wrapper
|
||||
"""Get FAQ by category - HTTP wrapper (with local cache fallback)
|
||||
|
||||
Args:
|
||||
category: FAQ category (register, order, pre-order, payment, shipment, return, other)
|
||||
locale: Language locale (default: en)
|
||||
limit: Maximum results to return
|
||||
"""
|
||||
# Try local knowledge base first
|
||||
kb = get_kb()
|
||||
try:
|
||||
local_result = kb.query_faq(category, locale, limit)
|
||||
if local_result["count"] > 0:
|
||||
return local_result
|
||||
except Exception as e:
|
||||
print(f"Local KB error: {e}")
|
||||
|
||||
# Fallback to Strapi API (if local cache is empty)
|
||||
try:
|
||||
# 从配置文件获取端点
|
||||
if strapi_config:
|
||||
@@ -151,7 +180,8 @@ async def query_faq_http(
|
||||
"count": 0,
|
||||
"category": category,
|
||||
"locale": locale,
|
||||
"results": []
|
||||
"results": [],
|
||||
"_source": "strapi_api"
|
||||
}
|
||||
|
||||
# Handle different response formats
|
||||
@@ -178,7 +208,7 @@ async def query_faq_http(
|
||||
elif isinstance(item_data, list):
|
||||
faq_list = item_data
|
||||
|
||||
# Format results
|
||||
# Format results and save to local cache
|
||||
results = []
|
||||
for item in faq_list[:limit]:
|
||||
faq_item = {
|
||||
@@ -209,12 +239,19 @@ async def query_faq_http(
|
||||
if "question" in faq_item or "answer" in faq_item:
|
||||
results.append(faq_item)
|
||||
|
||||
# Save to local cache for next time
|
||||
try:
|
||||
kb.save_faq_batch(faq_list, category, locale)
|
||||
except Exception as e:
|
||||
print(f"Failed to save to local cache: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"count": len(results),
|
||||
"category": category,
|
||||
"locale": locale,
|
||||
"results": results
|
||||
"results": results,
|
||||
"_source": "strapi_api"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -222,7 +259,8 @@ async def query_faq_http(
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"category": category,
|
||||
"results": []
|
||||
"results": [],
|
||||
"_source": "error"
|
||||
}
|
||||
|
||||
|
||||
@@ -360,7 +398,16 @@ async def search_knowledge_base_http(query: str, locale: str = "en", limit: int
|
||||
locale: Language locale
|
||||
limit: Maximum results
|
||||
"""
|
||||
# Search FAQ across all categories
|
||||
# Try local knowledge base first using FTS
|
||||
kb = get_kb()
|
||||
try:
|
||||
local_result = kb.search_faq(query, locale, limit)
|
||||
if local_result["count"] > 0:
|
||||
return local_result
|
||||
except Exception as e:
|
||||
print(f"Local KB search error: {e}")
|
||||
|
||||
# Fallback to searching FAQ across all categories via Strapi API
|
||||
return await search_faq_http(query, locale, limit)
|
||||
|
||||
|
||||
@@ -371,6 +418,16 @@ async def get_policy_http(policy_type: str, locale: str = "en"):
|
||||
policy_type: Type of policy (return_policy, privacy_policy, etc.)
|
||||
locale: Language locale
|
||||
"""
|
||||
# Try local knowledge base first
|
||||
kb = get_kb()
|
||||
try:
|
||||
local_result = kb.get_policy(policy_type, locale)
|
||||
if local_result["success"]:
|
||||
return local_result
|
||||
except Exception as e:
|
||||
print(f"Local KB error: {e}")
|
||||
|
||||
# Fallback to Strapi API
|
||||
try:
|
||||
# Map policy types to endpoints
|
||||
policy_map = {
|
||||
@@ -404,6 +461,21 @@ async def get_policy_http(policy_type: str, locale: str = "en"):
|
||||
}
|
||||
|
||||
item = data["data"]
|
||||
|
||||
policy_data = {
|
||||
"title": item.get("title"),
|
||||
"summary": item.get("summary"),
|
||||
"content": item.get("content"),
|
||||
"version": item.get("version"),
|
||||
"effective_date": item.get("effective_date")
|
||||
}
|
||||
|
||||
# Save to local cache for next time
|
||||
try:
|
||||
kb.save_policy(policy_type, locale, policy_data)
|
||||
except Exception as e:
|
||||
print(f"Failed to save to local cache: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
|
||||
418
mcp_servers/strapi_mcp/knowledge_base.py
Normal file
418
mcp_servers/strapi_mcp/knowledge_base.py
Normal file
@@ -0,0 +1,418 @@
|
||||
"""
|
||||
Local Knowledge Base using SQLite
|
||||
|
||||
Stores FAQ, company info, and policies locally for fast access.
|
||||
Syncs with Strapi CMS periodically.
|
||||
"""
|
||||
import sqlite3
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
import httpx
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class KnowledgeBaseSettings(BaseSettings):
|
||||
"""Knowledge base configuration"""
|
||||
strapi_api_url: str
|
||||
strapi_api_token: str = ""
|
||||
db_path: str = "/data/faq.db"
|
||||
sync_interval: int = 3600 # Sync every hour
|
||||
sync_on_startup: bool = True # Run initial sync on startup
|
||||
sync_interval_minutes: int = 60 # Sync interval in minutes
|
||||
|
||||
model_config = ConfigDict(env_file=".env")
|
||||
|
||||
|
||||
settings = KnowledgeBaseSettings()
|
||||
|
||||
|
||||
class LocalKnowledgeBase:
|
||||
"""Local SQLite knowledge base"""
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None):
|
||||
self.db_path = db_path or settings.db_path
|
||||
self._conn: Optional[sqlite3.Connection] = None
|
||||
|
||||
def _get_conn(self) -> sqlite3.Connection:
|
||||
"""Get database connection"""
|
||||
if self._conn is None:
|
||||
# Ensure directory exists
|
||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._init_db()
|
||||
return self._conn
|
||||
|
||||
def _init_db(self):
|
||||
"""Initialize database schema"""
|
||||
conn = self._get_conn()
|
||||
|
||||
# Create tables
|
||||
conn.executescript("""
|
||||
-- FAQ table
|
||||
CREATE TABLE IF NOT EXISTS faq (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
strapi_id TEXT,
|
||||
category TEXT NOT NULL,
|
||||
locale TEXT NOT NULL,
|
||||
question TEXT,
|
||||
answer TEXT,
|
||||
description TEXT,
|
||||
extra_data TEXT,
|
||||
synced_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(category, locale, strapi_id)
|
||||
);
|
||||
|
||||
-- Create indexes for FAQ
|
||||
CREATE INDEX IF NOT EXISTS idx_faq_category ON faq(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_faq_locale ON faq(locale);
|
||||
CREATE INDEX IF NOT EXISTS idx_faq_search ON faq(question, answer);
|
||||
|
||||
-- Full-text search
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS fts_faq USING fts5(
|
||||
question, answer, category, locale, content='faq'
|
||||
);
|
||||
|
||||
-- Trigger to update FTS
|
||||
CREATE TRIGGER IF NOT EXISTS fts_faq_insert AFTER INSERT ON faq BEGIN
|
||||
INSERT INTO fts_faq(rowid, question, answer, category, locale)
|
||||
VALUES (new.rowid, new.question, new.answer, new.category, new.locale);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS fts_faq_delete AFTER DELETE ON faq BEGIN
|
||||
INSERT INTO fts_faq(fts_faq, rowid, question, answer, category, locale)
|
||||
VALUES ('delete', old.rowid, old.question, old.answer, old.category, old.locale);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS fts_faq_update AFTER UPDATE ON faq BEGIN
|
||||
INSERT INTO fts_faq(fts_faq, rowid, question, answer, category, locale)
|
||||
VALUES ('delete', old.rowid, old.question, old.answer, old.category, old.locale);
|
||||
INSERT INTO fts_faq(rowid, question, answer, category, locale)
|
||||
VALUES (new.rowid, new.question, new.answer, new.category, new.locale);
|
||||
END;
|
||||
|
||||
-- Company info table
|
||||
CREATE TABLE IF NOT EXISTS company_info (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
section TEXT NOT NULL UNIQUE,
|
||||
locale TEXT NOT NULL,
|
||||
title TEXT,
|
||||
description TEXT,
|
||||
content TEXT,
|
||||
extra_data TEXT,
|
||||
synced_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(section, locale)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_company_section ON company_info(section);
|
||||
CREATE INDEX IF NOT EXISTS idx_company_locale ON company_info(locale);
|
||||
|
||||
-- Policy table
|
||||
CREATE TABLE IF NOT EXISTS policy (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
type TEXT NOT NULL,
|
||||
locale TEXT NOT NULL,
|
||||
title TEXT,
|
||||
summary TEXT,
|
||||
content TEXT,
|
||||
version TEXT,
|
||||
effective_date TEXT,
|
||||
synced_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(type, locale)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_policy_type ON policy(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_policy_locale ON policy(locale);
|
||||
|
||||
-- Sync status table
|
||||
CREATE TABLE IF NOT EXISTS sync_status (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
data_type TEXT NOT NULL,
|
||||
last_sync_at TIMESTAMP,
|
||||
status TEXT,
|
||||
error_message TEXT,
|
||||
items_count INTEGER
|
||||
);
|
||||
""")
|
||||
|
||||
# ============ FAQ Operations ============
|
||||
|
||||
def query_faq(
|
||||
self,
|
||||
category: str,
|
||||
locale: str,
|
||||
limit: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""Query FAQ from local database"""
|
||||
conn = self._get_conn()
|
||||
|
||||
# Query FAQ
|
||||
cursor = conn.execute(
|
||||
"""SELECT id, strapi_id, category, locale, question, answer, description, extra_data
|
||||
FROM faq
|
||||
WHERE category = ? AND locale = ?
|
||||
LIMIT ?""",
|
||||
(category, locale, limit)
|
||||
)
|
||||
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
item = {
|
||||
"id": row["strapi_id"],
|
||||
"category": row["category"],
|
||||
"locale": row["locale"],
|
||||
"question": row["question"],
|
||||
"answer": row["answer"]
|
||||
}
|
||||
if row["description"]:
|
||||
item["description"] = row["description"]
|
||||
if row["extra_data"]:
|
||||
item.update(json.loads(row["extra_data"]))
|
||||
results.append(item)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"count": len(results),
|
||||
"category": category,
|
||||
"locale": locale,
|
||||
"results": results,
|
||||
"_source": "local_cache"
|
||||
}
|
||||
|
||||
def search_faq(
|
||||
self,
|
||||
query: str,
|
||||
locale: str = "en",
|
||||
limit: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""Full-text search FAQ"""
|
||||
conn = self._get_conn()
|
||||
|
||||
# Use FTS for search
|
||||
cursor = conn.execute(
|
||||
"""SELECT fts_faq.question, fts_faq.answer, faq.category, faq.locale
|
||||
FROM fts_faq
|
||||
JOIN faq ON fts_faq.rowid = faq.id
|
||||
WHERE fts_faq MATCH ? AND faq.locale = ?
|
||||
LIMIT ?""",
|
||||
(query, locale, limit)
|
||||
)
|
||||
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
results.append({
|
||||
"question": row["question"],
|
||||
"answer": row["answer"],
|
||||
"category": row["category"],
|
||||
"locale": row["locale"]
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"count": len(results),
|
||||
"query": query,
|
||||
"locale": locale,
|
||||
"results": results,
|
||||
"_source": "local_cache"
|
||||
}
|
||||
|
||||
def save_faq_batch(self, faq_list: List[Dict[str, Any]], category: str, locale: str):
|
||||
"""Save batch of FAQ to database"""
|
||||
conn = self._get_conn()
|
||||
|
||||
count = 0
|
||||
for item in faq_list:
|
||||
try:
|
||||
# Extract fields
|
||||
question = item.get("question") or item.get("title") or item.get("content", "")
|
||||
answer = item.get("answer") or item.get("content") or ""
|
||||
description = item.get("description") or ""
|
||||
strapi_id = item.get("id", "")
|
||||
|
||||
# Extra data as JSON
|
||||
extra_data = json.dumps({
|
||||
k: v for k, v in item.items()
|
||||
if k not in ["id", "question", "answer", "title", "content", "description"]
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Insert or replace
|
||||
conn.execute(
|
||||
"""INSERT OR REPLACE INTO faq
|
||||
(strapi_id, category, locale, question, answer, description, extra_data, synced_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(strapi_id, category, locale, question, answer, description, extra_data, datetime.now().isoformat())
|
||||
)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
print(f"Error saving FAQ: {e}")
|
||||
|
||||
conn.commit()
|
||||
return count
|
||||
|
||||
# ============ Company Info Operations ============
|
||||
|
||||
def get_company_info(self, section: str, locale: str = "en") -> Dict[str, Any]:
|
||||
"""Get company info from local database"""
|
||||
conn = self._get_conn()
|
||||
|
||||
cursor = conn.execute(
|
||||
"""SELECT section, locale, title, description, content, extra_data
|
||||
FROM company_info
|
||||
WHERE section = ? AND locale = ?""",
|
||||
(section, locale)
|
||||
)
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Section '{section}' not found",
|
||||
"data": None,
|
||||
"_source": "local_cache"
|
||||
}
|
||||
|
||||
result_data = {
|
||||
"section": row["section"],
|
||||
"locale": row["locale"],
|
||||
"title": row["title"],
|
||||
"description": row["description"],
|
||||
"content": row["content"]
|
||||
}
|
||||
|
||||
if row["extra_data"]:
|
||||
result_data.update(json.loads(row["extra_data"]))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": result_data,
|
||||
"_source": "local_cache"
|
||||
}
|
||||
|
||||
def save_company_info(self, section: str, locale: str, data: Dict[str, Any]):
|
||||
"""Save company info to database"""
|
||||
conn = self._get_conn()
|
||||
|
||||
title = data.get("title") or data.get("section_title") or ""
|
||||
description = data.get("description") or ""
|
||||
content = data.get("content") or ""
|
||||
|
||||
extra_data = json.dumps({
|
||||
k: v for k, v in data.items()
|
||||
if k not in ["section", "title", "description", "content"]
|
||||
}, ensure_ascii=False)
|
||||
|
||||
conn.execute(
|
||||
"""INSERT OR REPLACE INTO company_info
|
||||
(section, locale, title, description, content, extra_data, synced_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
||||
(section, locale, title, description, content, extra_data, datetime.now().isoformat())
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
# ============ Policy Operations ============
|
||||
|
||||
def get_policy(self, policy_type: str, locale: str = "en") -> Dict[str, Any]:
|
||||
"""Get policy from local database"""
|
||||
conn = self._get_conn()
|
||||
|
||||
cursor = conn.execute(
|
||||
"""SELECT type, locale, title, summary, content, version, effective_date
|
||||
FROM policy
|
||||
WHERE type = ? AND locale = ?""",
|
||||
(policy_type, locale)
|
||||
)
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Policy '{policy_type}' not found",
|
||||
"data": None,
|
||||
"_source": "local_cache"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"type": row["type"],
|
||||
"locale": row["locale"],
|
||||
"title": row["title"],
|
||||
"summary": row["summary"],
|
||||
"content": row["content"],
|
||||
"version": row["version"],
|
||||
"effective_date": row["effective_date"]
|
||||
},
|
||||
"_source": "local_cache"
|
||||
}
|
||||
|
||||
def save_policy(self, policy_type: str, locale: str, data: Dict[str, Any]):
|
||||
"""Save policy to database"""
|
||||
conn = self._get_conn()
|
||||
|
||||
title = data.get("title") or ""
|
||||
summary = data.get("summary") or ""
|
||||
content = data.get("content") or ""
|
||||
version = data.get("version") or ""
|
||||
effective_date = data.get("effective_date") or ""
|
||||
|
||||
conn.execute(
|
||||
"""INSERT OR REPLACE INTO policy
|
||||
(type, locale, title, summary, content, version, effective_date, synced_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(policy_type, locale, title, summary, content, version, effective_date, datetime.now().isoformat())
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
# ============ Sync Status ============
|
||||
|
||||
def update_sync_status(self, data_type: str, status: str, items_count: int = 0, error: Optional[str] = None):
|
||||
"""Update sync status"""
|
||||
conn = self._get_conn()
|
||||
|
||||
conn.execute(
|
||||
"""INSERT INTO sync_status (data_type, last_sync_at, status, items_count, error_message)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(data_type, datetime.now().isoformat(), status, items_count, error)
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
def get_sync_status(self, data_type: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""Get sync status"""
|
||||
conn = self._get_conn()
|
||||
|
||||
if data_type:
|
||||
cursor = conn.execute(
|
||||
"""SELECT * FROM sync_status WHERE data_type = ? ORDER BY last_sync_at DESC LIMIT 1""",
|
||||
(data_type,)
|
||||
)
|
||||
else:
|
||||
cursor = conn.execute(
|
||||
"""SELECT * FROM sync_status ORDER BY last_sync_at DESC LIMIT 10"""
|
||||
)
|
||||
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def close(self):
|
||||
"""Close database connection"""
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
|
||||
# Global knowledge base instance
|
||||
kb = LocalKnowledgeBase()
|
||||
|
||||
|
||||
def get_kb() -> LocalKnowledgeBase:
|
||||
"""Get global knowledge base instance"""
|
||||
return kb
|
||||
@@ -20,3 +20,9 @@ structlog>=24.1.0
|
||||
|
||||
# Configuration
|
||||
pyyaml>=6.0
|
||||
|
||||
# Cache
|
||||
redis>=5.0.0
|
||||
|
||||
# Scheduler
|
||||
apscheduler>=3.10.0
|
||||
|
||||
@@ -3,7 +3,9 @@ Strapi MCP Server - FAQ and Knowledge Base
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
# Add shared module to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -13,6 +15,7 @@ from pydantic_settings import BaseSettings
|
||||
from fastapi import Request
|
||||
from starlette.responses import JSONResponse
|
||||
import uvicorn
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
|
||||
from pydantic import ConfigDict
|
||||
@@ -23,7 +26,9 @@ class Settings(BaseSettings):
|
||||
strapi_api_url: str
|
||||
strapi_api_token: str
|
||||
log_level: str = "INFO"
|
||||
|
||||
sync_interval_minutes: int = 60 # Sync every 60 minutes
|
||||
sync_on_startup: bool = True # Run initial sync on startup
|
||||
|
||||
model_config = ConfigDict(env_file=".env")
|
||||
|
||||
|
||||
@@ -196,6 +201,55 @@ async def health_check() -> dict:
|
||||
}
|
||||
|
||||
|
||||
# ============ Sync Scheduler ============
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
|
||||
async def run_scheduled_sync():
|
||||
"""Run scheduled sync from Strapi to local knowledge base"""
|
||||
try:
|
||||
from sync import StrapiSyncer
|
||||
from knowledge_base import get_kb
|
||||
|
||||
kb = get_kb()
|
||||
syncer = StrapiSyncer(kb)
|
||||
|
||||
print(f"[{datetime.now()}] Starting scheduled sync...")
|
||||
result = await syncer.sync_all()
|
||||
|
||||
if result["success"]:
|
||||
print(f"[{datetime.now()}] Sync completed successfully")
|
||||
else:
|
||||
print(f"[{datetime.now()}] Sync failed: {result.get('error', 'Unknown error')}")
|
||||
except Exception as e:
|
||||
print(f"[{datetime.now()}] Sync error: {e}")
|
||||
|
||||
|
||||
async def run_initial_sync():
|
||||
"""Run initial sync on startup if enabled"""
|
||||
if settings.sync_on_startup:
|
||||
print("Running initial sync on startup...")
|
||||
await run_scheduled_sync()
|
||||
print("Initial sync completed")
|
||||
|
||||
|
||||
def start_scheduler():
|
||||
"""Start the background sync scheduler"""
|
||||
if settings.sync_interval_minutes > 0:
|
||||
scheduler.add_job(
|
||||
run_scheduled_sync,
|
||||
'interval',
|
||||
minutes=settings.sync_interval_minutes,
|
||||
id='strapi_sync',
|
||||
replace_existing=True
|
||||
)
|
||||
scheduler.start()
|
||||
print(f"Sync scheduler started (interval: {settings.sync_interval_minutes} minutes)")
|
||||
else:
|
||||
print("Sync scheduler disabled (interval set to 0)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Create FastAPI app from MCP
|
||||
@@ -252,9 +306,23 @@ if __name__ == "__main__":
|
||||
|
||||
# Add routes using the correct method
|
||||
from fastapi import FastAPI
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
# Lifespan context manager for startup/shutdown events
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup: start scheduler and run initial sync
|
||||
start_scheduler()
|
||||
if settings.sync_on_startup:
|
||||
print("Running initial sync on startup...")
|
||||
await run_scheduled_sync()
|
||||
print("Initial sync completed")
|
||||
yield
|
||||
# Shutdown: stop scheduler
|
||||
scheduler.shutdown()
|
||||
|
||||
# Create a wrapper FastAPI app with custom routes first
|
||||
app = FastAPI()
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Add custom routes BEFORE mounting mcp_app
|
||||
app.add_route("/health", health_check, methods=["GET"])
|
||||
|
||||
252
mcp_servers/strapi_mcp/sync.py
Normal file
252
mcp_servers/strapi_mcp/sync.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
Strapi to Local Knowledge Base Sync Script
|
||||
|
||||
Periodically syncs FAQ, company info, and policies from Strapi CMS to local SQLite database.
|
||||
"""
|
||||
import asyncio
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
from knowledge_base import LocalKnowledgeBase, settings
|
||||
from config_loader import load_config, get_category_endpoint
|
||||
|
||||
|
||||
class StrapiSyncer:
|
||||
"""Sync data from Strapi to local knowledge base"""
|
||||
|
||||
def __init__(self, kb: LocalKnowledgeBase):
|
||||
self.kb = kb
|
||||
self.api_url = settings.strapi_api_url.rstrip("/")
|
||||
self.api_token = settings.strapi_api_token
|
||||
|
||||
async def sync_all(self) -> Dict[str, Any]:
|
||||
"""Sync all data from Strapi"""
|
||||
results = {
|
||||
"success": True,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"details": {}
|
||||
}
|
||||
|
||||
try:
|
||||
# Load config
|
||||
try:
|
||||
config = load_config()
|
||||
except:
|
||||
config = None
|
||||
|
||||
# Sync FAQ categories
|
||||
categories = ["register", "order", "pre-order", "payment", "shipment", "return", "other"]
|
||||
if config:
|
||||
categories = list(config.faq_categories.keys())
|
||||
|
||||
faq_total = 0
|
||||
for category in categories:
|
||||
count = await self.sync_faq_category(category, config)
|
||||
faq_total += count
|
||||
results["details"][f"faq_{category}"] = count
|
||||
|
||||
results["details"]["faq_total"] = faq_total
|
||||
|
||||
# Sync company info
|
||||
company_sections = ["contact", "about", "service"]
|
||||
for section in company_sections:
|
||||
await self.sync_company_info(section)
|
||||
results["details"]["company_info"] = len(company_sections)
|
||||
|
||||
# Sync policies
|
||||
policy_types = ["return_policy", "privacy_policy", "terms_of_service", "shipping_policy", "payment_policy"]
|
||||
for policy_type in policy_types:
|
||||
await self.sync_policy(policy_type)
|
||||
results["details"]["policies"] = len(policy_types)
|
||||
|
||||
# Update sync status
|
||||
self.kb.update_sync_status("all", "success", faq_total)
|
||||
|
||||
print(f"✅ Sync completed: {faq_total} FAQs, {len(company_sections)} company sections, {len(policy_types)} policies")
|
||||
|
||||
except Exception as e:
|
||||
results["success"] = False
|
||||
results["error"] = str(e)
|
||||
self.kb.update_sync_status("all", "error", 0, str(e))
|
||||
print(f"❌ Sync failed: {e}")
|
||||
|
||||
return results
|
||||
|
||||
async def sync_faq_category(self, category: str, config=None) -> int:
|
||||
"""Sync FAQ category from Strapi"""
|
||||
try:
|
||||
# Get endpoint from config
|
||||
if config:
|
||||
endpoint = get_category_endpoint(category, config)
|
||||
else:
|
||||
endpoint = f"faq-{category}"
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_token:
|
||||
headers["Authorization"] = f"Bearer {self.api_token}"
|
||||
|
||||
# Fetch from Strapi
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(
|
||||
f"{self.api_url}/api/{endpoint}",
|
||||
params={"populate": "deep"},
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Extract FAQ items
|
||||
faq_list = []
|
||||
item_data = data.get("data", {})
|
||||
|
||||
if isinstance(item_data, dict):
|
||||
if item_data.get("content"):
|
||||
faq_list = item_data["content"]
|
||||
elif item_data.get("faqs"):
|
||||
faq_list = item_data["faqs"]
|
||||
elif item_data.get("questions"):
|
||||
faq_list = item_data["questions"]
|
||||
elif isinstance(item_data, list):
|
||||
faq_list = item_data
|
||||
|
||||
# Save to local database
|
||||
count = self.kb.save_faq_batch(faq_list, category, "en")
|
||||
|
||||
# Also sync other locales if available
|
||||
locales = ["nl", "de", "es", "fr", "it", "tr"]
|
||||
for locale in locales:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(
|
||||
f"{self.api_url}/api/{endpoint}",
|
||||
params={"populate": "deep", "locale": locale},
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Extract and save
|
||||
faq_list_locale = []
|
||||
item_data_locale = data.get("data", {})
|
||||
|
||||
if isinstance(item_data_locale, dict):
|
||||
if item_data_locale.get("content"):
|
||||
faq_list_locale = item_data_locale["content"]
|
||||
elif item_data_locale.get("faqs"):
|
||||
faq_list_locale = item_data_locale["faqs"]
|
||||
|
||||
if faq_list_locale:
|
||||
self.kb.save_faq_batch(faq_list_locale, category, locale)
|
||||
count += len(faq_list_locale)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to sync {category} for locale {locale}: {e}")
|
||||
|
||||
print(f" ✓ Synced {count} FAQs for category '{category}'")
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to sync category '{category}': {e}")
|
||||
return 0
|
||||
|
||||
async def sync_company_info(self, section: str):
|
||||
"""Sync company info from Strapi"""
|
||||
try:
|
||||
section_map = {
|
||||
"contact": "info-contact",
|
||||
"about": "info-about",
|
||||
"service": "info-service",
|
||||
}
|
||||
|
||||
endpoint = section_map.get(section, f"info-{section}")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_token:
|
||||
headers["Authorization"] = f"Bearer {self.api_token}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(
|
||||
f"{self.api_url}/api/{endpoint}",
|
||||
params={"populate": "deep"},
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
item = data.get("data", {})
|
||||
if item:
|
||||
# Extract data
|
||||
company_data = {
|
||||
"section": section,
|
||||
"title": item.get("title"),
|
||||
"description": item.get("description"),
|
||||
"content": item.get("content")
|
||||
}
|
||||
|
||||
# Handle profile info
|
||||
if item.get("yehwang_profile"):
|
||||
profile = item["yehwang_profile"]
|
||||
company_data["profile"] = {
|
||||
"title": profile.get("title"),
|
||||
"content": profile.get("content")
|
||||
}
|
||||
|
||||
self.kb.save_company_info(section, "en", company_data)
|
||||
print(f" ✓ Synced company info '{section}'")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to sync company info '{section}': {e}")
|
||||
|
||||
async def sync_policy(self, policy_type: str):
|
||||
"""Sync policy from Strapi"""
|
||||
try:
|
||||
policy_map = {
|
||||
"return_policy": "policy-return",
|
||||
"privacy_policy": "policy-privacy",
|
||||
"terms_of_service": "policy-terms",
|
||||
"shipping_policy": "policy-shipping",
|
||||
"payment_policy": "policy-payment",
|
||||
}
|
||||
|
||||
endpoint = policy_map.get(policy_type, f"policy-{policy_type}")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_token:
|
||||
headers["Authorization"] = f"Bearer {self.api_token}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(
|
||||
f"{self.api_url}/api/{endpoint}",
|
||||
params={"populate": "deep"},
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
item = data.get("data", {})
|
||||
if item:
|
||||
policy_data = {
|
||||
"title": item.get("title"),
|
||||
"summary": item.get("summary"),
|
||||
"content": item.get("content"),
|
||||
"version": item.get("version"),
|
||||
"effective_date": item.get("effective_date")
|
||||
}
|
||||
|
||||
self.kb.save_policy(policy_type, "en", policy_data)
|
||||
print(f" ✓ Synced policy '{policy_type}'")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to sync policy '{policy_type}': {e}")
|
||||
|
||||
|
||||
async def run_sync(kb: LocalKnowledgeBase):
|
||||
"""Run sync process"""
|
||||
syncer = StrapiSyncer(kb)
|
||||
await syncer.sync_all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run sync
|
||||
kb_instance = LocalKnowledgeBase()
|
||||
asyncio.run(run_sync(kb_instance))
|
||||
Reference in New Issue
Block a user