Source code for sparqlmodel.async_session

"""Async ORM unit of work over a graph store (:class:`AsyncSPARQLSession`)."""

from __future__ import annotations

from types import TracebackType
from typing import Any

from triplemodel import Store
from typing_extensions import Self

from sparqlmodel import session_core
from sparqlmodel.async_query import AsyncQuery
from sparqlmodel.graph import (
    cascade_subjects_for_removal,
    owned_triples_for_subjects,
    triples_to_graph,
)
from sparqlmodel.model import SPARQLModel
from sparqlmodel.rdf_bridge import model_to_graph
from sparqlmodel.session_state import SessionState, identity_key, identity_key_for_iri
from sparqlmodel.stores.async_base import AsyncStoreProtocol
from sparqlmodel.stores.async_memory import AsyncMemoryStore
from sparqlmodel.types import IRI, NamespaceRegistry


class AsyncSPARQLSession:
    """Async ORM session: CRUD, queries, and graph sync with the backing store.

    Use as an async context manager::

        async with AsyncSPARQLSession(store=AsyncHttpStore(endpoint)) as session:
            await session.put(model)

    Same identity map, hydration, and cascade semantics as
    :class:`~sparqlmodel.session.SPARQLSession`.

    Not safe to share across concurrent asyncio tasks; use one session per task.
    """

    def __init__(
        self,
        store: AsyncStoreProtocol | None = None,
        *,
        prefixes: dict[str, str] | None = None,
        autoflush: bool = True,
        close_on_exit: bool = True,
        rollback_on_error: bool = True,
    ) -> None:
        """Create a session over ``store``.

        Args:
            store: Backing async store. When ``None``, a new
                ``AsyncMemoryStore(prefixes=prefixes)`` is created.
            prefixes: Extra prefix -> namespace mappings. They are merged over
                any prefixes exposed by ``store.namespaces`` (these win on
                conflict) and bound onto ``store.graph``.
            autoflush: When true, pending ``put(..., flush=False)`` models are
                flushed before reads/writes such as ``get``, ``add``,
                ``delete`` and ``execute``.
            close_on_exit: When true, :meth:`close` is called on
                ``async with`` exit.
            rollback_on_error: When true, pending puts are discarded if the
                ``async with`` body raises.
        """
        # Compare against None explicitly: ``store or ...`` would silently swap
        # in a fresh memory store for any caller-supplied store that happens to
        # be falsy (e.g. one defining ``__len__`` while empty).
        self._store: AsyncStoreProtocol = (
            store if store is not None else AsyncMemoryStore(prefixes=prefixes)
        )
        store_prefixes = getattr(self._store, "namespaces", None)
        store_pfx = store_prefixes.prefixes if store_prefixes else {}
        # Session-level prefixes override the store's own mappings.
        merged_prefixes = {**store_pfx, **(prefixes or {})}
        self._namespaces = NamespaceRegistry(merged_prefixes)
        self._namespaces.bind(self._store.graph)
        self._state = SessionState()
        self.autoflush = autoflush
        self.close_on_exit = close_on_exit
        self.rollback_on_error = rollback_on_error
        self._closed = False

    @property
    def store(self) -> AsyncStoreProtocol:
        """The backing async store."""
        return self._store

    @property
    def namespaces(self) -> NamespaceRegistry:
        """Prefix registry used to prepend ``PREFIX`` blocks in :meth:`execute`."""
        return self._namespaces

    @property
    def graph(self) -> Store:
        """The backing store's graph (``self.store.graph``)."""
        return self._store.graph

    def _check_open(self) -> None:
        """Raise ``RuntimeError`` if :meth:`close` has already completed."""
        if self._closed:
            raise RuntimeError(session_core.CLOSED_ASYNC_SESSION_MSG)

    async def flush(self) -> None:
        """Write all pending models queued with ``put(..., flush=False)``.

        Models are written in queue order. If a write raises, the failing
        model and every model after it are re-queued (in their original order)
        and the exception is re-raised; models already written are dropped
        from the queue so a retry resumes where this one stopped.
        """
        self._check_open()
        pending = list(self._state.pending)
        index = 0
        try:
            while index < len(pending):
                await session_core.put_impl_async(self._store, self._state, pending[index])
                index += 1
        except Exception:
            # Keep only the not-yet-written tail (including the failed model).
            self._state.clear_pending()
            for model in pending[index:]:
                self._state.add_pending(model)
            raise
        self._state.clear_pending()

    async def rollback_pending(self) -> None:
        """Discard pending models without writing to the store."""
        self._check_open()
        self._state.clear_pending()

    async def close(self) -> None:
        """Close the backing store when it implements ``aclose()``.

        Idempotent: a second call is a no-op.

        Raises:
            RuntimeError: if puts are still pending; call :meth:`flush` or
                :meth:`rollback_pending` first. The session stays open.
        """
        if self._closed:
            return
        if self._state.pending:
            n = len(self._state.pending)
            raise RuntimeError(
                f"Cannot close AsyncSPARQLSession with {n} pending put(s); "
                "call flush() or rollback_pending() first"
            )
        aclose = getattr(self._store, "aclose", None)
        if callable(aclose):
            await aclose()
        # Only marked closed once aclose() succeeded, so a failed close can be retried.
        self._closed = True

    async def __aenter__(self) -> Self:
        """Enter the context; raises ``RuntimeError`` if the session is closed."""
        self._check_open()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Flush or roll back pending puts, then optionally close.

        * Clean exit: pending puts are flushed. A flush failure is held and
          re-raised after the close attempt.
        * Exit with an exception: pending puts are discarded when
          ``rollback_on_error`` is set; the body's exception propagates.
        * With ``close_on_exit``, :meth:`close` runs last. A ``RuntimeError``
          from it (e.g. puts still pending) is re-raised on a clean exit or
          when ``rollback_on_error`` is set, and suppressed otherwise so the
          body's exception is not masked. If a flush failed, the flush error
          is raised instead.
        """
        flush_err: BaseException | None = None
        try:
            if exc_type is None:
                if self._state.pending:
                    try:
                        await self.flush()
                    except Exception as exc:
                        flush_err = exc
            elif self.rollback_on_error and not self._closed:
                # A session the body already closed cannot have pending puts
                # (close() refuses while any are queued, put() refuses after
                # close), and rollback_pending() would raise the "closed
                # session" error here, masking the body's real exception.
                await self.rollback_pending()
        finally:
            if self.close_on_exit:
                try:
                    await self.close()
                except RuntimeError as close_err:
                    # After a failed flush, close() fails because the tail was
                    # re-queued; surface the flush error as the primary one.
                    if flush_err is not None:
                        raise flush_err from close_err
                    if exc_type is None or self.rollback_on_error:
                        raise
        if flush_err is not None:
            raise flush_err

    async def expire(self, model_cls: type[SPARQLModel], iri: str | IRI) -> None:
        """Remove a resource from the identity map and hydration cache."""
        self._check_open()
        session_core.expire_impl(self._state, model_cls, iri)

    async def expunge(self, model: SPARQLModel) -> None:
        """Detach ``model`` from the identity map and hydration cache (store unchanged)."""
        self._check_open()
        session_core.expunge_impl(self._state, model)

    async def expunge_all(self) -> None:
        """Clear the identity map and hydration cache; pending ``put`` queue is kept."""
        self._check_open()
        session_core.expunge_all_impl(self._state)

    async def refresh(self, model: SPARQLModel, *, depth: int = 0) -> SPARQLModel:
        """Reload ``model`` from the store at ``depth`` (updates cached instance when present)."""
        self._check_open()
        await self._maybe_autoflush()
        return await session_core.refresh_impl_async(self._state, self._store, model, depth=depth)

    async def merge(self, model: SPARQLModel) -> SPARQLModel:
        """Return the session instance for ``model``'s identity key (no store write)."""
        self._check_open()
        await self._maybe_autoflush()
        return session_core.merge_impl(self._state, model)

    async def _maybe_autoflush(self) -> None:
        """Flush pending puts when ``autoflush`` is enabled and any are queued."""
        if self.autoflush and self._state.pending:
            await self.flush()

    async def add(self, model: SPARQLModel) -> SPARQLModel:
        """Insert model triples into the store (no delete).

        Any pending ``put`` for the same identity is dropped first, since this
        write supersedes it.

        Returns:
            ``model`` itself, now registered in the identity map.
        """
        self._check_open()
        await self._maybe_autoflush()
        model.ensure_id()
        id_key = identity_key(model)
        self._state.remove_pending_for(type(model), id_key[1])
        session_core.check_stale_add(self._store.graph, model)
        g = model_to_graph(model)
        await self._store.update_graph(add=g)
        self._state.set_identity(model)
        session_core._register_embedded_identities(self._state, model)
        session_core.invalidate_cascade_keys(self._state, self._store.graph, model, for_put=False)
        return model

    async def put(self, model: SPARQLModel, *, flush: bool = True) -> SPARQLModel:
        """Upsert model and cascaded embedded resources.

        Args:
            model: The model to upsert.
            flush: When true, write immediately. When false, queue ``model``
                for the next :meth:`flush` (explicit or autoflush) and only
                invalidate cached state for it now.

        Returns:
            The result of the immediate upsert, or ``model`` when queued.
        """
        self._check_open()
        if flush:
            await self._maybe_autoflush()
            return await session_core.put_impl_async(self._store, self._state, model)
        model.ensure_id()
        assert model.id is not None  # type narrowing: ensure_id() assigns an id
        session_core.invalidate_cascade_keys(self._state, self._store.graph, model, for_put=True)
        key = identity_key_for_iri(type(model), model.id)
        # Drop cached instances so reads before the flush do not see stale data.
        self._state.evict_identity_prefix(key[0], key[1])
        self._state.add_pending(model)
        self._state.invalidate_hydration_for_iri(key[1])
        return model

    async def delete(self, model: SPARQLModel) -> None:
        """Remove owned triples for the model and cascaded embedded resources.

        Pending puts for any of the removed subjects are discarded as well.
        """
        self._check_open()
        await self._maybe_autoflush()
        model.ensure_id()
        subjects = cascade_subjects_for_removal(model, self._store.graph, for_put=False)
        session_core.remove_pending_for_subjects(self._state, subjects)
        session_core.invalidate_cascade_keys(self._state, self._store.graph, model, for_put=False)
        remove_g = triples_to_graph(owned_triples_for_subjects(subjects, self._store.graph))
        # Skip the store round-trip when there is nothing to remove.
        if len(remove_g):
            await self._store.update_graph(remove=remove_g)
        self._state.expire_model(model)

    async def get(
        self,
        model_cls: type[SPARQLModel],
        iri: str | IRI,
        *,
        depth: int = 0,
    ) -> SPARQLModel | None:
        """Load a model by IRI with optional relationship depth.

        Returns:
            The hydrated model, or ``None`` when the resource is not found.
        """
        self._check_open()
        await self._maybe_autoflush()
        return await session_core.get_impl_async(
            self._state, self._store, model_cls, iri, depth=depth
        )

    async def hydrate_bindings(
        self,
        model_cls: type[SPARQLModel],
        bindings: list[dict[str, Any]],
        *,
        depth: int = 0,
        polymorphic: bool = False,
    ) -> list[SPARQLModel]:
        """Hydrate query results with identity map and session cache.

        Args:
            model_cls: Model class to hydrate into.
            bindings: Result rows as returned by :meth:`execute`.
            depth: Relationship depth for each hydrated model.
            polymorphic: Forwarded to each per-subject load.
        """
        self._check_open()
        await self._maybe_autoflush()

        # Per-subject loader that closes over ``polymorphic`` so the shared
        # hydration helper does not need to know about it.
        async def get_fn(
            mcls: type[SPARQLModel],
            subject_iri: str | IRI,
            *,
            depth: int = 0,
        ) -> SPARQLModel | None:
            return await session_core.get_impl_async(
                self._state,
                self._store,
                mcls,
                subject_iri,
                depth=depth,
                polymorphic=polymorphic,
            )

        return await session_core.hydrate_bindings_impl_async(
            self._state,
            self._store,
            model_cls,
            bindings,
            depth=depth,
            get_fn=get_fn,
        )

    def query(self, model_cls: type[SPARQLModel]) -> AsyncQuery:
        """Start a fluent async query for the given model class."""
        self._check_open()
        return AsyncQuery(self, model_cls)

    async def execute(self, sparql: str) -> list[dict[str, Any]]:
        """Execute raw SPARQL SELECT.

        When ``sparql`` declares no prefixes itself, the session's registered
        ``PREFIX`` block (if any) is prepended.
        """
        self._check_open()
        await self._maybe_autoflush()
        if not session_core.sparql_has_prefix_declarations(sparql):
            prefix_block = self._namespaces.sparql_prefixes()
            if prefix_block:
                sparql = f"{prefix_block}\n\n{sparql}"
        return await self._store.query(sparql)