""" Redis cache management for conversation context """ import json from typing import Any, Optional from datetime import timedelta import redis.asyncio as redis from config import settings, get_redis_url from .logger import get_logger logger = get_logger(__name__) class CacheManager: """Redis cache manager for conversation context""" def __init__(self, redis_url: Optional[str] = None): """Initialize cache manager Args: redis_url: Redis connection URL, defaults to settings """ self._redis_url = redis_url or get_redis_url() self._client: Optional[redis.Redis] = None async def connect(self) -> None: """Connect to Redis""" if self._client is None: self._client = redis.from_url( self._redis_url, encoding="utf-8", decode_responses=True ) logger.info("Connected to Redis") async def disconnect(self) -> None: """Disconnect from Redis""" if self._client: await self._client.close() self._client = None logger.info("Disconnected from Redis") async def _ensure_connected(self) -> redis.Redis: """Ensure Redis connection is established""" if self._client is None: await self.connect() return self._client # ============ Conversation Context ============ def _context_key(self, conversation_id: str) -> str: """Generate Redis key for conversation context""" return f"conversation:{conversation_id}" async def get_context(self, conversation_id: str) -> Optional[dict[str, Any]]: """Get conversation context Args: conversation_id: Unique conversation identifier Returns: Context dictionary or None if not found """ client = await self._ensure_connected() key = self._context_key(conversation_id) data = await client.get(key) if data: logger.debug("Context retrieved", conversation_id=conversation_id) return json.loads(data) return None async def set_context( self, conversation_id: str, context: dict[str, Any], ttl: Optional[int] = None ) -> None: """Set conversation context Args: conversation_id: Unique conversation identifier context: Context dictionary ttl: Time-to-live in seconds, defaults to settings """ client = await self._ensure_connected() key = self._context_key(conversation_id) ttl = ttl or settings.conversation_timeout await client.setex( key, timedelta(seconds=ttl), json.dumps(context, ensure_ascii=False) ) logger.debug("Context saved", conversation_id=conversation_id, ttl=ttl) async def update_context( self, conversation_id: str, updates: dict[str, Any] ) -> dict[str, Any]: """Update conversation context with new values Args: conversation_id: Unique conversation identifier updates: Dictionary of updates to merge Returns: Updated context dictionary """ context = await self.get_context(conversation_id) or {} context.update(updates) await self.set_context(conversation_id, context) return context async def delete_context(self, conversation_id: str) -> bool: """Delete conversation context Args: conversation_id: Unique conversation identifier Returns: True if deleted, False if not found """ client = await self._ensure_connected() key = self._context_key(conversation_id) result = await client.delete(key) if result: logger.debug("Context deleted", conversation_id=conversation_id) return bool(result) # ============ Message History ============ def _messages_key(self, conversation_id: str) -> str: """Generate Redis key for message history""" return f"messages:{conversation_id}" async def add_message( self, conversation_id: str, role: str, content: str, max_messages: int = 20 ) -> None: """Add message to conversation history Args: conversation_id: Unique conversation identifier role: Message role (user/assistant/system) content: Message content max_messages: Maximum messages to keep """ client = await self._ensure_connected() key = self._messages_key(conversation_id) message = json.dumps({ "role": role, "content": content }, ensure_ascii=False) # Add to list and trim await client.rpush(key, message) await client.ltrim(key, -max_messages, -1) await client.expire(key, settings.conversation_timeout) logger.debug( "Message added", conversation_id=conversation_id, role=role ) async def get_messages( self, conversation_id: str, limit: int = 20 ) -> list[dict[str, str]]: """Get conversation message history Args: conversation_id: Unique conversation identifier limit: Maximum messages to retrieve Returns: List of message dictionaries """ client = await self._ensure_connected() key = self._messages_key(conversation_id) messages = await client.lrange(key, -limit, -1) return [json.loads(m) for m in messages] async def clear_messages(self, conversation_id: str) -> bool: """Clear conversation message history Args: conversation_id: Unique conversation identifier Returns: True if cleared, False if not found """ client = await self._ensure_connected() key = self._messages_key(conversation_id) result = await client.delete(key) return bool(result) # ============ Generic Cache Operations ============ async def get(self, key: str) -> Optional[str]: """Get value from cache""" client = await self._ensure_connected() return await client.get(key) async def set( self, key: str, value: str, ttl: Optional[int] = None ) -> None: """Set value in cache""" client = await self._ensure_connected() if ttl: await client.setex(key, timedelta(seconds=ttl), value) else: await client.set(key, value) async def delete(self, key: str) -> bool: """Delete key from cache""" client = await self._ensure_connected() return bool(await client.delete(key)) # Global cache manager instance cache_manager: Optional[CacheManager] = None def get_cache_manager() -> CacheManager: """Get or create global cache manager instance""" global cache_manager if cache_manager is None: cache_manager = CacheManager() return cache_manager