
"""
Redis cache management for conversation context
"""
import json
from typing import Any, Optional
from datetime import timedelta
import redis.asyncio as redis
from config import settings, get_redis_url
from .logger import get_logger
logger = get_logger(__name__)
class CacheManager:
    """Redis-backed cache for conversation context and message history.

    Wraps an async redis client with lazy connection handling plus three
    groups of operations: per-conversation context blobs (JSON under
    ``conversation:<id>``), capped message lists (``messages:<id>``), and
    generic string get/set/delete.
    """

    def __init__(self, redis_url: Optional[str] = None):
        """Initialize the cache manager.

        Args:
            redis_url: Redis connection URL, defaults to settings
        """
        self._redis_url = redis_url if redis_url else get_redis_url()
        # Created lazily on first use; None means "not connected yet".
        self._client: Optional[redis.Redis] = None

    async def connect(self) -> None:
        """Open the Redis connection. No-op when already connected."""
        if self._client is not None:
            return
        self._client = redis.from_url(
            self._redis_url,
            encoding="utf-8",
            decode_responses=True,
        )
        logger.info("Connected to Redis")

    async def disconnect(self) -> None:
        """Close the Redis connection. No-op when not connected."""
        if not self._client:
            return
        await self._client.close()
        self._client = None
        logger.info("Disconnected from Redis")

    async def _ensure_connected(self) -> redis.Redis:
        """Return the client, connecting first if necessary."""
        if self._client is None:
            await self.connect()
        return self._client

    # ============ Conversation Context ============

    def _context_key(self, conversation_id: str) -> str:
        """Build the Redis key under which a conversation's context lives."""
        return f"conversation:{conversation_id}"

    async def get_context(self, conversation_id: str) -> Optional[dict[str, Any]]:
        """Fetch the stored context for a conversation.

        Args:
            conversation_id: Unique conversation identifier

        Returns:
            Context dictionary, or None when nothing is cached.
        """
        conn = await self._ensure_connected()
        raw = await conn.get(self._context_key(conversation_id))
        if not raw:
            return None
        logger.debug("Context retrieved", conversation_id=conversation_id)
        return json.loads(raw)

    async def set_context(
        self,
        conversation_id: str,
        context: dict[str, Any],
        ttl: Optional[int] = None
    ) -> None:
        """Store a conversation's context as JSON with an expiry.

        Args:
            conversation_id: Unique conversation identifier
            context: Context dictionary
            ttl: Time-to-live in seconds, defaults to settings
        """
        conn = await self._ensure_connected()
        effective_ttl = ttl or settings.conversation_timeout
        payload = json.dumps(context, ensure_ascii=False)
        await conn.setex(
            self._context_key(conversation_id),
            timedelta(seconds=effective_ttl),
            payload,
        )
        logger.debug("Context saved", conversation_id=conversation_id, ttl=effective_ttl)

    async def update_context(
        self,
        conversation_id: str,
        updates: dict[str, Any]
    ) -> dict[str, Any]:
        """Merge *updates* into the stored context and persist the result.

        Args:
            conversation_id: Unique conversation identifier
            updates: Dictionary of updates to merge

        Returns:
            The merged context dictionary.
        """
        # Missing context is treated as empty, so updates become the context.
        merged = await self.get_context(conversation_id) or {}
        merged.update(updates)
        await self.set_context(conversation_id, merged)
        return merged

    async def delete_context(self, conversation_id: str) -> bool:
        """Remove a conversation's context.

        Args:
            conversation_id: Unique conversation identifier

        Returns:
            True if a key was deleted, False if none existed.
        """
        conn = await self._ensure_connected()
        removed = await conn.delete(self._context_key(conversation_id))
        if removed:
            logger.debug("Context deleted", conversation_id=conversation_id)
        return bool(removed)

    # ============ Message History ============

    def _messages_key(self, conversation_id: str) -> str:
        """Build the Redis key under which a conversation's messages live."""
        return f"messages:{conversation_id}"

    async def add_message(
        self,
        conversation_id: str,
        role: str,
        content: str,
        max_messages: int = 20
    ) -> None:
        """Append a message to the history, keeping only the newest entries.

        Args:
            conversation_id: Unique conversation identifier
            role: Message role (user/assistant/system)
            content: Message content
            max_messages: Maximum messages to keep
        """
        conn = await self._ensure_connected()
        key = self._messages_key(conversation_id)
        entry = json.dumps({"role": role, "content": content}, ensure_ascii=False)
        # Push, trim to the newest max_messages, and refresh the TTL.
        await conn.rpush(key, entry)
        await conn.ltrim(key, -max_messages, -1)
        await conn.expire(key, settings.conversation_timeout)
        logger.debug(
            "Message added",
            conversation_id=conversation_id,
            role=role
        )

    async def get_messages(
        self,
        conversation_id: str,
        limit: int = 20
    ) -> list[dict[str, str]]:
        """Return the newest *limit* messages for a conversation.

        Args:
            conversation_id: Unique conversation identifier
            limit: Maximum messages to retrieve

        Returns:
            List of message dictionaries (oldest first within the window).
        """
        conn = await self._ensure_connected()
        raw_items = await conn.lrange(self._messages_key(conversation_id), -limit, -1)
        return [json.loads(item) for item in raw_items]

    async def clear_messages(self, conversation_id: str) -> bool:
        """Drop the full message history for a conversation.

        Args:
            conversation_id: Unique conversation identifier

        Returns:
            True if a key was deleted, False if none existed.
        """
        conn = await self._ensure_connected()
        removed = await conn.delete(self._messages_key(conversation_id))
        return bool(removed)

    # ============ Generic Cache Operations ============

    async def get(self, key: str) -> Optional[str]:
        """Return the raw string stored at *key*, or None."""
        conn = await self._ensure_connected()
        return await conn.get(key)

    async def set(
        self,
        key: str,
        value: str,
        ttl: Optional[int] = None
    ) -> None:
        """Store *value* at *key*, with an expiry when *ttl* is given."""
        conn = await self._ensure_connected()
        if not ttl:
            await conn.set(key, value)
        else:
            await conn.setex(key, timedelta(seconds=ttl), value)

    async def delete(self, key: str) -> bool:
        """Delete *key*; True when a key was actually removed."""
        conn = await self._ensure_connected()
        return bool(await conn.delete(key))
# Global cache manager instance
cache_manager: Optional[CacheManager] = None


def get_cache_manager() -> CacheManager:
    """Return the process-wide CacheManager, constructing it on first use."""
    global cache_manager
    if cache_manager is not None:
        return cache_manager
    cache_manager = CacheManager()
    return cache_manager