import logging
from collections.abc import Iterable
from contextlib import suppress
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast, overload
from litestar.di import Provide
from litestar.exceptions import NotFoundException
from litestar.middleware import DefineMiddleware
from litestar.plugins import CLIPlugin, InitPluginProtocol, OpenAPISchemaPlugin
from sqlspec.base import SQLSpec
from sqlspec.config import (
AsyncConfigT,
AsyncDatabaseConfig,
DatabaseConfigProtocol,
DriverT,
NoPoolAsyncConfig,
NoPoolSyncConfig,
SyncConfigT,
SyncDatabaseConfig,
)
from sqlspec.core._pagination import OffsetPagination
from sqlspec.core.sqlcommenter import SQLCommenterContext
from sqlspec.exceptions import ImproperConfigurationError, NotFoundError
from sqlspec.extensions.litestar._utils import (
delete_sqlspec_scope_state,
get_sqlspec_scope_state,
set_sqlspec_scope_state,
)
from sqlspec.extensions.litestar.handlers import (
autocommit_handler_maker,
connection_provider_maker,
lifespan_handler_maker,
manual_handler_maker,
pool_provider_maker,
session_provider_maker,
)
from sqlspec.typing import NUMPY_INSTALLED, ConnectionT, PoolT, SchemaT
from sqlspec.utils.correlation import CorrelationContext
from sqlspec.utils.logging import get_logger, log_with_context
from sqlspec.utils.serializers import DEFAULT_TYPE_ENCODERS, numpy_array_dec_hook
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractAsyncContextManager
from litestar import Litestar, Request
from litestar._openapi.schema_generation.schema import SchemaCreator
from litestar.config.app import AppConfig
from litestar.datastructures.state import State
from litestar.openapi.spec import Schema
from litestar.types import ASGIApp, BeforeMessageSendHookHandler, Receive, Scope, Send
from litestar.typing import FieldDefinition
from rich_click import Group
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
from sqlspec.loader import SQLFileLoader
# Module-level logger for this extension; passed to ``log_with_context`` for the
# plugin's structured "extension.init" debug events.
logger = get_logger("sqlspec.extensions.litestar")
# Transaction strategies a config may select through
# ``extension_config["litestar"]["commit_mode"]``.
CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"]
DEFAULT_COMMIT_MODE: CommitMode = "manual"

# Default dependency-injection keys used when a config does not override them.
DEFAULT_CONNECTION_KEY: str = "db_connection"
DEFAULT_POOL_KEY: str = "db_pool"
DEFAULT_SESSION_KEY: str = "db_session"

# Request header consulted first for an incoming correlation ID.
DEFAULT_CORRELATION_HEADER: str = "x-request-id"

# The default header followed by other common request/trace ID headers.
TRACE_CONTEXT_FALLBACK_HEADERS: tuple[str, ...] = (
    DEFAULT_CORRELATION_HEADER,
    *(
        "x-correlation-id",
        "traceparent",
        "x-cloud-trace-context",
        "grpc-trace-bin",
        "x-amzn-trace-id",
        "x-b3-traceid",
        "x-client-trace-id",
    ),
)

# Key under which CorrelationMiddleware stores the active ID in sqlspec scope state.
CORRELATION_STATE_KEY: str = "sqlspec_correlation_id"

# Placeholder for Litestar's NumPy array type; ``None`` until resolved.
_LITESTAR_NUMPY_ARRAY_TYPE: "type[Any] | None" = None
# Names exported by ``from sqlspec.extensions.litestar.plugin import *``.
# NOTE(review): ``SQLCommenterMiddleware`` is defined in this module but is not
# listed here, while the private ``_OffsetPaginationSchemaPlugin`` is — confirm
# that is intentional.
__all__ = (
    "CORRELATION_STATE_KEY",
    "DEFAULT_COMMIT_MODE",
    "DEFAULT_CONNECTION_KEY",
    "DEFAULT_CORRELATION_HEADER",
    "DEFAULT_POOL_KEY",
    "DEFAULT_SESSION_KEY",
    "TRACE_CONTEXT_FALLBACK_HEADERS",
    "CommitMode",
    "CorrelationMiddleware",
    "PluginConfigState",
    "SQLSpecPlugin",
    "_OffsetPaginationSchemaPlugin",
    "not_found_error_handler",
)
def not_found_error_handler(_request: "Request[Any, Any, Any]", exc: NotFoundError) -> NoReturn:
    """Re-raise a sqlspec :class:`~sqlspec.exceptions.NotFoundError` as an HTTP 404.

    Raising Litestar's :class:`~litestar.exceptions.NotFoundException` (chained to the
    original error) hands the response back to Litestar's normal exception-handler
    chain, so user-registered handlers — RFC 7807 ones included — render it and the
    OpenAPI 404 schema stays consistent.

    Args:
        _request: The current request (unused).
        exc: The sqlspec not-found error raised while handling the request.

    Raises:
        NotFoundException: Always. ``detail`` is ``str(exc)``, or ``"Not Found"``
            when the error has no message.
    """
    message = str(exc)
    if not message:
        message = "Not Found"
    raise NotFoundException(detail=message) from exc
class CorrelationMiddleware:
    """ASGI middleware that binds a correlation ID to each HTTP request.

    The ID is taken from the first configured header (in ``headers`` order) that
    carries a non-empty value in the request. If none does, a new ID comes from
    ``CorrelationContext.generate()``. While the downstream app runs, the ID is
    installed in ``CorrelationContext`` and stored in the sqlspec scope state under
    ``CORRELATION_STATE_KEY``. Afterwards — even if the app raises — the scope entry
    is removed and the previous ``CorrelationContext`` value is restored.
    Non-HTTP scopes, or an empty ``headers`` tuple, are passed straight through.
    """

    __slots__ = ("_app", "_headers")

    def __init__(self, app: "ASGIApp", *, headers: tuple[str, ...]) -> None:
        """Wrap ``app``.

        Args:
            app: The downstream ASGI application.
            headers: Header names to consult, in priority order. They are compared
                against lower-cased request header names, so they should be lower-case.
        """
        self._app = app
        self._headers = headers

    async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
        if str(scope.get("type")) != "http" or not self._headers:
            await self._app(scope, receive, send)
            return

        correlation_id = self._extract_correlation_id(scope.get("headers") or [])
        if not correlation_id:
            correlation_id = CorrelationContext.generate()

        previous_correlation_id = CorrelationContext.get()
        CorrelationContext.set(correlation_id)
        set_sqlspec_scope_state(scope, CORRELATION_STATE_KEY, correlation_id)
        try:
            await self._app(scope, receive, send)
        finally:
            # The app may already have removed the entry; cleanup is best-effort.
            with suppress(KeyError):
                delete_sqlspec_scope_state(scope, CORRELATION_STATE_KEY)
            CorrelationContext.set(previous_correlation_id)

    def _extract_correlation_id(self, raw_headers: "Iterable[tuple[bytes, bytes]]") -> "str | None":
        """Return the first non-empty value among the configured headers, or ``None``.

        Only the first occurrence of each request header name is considered. Header
        bytes are decoded as latin-1, the convention ASGI frameworks use for HTTP
        header bytes. Unlike the default UTF-8 codec this never raises, so a client
        sending non-UTF-8 bytes cannot turn the request into a server error.

        Args:
            raw_headers: ASGI ``(name, value)`` byte pairs from ``scope["headers"]``.

        Returns:
            The selected header value, or ``None`` if no configured header has a value.
        """
        # Decode each request header once instead of once per configured header.
        first_values: dict[str, str] = {}
        for raw_name, raw_value in raw_headers:
            first_values.setdefault(raw_name.decode("latin-1").lower(), raw_value.decode("latin-1"))
        for header in self._headers:
            value = first_values.get(header)
            if value:
                return value
        return None
@dataclass
class PluginConfigState:
    """Internal state for each database configuration.

    ``SQLSpecPlugin`` creates one instance per config registered on its ``SQLSpec``.
    Fields declared with ``field(init=False)`` are not constructor arguments. They
    are assigned after construction, and reading one before it is assigned raises
    ``AttributeError``.
    """

    # The database configuration this state belongs to.
    config: "DatabaseConfigProtocol[Any, Any, Any]"
    # Dependency-injection / scope keys for the connection, the pool and the driver session.
    connection_key: str
    pool_key: str
    session_key: str
    # Transaction strategy selecting the before-send handler ("manual" or an autocommit variant).
    commit_mode: CommitMode
    # Optional extra HTTP status codes forwarded to the autocommit handler.
    extra_commit_statuses: "set[int] | None"
    extra_rollback_statuses: "set[int] | None"
    # Whether this config contributes its headers to the shared correlation middleware.
    enable_correlation_middleware: bool
    # Primary correlation header (lower-cased by the plugin when read from extension_config).
    correlation_header: str
    # Whether this config asks for the SQLCommenter middleware.
    enable_sqlcommenter_middleware: bool
    # Full ordered list of headers consulted for correlation IDs; set after construction.
    correlation_headers: tuple[str, ...] = field(init=False)
    # When True, the plugin creates and registers no providers or handlers for this config.
    disable_di: bool
    # Request-scoped providers and hooks; assigned only when DI is enabled.
    connection_provider: "Callable[[State, Scope], AsyncGenerator[Any, None]]" = field(init=False)
    pool_provider: "Callable[[State, Scope], Any]" = field(init=False)
    session_provider: "Callable[..., AsyncGenerator[Any, None]]" = field(init=False)
    before_send_handler: "BeforeMessageSendHookHandler" = field(init=False)
    lifespan_handler: "Callable[[Litestar], AbstractAsyncContextManager[None]]" = field(init=False)
    # The config's class; registered as a signature type and used as a lookup key.
    annotation: "type[DatabaseConfigProtocol[Any, Any, Any]]" = field(init=False)
[docs]
class SQLSpecPlugin(InitPluginProtocol, CLIPlugin):
"""Litestar plugin for SQLSpec database integration.
Automatically configures NumPy array serialization when NumPy is installed,
enabling seamless bidirectional conversion between NumPy arrays and JSON
for vector embedding workflows.
Session Table Migrations:
The Litestar extension includes migrations for creating session storage tables.
To include these migrations in your database migration workflow, add 'litestar'
to the include_extensions list in your migration configuration.
Example:
config = AsyncpgConfig(
connection_config={"dsn": "postgresql://localhost/db"},
extension_config={
"litestar": {
"session_table": "custom_sessions" # Optional custom table name
}
},
migration_config={
"script_location": "migrations",
"include_extensions": ["litestar"], # Simple string list only
}
)
The session table migration will automatically use the appropriate column types
for your database dialect (JSONB for PostgreSQL, JSON for MySQL, TEXT for SQLite).
Extension migrations use the ext_litestar_ prefix (e.g., ext_litestar_0001) to
prevent version conflicts with application migrations.
"""
__slots__ = ("_correlation_headers", "_enable_sqlcommenter_middleware", "_plugin_configs", "_sqlspec")
[docs]
def __init__(self, sqlspec: SQLSpec, *, loader: "SQLFileLoader | None" = None) -> None:
"""Initialize SQLSpec plugin.
Args:
sqlspec: Pre-configured SQLSpec instance with registered database configs.
loader: Optional SQL file loader instance (SQLSpec may already have one).
"""
self._sqlspec = sqlspec
self._plugin_configs: list[PluginConfigState] = []
for cfg in self._sqlspec.configs.values():
config_union = cast(
"SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
cfg,
)
settings = self._extract_litestar_settings(config_union)
state = self._create_config_state(config_union, settings)
self._plugin_configs.append(state)
correlation_headers: list[str] = []
enable_sqlcommenter = False
for state in self._plugin_configs:
if state.enable_sqlcommenter_middleware and state.config.statement_config.enable_sqlcommenter:
enable_sqlcommenter = True
if not state.enable_correlation_middleware:
continue
for header in state.correlation_headers:
if header not in correlation_headers:
correlation_headers.append(header)
self._correlation_headers = tuple(correlation_headers)
self._enable_sqlcommenter_middleware = enable_sqlcommenter
log_with_context(
logger,
logging.DEBUG,
"extension.init",
framework="litestar",
stage="init",
config_count=len(self._plugin_configs),
correlation_headers=len(self._correlation_headers),
)
def _extract_litestar_settings(
self,
config: "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
) -> "dict[str, Any]":
"""Extract Litestar settings from config.extension_config."""
litestar_config = config.extension_config.get("litestar", {})
connection_key = litestar_config.get("connection_key", DEFAULT_CONNECTION_KEY)
pool_key = litestar_config.get("pool_key", DEFAULT_POOL_KEY)
session_key = litestar_config.get("session_key", DEFAULT_SESSION_KEY)
commit_mode = litestar_config.get("commit_mode", DEFAULT_COMMIT_MODE)
if not config.supports_connection_pooling and pool_key == DEFAULT_POOL_KEY:
pool_key = f"_{DEFAULT_POOL_KEY}_{id(config)}"
correlation_header = str(litestar_config.get("correlation_header", DEFAULT_CORRELATION_HEADER)).lower()
configured_headers = _normalize_header_list(litestar_config.get("correlation_headers"))
auto_trace_headers = bool(litestar_config.get("auto_trace_headers", True))
return {
"connection_key": connection_key,
"pool_key": pool_key,
"session_key": session_key,
"commit_mode": commit_mode,
"extra_commit_statuses": litestar_config.get("extra_commit_statuses"),
"extra_rollback_statuses": litestar_config.get("extra_rollback_statuses"),
"enable_correlation_middleware": litestar_config.get("enable_correlation_middleware", True),
"correlation_header": correlation_header,
"correlation_headers": _build_correlation_headers(
primary=correlation_header, configured=configured_headers, auto_trace_headers=auto_trace_headers
),
"disable_di": litestar_config.get("disable_di", False),
"enable_sqlcommenter_middleware": litestar_config.get("enable_sqlcommenter_middleware", True),
}
def _create_config_state(
self,
config: "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
settings: "dict[str, Any]",
) -> PluginConfigState:
"""Create plugin state with handlers for the given configuration."""
state = PluginConfigState(
config=config,
connection_key=settings["connection_key"],
pool_key=settings["pool_key"],
session_key=settings["session_key"],
commit_mode=settings["commit_mode"],
extra_commit_statuses=settings.get("extra_commit_statuses"),
extra_rollback_statuses=settings.get("extra_rollback_statuses"),
enable_correlation_middleware=settings["enable_correlation_middleware"],
correlation_header=settings["correlation_header"],
enable_sqlcommenter_middleware=settings["enable_sqlcommenter_middleware"],
disable_di=settings["disable_di"],
)
state.correlation_headers = tuple(settings["correlation_headers"])
if not state.disable_di:
self._setup_handlers(state)
return state
def _setup_handlers(self, state: PluginConfigState) -> None:
"""Setup handlers for the plugin state."""
connection_key = state.connection_key
pool_key = state.pool_key
commit_mode = state.commit_mode
config = state.config
is_async = config.is_async
state.connection_provider = connection_provider_maker(config, pool_key, connection_key)
state.pool_provider = pool_provider_maker(config, pool_key)
state.session_provider = session_provider_maker(config, connection_key)
state.lifespan_handler = lifespan_handler_maker(config, pool_key)
if commit_mode == "manual":
state.before_send_handler = manual_handler_maker(connection_key, is_async)
else:
commit_on_redirect = commit_mode == "autocommit_include_redirect"
state.before_send_handler = autocommit_handler_maker(
connection_key, is_async, commit_on_redirect, state.extra_commit_statuses, state.extra_rollback_statuses
)
@property
def config(
self,
) -> "list[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]":
"""Return the plugin configurations.
Returns:
List of database configurations.
"""
return [
cast(
"SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
state.config,
)
for state in self._plugin_configs
]
[docs]
def on_cli_init(self, cli: "Group") -> None:
"""Configure CLI commands for SQLSpec database operations.
Args:
cli: The Click command group to add commands to.
"""
from sqlspec.extensions.litestar.cli import database_group
cli.add_command(database_group)
[docs]
def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
"""Configure Litestar application with SQLSpec database integration.
Automatically registers NumPy array serialization when NumPy is installed.
Args:
app_config: The Litestar application configuration instance.
Returns:
The updated application configuration instance.
"""
self._validate_dependency_keys()
def store_sqlspec_in_state() -> None:
app_config.state.sqlspec = self
app_config.on_startup.append(store_sqlspec_in_state)
app_config.signature_types.extend([SQLSpec, DatabaseConfigProtocol, SyncConfigT, AsyncConfigT])
signature_namespace = {"ConnectionT": ConnectionT, "PoolT": PoolT, "DriverT": DriverT, "SchemaT": SchemaT}
for state in self._plugin_configs:
state.annotation = type(state.config)
app_config.signature_types.append(state.annotation)
app_config.signature_types.append(state.config.connection_type)
app_config.signature_types.append(state.config.driver_type)
signature_namespace.update(state.config.get_signature_namespace())
if not state.disable_di:
app_config.before_send.append(state.before_send_handler)
app_config.lifespan.append(state.lifespan_handler)
app_config.dependencies.update({
state.connection_key: Provide(state.connection_provider),
state.pool_key: Provide(state.pool_provider),
state.session_key: Provide(state.session_provider),
})
if signature_namespace:
app_config.signature_namespace.update(signature_namespace)
if NUMPY_INSTALLED and (ndarray_type := _get_litestar_numpy_array_type()) is not None:
import numpy as np
numpy_namespace = _NumpySignatureNamespace(np, ndarray_type)
app_config.signature_namespace.update({
"np": numpy_namespace,
"numpy": numpy_namespace,
"ndarray": ndarray_type,
})
if not any(isinstance(p, _OffsetPaginationSchemaPlugin) for p in app_config.plugins):
app_config.plugins.append(_OffsetPaginationSchemaPlugin())
if app_config.exception_handlers is None:
app_config.exception_handlers = {}
app_config.exception_handlers.setdefault(NotFoundError, not_found_error_handler)
# Inject sqlspec's DEFAULT_TYPE_ENCODERS into Litestar's response serializer
# (user-supplied encoders win on conflict). Litestar's per-handler
# resolve_type_encoders() merges these with route/controller/router-level
# overrides automatically — no bidirectional thread needed.
app_config.type_encoders = {**DEFAULT_TYPE_ENCODERS, **(app_config.type_encoders or {})}
sqlspec_decoders = _build_litestar_type_decoders()
if sqlspec_decoders:
app_config.type_decoders = [*(app_config.type_decoders or []), *sqlspec_decoders]
if self._correlation_headers:
middleware = DefineMiddleware(CorrelationMiddleware, headers=self._correlation_headers)
existing_middleware = list(app_config.middleware or [])
existing_middleware.append(middleware)
app_config.middleware = existing_middleware
if self._enable_sqlcommenter_middleware:
sc_middleware = DefineMiddleware(SQLCommenterMiddleware)
existing_middleware = list(app_config.middleware or [])
existing_middleware.append(sc_middleware)
app_config.middleware = existing_middleware
log_with_context(
logger,
logging.DEBUG,
"extension.init",
framework="litestar",
stage="configured",
config_count=len(self._plugin_configs),
correlation_headers=len(self._correlation_headers),
numpy_enabled=bool(NUMPY_INSTALLED),
)
return app_config
[docs]
def get_annotations(
self,
) -> "list[type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]]":
"""Return the list of annotations.
Returns:
List of annotations.
"""
return [
cast(
"type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
state.annotation,
)
for state in self._plugin_configs
]
[docs]
def get_annotation(
self,
key: "str | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any] | type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
) -> "type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]":
"""Return the annotation for the given configuration.
Args:
key: The configuration instance or key to lookup.
Raises:
KeyError: If no configuration is found for the given key.
Returns:
The annotation for the configuration.
"""
for state in self._plugin_configs:
if key in {state.config, state.annotation} or key in {state.connection_key, state.pool_key}:
return cast(
"type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
state.annotation,
)
msg = f"No configuration found for {key}"
raise KeyError(msg)
@overload
def get_config(
self, name: "type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any]]"
) -> "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any]": ...
@overload
def get_config(
self, name: "type[AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]"
) -> "AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]": ...
@overload
def get_config(
self, name: str
) -> "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]": ...
[docs]
def get_config(
self, name: "type[DatabaseConfigProtocol[Any, Any, Any]] | str | Any"
) -> "DatabaseConfigProtocol[Any, Any, Any]":
"""Get a configuration instance by name.
Args:
name: The configuration identifier.
Raises:
KeyError: If no configuration is found for the given name.
Returns:
The configuration instance for the specified name.
"""
if isinstance(name, str):
for state in self._plugin_configs:
if name in {state.connection_key, state.pool_key, state.session_key}:
return cast("DatabaseConfigProtocol[Any, Any, Any]", state.config) # type: ignore[redundant-cast]
for state in self._plugin_configs:
if name in {state.config, state.annotation}:
return cast("DatabaseConfigProtocol[Any, Any, Any]", state.config) # type: ignore[redundant-cast]
msg = f"No database configuration found for name '{name}'. Available keys: {self._get_available_keys()}"
raise KeyError(msg)
def _ensure_connection_sync(self, plugin_state: PluginConfigState, state: "State", scope: "Scope") -> Any:
"""Ensure a connection exists in scope, creating one from the pool if needed (sync)."""
connection = get_sqlspec_scope_state(scope, plugin_state.connection_key)
if connection is not None:
return connection
pool = state.get(plugin_state.pool_key)
if pool is None:
self._raise_missing_connection(plugin_state.connection_key)
cm = plugin_state.config.provide_connection(pool)
connection = cm.__enter__() # type: ignore[union-attr]
set_sqlspec_scope_state(scope, plugin_state.connection_key, connection)
return connection
async def _ensure_connection_async(self, plugin_state: PluginConfigState, state: "State", scope: "Scope") -> Any:
"""Ensure a connection exists in scope, creating one from the pool if needed (async)."""
connection = get_sqlspec_scope_state(scope, plugin_state.connection_key)
if connection is not None:
return connection
pool = state.get(plugin_state.pool_key)
if pool is None:
self._raise_missing_connection(plugin_state.connection_key)
cm = plugin_state.config.provide_connection(pool)
connection = await cm.__aenter__() # type: ignore[union-attr]
set_sqlspec_scope_state(scope, plugin_state.connection_key, connection)
return connection
def _create_session(
self, plugin_state: PluginConfigState, connection: Any, scope: "Scope"
) -> "SyncDriverAdapterBase | AsyncDriverAdapterBase":
"""Create a session from a connection and store it in scope."""
session_scope_key = f"{plugin_state.session_key}_instance"
session = get_sqlspec_scope_state(scope, session_scope_key)
if session is not None:
return cast("SyncDriverAdapterBase | AsyncDriverAdapterBase", session)
session = plugin_state.config.driver_type(
connection=connection,
statement_config=plugin_state.config.statement_config,
driver_features=plugin_state.config.driver_features,
)
set_sqlspec_scope_state(scope, session_scope_key, session)
return cast("SyncDriverAdapterBase | AsyncDriverAdapterBase", session)
@overload
def provide_request_session(
self,
key: "SyncDatabaseConfig[Any, Any, DriverT] | NoPoolSyncConfig[Any, DriverT] | type[SyncDatabaseConfig[Any, Any, DriverT] | NoPoolSyncConfig[Any, DriverT]]",
state: "State",
scope: "Scope",
) -> "DriverT": ...
@overload
def provide_request_session(
self,
key: "AsyncDatabaseConfig[Any, Any, DriverT] | NoPoolAsyncConfig[Any, DriverT] | type[AsyncDatabaseConfig[Any, Any, DriverT] | NoPoolAsyncConfig[Any, DriverT]]",
state: "State",
scope: "Scope",
) -> "DriverT": ...
@overload
def provide_request_session(
self, key: str, state: "State", scope: "Scope"
) -> "SyncDriverAdapterBase | AsyncDriverAdapterBase": ...
[docs]
def provide_request_session(
self,
key: "str | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any] | type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
state: "State",
scope: "Scope",
) -> "SyncDriverAdapterBase | AsyncDriverAdapterBase":
"""Provide a database session for the specified configuration key from request scope.
This method requires the connection to already exist in scope (e.g., from DI injection).
For on-demand connection creation, use ``provide_request_session_sync`` or
``provide_request_session_async`` instead.
Args:
key: The configuration identifier (same as get_config).
state: The Litestar application State object.
scope: The ASGI scope containing the request context.
Returns:
A driver session instance for the specified database configuration.
"""
plugin_state = self._get_plugin_state(key)
connection = get_sqlspec_scope_state(scope, plugin_state.connection_key)
if connection is None:
self._raise_missing_connection(plugin_state.connection_key)
return self._create_session(plugin_state, connection, scope)
@overload
def provide_request_session_sync(
self,
key: "SyncDatabaseConfig[Any, Any, DriverT] | NoPoolSyncConfig[Any, DriverT]",
state: "State",
scope: "Scope",
) -> "DriverT": ...
@overload
def provide_request_session_sync(
self,
key: "type[SyncDatabaseConfig[Any, Any, DriverT] | NoPoolSyncConfig[Any, DriverT]]",
state: "State",
scope: "Scope",
) -> "DriverT": ...
@overload
def provide_request_session_sync(self, key: str, state: "State", scope: "Scope") -> "SyncDriverAdapterBase": ...
[docs]
def provide_request_session_sync(
self,
key: "str | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any]]",
state: "State",
scope: "Scope",
) -> "SyncDriverAdapterBase | Any":
"""Provide a sync database session for the specified configuration key from request scope.
If no connection exists in scope, one will be created from the pool and stored
in scope for reuse. The connection will be cleaned up by the before_send handler.
For async configurations, use ``provide_request_session_async`` instead.
Args:
key: The configuration identifier (same as get_config).
state: The Litestar application State object.
scope: The ASGI scope containing the request context.
Returns:
A sync driver session instance for the specified database configuration.
"""
plugin_state = self._get_plugin_state(key)
connection = self._ensure_connection_sync(plugin_state, state, scope)
return cast("SyncDriverAdapterBase", self._create_session(plugin_state, connection, scope))
@overload
async def provide_request_session_async(
self,
key: "AsyncDatabaseConfig[Any, Any, DriverT] | NoPoolAsyncConfig[Any, DriverT]",
state: "State",
scope: "Scope",
) -> "DriverT": ...
@overload
async def provide_request_session_async(
self,
key: "type[AsyncDatabaseConfig[Any, Any, DriverT] | NoPoolAsyncConfig[Any, DriverT]]",
state: "State",
scope: "Scope",
) -> "DriverT": ...
@overload
async def provide_request_session_async(
self, key: str, state: "State", scope: "Scope"
) -> "AsyncDriverAdapterBase": ...
[docs]
async def provide_request_session_async(
self,
key: "str | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any] | type[AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
state: "State",
scope: "Scope",
) -> "AsyncDriverAdapterBase | Any":
"""Provide an async database session for the specified configuration key from request scope.
If no connection exists in scope, one will be created from the pool and stored
in scope for reuse. The connection will be cleaned up by the before_send handler.
For sync configurations, use ``provide_request_session`` instead.
Args:
key: The configuration identifier (same as get_config).
state: The Litestar application State object.
scope: The ASGI scope containing the request context.
Returns:
An async driver session instance for the specified database configuration.
"""
plugin_state = self._get_plugin_state(key)
connection = await self._ensure_connection_async(plugin_state, state, scope)
return cast("AsyncDriverAdapterBase", self._create_session(plugin_state, connection, scope))
@overload
def provide_request_connection(
self,
key: "SyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolSyncConfig[ConnectionT, Any] | AsyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolAsyncConfig[ConnectionT, Any]",
state: "State",
scope: "Scope",
) -> "ConnectionT": ...
@overload
def provide_request_connection(
self,
key: "type[SyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolSyncConfig[ConnectionT, Any] | AsyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolAsyncConfig[ConnectionT, Any]]",
state: "State",
scope: "Scope",
) -> "ConnectionT": ...
@overload
def provide_request_connection(self, key: str, state: "State", scope: "Scope") -> Any: ...
[docs]
def provide_request_connection(
self,
key: "str | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any] | type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
state: "State",
scope: "Scope",
) -> Any:
"""Provide a database connection for the specified configuration key from request scope.
This method requires the connection to already exist in scope (e.g., from DI injection).
For on-demand connection creation, use ``provide_request_connection_sync`` or
``provide_request_connection_async`` instead.
Args:
key: The configuration identifier (same as get_config).
state: The Litestar application State object.
scope: The ASGI scope containing the request context.
Returns:
A database connection instance for the specified database configuration.
"""
plugin_state = self._get_plugin_state(key)
connection = get_sqlspec_scope_state(scope, plugin_state.connection_key)
if connection is None:
self._raise_missing_connection(plugin_state.connection_key)
return connection
@overload
def provide_request_connection_sync(
self,
key: "SyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolSyncConfig[ConnectionT, Any]",
state: "State",
scope: "Scope",
) -> "ConnectionT": ...
@overload
def provide_request_connection_sync(
self,
key: "type[SyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolSyncConfig[ConnectionT, Any]]",
state: "State",
scope: "Scope",
) -> "ConnectionT": ...
@overload
def provide_request_connection_sync(self, key: str, state: "State", scope: "Scope") -> Any: ...
[docs]
def provide_request_connection_sync(
self,
key: "str | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any]]",
state: "State",
scope: "Scope",
) -> Any:
"""Provide a sync database connection for the specified configuration key from request scope.
If no connection exists in scope, one will be created from the pool and stored
in scope for reuse. The connection will be cleaned up by the before_send handler.
For async configurations, use ``provide_request_connection_async`` instead.
Args:
key: The configuration identifier (same as get_config).
state: The Litestar application State object.
scope: The ASGI scope containing the request context.
Returns:
A database connection instance for the specified database configuration.
"""
plugin_state = self._get_plugin_state(key)
return self._ensure_connection_sync(plugin_state, state, scope)
@overload
async def provide_request_connection_async(
self,
key: "AsyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolAsyncConfig[ConnectionT, Any]",
state: "State",
scope: "Scope",
) -> "ConnectionT": ...
@overload
async def provide_request_connection_async(
self,
key: "type[AsyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolAsyncConfig[ConnectionT, Any]]",
state: "State",
scope: "Scope",
) -> "ConnectionT": ...
@overload
async def provide_request_connection_async(self, key: str, state: "State", scope: "Scope") -> Any: ...
[docs]
async def provide_request_connection_async(
self,
key: "str | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any] | type[AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
state: "State",
scope: "Scope",
) -> Any:
"""Provide an async database connection for the specified configuration key from request scope.
If no connection exists in scope, one will be created from the pool and stored
in scope for reuse. The connection will be cleaned up by the before_send handler.
For sync configurations, use ``provide_request_connection`` instead.
Args:
key: The configuration identifier (same as get_config).
state: The Litestar application State object.
scope: The ASGI scope containing the request context.
Returns:
A database connection instance for the specified database configuration.
"""
plugin_state = self._get_plugin_state(key)
return await self._ensure_connection_async(plugin_state, state, scope)
def _get_plugin_state(
self, key: "str | DatabaseConfigProtocol[Any, Any, Any] | type[DatabaseConfigProtocol[Any, Any, Any]]"
) -> PluginConfigState:
"""Get plugin state for a configuration by key."""
if isinstance(key, str):
for state in self._plugin_configs:
if key in {state.connection_key, state.pool_key, state.session_key}:
return state
for state in self._plugin_configs:
if key in {state.config, state.annotation}:
return state
self._raise_config_not_found(key)
return None
def _get_available_keys(self) -> "list[str]":
"""Get a list of all available configuration keys for error messages."""
keys = []
for state in self._plugin_configs:
keys.extend([state.connection_key, state.pool_key, state.session_key])
return keys
def _validate_dependency_keys(self) -> None:
"""Validate that connection and pool keys are unique across configurations."""
connection_keys = [state.connection_key for state in self._plugin_configs]
pool_keys = [state.pool_key for state in self._plugin_configs]
if len(set(connection_keys)) != len(connection_keys):
self._raise_duplicate_connection_keys()
if len(set(pool_keys)) != len(pool_keys):
self._raise_duplicate_pool_keys()
def _raise_missing_connection(self, connection_key: str) -> None:
"""Raise error when connection is not found in scope."""
msg = f"No database connection found in scope for key '{connection_key}'. "
msg += "Ensure the connection dependency is properly configured and available."
raise ImproperConfigurationError(detail=msg)
def _raise_config_not_found(self, key: Any) -> NoReturn:
"""Raise error when configuration is not found."""
msg = f"No database configuration found for name '{key}'. Available keys: {self._get_available_keys()}"
raise KeyError(msg)
def _raise_duplicate_connection_keys(self) -> NoReturn:
    """Raise error when connection keys are not unique.

    Annotated ``NoReturn`` because it always raises, consistent with
    ``_raise_config_not_found``.

    Raises:
        ImproperConfigurationError: Always.
    """
    msg = "When using multiple database configuration, each configuration must have a unique `connection_key`."
    raise ImproperConfigurationError(detail=msg)
def _raise_duplicate_pool_keys(self) -> NoReturn:
    """Raise error when pool keys are not unique.

    Annotated ``NoReturn`` because it always raises, consistent with
    ``_raise_config_not_found``.

    Raises:
        ImproperConfigurationError: Always.
    """
    msg = "When using multiple database configuration, each configuration must have a unique `pool_key`."
    raise ImproperConfigurationError(detail=msg)
class SQLCommenterMiddleware:
    """ASGI middleware that populates SQLCommenterContext with Litestar request attributes.

    For HTTP scopes it collects:

    * ``route``: the request path,
    * ``framework``: always ``"litestar"``,
    * ``action`` and ``controller``: taken from the scope's route handler when present.

    These are set in :class:`~sqlspec.core.sqlcommenter.SQLCommenterContext` for the
    duration of the downstream call, and the previous context value is restored
    afterwards. Non-HTTP scopes are passed through untouched.
    """

    __slots__ = ("app",)

    def __init__(self, app: "ASGIApp") -> None:
        self.app = app

    async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
        if str(scope.get("type")) != "http":
            await self.app(scope, receive, send)
            return
        attrs: dict[str, str] = {"route": scope.get("path", ""), "framework": "litestar"}
        handler = scope.get("route_handler")
        if handler is not None:
            fn = getattr(handler, "fn", None)
            if fn is not None:
                attrs["action"] = getattr(fn, "__name__", "")
            owner = getattr(handler, "owner", None)
            if owner is not None:
                attrs["controller"] = self._controller_name(owner)
        previous = SQLCommenterContext.get()
        SQLCommenterContext.set(attrs)
        try:
            await self.app(scope, receive, send)
        finally:
            # Restore even on error so the context does not leak into later work.
            SQLCommenterContext.set(previous)

    @staticmethod
    def _controller_name(owner: Any) -> str:
        """Return a display name for a route handler's ``owner``.

        A class owner contributes its ``__name__``. Instances have no ``__name__``.
        In Litestar 2.x, controller-defined handlers are assigned the controller
        *instance* as ``owner`` (confirm for other Litestar versions), so a
        ``Controller`` instance contributes its class name. Any other owner keeps the
        previous ``__name__``-or-empty-string behaviour.
        """
        if isinstance(owner, type):
            return owner.__name__
        from litestar import Controller

        if isinstance(owner, Controller):
            return type(owner).__name__
        return getattr(owner, "__name__", "")
class _OffsetPaginationSchemaPlugin(OpenAPISchemaPlugin):
    """Render ``OffsetPagination[T]`` as an explicit OpenAPI object schema.

    Litestar's default schema generation is expected to handle ``OffsetPagination``
    already. This plugin is a safeguard that pins the documented shape in case
    auto-detection in Litestar or msgspec changes. The pinned shape is an ``items``
    array of ``T`` plus integer ``limit``, ``offset`` and ``total``, all required.
    """

    @staticmethod
    def is_plugin_supported_type(value: Any) -> bool:
        """Return True for ``OffsetPagination``, its subclasses, and their parameterizations."""
        target = getattr(value, "__origin__", value)
        if not isinstance(target, type):
            return False
        return issubclass(target, OffsetPagination)

    def to_openapi_schema(self, field_definition: "FieldDefinition", schema_creator: "SchemaCreator") -> "Schema":
        """Build the pagination object schema, delegating the item schema to ``schema_creator``.

        The item type is the first inner type of ``field_definition``. When there is
        none (for example a bare ``OffsetPagination``), ``Any`` is used.
        """
        from litestar.openapi.spec import OpenAPIType, Schema
        from litestar.typing import FieldDefinition

        inner_types = getattr(field_definition, "inner_types", ())
        item_annotation: Any = inner_types[0].annotation if inner_types else Any
        item_schema = schema_creator.for_field_definition(FieldDefinition.from_annotation(item_annotation))

        properties = {"items": Schema(type=OpenAPIType.ARRAY, items=item_schema)}
        for counter in ("limit", "offset", "total"):
            properties[counter] = Schema(type=OpenAPIType.INTEGER)

        # Every declared property is required, in declaration order.
        return Schema(type=OpenAPIType.OBJECT, properties=properties, required=list(properties))
class _NumpySignatureNamespace:
    """Proxy ``numpy`` module namespace that maps ``ndarray`` to SQLSpec's decoder-safe subclass.

    ``ndarray`` resolves to the ``ndarray_type`` given at construction. Every other
    attribute lookup is forwarded to the wrapped ``numpy_module``.
    """

    __slots__ = ("_numpy", "ndarray")

    def __init__(self, numpy_module: Any, ndarray_type: type[Any]) -> None:
        self._numpy = numpy_module
        self.ndarray = ndarray_type

    def __getattr__(self, name: str) -> Any:
        """Forward lookups that miss the instance slots to the wrapped module.

        Raises:
            AttributeError: If the wrapped module has no ``name``, or if the
                ``_numpy`` slot has not been populated yet.
        """
        # __getattr__ only runs after normal lookup fails. Reading the slot through
        # object.__getattribute__ makes a half-initialised instance (e.g. one built
        # by copy.copy before its slot state is restored) raise AttributeError,
        # instead of re-entering __getattr__("_numpy") until RecursionError.
        wrapped = object.__getattribute__(self, "_numpy")
        return getattr(wrapped, name)
def _normalize_header_list(headers: Any) -> list[str]:
    """Coerce a correlation-header setting into a list of lowercase header names.

    ``None`` yields an empty list and a single string yields a one-element list.
    Any other iterable must contain only strings.

    Raises:
        ImproperConfigurationError: If ``headers`` is not a string or iterable, or
            if the iterable contains a non-string item.
    """
    if headers is None:
        return []
    if isinstance(headers, str):
        return [headers.lower()]
    if not isinstance(headers, Iterable):
        msg = "litestar correlation_headers must be a string or iterable of strings"
        raise ImproperConfigurationError(msg)
    lowered: list[str] = []
    for entry in headers:
        if not isinstance(entry, str):
            msg = "litestar correlation headers must be strings"
            raise ImproperConfigurationError(msg)
        lowered.append(entry.lower())
    return lowered
def _dedupe_headers(headers: Iterable[str]) -> list[str]:
    """Lowercase header names, drop empty ones, and keep only first occurrences.

    The relative order of first occurrences is preserved.
    """
    # A dict doubles as an insertion-ordered set.
    unique: dict[str, None] = {}
    for name in headers:
        key = name.lower()
        if key:
            unique.setdefault(key, None)
    return list(unique)
def _build_correlation_headers(*, primary: str, configured: list[str], auto_trace_headers: bool) -> tuple[str, ...]:
    """Return the ordered, de-duplicated correlation header names to inspect.

    Candidates are taken in this order:

    1. ``primary`` (lowercased),
    2. the ``configured`` headers,
    3. ``TRACE_CONTEXT_FALLBACK_HEADERS``, when ``auto_trace_headers`` is true.

    ``_dedupe_headers`` then lowercases them and drops empty names and repeats.
    """
    candidates = [primary.lower(), *configured]
    if auto_trace_headers:
        candidates += TRACE_CONTEXT_FALLBACK_HEADERS
    return tuple(_dedupe_headers(candidates))
def _get_litestar_numpy_array_type() -> type[Any] | None:
"""Return a mutable ndarray subclass Litestar can attach a decoder to."""
if not NUMPY_INSTALLED:
return None
global _LITESTAR_NUMPY_ARRAY_TYPE
if _LITESTAR_NUMPY_ARRAY_TYPE is None:
import numpy as np
class SQLSpecNumpyArray(np.ndarray):
pass
_LITESTAR_NUMPY_ARRAY_TYPE = SQLSpecNumpyArray
return _LITESTAR_NUMPY_ARRAY_TYPE
def _litestar_numpy_array_predicate(target_type: Any) -> bool:
    """Return True only when ``target_type`` is exactly SQLSpec's Litestar ndarray subclass.

    Always False when NumPy is unavailable, i.e. when no subclass exists.
    """
    array_cls = _get_litestar_numpy_array_type()
    if array_cls is None:
        return False
    return target_type is array_cls
def _litestar_numpy_array_dec_hook(target_type: type[Any], value: Any) -> Any:
    """Decode ``value`` via ``numpy_array_dec_hook`` and coerce it to ``target_type``.

    When the decoded result is an ``np.ndarray`` that is not already an instance
    of ``target_type`` (and ``target_type`` is a class), a ``.view`` of it as
    ``target_type`` is returned. Otherwise the decoded result is returned
    unchanged.
    """
    result = numpy_array_dec_hook(target_type, value)
    if not NUMPY_INSTALLED:
        return result
    import numpy as np

    needs_view = (
        isinstance(result, np.ndarray)
        and isinstance(target_type, type)
        and not isinstance(result, target_type)
    )
    return result.view(target_type) if needs_view else result
def _build_litestar_type_decoders() -> "list[tuple[Callable[[Any], bool], Callable[[type, Any], Any]]]":
    """Assemble the Litestar-specific ``type_decoders`` used for request-body parsing.

    Each entry is a ``(predicate, decoder)`` pair. These are consumed by Litestar
    rather than sqlspec's serializer registry, which is why they are built here
    instead of living in :data:`sqlspec.utils.serializers.DEFAULT_TYPE_ENCODERS`.

    Returns:
        A list holding the NumPy ndarray decoder when NumPy is installed, followed by
        a ``uuid_utils.UUID`` decoder when ``uuid_utils`` can be imported.
    """
    pairs: list[tuple[Callable[[Any], bool], Callable[[type, Any], Any]]] = []
    if NUMPY_INSTALLED:
        pairs.append((_litestar_numpy_array_predicate, _litestar_numpy_array_dec_hook))
    try:
        import uuid_utils  # pyright: ignore[reportMissingImports]
    except ImportError:
        # uuid_utils is optional; without it no UUID decoder is registered.
        return pairs
    pairs.append((lambda t: t is uuid_utils.UUID, lambda t, v: t(str(v))))
    return pairs