195 lines
5.1 KiB
Python
195 lines
5.1 KiB
Python
|
|
"""
|
||
|
|
LLM Response Cache for FAQ and common queries
|
||
|
|
"""
|
||
|
|
import hashlib
|
||
|
|
import json
|
||
|
|
from typing import Any, Optional
|
||
|
|
from datetime import timedelta
|
||
|
|
|
||
|
|
from .cache import CacheManager
|
||
|
|
from .logger import get_logger
|
||
|
|
|
||
|
|
# Module-level logger for this module. Call sites below pass keyword event
# fields (model=..., key=...), so get_logger presumably returns a
# structlog-style logger — TODO confirm against .logger.
logger = get_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class ResponseCache:
    """Cache LLM responses for common queries.

    Degrades gracefully: when no cache manager is configured, get() always
    misses, set() is a no-op, and invalidate() deletes nothing.
    """

    def __init__(
        self,
        cache_manager: Optional[CacheManager] = None,
        default_ttl: int = 3600  # 1 hour default
    ):
        """Initialize response cache

        Args:
            cache_manager: Cache manager instance; None disables caching
            default_ttl: Default TTL in seconds for cached responses
        """
        self.cache = cache_manager
        self.default_ttl = default_ttl

    def _generate_key(
        self,
        model: str,
        messages: list[dict[str, str]],
        temperature: float = 0.7,
        **kwargs: Any
    ) -> str:
        """Generate cache key from request parameters

        Args:
            model: Model name
            messages: List of messages
            temperature: Temperature parameter
            **kwargs: Additional parameters; None values are dropped so that
                an omitted parameter and an explicit None hash identically

        Returns:
            Cache key string of the form "llm_response:<model>:<hash>"
        """
        # Create a normalized representation of the request
        cache_input = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            **{k: v for k, v in kwargs.items() if v is not None}
        }

        # Hash the input to create a short, unique key. sort_keys makes the
        # serialization deterministic regardless of kwargs ordering; the first
        # 16 hex chars (64 bits) keep keys compact with negligible collision
        # risk at FAQ-cache scale.
        cache_str = json.dumps(cache_input, sort_keys=True, ensure_ascii=False)
        cache_hash = hashlib.sha256(cache_str.encode()).hexdigest()[:16]

        return f"llm_response:{model}:{cache_hash}"

    async def get(
        self,
        model: str,
        messages: list[dict[str, str]],
        temperature: float = 0.7,
        **kwargs: Any
    ) -> Optional[str]:
        """Get cached response if available

        Args:
            model: Model name
            messages: List of messages
            temperature: Temperature parameter
            **kwargs: Additional parameters

        Returns:
            Cached response content or None (on miss, disabled cache, or
            corrupt cached data)
        """
        if not self.cache:
            return None

        key = self._generate_key(model, messages, temperature, **kwargs)
        cached = await self.cache.get(key)

        if cached:
            logger.info(
                "Cache hit",
                model=model,
                key=key,
                response_length=len(cached)
            )
            try:
                # Entries are stored as JSON envelopes (see set()); unwrap
                # the actual response text.
                data = json.loads(cached)
                return data.get("response")
            except json.JSONDecodeError:
                # Corrupt/legacy entry: treat as a miss rather than raising.
                logger.warning("Invalid cached data", key=key)
                return None

        logger.debug("Cache miss", model=model, key=key)
        return None

    async def set(
        self,
        model: str,
        messages: list[dict[str, str]],
        response: str,
        temperature: float = 0.7,
        ttl: Optional[int] = None,
        **kwargs: Any
    ) -> None:
        """Cache LLM response

        Args:
            model: Model name
            messages: List of messages
            response: Response content to cache
            temperature: Temperature parameter
            ttl: Time-to-live in seconds; None uses the instance default
            **kwargs: Additional parameters
        """
        if not self.cache:
            return

        key = self._generate_key(model, messages, temperature, **kwargs)
        # BUG FIX: the previous `ttl or self.default_ttl` treated an explicit
        # ttl=0 as falsy and silently replaced it with the default; only
        # substitute the default when the caller passed no ttl at all.
        ttl = self.default_ttl if ttl is None else ttl

        # Store response with metadata
        data = {
            "response": response,
            "model": model,
            "response_length": len(response),
            "temperature": temperature
        }

        await self.cache.set(
            key,
            json.dumps(data, ensure_ascii=False),
            ttl=ttl
        )

        logger.info(
            "Response cached",
            model=model,
            key=key,
            response_length=len(response),
            ttl=ttl
        )

    async def invalidate(self, pattern: str = "llm_response:*") -> int:
        """Invalidate cached responses matching pattern

        Args:
            pattern: Redis key pattern to match

        Returns:
            Number of keys deleted (currently always 0 — see note below)
        """
        if not self.cache:
            return 0

        # This would need scan/delete operation on the underlying cache
        # backend. For now, just log the request.
        logger.info("Cache invalidation requested", pattern=pattern)
        return 0

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache statistics

        Returns:
            Dictionary with cache stats: whether caching is enabled and the
            default TTL in seconds
        """
        return {
            "enabled": self.cache is not None,
            "default_ttl": self.default_ttl
        }
|
||
|
|
|
||
|
|
|
||
|
|
# Global response cache instance, lazily created by get_response_cache().
response_cache: Optional[ResponseCache] = None
|
||
|
|
|
||
|
|
|
||
|
|
def get_response_cache() -> ResponseCache:
    """Get or create global response cache instance"""
    global response_cache
    # Fast path: already initialized.
    if response_cache is not None:
        return response_cache

    # Import locally to avoid a circular import at module load time.
    from .cache import get_cache_manager

    response_cache = ResponseCache(
        cache_manager=get_cache_manager(),
        default_ttl=3600  # 1 hour
    )
    return response_cache
|