import asyncio
import base64
import io
import os
import tarfile
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Union
import aiofiles
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from .. import exceptions
class CompressionType(Enum):
    """Compression codecs selectable for an archive.

    Each member's value is the codec suffix that gets appended to a tarfile
    mode string after a colon (e.g. ``"w:gz"``). The empty string stands for
    an uncompressed tar stream.
    """

    # No compression: the tar mode string is used unchanged.
    NONE = ""
    # gzip compression.
    GZIP = "gz"
    # bzip2 compression.
    BZIP2 = "bz2"
    # LZMA/xz compression.
    XZ = "xz"
class FileType(Enum):
    """Kinds of entry an archive can hold.

    The member values are the strings used when an entry is rendered to a
    plain dictionary.
    """

    # Regular file carrying content bytes.
    FILE = "file"
    # Directory; carries no content.
    DIRECTORY = "directory"
    # Link entry; carries no content.
    SYMLINK = "symlink"
@dataclass(frozen=True, slots=True)
class ArchiveEntry:
    """Represents an entry in the archive.

    This is an immutable value object that represents a file, directory, or
    symlink in an archive. For file entries the SHA-256 digest of the content
    is computed once, at construction time, and reused by ``content_hash``,
    ``__hash__`` and ``__eq__``.
    """

    name: str
    """Entry name/path in the archive."""
    size: int
    """Size of the entry in bytes."""
    file_type: FileType
    """Type of the entry."""
    mode: int
    """File mode/permissions."""
    mtime: datetime
    """Last modification time."""
    _content: Optional[bytes] = None
    """Content of the file entry. Should only be set for FILE type."""
    uid: int = 0
    """User ID of the entry."""
    gid: int = 0
    """Group ID of the entry."""
    _uname: str = ""
    """User name of the entry (cached)."""
    _gname: str = ""
    """Group name of the entry (cached)."""
    pax_headers: Dict[str, str] = field(default_factory=dict)
    """PAX extended attributes/headers."""
    linkname: str = ""
    """Link target for SYMLINK entries; empty for files and directories."""
    _content_hash_cache: Optional[bytes] = field(default=None, init=False)
    """Cached SHA-256 hash of content."""

    def __post_init__(self) -> None:
        """Validate the entry and pre-compute the content digest.

        Raises:
            ValueError: If a FILE entry has no content, or a non-FILE entry
                has content.
        """
        if self.file_type == FileType.FILE and self._content is None:
            raise ValueError("File entries must have content.")
        if self.file_type != FileType.FILE and self._content is not None:
            raise ValueError("Only file entries can have content.")

        # Hash the content once so __hash__ and __eq__ never re-read it.
        if self._content is not None:
            digest = hashes.Hash(hashes.SHA256())
            digest.update(self._content)
            # The dataclass is frozen, so bypass its __setattr__ guard.
            object.__setattr__(self, "_content_hash_cache", digest.finalize())

        # Defensive copy: later mutation of the caller's dict must not leak
        # into this entry (the copy itself is still a plain dict).
        object.__setattr__(self, "pax_headers", dict(self.pax_headers))

    @property
    def content(self) -> Optional[bytes]:
        """Get content bytes.

        Returns:
            File content or None for non-file entries
        """
        return self._content

    @property
    def content_hash(self) -> Optional[bytes]:
        """Get SHA-256 hash of the content.

        The digest is the one computed when the entry was constructed.

        Returns:
            SHA-256 hash bytes or None for non-file entries
        """
        return self._content_hash_cache

    @property
    def uname(self) -> str:
        """Get user name, falling back to a ``pwd`` lookup of ``uid``.

        Returns:
            User name or empty string if unavailable
        """
        if self._uname:
            return self._uname
        try:
            # pwd is missing on some platforms, hence the local import.
            import pwd

            return pwd.getpwuid(self.uid).pw_name  # type: ignore
        except (ImportError, KeyError, OSError):
            return ""

    @property
    def gname(self) -> str:
        """Get group name, falling back to a ``grp`` lookup of ``gid``.

        Returns:
            Group name or empty string if unavailable
        """
        if self._gname:
            return self._gname
        try:
            # grp is missing on some platforms, hence the local import.
            import grp

            return grp.getgrgid(self.gid).gr_name  # type: ignore
        except (ImportError, KeyError, OSError):
            return ""

    @classmethod
    def from_tar_member(
        cls, member: tarfile.TarInfo, content: Optional[bytes] = None
    ) -> "ArchiveEntry":
        """Create ArchiveEntry from TarInfo.

        Directories map to DIRECTORY, symbolic and hard links map to SYMLINK,
        and every other member type maps to FILE.

        Args:
            member: TarInfo object
            content: File content bytes (required for files)

        Returns:
            New ArchiveEntry instance

        Raises:
            ValueError: If the member maps to FILE and ``content`` is None,
                or maps to another type and ``content`` is given.
        """
        if member.isdir():
            file_type = FileType.DIRECTORY
        elif member.issym() or member.islnk():
            file_type = FileType.SYMLINK
        else:
            file_type = FileType.FILE

        # Copy the PAX headers so the entry does not alias the TarInfo's dict.
        pax_headers = dict(getattr(member, "pax_headers", None) or {})

        return cls(
            name=member.name,
            size=member.size,
            file_type=file_type,
            mode=member.mode,
            mtime=datetime.fromtimestamp(member.mtime),
            _content=content,
            uid=member.uid,
            gid=member.gid,
            _uname=member.uname,
            _gname=member.gname,
            pax_headers=pax_headers,
            # Keep the link target; without it the link cannot be rewritten.
            linkname=member.linkname if file_type is FileType.SYMLINK else "",
        )

    @classmethod
    async def from_file(
        cls,
        file_path: Path,
        arcname: Optional[str] = None,
        pax_headers: Optional[dict] = None,
    ) -> "ArchiveEntry":
        """Create ArchiveEntry from file system path.

        Directories become DIRECTORY entries of size 0; anything else is read
        fully into memory and becomes a FILE entry.

        Args:
            file_path: Path to file or directory
            arcname: Name to use in archive (defaults to file name)
            pax_headers: Optional PAX headers to include

        Returns:
            New ArchiveEntry instance
        """
        st = file_path.stat()

        if file_path.is_dir():
            file_type = FileType.DIRECTORY
            content = None
            size = 0
        else:
            file_type = FileType.FILE
            async with aiofiles.open(file_path, "rb") as f:
                content = await f.read()
            size = len(content)

        return cls(
            name=arcname or file_path.name,
            size=size,
            file_type=file_type,
            # Keep only the permission bits. tarfile stores mode & 0o7777, so
            # keeping the file-type bits here would make this entry compare
            # unequal to itself after a write/read round trip.
            mode=st.st_mode & 0o7777,
            mtime=datetime.fromtimestamp(st.st_mtime),
            _content=content,
            uid=getattr(st, "st_uid", 0),
            gid=getattr(st, "st_gid", 0),
            pax_headers=pax_headers or {},
        )

    def to_tar_info(self) -> tarfile.TarInfo:
        """Convert to TarInfo for writing to tar archive.

        Returns:
            TarInfo object ready for tar.addfile()
        """
        info = tarfile.TarInfo(name=self.name)
        info.size = self.size
        info.mode = self.mode
        info.mtime = int(self.mtime.timestamp())
        info.uid = self.uid
        info.gid = self.gid
        info.uname = self.uname
        info.gname = self.gname

        # Hand tarfile its own copy so it cannot mutate this entry's headers.
        if self.pax_headers:
            info.pax_headers = dict(self.pax_headers)

        if self.file_type == FileType.DIRECTORY:
            info.type = tarfile.DIRTYPE
        elif self.file_type == FileType.SYMLINK:
            info.type = tarfile.SYMTYPE
            info.linkname = self.linkname
        else:
            info.type = tarfile.REGTYPE
        return info

    def to_dict(self) -> dict:
        """Convert to dictionary representation.

        Returns:
            Dictionary with entry metadata (excludes content)
        """
        digest = self.content_hash
        return {
            "name": self.name,
            "size": self.size,
            "type": self.file_type.value,
            "mode": self.mode,
            "mtime": self.mtime.isoformat(),
            "uid": self.uid,
            "gid": self.gid,
            "uname": self.uname,
            "gname": self.gname,
            "content_hash": digest.hex() if digest else None,
            # Return a copy: mutating the result must not alter this entry
            # (and with it the entry's hash).
            "pax_headers": dict(self.pax_headers),
        }

    def __hash__(self) -> int:
        """Compute hash for use in sets and dicts.

        Uses the cached content digest. uname/gname and mtime are left out,
        matching ``__eq__``.

        Returns:
            Hash value
        """
        # A dict is unhashable; sorted items give an order-independent tuple.
        pax_tuple = tuple(sorted(self.pax_headers.items()))
        return hash(
            (
                self.name,
                self.size,
                self.file_type,
                self._content_hash_cache,
                self.mode,
                self.uid,
                self.gid,
                pax_tuple,
                self.linkname,
            )
        )

    def __eq__(self, other: object) -> bool:
        """Check equality with another ArchiveEntry.

        Compares the cached content digest rather than the content itself.
        uname/gname and mtime are not compared.

        Args:
            other: Object to compare with

        Returns:
            True if entries are equal
        """
        if not isinstance(other, ArchiveEntry):
            return NotImplemented
        if self is other:
            return True
        return (
            self.name == other.name
            and self.size == other.size
            and self.file_type == other.file_type
            and self._content_hash_cache == other._content_hash_cache
            and self.mode == other.mode
            and self.uid == other.uid
            and self.gid == other.gid
            and self.pax_headers == other.pax_headers
            and self.linkname == other.linkname
        )

    def __repr__(self) -> str:
        """String representation for debugging."""
        return (
            f"ArchiveEntry(name={self.name!r}, "
            f"type={self.file_type.value}, "
            f"size={self.size}, "
            f"mode=0o{self.mode:o})"
        )
class Archive:
    """In-memory tar archive with optional compression and password encryption.

    Entries are held in memory as ``ArchiveEntry`` objects keyed by name.
    Decryption happens only on load, encryption only on export.

    Examples:
        ```
        # Create archive from directory and save to file
        async with await Archive.from_directory("my_folder", password="secret") as archive:
            await archive.write("archive.tar.gz")

        # Load archive from file and read one member
        async with await Archive.load("archive.tar.gz", password="secret") as archive:
            file_content = archive.get_file("my_folder/document.txt")

        # Create archive from bytes
        data = b"..."  # Encrypted archive bytes
        async with await Archive.from_bytes(data, password="secret") as archive:
            names = archive.list_file_names()
        ```
    """

    # Maximum archive size in memory (50 MB)
    MAX_SIZE = 50 * 1024 * 1024

    def __init__(
        self,
        password: Optional[str] = None,
        compression: Union[str, CompressionType] = CompressionType.GZIP,
        salt: Optional[bytes] = None,
    ):
        """Create an empty archive.

        Args:
            password: Password the encryption key is derived from. When
                empty or None, the fixed string ``"password"`` is used.
            compression: Compression type, as enum member or its string value.
            salt: Key-derivation salt; a random one is generated if omitted.

        Raises:
            exceptions.InvalidCompressionTypeError: If ``compression`` is a
                string that is not a valid ``CompressionType`` value.
        """
        # Accept both the enum and its string value.
        if isinstance(compression, str):
            try:
                self.compression = CompressionType(compression)
            except ValueError as exc:
                raise exceptions.InvalidCompressionTypeError(
                    f"Invalid compression type: {compression}"
                ) from exc
        else:
            self.compression = compression

        self._cipher: Optional[Fernet] = None
        self._salt = salt
        self._entries: Dict[str, ArchiveEntry] = {}  # In-memory storage
        # NOTE(security): the fallback is a fixed, publicly known password, so
        # it only scrambles the data and offers no confidentiality. It is kept
        # because archives already written with it must stay readable.
        self._password = password if password else "password"
        if self._password:
            self._cipher = self._create_cipher()

    @classmethod
    async def load(
        cls,
        path: Union[str, Path],
        **kwargs,
    ) -> "Archive":
        """Load archive from file.

        Args:
            path: Path to archive file
            **kwargs: Additional arguments for Archive constructor

        Returns:
            Archive instance with loaded data
        """
        archive = cls(**kwargs)
        await archive.read(path)
        return archive

    @classmethod
    async def from_bytes(
        cls,
        data: bytes,
        **kwargs,
    ) -> "Archive":
        """Load archive from bytes.

        Args:
            data: Archive data bytes
            **kwargs: Additional arguments for Archive constructor

        Returns:
            Archive instance with loaded data
        """
        archive = cls(**kwargs)
        await archive._load_from_bytes(data)
        return archive

    @classmethod
    async def from_base64(
        cls,
        b64_string: Union[str, bytes],
        **kwargs,
    ) -> "Archive":
        """Load archive from base64 string.

        Args:
            b64_string: Base64 encoded archive, as str or bytes
            **kwargs: Additional arguments for Archive constructor

        Returns:
            Archive instance with loaded data
        """
        loop = asyncio.get_running_loop()
        # Handle both str and bytes input
        b64_bytes = (
            b64_string.encode("utf-8") if isinstance(b64_string, str) else b64_string
        )
        data = await loop.run_in_executor(None, base64.b64decode, b64_bytes)
        return await cls.from_bytes(data, **kwargs)

    @classmethod
    async def from_directory(
        cls,
        directory: Union[str, Path],
        **kwargs,
    ) -> "Archive":
        """Create archive from directory.

        Args:
            directory: Directory to archive
            **kwargs: Additional arguments for Archive constructor

        Returns:
            Archive instance with directory contents
        """
        archive = cls(**kwargs)
        await archive.add_directory(directory)
        return archive

    def _create_cipher(self) -> Fernet:
        """Create Fernet cipher from password.

        Derives a 32-byte key with PBKDF2-HMAC-SHA256 from the password and
        the salt, generating a random 16-byte salt first if none is set.

        Raises:
            ValueError: If no password is set.
        """
        if not self._password:
            raise ValueError("Password is required for encryption/decryption.")
        if self._salt is None:
            self._salt = os.urandom(16)
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=self._salt,
            # Changing this would make existing archives undecryptable.
            iterations=100000,
        )
        key = base64.urlsafe_b64encode(kdf.derive(self._password.encode()))
        return Fernet(key)

    def get_salt(self) -> Optional[bytes]:
        """Get the salt used for encryption."""
        return self._salt

    def _get_tar_mode(self, mode: str) -> str:
        """Get tarfile mode with compression."""
        if self.compression == CompressionType.NONE:
            return mode
        return f"{mode}:{self.compression.value}"

    def _get_tar_format(self) -> int:
        """Get tar format (PAX by default)."""
        return tarfile.PAX_FORMAT

    def _check_size_limit(
        self, additional_size: int = 0, replacing: Optional[str] = None
    ) -> None:
        """Check if adding data would exceed size limit.

        Args:
            additional_size: Number of bytes about to be added.
            replacing: Name of an existing entry the new data will overwrite.
                Its size is left out of the current total so that replacing
                an entry is not counted twice.

        Raises:
            exceptions.SizeLimitExceededError: If the resulting total would
                exceed ``MAX_SIZE``.
        """
        current_size = sum(
            entry.size for name, entry in self._entries.items() if name != replacing
        )
        if current_size + additional_size > self.MAX_SIZE:
            raise exceptions.SizeLimitExceededError(
                f"Archive size would exceed {self.MAX_SIZE / (1024 * 1024):.1f} MB limit. "
                f"Current: {current_size / (1024 * 1024):.1f} MB, "
                f"Adding: {additional_size / (1024 * 1024):.1f} MB"
            )

    async def _load_from_bytes(self, data: bytes) -> None:
        """Load archive from bytes.

        The first 16 bytes are the salt and the rest is the raw Fernet token.
        The existing entries are replaced only after parsing and the size
        check both succeed.

        Raises:
            exceptions.InvalidArchiveError: If the data is shorter than the
                salt or cannot be parsed as a tar archive.
            exceptions.InvalidPasswordError: If decryption fails.
            exceptions.SizeLimitExceededError: If the content exceeds
                ``MAX_SIZE``.
        """
        loop = asyncio.get_running_loop()

        # Decrypt if needed
        if self._password:
            if len(data) < 16:
                raise exceptions.InvalidArchiveError(
                    "Encrypted data is too short to contain salt."
                )
            # Rebuild the cipher from the salt stored with the archive.
            self._salt = data[:16]
            self._cipher = self._create_cipher()
            # Fernet expects base64 encoded bytes
            token = base64.urlsafe_b64encode(data[16:])
            try:
                data = await loop.run_in_executor(None, self._cipher.decrypt, token)
            except InvalidToken as exc:
                raise exceptions.InvalidPasswordError(
                    "Incorrect password for archive decryption."
                ) from exc

        # Parse tar archive
        entries = await loop.run_in_executor(
            None,
            self._parse_tar_to_entries,
            data,
        )

        # Check size after parsing (uncompressed content size)
        total_size = sum(entry.size for entry in entries.values())
        if total_size > self.MAX_SIZE:
            raise exceptions.SizeLimitExceededError(
                f"Archive content size {total_size / (1024 * 1024):.1f} MB exceeds {self.MAX_SIZE / (1024 * 1024):.1f} MB limit"
            )

        # Commit only now, so a rejected archive never stays loaded.
        self._entries = entries

    def _parse_tar_to_entries(self, data: bytes) -> Dict[str, ArchiveEntry]:
        """Parse tar archive to entries.

        The size limit is enforced from the member headers before any content
        is read, so an oversized archive is rejected without being loaded.

        Raises:
            exceptions.InvalidArchiveError: If the data cannot be parsed, or
                contains a member that is not a file, directory or link.
            exceptions.SizeLimitExceededError: If the members' total size
                exceeds ``MAX_SIZE``.
        """
        entries: Dict[str, ArchiveEntry] = {}
        total_size = 0
        try:
            mode = self._get_tar_mode("r")
            with tarfile.open(fileobj=io.BytesIO(data), mode=mode) as tar:  # type: ignore
                for member in tar.getmembers():
                    supported = (
                        member.isfile()
                        or member.isdir()
                        or member.issym()
                        or member.islnk()
                    )
                    if not supported:
                        # FIFOs, devices etc. have no FileType equivalent.
                        raise exceptions.InvalidArchiveError(
                            f"Unsupported archive member type: {member.name}"
                        )

                    # A repeated name overwrites the earlier entry, so count
                    # only the size that will actually be kept.
                    previous = entries.get(member.name)
                    total_size += member.size - (previous.size if previous else 0)
                    if total_size > self.MAX_SIZE:
                        raise exceptions.SizeLimitExceededError(
                            f"Archive content size {total_size / (1024 * 1024):.1f} MB exceeds {self.MAX_SIZE / (1024 * 1024):.1f} MB limit"
                        )

                    content = None
                    if member.isfile():
                        file_obj = tar.extractfile(member)
                        if file_obj:
                            content = file_obj.read()
                    entry = ArchiveEntry.from_tar_member(member, content)
                    entries[entry.name] = entry
        except (exceptions.InvalidArchiveError, exceptions.SizeLimitExceededError):
            # Already the right type; do not let the handler below rewrap it.
            raise
        except (tarfile.TarError, EOFError, OSError) as e:
            # Truncated or corrupt compressed streams can surface as
            # EOFError/OSError rather than TarError.
            raise exceptions.InvalidArchiveError(f"Failed to parse archive: {e}") from e
        return entries

    def _build_tar_from_entries(self) -> bytes:
        """Build tar archive from entries.

        Returns:
            The serialized (and, if configured, compressed) tar stream.
        """
        buffer = io.BytesIO()
        mode = self._get_tar_mode("w")
        with tarfile.open(
            fileobj=buffer, mode=mode, format=self._get_tar_format()
        ) as tar:  # type: ignore
            for entry in self._entries.values():
                tar_info = entry.to_tar_info()
                # Empty files take the header-only branch as well.
                if entry.file_type == FileType.FILE and entry.content:
                    tar.addfile(tar_info, io.BytesIO(entry.content))
                else:
                    tar.addfile(tar_info)
        # Read the buffer only after the tar is closed and fully flushed.
        return buffer.getvalue()

    async def read(self, path: Union[str, Path]) -> None:
        """Read archive from file.

        Args:
            path: Path to archive file

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Archive file not found: {path}")
        async with aiofiles.open(path, "rb") as f:
            data = await f.read()
        await self._load_from_bytes(data)

    async def write(self, path: Union[str, Path]) -> int:
        """Write archive to file, creating parent directories as needed.

        Args:
            path: Output path

        Returns:
            Number of bytes written

        Raises:
            ValueError: If the archive has no entries.
        """
        # NOTE(review): to_bytes() raises exceptions.EmptyArchiveError for the
        # same condition; ValueError is kept here for existing callers.
        if not self._entries:
            raise ValueError("Archive is empty. Nothing to write.")
        data = await self.to_bytes()
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(path, "wb") as f:
            return await f.write(data)

    async def to_bytes(self) -> bytes:
        """Get archive as bytes.

        With a cipher configured, the result is the 16-byte salt followed by
        the raw (base64-decoded) Fernet token.

        Returns:
            Archive bytes

        Raises:
            exceptions.EmptyArchiveError: If the archive has no entries.
            ValueError: If a cipher is set but no salt is available.
        """
        if not self._entries:
            raise exceptions.EmptyArchiveError("Archive is empty.")

        # Build tar archive
        loop = asyncio.get_running_loop()
        data = await loop.run_in_executor(None, self._build_tar_from_entries)

        # Encrypt if needed
        if self._cipher:
            if self._salt is None:
                raise ValueError("Salt must be set for encryption.")
            # Fernet returns a base64 token; store it decoded to save space.
            token = await loop.run_in_executor(None, self._cipher.encrypt, data)
            # Prepend salt to encrypted data
            data = self._salt + base64.urlsafe_b64decode(token)
        return data

    async def to_base64(self) -> str:
        """Get archive as base64 string.

        Returns:
            Base64 encoded archive
        """
        data = await self.to_bytes()
        loop = asyncio.get_running_loop()
        encoded = await loop.run_in_executor(None, base64.b64encode, data)
        return encoded.decode("utf-8")

    async def add_file(
        self,
        file_path: Union[str, Path],
        arcname: Optional[str] = None,
        pax_headers: Optional[Dict[str, str]] = None,
    ) -> ArchiveEntry:
        """Add file to archive, replacing any entry with the same name.

        Args:
            file_path: Path to file
            arcname: Name in archive (defaults to filename)
            pax_headers: Optional PAX extended attributes

        Returns:
            Created ArchiveEntry

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            exceptions.InvalidPathError: If ``file_path`` is not a file.
            exceptions.SizeLimitExceededError: If adding the file would
                exceed ``MAX_SIZE``.
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")
        if not file_path.is_file():
            raise exceptions.InvalidPathError(f"Path is not a file: {file_path}")

        name = arcname or file_path.name

        # Reject oversized files before reading them into memory.
        self._check_size_limit(file_path.stat().st_size, replacing=name)

        entry = await ArchiveEntry.from_file(
            file_path, name, pax_headers=pax_headers or {}
        )

        # Re-check with the size actually read; the file may have grown.
        self._check_size_limit(entry.size, replacing=name)
        self._entries[name] = entry
        return entry

    async def add_directory(
        self,
        dir_path: Union[str, Path],
        arcname: Optional[str] = None,
        pax_headers: Optional[Dict[str, str]] = None,
    ) -> List[ArchiveEntry]:
        """Add directory and everything beneath it to the archive.

        The addition is all-or-nothing: if any error (such as the size limit)
        interrupts the walk, the archive is left unchanged. Children are
        visited in sorted path order so the entry order is deterministic.

        Args:
            dir_path: Path to directory
            arcname: Name in archive (defaults to directory name)
            pax_headers: Optional PAX extended attributes, applied to the
                entries beneath the directory

        Returns:
            List of created ArchiveEntry objects

        Raises:
            FileNotFoundError: If ``dir_path`` does not exist.
            NotADirectoryError: If ``dir_path`` is not a directory.
            exceptions.SizeLimitExceededError: If adding the tree would
                exceed ``MAX_SIZE``.
        """
        dir_path = Path(dir_path)
        if not dir_path.exists():
            raise FileNotFoundError(f"Directory not found: {dir_path}")
        if not dir_path.is_dir():
            raise NotADirectoryError(f"Path is not a directory: {dir_path}")

        base_name = arcname or dir_path.name

        # Stage everything here and commit at the end.
        pending: Dict[str, ArchiveEntry] = {}
        pending_size = 0

        # Add directory entry itself
        pending[base_name] = await ArchiveEntry.from_file(dir_path, base_name)

        # Add all files recursively
        for src_file in sorted(dir_path.rglob("*")):
            rel_path = src_file.relative_to(dir_path)
            arc_path = f"{base_name}/{rel_path}".replace("\\", "/")

            # Check the on-disk size first to avoid reading an oversized file.
            if src_file.is_file():
                self._check_size_limit(pending_size + src_file.stat().st_size)

            entry = await ArchiveEntry.from_file(
                src_file, arc_path, pax_headers=pax_headers or {}
            )
            pending_size += entry.size
            self._check_size_limit(pending_size)
            pending[arc_path] = entry

        self._entries.update(pending)
        return list(pending.values())

    def add_bytes(
        self,
        name: str,
        data: bytes,
        mode: int = 0o644,
        pax_headers: Optional[Dict[str, str]] = None,
    ) -> ArchiveEntry:
        """Add file from bytes to archive, replacing any entry named ``name``.

        Args:
            name: Name in archive
            data: File content
            mode: File mode (default: 0o644)
            pax_headers: Optional PAX extended attributes

        Returns:
            Created ArchiveEntry

        Raises:
            exceptions.SizeLimitExceededError: If adding the data would
                exceed ``MAX_SIZE``.
        """
        self._check_size_limit(len(data), replacing=name)
        entry = ArchiveEntry(
            name=name,
            size=len(data),
            file_type=FileType.FILE,
            mode=mode,
            mtime=datetime.now(),
            _content=data,
            pax_headers=pax_headers or {},
        )
        self._entries[name] = entry
        return entry

    def get_entry(self, name: str) -> ArchiveEntry:
        """Get archive entry by name.

        Args:
            name: Entry name

        Returns:
            ArchiveEntry

        Raises:
            FileNotFoundError: If no entry has that name.
        """
        try:
            return self._entries[name]
        except KeyError:
            raise FileNotFoundError(f"Entry not found in archive: {name}") from None

    def get_file(self, name: str) -> bytes:
        """Get content of a file from archive.

        Args:
            name: Path of file within archive

        Returns:
            File content as bytes

        Raises:
            FileNotFoundError: If no entry has that name.
            ValueError: If the entry is not a file or has no content.
        """
        entry = self.get_entry(name)
        if entry.file_type != FileType.FILE:
            raise ValueError(f"Entry is not a file: {name}")
        if entry.content is None:
            raise ValueError(f"File has no content: {name}")
        return entry.content

    def has_file(self, name: str) -> bool:
        """Check if archive contains an entry (of any type) with this name."""
        return name in self._entries

    def remove_file(self, name: str) -> ArchiveEntry:
        """Remove an entry from the archive.

        Args:
            name: Entry name

        Returns:
            Removed ArchiveEntry

        Raises:
            FileNotFoundError: If no entry has that name.
        """
        try:
            return self._entries.pop(name)
        except KeyError:
            raise FileNotFoundError(f"Entry not found in archive: {name}") from None

    def list_files(self) -> List[ArchiveEntry]:
        """List all entries in archive.

        Returns:
            List of ArchiveEntry objects
        """
        return list(self._entries.values())

    def list_file_names(self) -> List[str]:
        """List all entry names in archive.

        Returns:
            List of entry names
        """
        return list(self._entries)

    def get_size(self) -> int:
        """Get total size of archive contents in bytes."""
        return sum(entry.size for entry in self._entries.values())

    def get_file_count(self) -> int:
        """Get number of FILE entries in archive."""
        return sum(
            1 for entry in self._entries.values() if entry.file_type == FileType.FILE
        )

    def clear(self) -> None:
        """Clear archive contents."""
        self._entries.clear()

    def is_empty(self) -> bool:
        """Check if archive is empty."""
        return not self._entries

    def close(self) -> None:
        """Close the archive and release resources."""
        self.clear()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.close()

    def __iter__(self) -> Iterator[ArchiveEntry]:
        """Iterator over archive entries."""
        return iter(self._entries.values())

    def __hash__(self):
        """Hash of the archive based on its entries and compression.

        The value changes whenever the entries change, so an archive should
        not be mutated while it is used as a dict key or set member.
        """
        return hash((frozenset(self._entries.items()), self.compression))

    def __eq__(self, other: object) -> bool:
        """Check equality with another Archive.

        Args:
            other: Object to compare with

        Returns:
            True if archives are equal
        """
        if not isinstance(other, Archive):
            return NotImplemented
        if self is other:
            return True
        # Password and salt are not compared: the salt is randomly generated,
        # so otherwise-identical archives would never be equal.
        return self._entries == other._entries and self.compression == other.compression

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        self.close()