Source code for sqlspec.storage.pipeline

"""Storage pipeline scaffolding for driver-aware storage bridge."""

from collections import deque
from functools import partial
from pathlib import Path
from time import perf_counter, time
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, TypeAlias, cast
from uuid import uuid4

from mypy_extensions import mypyc_attr
from typing_extensions import NotRequired, TypedDict

from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.storage._arrow_payload import decode_arrow_payload, encode_arrow_payload
from sqlspec.storage.errors import execute_async_storage_operation, execute_sync_storage_operation
from sqlspec.storage.registry import StorageRegistry, storage_registry
from sqlspec.utils.serializers import get_serializer_metrics, serialize_collection, to_json
from sqlspec.utils.sync_tools import async_
from sqlspec.utils.type_guards import supports_async_delete, supports_async_read_bytes, supports_async_write_bytes

if TYPE_CHECKING:
    from collections.abc import AsyncIterator, Iterator

    from sqlspec.protocols import ObjectStoreProtocol
    from sqlspec.typing import ArrowTable


__all__ = (
    "AsyncStoragePipeline",
    "PartitionStrategyConfig",
    "StagedArtifact",
    "StorageBridgeJob",
    "StorageCapabilities",
    "StorageDestination",
    "StorageDiagnostics",
    "StorageFormat",
    "StorageLoadRequest",
    "StorageTelemetry",
    "SyncStoragePipeline",
    "create_storage_bridge_job",
    "get_recent_storage_events",
    "get_storage_bridge_diagnostics",
    "get_storage_bridge_metrics",
    "record_storage_diagnostic_event",
    "reset_storage_bridge_events",
    "reset_storage_bridge_metrics",
)

StorageFormat = Literal["jsonl", "json", "parquet", "arrow-ipc", "csv"]
StorageDestination: TypeAlias = str | Path
StorageDiagnostics: TypeAlias = dict[str, float]


[docs] class StorageCapabilities(TypedDict): """Runtime-evaluated driver storage capabilities.""" arrow_export_enabled: bool arrow_import_enabled: bool parquet_export_enabled: bool parquet_import_enabled: bool requires_staging_for_load: bool staging_protocols: "list[str]" partition_strategies: "list[str]" default_storage_profile: NotRequired[str | None]
class PartitionStrategyConfig(TypedDict, total=False):
    """Optional settings that control how work is fanned out into partitions.

    All keys are optional (``total=False``).
    """

    # Strategy identifier.
    kind: str
    # Target number of partitions.
    partitions: int
    # Row budget per chunk.
    rows_per_chunk: int
    # Location of a partition manifest.
    manifest_path: str
[docs] class StorageLoadRequest(TypedDict): """Request describing a staging allocation.""" partition_id: str destination_uri: str ttl_seconds: int correlation_id: str source_uri: NotRequired[str]
class StagedArtifact(TypedDict):
    """Bookkeeping record for an artifact the pipeline has staged."""

    # Identifier of the partition this artifact belongs to.
    partition_id: str
    # Storage location of the artifact.
    uri: str
    # Token built by the pipeline as "<correlation_id>::<partition_id>".
    cleanup_token: str
    # Non-negative lifetime in seconds.
    ttl_seconds: int
    # Wall-clock expiry timestamp (seconds since the epoch, from time.time()).
    expires_at: float
    # Correlation identifier copied from the originating request.
    correlation_id: str
class StorageTelemetry(TypedDict, total=False):
    """Metrics payload emitted by storage bridge operations.

    Every key is optional; producers fill in only what they measured.
    """

    # Normalized object path the operation targeted.
    destination: str
    # Size of the payload read or written, in bytes.
    bytes_processed: int
    # Number of rows contained in the payload.
    rows_processed: int
    # Number of partitions produced.
    partitions_created: int
    # Elapsed time in seconds.
    duration_s: float
    # Format label such as "jsonl" or "parquet".
    format: str
    # Free-form additional attributes.
    extra: "dict[str, object]"
    # Backend type that served the operation.
    backend: str
    # Caller-supplied correlation identifier.
    correlation_id: str
    # Configuration name associated with the operation.
    config: str
    # Binding key associated with the operation.
    bind_key: str
class StorageBridgeJob(NamedTuple):
    """Immutable handle describing one storage bridge operation."""

    # Unique identifier for the job.
    job_id: str
    # Caller-supplied status label.
    status: str
    # Telemetry captured for the operation.
    telemetry: StorageTelemetry
class _StorageBridgeMetrics:
    """Module-wide counters for storage bridge activity.

    Negative increments are treated as zero, so the counters never decrease
    except through :meth:`reset`.
    """

    __slots__ = ("bytes_written", "partitions_created")

    def __init__(self) -> None:
        self.bytes_written = 0
        self.partitions_created = 0

    def record_bytes(self, count: int) -> None:
        """Add ``count`` bytes to the running total (ignored when negative)."""
        self.bytes_written += count if count > 0 else 0

    def record_partitions(self, count: int) -> None:
        """Add ``count`` partitions to the running total (ignored when negative)."""
        self.partitions_created += count if count > 0 else 0

    def snapshot(self) -> "dict[str, int]":
        """Return the current counter values keyed by their metric names."""
        totals = {}
        totals["storage_bridge.bytes_written"] = self.bytes_written
        totals["storage_bridge.partitions_created"] = self.partitions_created
        return totals

    def reset(self) -> None:
        """Zero both counters."""
        self.bytes_written = self.partitions_created = 0


# Process-wide metrics singleton shared by every pipeline instance.
_METRICS = _StorageBridgeMetrics()
# Bounded buffer of recent telemetry; appending past maxlen drops the oldest entry.
_RECENT_STORAGE_EVENTS: "deque[StorageTelemetry]" = deque(maxlen=25)
# Shared "no options" mapping; it is only read or unpacked, never mutated in this module.
_EMPTY_STORAGE_OPTIONS: dict[str, Any] = {}


def _resolve_pipeline_storage_options(
    default_options: "dict[str, Any]", storage_options: "dict[str, Any] | None"
) -> "dict[str, Any]":
    """Pick per-call options when supplied, otherwise the pipeline defaults."""
    if storage_options is not None:
        return storage_options
    return default_options


def _extract_csv_write_options(storage_options: "dict[str, Any]") -> "dict[str, Any] | None":
    """Return the ``write_options`` entry of ``storage_options`` (``None`` when absent)."""
    write_options = storage_options.get("write_options")
    return cast("dict[str, Any] | None", write_options)


def _get_csv_write_options(
    format_choice: "StorageFormat",
    resolved_options: "dict[str, Any]",
    default_options: "dict[str, Any]",
    default_write_options: "dict[str, Any] | None",
) -> "dict[str, Any] | None":
    """Return CSV write options for ``format_choice``, or ``None`` for non-CSV formats.

    When the resolved options are the very same object as the pipeline
    defaults, the pre-extracted ``default_write_options`` are reused instead
    of looking them up again.
    """
    if format_choice != "csv":
        return None
    using_defaults = resolved_options is default_options
    return default_write_options if using_defaults else _extract_csv_write_options(resolved_options)
def get_storage_bridge_metrics() -> "dict[str, int]":
    """Snapshot the module-wide storage bridge counters.

    Returns:
        A fresh mapping of metric name to current count.
    """
    current = _METRICS.snapshot()
    return current
def reset_storage_bridge_metrics() -> None:
    """Zero every module-wide storage bridge counter."""
    metrics = _METRICS
    metrics.reset()
def record_storage_diagnostic_event(telemetry: StorageTelemetry) -> None:
    """Record telemetry for inclusion in diagnostics snapshots.

    A shallow copy of ``telemetry`` is appended to the bounded
    ``_RECENT_STORAGE_EVENTS`` buffer. Once that buffer reaches its
    ``maxlen``, each append discards the oldest entry. Nested values such as
    ``extra`` are shared with the caller's mapping, not copied.
    """
    _RECENT_STORAGE_EVENTS.append(cast("StorageTelemetry", dict(telemetry)))


def get_recent_storage_events() -> "list[StorageTelemetry]":
    """Return recent storage telemetry events in recording order (oldest first).

    The newest event is the last element of the returned list. Each element is
    a shallow copy, so callers may mutate top-level keys without affecting the
    recorded buffer.
    """
    # Iterating a deque yields items from left (oldest append) to right (newest).
    return [cast("StorageTelemetry", dict(entry)) for entry in _RECENT_STORAGE_EVENTS]


def reset_storage_bridge_events() -> None:
    """Clear recorded storage telemetry events."""
    _RECENT_STORAGE_EVENTS.clear()
def create_storage_bridge_job(status: str, telemetry: StorageTelemetry) -> StorageBridgeJob:
    """Build a job handle with a fresh UUID4 identifier and log its telemetry.

    Args:
        status: Status label stored on the handle.
        telemetry: Telemetry stored on the handle and recorded as a diagnostic event.

    Returns:
        The newly created :class:`StorageBridgeJob`.
    """
    job_id = str(uuid4())
    handle = StorageBridgeJob(job_id, status, telemetry)
    record_storage_diagnostic_event(handle.telemetry)
    return handle
def get_storage_bridge_diagnostics() -> "StorageDiagnostics":
    """Merge storage bridge counters and serializer metrics into one float mapping.

    Serializer metric names are prefixed with ``"serializer."``; storage bridge
    metric names are kept as-is. All values are coerced to ``float``.
    """
    diagnostics: dict[str, float] = {}
    for metric_name, count in get_storage_bridge_metrics().items():
        diagnostics[metric_name] = float(count)
    serializer_stats = get_serializer_metrics()
    diagnostics.update({f"serializer.{metric_name}": float(value) for metric_name, value in serializer_stats.items()})
    return diagnostics
def _encode_row_payload(rows: "list[Any]", format_hint: StorageFormat) -> bytes: if format_hint == "json": data = to_json(rows, as_bytes=True) if isinstance(data, bytes): return data return data.encode() buffer = bytearray() for row in rows: buffer.extend(to_json(row, as_bytes=True)) buffer.extend(b"\n") return bytes(buffer) def _encode_arrow_payload( table: "ArrowTable", format_choice: StorageFormat, *, compression: str | None, write_options: "dict[str, Any] | None" = None, ) -> bytes: return encode_arrow_payload(table, format_choice, compression=compression, write_options=write_options) def _delete_backend_sync(backend: "ObjectStoreProtocol", path: str, *, backend_name: str) -> None: execute_sync_storage_operation( partial(backend.delete_sync, path), backend=backend_name, operation="delete", path=path ) def _write_backend_sync(backend: "ObjectStoreProtocol", path: str, payload: bytes, *, backend_name: str) -> None: execute_sync_storage_operation( partial(backend.write_bytes_sync, path, payload), backend=backend_name, operation="write_bytes", path=path ) def _read_backend_sync(backend: "ObjectStoreProtocol", path: str, *, backend_name: str) -> bytes: return execute_sync_storage_operation( partial(backend.read_bytes_sync, path), backend=backend_name, operation="read_bytes", path=path ) def _decode_arrow_payload(payload: bytes, format_choice: StorageFormat) -> "ArrowTable": return decode_arrow_payload(payload, format_choice) def _resolve_alias_destination( registry: StorageRegistry, destination: str, backend_options: "dict[str, Any]" ) -> "tuple[ObjectStoreProtocol, str, str] | None": if not destination.startswith("alias://"): return None payload = destination.removeprefix("alias://") alias_name, _, relative_path = payload.partition("/") alias = alias_name.strip() if not alias: msg = "Alias destinations must include a registry alias before the path component" raise ImproperConfigurationError(msg) path_segment = relative_path.strip() if not path_segment: msg = "Alias 
destinations must include an object path after the alias name" raise ImproperConfigurationError(msg) backend = registry.get(alias, **backend_options) return backend, path_segment.lstrip("/"), backend.backend_type def _normalize_path_for_backend(destination: str) -> str: if destination.startswith("file://"): return destination.removeprefix("file://") if "://" in destination: _, remainder = destination.split("://", 1) return remainder.lstrip("/") return destination def _resolve_storage_backend( registry: StorageRegistry, destination: StorageDestination, backend_options: "dict[str, Any] | None" ) -> "tuple[ObjectStoreProtocol, str, str]": destination_str = destination.as_posix() if isinstance(destination, Path) else str(destination) options = _EMPTY_STORAGE_OPTIONS if backend_options is None else backend_options alias_resolution = _resolve_alias_destination(registry, destination_str, options) if alias_resolution is not None: return alias_resolution backend = registry.get(destination_str, **options) normalized_path = _normalize_path_for_backend(destination_str) return backend, normalized_path, backend.backend_type def _make_resolved_backend_cache_key( destination: StorageDestination, backend_options: "dict[str, Any] | None" ) -> "str | None": if backend_options: return None return destination.as_posix() if isinstance(destination, Path) else str(destination)
@mypyc_attr(allow_interpreted_subclasses=True)
class SyncStoragePipeline:
    """Pipeline coordinating storage registry operations and telemetry.

    Destinations are resolved to ``(backend, path, backend_type)`` through the
    storage registry. Resolutions made without backend options are memoized
    per instance; see :meth:`clear_cache`.
    """

    __slots__ = ("_csv_write_options", "_resolved_backend_cache", "_storage_options", "registry")

    def __init__(
        self, *, registry: StorageRegistry | None = None, storage_options: "dict[str, Any] | None" = None
    ) -> None:
        """Create a pipeline.

        Args:
            registry: Registry used to resolve backends; defaults to the module-level ``storage_registry``.
            storage_options: Default options used whenever a call does not pass its own.
        """
        self.registry = registry or storage_registry
        self._resolved_backend_cache: dict[str, tuple[ObjectStoreProtocol, str, str]] = {}
        self._storage_options = _EMPTY_STORAGE_OPTIONS if storage_options is None else storage_options
        # Pre-extracted so CSV writes using the default options skip the lookup.
        self._csv_write_options = _extract_csv_write_options(self._storage_options)

    def clear_cache(self) -> None:
        """Clear cached storage backend resolutions for this pipeline instance."""
        self._resolved_backend_cache.clear()

    def _resolve_backend(
        self, destination: StorageDestination, backend_options: "dict[str, Any] | None"
    ) -> "tuple[ObjectStoreProtocol, str, str]":
        """Resolve storage backend and normalized path for a destination.

        Results are cached only when ``backend_options`` is empty or ``None``.
        """
        cache_key = _make_resolved_backend_cache_key(destination, backend_options)
        if cache_key is None:
            return _resolve_storage_backend(self.registry, destination, backend_options)
        cached = self._resolved_backend_cache.get(cache_key)
        if cached is not None:
            return cached
        resolved = _resolve_storage_backend(self.registry, destination, backend_options)
        self._resolved_backend_cache[cache_key] = resolved
        return resolved

    def write_rows(
        self,
        rows: "list[dict[str, Any]]",
        destination: StorageDestination,
        *,
        format_hint: StorageFormat | None = None,
        storage_options: "dict[str, Any] | None" = None,
    ) -> StorageTelemetry:
        """Write dictionary rows to storage using cached serializers.

        Rows are encoded as a JSON array when ``format_hint == "json"``. Any
        other hint, and the default ``"jsonl"``, produces JSON Lines; the hint
        is still reported verbatim as the telemetry ``format``.
        """
        serialized = serialize_collection(rows)
        format_choice = format_hint or "jsonl"
        payload = _encode_row_payload(serialized, format_choice)
        resolved_options = _resolve_pipeline_storage_options(self._storage_options, storage_options)
        return self._write_bytes(
            payload, destination, rows=len(serialized), format_label=format_choice, storage_options=resolved_options
        )

    def write_arrow(
        self,
        table: "ArrowTable",
        destination: StorageDestination,
        *,
        format_hint: StorageFormat | None = None,
        storage_options: "dict[str, Any] | None" = None,
        compression: str | None = None,
    ) -> StorageTelemetry:
        """Write an Arrow table to storage (format defaults to ``"parquet"``).

        For CSV output, ``write_options`` taken from the effective storage
        options are forwarded to the encoder.
        """
        format_choice = format_hint or "parquet"
        resolved_options = _resolve_pipeline_storage_options(self._storage_options, storage_options)
        format_write_options = _get_csv_write_options(
            format_choice, resolved_options, self._storage_options, self._csv_write_options
        )
        payload = _encode_arrow_payload(
            table, format_choice, compression=compression, write_options=format_write_options
        )
        return self._write_bytes(
            payload, destination, rows=int(table.num_rows), format_label=format_choice, storage_options=resolved_options
        )

    def read_arrow(
        self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None
    ) -> "tuple[ArrowTable, StorageTelemetry]":
        """Read an artifact from storage and decode it into an Arrow table.

        Returns:
            The decoded table plus telemetry (path, byte count, row count, format, backend).
        """
        backend, path, backend_name = self._resolve_backend(source, storage_options)
        payload = _read_backend_sync(backend, path, backend_name=backend_name)
        table = _decode_arrow_payload(payload, file_format)
        rows_processed = int(table.num_rows)
        telemetry: StorageTelemetry = {
            "destination": path,
            "bytes_processed": len(payload),
            "rows_processed": rows_processed,
            "format": file_format,
            "backend": backend_name,
        }
        return table, telemetry

    def stream_read(
        self,
        source: StorageDestination,
        *,
        chunk_size: int | None = None,
        storage_options: "dict[str, Any] | None" = None,
    ) -> "Iterator[bytes]":
        """Stream bytes from an artifact via the backend's ``stream_read_sync``."""
        backend, path, _backend_name = self._resolve_backend(source, storage_options)
        return backend.stream_read_sync(path, chunk_size=chunk_size)

    def allocate_staging_artifacts(self, requests: "list[StorageLoadRequest]") -> "list[StagedArtifact]":
        """Allocate staging metadata for upcoming loads.

        Negative TTLs are clamped to zero, and ``expires_at`` is ``time()`` plus
        the clamped TTL. The partition counter is incremented by the number of
        artifacts allocated.
        """
        artifacts: list[StagedArtifact] = []
        now = time()
        for request in requests:
            ttl = max(request["ttl_seconds"], 0)
            cleanup_token = f"{request['correlation_id']}::{request['partition_id']}"
            artifacts.append({
                "partition_id": request["partition_id"],
                "uri": request["destination_uri"],
                "cleanup_token": cleanup_token,
                "ttl_seconds": ttl,
                # ttl is never negative here, so a zero TTL simply yields ``now``.
                "expires_at": now + ttl,
                "correlation_id": request["correlation_id"],
            })
        if artifacts:
            _METRICS.record_partitions(len(artifacts))
        return artifacts

    def cleanup_staging_artifacts(self, artifacts: "list[StagedArtifact]", *, ignore_errors: bool = True) -> None:
        """Delete staged artifacts best-effort.

        Args:
            artifacts: Artifacts to delete, identified by their ``uri``.
            ignore_errors: When ``True`` (default), any failure for an artifact is
                suppressed and cleanup continues with the next one. This includes
                failing to resolve its backend. When ``False``, the first failure
                is re-raised.
        """
        for artifact in artifacts:
            # Resolution is inside the guard so an unresolvable URI does not
            # abort best-effort cleanup of the remaining artifacts.
            try:
                backend, path, backend_name = self._resolve_backend(artifact["uri"], None)
                _delete_backend_sync(backend, path, backend_name=backend_name)
            except Exception:
                if not ignore_errors:
                    raise

    def _write_bytes(
        self,
        payload: bytes,
        destination: StorageDestination,
        *,
        rows: int,
        format_label: str,
        storage_options: "dict[str, Any]",
    ) -> StorageTelemetry:
        """Write ``payload`` to ``destination``, update byte metrics, and return telemetry.

        ``duration_s`` covers only the backend write, measured with ``perf_counter``.
        """
        backend, path, backend_name = self._resolve_backend(destination, storage_options)
        start = perf_counter()
        _write_backend_sync(backend, path, payload, backend_name=backend_name)
        elapsed = perf_counter() - start
        bytes_written = len(payload)
        _METRICS.record_bytes(bytes_written)
        telemetry: StorageTelemetry = {
            "destination": path,
            "bytes_processed": bytes_written,
            "rows_processed": rows,
            "duration_s": elapsed,
            "format": format_label,
            "backend": backend_name,
        }
        return telemetry
@mypyc_attr(allow_interpreted_subclasses=True)
class AsyncStoragePipeline:
    """Async variant of the storage pipeline leveraging async-capable backends when available.

    When a backend lacks a native async operation, the synchronous helper is
    invoked through ``async_``. Resolutions made without backend options are
    memoized per instance; see :meth:`clear_cache`.
    """

    __slots__ = ("_csv_write_options", "_resolved_backend_cache", "_storage_options", "registry")

    def __init__(
        self, *, registry: StorageRegistry | None = None, storage_options: "dict[str, Any] | None" = None
    ) -> None:
        """Create a pipeline.

        Args:
            registry: Registry used to resolve backends; defaults to the module-level ``storage_registry``.
            storage_options: Default options used whenever a call does not pass its own.
        """
        self.registry = registry or storage_registry
        self._resolved_backend_cache: dict[str, tuple[ObjectStoreProtocol, str, str]] = {}
        self._storage_options = _EMPTY_STORAGE_OPTIONS if storage_options is None else storage_options
        # Pre-extracted so CSV writes using the default options skip the lookup.
        self._csv_write_options = _extract_csv_write_options(self._storage_options)

    def clear_cache(self) -> None:
        """Clear cached storage backend resolutions for this pipeline instance."""
        self._resolved_backend_cache.clear()

    def _resolve_backend(
        self, destination: StorageDestination, backend_options: "dict[str, Any] | None"
    ) -> "tuple[ObjectStoreProtocol, str, str]":
        """Resolve storage backend and normalized path for a destination.

        Results are cached only when ``backend_options`` is empty or ``None``.
        """
        cache_key = _make_resolved_backend_cache_key(destination, backend_options)
        if cache_key is None:
            return _resolve_storage_backend(self.registry, destination, backend_options)
        cached = self._resolved_backend_cache.get(cache_key)
        if cached is not None:
            return cached
        resolved = _resolve_storage_backend(self.registry, destination, backend_options)
        self._resolved_backend_cache[cache_key] = resolved
        return resolved

    async def write_rows(
        self,
        rows: "list[dict[str, Any]]",
        destination: StorageDestination,
        *,
        format_hint: StorageFormat | None = None,
        storage_options: "dict[str, Any] | None" = None,
    ) -> StorageTelemetry:
        """Write dictionary rows to storage.

        Rows are encoded as a JSON array when ``format_hint == "json"``, and as
        JSON Lines otherwise. Encoding is dispatched through ``async_``
        (presumably off the event loop; see ``sqlspec.utils.sync_tools``).
        """
        serialized = serialize_collection(rows)
        format_choice = format_hint or "jsonl"
        payload = await async_(_encode_row_payload)(serialized, format_choice)
        resolved_options = _resolve_pipeline_storage_options(self._storage_options, storage_options)
        return await self._write_bytes_async(
            payload, destination, rows=len(serialized), format_label=format_choice, storage_options=resolved_options
        )

    async def write_arrow(
        self,
        table: "ArrowTable",
        destination: StorageDestination,
        *,
        format_hint: StorageFormat | None = None,
        storage_options: "dict[str, Any] | None" = None,
        compression: str | None = None,
    ) -> StorageTelemetry:
        """Write an Arrow table to storage (format defaults to ``"parquet"``).

        For CSV output, ``write_options`` taken from the effective storage
        options are forwarded to the encoder.
        """
        format_choice = format_hint or "parquet"
        resolved_options = _resolve_pipeline_storage_options(self._storage_options, storage_options)
        format_write_options = _get_csv_write_options(
            format_choice, resolved_options, self._storage_options, self._csv_write_options
        )
        payload = await async_(_encode_arrow_payload)(
            table, format_choice, compression=compression, write_options=format_write_options
        )
        return await self._write_bytes_async(
            payload, destination, rows=int(table.num_rows), format_label=format_choice, storage_options=resolved_options
        )

    async def cleanup_staging_artifacts(self, artifacts: "list[StagedArtifact]", *, ignore_errors: bool = True) -> None:
        """Delete staged artifacts best-effort.

        Native async deletion is used when the backend supports it; otherwise
        the synchronous delete is dispatched through ``async_``.

        Args:
            artifacts: Artifacts to delete, identified by their ``uri``.
            ignore_errors: When ``True`` (default), any failure for an artifact is
                suppressed and cleanup continues with the next one. This includes
                failing to resolve its backend. When ``False``, the first failure
                is re-raised.
        """
        for artifact in artifacts:
            # Resolution is inside the guard so an unresolvable URI does not
            # abort best-effort cleanup of the remaining artifacts.
            try:
                backend, path, backend_name = self._resolve_backend(artifact["uri"], None)
                if supports_async_delete(backend):
                    await execute_async_storage_operation(
                        partial(backend.delete_async, path), backend=backend_name, operation="delete", path=path
                    )
                else:
                    await async_(_delete_backend_sync)(backend=backend, path=path, backend_name=backend_name)
            except Exception:
                if not ignore_errors:
                    raise

    async def _write_bytes_async(
        self,
        payload: bytes,
        destination: StorageDestination,
        *,
        rows: int,
        format_label: str,
        storage_options: "dict[str, Any]",
    ) -> StorageTelemetry:
        """Write ``payload`` to ``destination``, update byte metrics, and return telemetry.

        ``duration_s`` covers only the backend write, measured with ``perf_counter``.
        """
        backend, path, backend_name = self._resolve_backend(destination, storage_options)
        start = perf_counter()
        if supports_async_write_bytes(backend):
            await execute_async_storage_operation(
                partial(backend.write_bytes_async, path, payload),
                backend=backend_name,
                operation="write_bytes",
                path=path,
            )
        else:
            await async_(_write_backend_sync)(backend=backend, path=path, payload=payload, backend_name=backend_name)
        elapsed = perf_counter() - start
        bytes_written = len(payload)
        _METRICS.record_bytes(bytes_written)
        telemetry: StorageTelemetry = {
            "destination": path,
            "bytes_processed": bytes_written,
            "rows_processed": rows,
            "duration_s": elapsed,
            "format": format_label,
            "backend": backend_name,
        }
        return telemetry

    async def read_arrow_async(
        self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None
    ) -> "tuple[ArrowTable, StorageTelemetry]":
        """Read an artifact from storage and decode it into an Arrow table.

        Returns:
            The decoded table plus telemetry (path, byte count, row count, format, backend).
        """
        backend, path, backend_name = self._resolve_backend(source, storage_options)
        if supports_async_read_bytes(backend):
            payload = await execute_async_storage_operation(
                partial(backend.read_bytes_async, path), backend=backend_name, operation="read_bytes", path=path
            )
        else:
            payload = await async_(_read_backend_sync)(backend=backend, path=path, backend_name=backend_name)
        table = await async_(_decode_arrow_payload)(payload, file_format)
        rows_processed = int(table.num_rows)
        telemetry: StorageTelemetry = {
            "destination": path,
            "bytes_processed": len(payload),
            "rows_processed": rows_processed,
            "format": file_format,
            "backend": backend_name,
        }
        return table, telemetry

    async def stream_read_async(
        self,
        source: StorageDestination,
        *,
        chunk_size: int | None = None,
        storage_options: "dict[str, Any] | None" = None,
    ) -> "AsyncIterator[bytes]":
        """Stream bytes from an artifact asynchronously.

        Awaits the backend's ``stream_read_async`` and returns its result.
        NOTE(review): this assumes the backend method is a coroutine that
        returns an iterator, not an async generator function — confirm
        against ObjectStoreProtocol implementations.
        """
        backend, path, _backend_name = self._resolve_backend(source, storage_options)
        return await backend.stream_read_async(path, chunk_size=chunk_size)