Source code for cligram.state_manager

"""State management module for presistent application data."""

import asyncio
import copy
import hashlib
import logging
import os
import shutil
from abc import ABC, abstractmethod
from argparse import ArgumentError
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, TypeVar, Union, overload

import aiofiles

from . import utils

logger: logging.Logger = logging.getLogger(__name__)


[docs] class State(ABC): """Abstract base class for state storage. Defines interface for loading, exporting, and tracking changes to state data. """ suffix: str = ".state" """File suffix for state files"""
[docs] @staticmethod def parse(content: str) -> Any: """Parse state input contents. Subclasses may override this method to implement custom parsing logic. The default implementation returns the content as-is. Args: content: Content of the state file Returns: Parsed state data, could be passed to load() method """ return content
[docs] @abstractmethod def load(self, data: Any) -> None: """Load state data. Args: data: Data to load into state """ pass
[docs] @abstractmethod def export(self) -> Any: """Export current state data. Returns: Current state data """ pass
[docs] @abstractmethod def changed(self) -> bool: """Check if state changed since last reset. Returns: True if state has changed, False otherwise """ pass
[docs] @abstractmethod def set_changed(self, changed: bool) -> None: """Set the changed status of the state. This method is used internally to mark the state as changed or unchanged. Args: changed: True to mark state as changed, False to mark as unchanged """ pass
StateT = TypeVar("StateT", bound=State)
[docs] class JsonState(State): """JSON-based state implementation.""" data: Dict[str, Any] """Current state data""" schema: Optional[Dict[str, Any]] """Optional schema for data validation""" suffix: str = ".json" """File suffix for JSON state files""" def __init__(self): """Initialize JSON state. Initializes data with default values and sets up change tracking. If a schema is defined, it will be used to validate data structure If subclasses define _default_data or schema attributes, they will be used instead of empty defaults. """ self._default_data: Dict[str, Any] = getattr(self, "_default_data", {}) self.data: Dict[str, Any] = copy.deepcopy(self._default_data) self.schema: Optional[Dict[str, Any]] = getattr(self, "schema", None) self.set_changed(False)
[docs] @staticmethod def parse(content: str) -> Dict[str, Any] | list: """Parse JSON state contents. Args: content: Content of the state file Returns: Parsed state data Raises: ValueError: If file content is invalid """ try: data = utils.json.loads(content) except utils.json.JSONDecodeError as e: raise ValueError(f"Invalid JSON content: {e}") from e if not isinstance(data, (dict, list)): raise ValueError("Invalid JSON content: expected dict or list at top level") return data
[docs] def load(self, data: Optional[Union[Dict[str, Any], str]]) -> None: """Load state data from JSON string or dictionary. Args: data: Data to load into state, either as a dict or JSON string Raises: RuntimeError: If there are unsaved changes """ if self.changed(): raise RuntimeError("Cannot load state with unsaved changes") if data is None: return if not isinstance(data, dict): raise ValueError( f"Invalid data format: expected dict, got {type(data).__name__}" ) # merge with default data data = copy.deepcopy(self._default_data) | data self.ensure_schema(data) self.data = data self.set_changed(False)
[docs] def export(self) -> str: """Export current state data. Returns: Current state data """ self.ensure_schema(self.data) normalized = JsonState._sets_to_lists(self.data) json_data = utils.json.dumps(normalized, sort_keys=True) return json_data
[docs] def changed(self) -> bool: """Check if state changed since last reset. Returns: True if state has changed, False otherwise """ return self.get_hash() != self._last_hash
[docs] def set_changed(self, changed: bool) -> None: """Set the changed status of the state. This method is used internally to mark the state as changed or unchanged. Args: changed: True to mark state as changed, False to mark as unchanged """ if not changed: self._last_hash = self.get_hash() else: self._last_hash = ""
[docs] def ensure_schema(self, data: Dict[str, Any]) -> None: """Ensure current data matches schema. Args: data: Data to validate Raises: ValueError: If data does not match schema """ if self.schema: if not self.verify_structure(data, self.schema): raise ValueError("State data does not match schema")
[docs] def get_hash(self) -> str: """Get a hash of the current state data.""" m = hashlib.sha256() json_data = JsonState._sets_to_lists(self.data) m.update(utils.json.dumps(json_data, sort_keys=True).encode("utf-8")) return m.hexdigest()
[docs] @classmethod def verify_structure( cls, data: Any, schema: Dict[str, Any], path: str = "" ) -> bool: """Recursively verify that data matches the provided schema. Args: data: Data to verify schema: Schema definition path: Current data path for error messages Returns: True if data matches schema, False otherwise """ if isinstance(schema, dict): if not isinstance(data, dict): logger.warning( f"Structure mismatch at {path or 'root'}: expected dict, got {type(data).__name__}" ) return False for key, subschema in schema.items(): if key not in data: logger.warning(f"Missing key '{key}' at {path or 'root'}") return False if not cls.verify_structure( data[key], subschema, path=f"{path}.{key}" if path else key ): return False return True elif isinstance(schema, list): if not isinstance(data, list): logger.warning( f"Structure mismatch at {path or 'root'}: expected list, got {type(data).__name__}" ) return False if schema: subschema = schema[0] for idx, item in enumerate(data): if not cls.verify_structure(item, subschema, path=f"{path}[{idx}]"): return False return True elif isinstance(schema, type): if not isinstance(data, schema): logger.warning( f"Type mismatch at {path or 'root'}: expected {schema.__name__}, got {type(data).__name__}" ) return False return True else: logger.warning(f"Unknown schema type at {path or 'root'}: {schema}") return False
@staticmethod def _sets_to_lists(data: Any) -> Any: """Convert set objects to lists for JSON serialization. Args: data: Data structure containing sets Returns: Data structure with sets converted to lists """ if isinstance(data, dict): return {k: JsonState._sets_to_lists(v) for k, v in data.items()} elif isinstance(data, set): return list(data) elif isinstance(data, list): return [JsonState._sets_to_lists(item) for item in data] return data
[docs] class StateManager: """Manages persistent application state and handles file-based storage operations. Provides methods to register state types, load/save states, and perform backups/restores of state data. """ _registered_states: Dict[str, type[State]] = {} """Class-level registry of state types""" def __init__(self, data_dir: str | Path, backup_dir: Optional[str | Path] = None): """Initialize the state manager. Args: data_dir: Directory to store state files backup_dir: Directory to store backups (optional) """ self.data_dir = Path(data_dir).resolve() """Directory for state files""" self.data_dir.mkdir(parents=True, exist_ok=True) self.backup_dir = ( Path(backup_dir).resolve() if backup_dir else self.data_dir / "backup" ) """Directory for backup files""" self._need_backup = False """Flag indicating if state changes need backup""" self.lock = asyncio.Lock() """Lock for synchronizing state operations""" self.states: Dict[str, State] = {} """Registry of state handlers by name""" # Initialize all registered states for name, state_class in self._registered_states.items(): self.states[name] = state_class() def _get_state_path(self, name: str) -> Path: """Get full path for state file.""" suffix = self.states[name].suffix return self.data_dir / f"{name}{suffix}"
[docs] @classmethod def register(cls, name: str, state_class: type[State]) -> None: """Register a new state type. This method does not affect existing StateManager instances. Args: name: Name of the state state_class: State class to register Raises: TypeError: If state_class is not a State subclass ArgumentError: If name is already registered """ if not isinstance(state_class, type) or not issubclass(state_class, State): raise TypeError("State must be a subclass of State") if name in cls._registered_states: raise ArgumentError(None, f"State '{name}' is already registered") cls._registered_states[name] = state_class
@overload def get(self, name: str) -> State: ... @overload def get(self, name: str, expected_type: type[StateT]) -> StateT: ...
[docs] def get(self, name: str, expected_type: Optional[type[State]] = None) -> State: """Get registered state by name. Args: name: Name of the state expected_type: Optional expected type of the state Returns: Registered state instance Raises: KeyError: If state is not registered TypeError: If state type does not match expected_type """ if name not in self.states: raise KeyError(f"State '{name}' is not registered") state = self.states[name] if expected_type and not isinstance(state, expected_type): raise TypeError( f"State '{name}' is not of expected type {expected_type.__name__}" ) return state
[docs] async def load(self): """Load all states from disk.""" logger.info("Loading state...") for name, state in self.states.items(): filepath = self._get_state_path(name) if not filepath.exists(): continue async with aiofiles.open(filepath, "r", encoding="utf-8") as f: content = await f.read() data = state.parse(content) state.load(data) logger.info(f"Loaded {name} state")
[docs] async def save(self): """Save changed states to disk.""" changed = False async with self.lock: logger.info("Saving state...") for name, state in self.states.items(): if not state.changed(): continue data = state.export() filepath = self._get_state_path(name) await self._atomic_save(filepath, data) state.set_changed(False) changed = True logger.debug(f"Saved {name} state") if changed: logger.info("All states saved") self._need_backup = True else: logger.debug("No changes detected")
async def _atomic_save(self, path: str | Path, data: str): """Atomically save state data to disk. Args: path: Path to state file data: State data to save """ filepath = Path(path) temp_path = filepath.with_suffix(".tmp") async with aiofiles.open(temp_path, "w", encoding="utf-8") as f: await f.write(data) os.replace(temp_path, filepath)
[docs] def backup(self): """Create backup of all registered states.""" if not self._need_backup: logger.debug("No changes detected, skipping backup") return if not self.backup_dir: raise ValueError("No backup directory configured") if self.backup_dir.is_file(): raise ValueError("Invalid backup directory") logger.info("Creating backup of all states...") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") backup_path = self.backup_dir / timestamp backup_path.mkdir(parents=True, exist_ok=True) for name, state in self.states.items(): if state.changed(): raise RuntimeError(f"Cannot backup state with unsaved changes: {name}") src = self._get_state_path(name) if src.exists(): dest = backup_path / f"{name}{state.suffix}" shutil.copy2(src, dest) logger.debug(f"Backed up {name}") logger.info(f"All states backed up to {backup_path}") self._need_backup = False
[docs] def restore(self, timestamp: Optional[str] = None): """Restore all states from backup. Restore just copies state files from the backup, overwriting current state files. But it does not load them into current StateManager. It's highly recommended to create a new StateManager instance after restore. Args: timestamp: Timestamp of the backup to restore (format: YYYYMMDD_HHMMSS """ if not self.backup_dir: raise ValueError("No backup directory configured") if self.backup_dir.is_file(): raise ValueError("Invalid backup directory") if not self.backup_dir.exists(): raise ValueError("Backup directory does not exist") backup_base = self.backup_dir if not timestamp: backups = [p for p in backup_base.iterdir() if p.is_dir()] if not backups: raise ValueError("No backups found") backup_path = max(backups, key=lambda p: p.name) else: backup_path = backup_base / timestamp if not backup_path.exists(): raise ValueError(f"Backup {timestamp} does not exist") logger.info(f"Restoring states from {backup_path.name}...") restored = 0 for name, state in self.states.items(): if state.changed(): raise RuntimeError(f"Cannot restore state with unsaved changes: {name}") backup = backup_path / f"{name}{state.suffix}" target = self._get_state_path(name) if backup.exists(): shutil.copy2(backup, target) restored += 1 logger.debug(f"Restored {name}") if restored > 0: logger.info(f"Restored {restored} states from {backup_path.name}") else: logger.warning("No states restored")