Source code for sparqlmodel.stores.http_common

"""Shared SPARQL HTTP helpers for sync and async endpoint stores."""

from __future__ import annotations

import asyncio
import base64
import re
import time
from collections.abc import Iterable, Iterator, Mapping
from typing import Any, Literal, cast
from urllib.parse import parse_qsl, urlencode, urljoin, urlparse, urlunparse

import httpx
from triplemodel import Store, load_graph

from sparqlmodel.exceptions import ConfigurationError, QueryError
from sparqlmodel.graph import _subject_pattern
from sparqlmodel.rdf_n3 import triple_to_n3
from sparqlmodel.types import IRI, expand_iri

# HTTP verbs accepted by validate_query_method() for SPARQL query requests.
QueryMethod = Literal["post", "get"]
_VALID_QUERY_METHODS = frozenset({"post", "get"})
# Status codes that is_retryable_status() reports as worth retrying.
_RETRYABLE_STATUS_CODES = frozenset({502, 503, 504})
# Upper bound that _backoff_seconds() applies to a single retry delay.
_MAX_BACKOFF_SECONDS = 30.0

# Mirror modes accepted by validate_mirror_mode().
MirrorMode = Literal["writer", "remote_authoritative"]

_VALID_MIRROR_MODES = frozenset({"writer", "remote_authoritative"})


def validate_mirror_mode(mirror_mode: str) -> MirrorMode:
    """Return ``mirror_mode`` typed as ``MirrorMode`` once it has been checked.

    Raises ``ValueError`` when the value is not in ``_VALID_MIRROR_MODES``.
    """
    if mirror_mode in _VALID_MIRROR_MODES:
        return cast(MirrorMode, mirror_mode)
    raise ValueError(
        f"mirror_mode must be 'writer' or 'remote_authoritative', got {mirror_mode!r}"
    )


def validate_query_method(query_method: str) -> QueryMethod:
    """Return ``query_method`` typed as ``QueryMethod`` once it has been checked.

    Raises ``ValueError`` when the value is not in ``_VALID_QUERY_METHODS``.
    """
    if query_method in _VALID_QUERY_METHODS:
        return cast(QueryMethod, query_method)
    raise ValueError(f"query_method must be 'post' or 'get', got {query_method!r}")


def validate_http_resilience(
    *,
    max_retries: int,
    retry_backoff: float,
    max_triples_per_update: int,
) -> None:
    """Check the retry and chunking settings, raising ``ValueError`` on the first bad one.

    The settings are examined in declaration order: ``max_retries`` and
    ``retry_backoff`` must not be negative, ``max_triples_per_update`` must be
    at least 1.
    """
    problem: str | None = None
    if max_retries < 0:
        problem = f"max_retries must be >= 0, got {max_retries}"
    elif retry_backoff < 0:
        problem = f"retry_backoff must be >= 0, got {retry_backoff}"
    elif max_triples_per_update < 1:
        problem = f"max_triples_per_update must be >= 1, got {max_triples_per_update}"
    if problem is not None:
        raise ValueError(problem)


def is_retryable_status(status_code: int) -> bool:
    """Return True when ``status_code`` is listed in ``_RETRYABLE_STATUS_CODES``."""
    transient = status_code in _RETRYABLE_STATUS_CODES
    return transient


def is_retryable_exception(exc: BaseException) -> bool:
    """Return True when ``exc`` is a failure that a retry might resolve.

    An ``httpx.HTTPStatusError`` is judged by its response status code; any
    other exception qualifies only when it is a connect, timeout or network
    error from httpx.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        transient_types = (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError)
        return isinstance(exc, transient_types)
    return is_retryable_status(exc.response.status_code)


def _backoff_seconds(retry_backoff: float, attempt: int) -> float:
    """Return the delay before the next retry: ``retry_backoff * 2**attempt``, capped.

    The cap is ``_MAX_BACKOFF_SECONDS``.
    """
    delay = retry_backoff * (2**attempt)
    # Return the cap only when it is strictly smaller, which is what ``min`` does on ties.
    return _MAX_BACKOFF_SECONDS if _MAX_BACKOFF_SECONDS < delay else delay


def _chunk_triples(graph: Store, max_triples: int) -> list[list[tuple[Any, Any, Any]]]:
    triples = list(graph)
    if not triples:
        return []
    chunks: list[list[tuple[Any, Any, Any]]] = []
    for index in range(0, len(triples), max_triples):
        chunks.append(triples[index : index + max_triples])
    return chunks


def iter_graph_chunks(graph: Store, max_triples: int) -> Iterator[Store]:
    """Yield fresh stores, each holding at most ``max_triples`` triples of ``graph``."""
    for batch in _chunk_triples(graph, max_triples):
        piece = Store()
        for triple in batch:
            piece.add(triple)
        yield piece
def build_update_chunks(
    remove: Store | None,
    add: Store | None,
    max_triples: int,
) -> list[str]:
    """Return SPARQL UPDATE strings: every DELETE chunk first, then every INSERT chunk.

    Each graph is split into pieces of at most ``max_triples`` triples. A
    ``None`` or empty graph contributes nothing, and empty update strings are
    dropped.
    """
    updates: list[str] = []
    if remove is not None and len(remove):
        deletions = map(graph_to_delete_data, iter_graph_chunks(remove, max_triples))
        updates.extend(filter(None, deletions))
    if add is not None and len(add):
        insertions = map(graph_to_insert_data, iter_graph_chunks(add, max_triples))
        updates.extend(filter(None, insertions))
    return updates
def append_query_params(url: str, **params: str) -> str:
    """Return ``url`` with ``params`` appended after any query items it already has."""
    parts = urlparse(url)
    # Blank values are kept so items such as ``a=`` are not dropped.
    pairs = [*parse_qsl(parts.query, keep_blank_values=True), *params.items()]
    rebuilt = parts._replace(query=urlencode(pairs))
    return urlunparse(rebuilt)
def expand_subject_iris(
    iris: Iterable[str | IRI],
    prefixes: dict[str, str],
) -> list[str]:
    """Expand each distinct IRI via ``expand_iri``, keeping first-seen order.

    Values are compared by their ``str()`` form, so duplicates are expanded
    once. A ``ConfigurationError`` from expansion is re-raised as
    ``QueryError``.
    """
    absolute: list[str] = []
    for candidate in dict.fromkeys(map(str, iris)):
        try:
            resolved = expand_iri(candidate, prefixes)
        except ConfigurationError as exc:
            raise QueryError(f"Invalid IRI for CONSTRUCT: {exc}") from exc
        absolute.append(resolved)
    return absolute
def request_with_retry(
    client: httpx.Client,
    method: str,
    url: str,
    *,
    operation: str,
    max_retries: int,
    retry_backoff: float,
    **kwargs: Any,
) -> httpx.Response:
    """Send one HTTP request, retrying failures that look transient.

    Up to ``max_retries + 1`` attempts are made. A response whose status is
    retryable is closed and the request repeated after a backoff delay; the
    same happens for exceptions accepted by ``is_retryable_exception``. Any
    other ``httpx.HTTPError``, or a retryable one on the last attempt, is
    re-raised as ``QueryError`` prefixed with ``operation``.
    """
    failure: BaseException | None = None
    for attempt in range(max_retries + 1):
        out_of_retries = attempt >= max_retries
        try:
            response = client.request(method.upper(), url, **kwargs)
            if not is_retryable_status(response.status_code):
                response.raise_for_status()
                return response
            if out_of_retries:
                # Turn the final retryable status into an HTTPError handled below.
                response.raise_for_status()
            response.close()
        except httpx.HTTPError as exc:
            failure = exc
            if out_of_retries or not is_retryable_exception(exc):
                raise QueryError(f"{operation} failed: {exc}") from exc
        # Reached only when this attempt ended in a retryable failure.
        time.sleep(_backoff_seconds(retry_backoff, attempt))
    raise QueryError(f"{operation} failed: {failure}") from failure  # pragma: no cover
async def async_request_with_retry(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    *,
    operation: str,
    max_retries: int,
    retry_backoff: float,
    **kwargs: Any,
) -> httpx.Response:
    """Send one HTTP request on an async client, retrying failures that look transient.

    Up to ``max_retries + 1`` attempts are made. A response whose status is
    retryable is closed and the request repeated after an awaited backoff
    delay; the same happens for exceptions accepted by
    ``is_retryable_exception``. Any other ``httpx.HTTPError``, or a retryable
    one on the last attempt, is re-raised as ``QueryError`` prefixed with
    ``operation``.
    """
    failure: BaseException | None = None
    for attempt in range(max_retries + 1):
        out_of_retries = attempt >= max_retries
        try:
            response = await client.request(method.upper(), url, **kwargs)
            if not is_retryable_status(response.status_code):
                response.raise_for_status()
                return response
            if out_of_retries:
                # Turn the final retryable status into an HTTPError handled below.
                response.raise_for_status()
            await response.aclose()
        except httpx.HTTPError as exc:
            failure = exc
            if out_of_retries or not is_retryable_exception(exc):
                raise QueryError(f"{operation} failed: {exc}") from exc
        # Reached only when this attempt ended in a retryable failure.
        await asyncio.sleep(_backoff_seconds(retry_backoff, attempt))
    raise QueryError(f"{operation} failed: {failure}") from failure  # pragma: no cover
def execute_select(
    client: httpx.Client,
    url: str,
    sparql: str,
    *,
    query_method: QueryMethod,
    max_retries: int,
    retry_backoff: float,
) -> httpx.Response:
    """Send ``sparql`` to ``url`` and return the raw response.

    With ``query_method == "get"`` the query travels as a ``query`` URL
    parameter; otherwise it is POSTed as the UTF-8 request body with
    ``SPARQL_QUERY_HEADERS``.
    """
    retry_options: dict[str, Any] = {
        "operation": "SPARQL query",
        "max_retries": max_retries,
        "retry_backoff": retry_backoff,
    }
    if query_method != "get":
        return request_with_retry(
            client,
            "POST",
            url,
            content=sparql.encode("utf-8"),
            headers=SPARQL_QUERY_HEADERS,
            **retry_options,
        )
    return request_with_retry(
        client,
        "GET",
        append_query_params(url, query=sparql),
        headers={"Accept": "application/sparql-results+json"},
        **retry_options,
    )
async def async_execute_select(
    client: httpx.AsyncClient,
    url: str,
    sparql: str,
    *,
    query_method: QueryMethod,
    max_retries: int,
    retry_backoff: float,
) -> httpx.Response:
    """Send ``sparql`` to ``url`` on an async client and return the raw response.

    With ``query_method == "get"`` the query travels as a ``query`` URL
    parameter; otherwise it is POSTed as the UTF-8 request body with
    ``SPARQL_QUERY_HEADERS``.
    """
    retry_options: dict[str, Any] = {
        "operation": "SPARQL query",
        "max_retries": max_retries,
        "retry_backoff": retry_backoff,
    }
    if query_method != "get":
        return await async_request_with_retry(
            client,
            "POST",
            url,
            content=sparql.encode("utf-8"),
            headers=SPARQL_QUERY_HEADERS,
            **retry_options,
        )
    return await async_request_with_retry(
        client,
        "GET",
        append_query_params(url, query=sparql),
        headers={"Accept": "application/sparql-results+json"},
        **retry_options,
    )
def remove_mirror_subjects(
    graph: Store,
    iris: Iterable[str | IRI],
    prefixes: dict[str, str],
) -> None:
    """Delete from ``graph`` every triple whose subject is one of ``iris``.

    Each distinct IRI (by ``str()`` form) is turned into a subject pattern
    with ``_subject_pattern`` before matching.
    """
    for iri in dict.fromkeys(str(item) for item in iris):
        subject = _subject_pattern(iri, prefixes)
        # Snapshot the matches first: the graph is mutated while they are removed.
        doomed = list(graph.triples((subject, None, None)))
        for triple in doomed:
            graph.remove(triple)
def apply_construct_to_mirror(
    graph: Store,
    remote: Store | None,
    *,
    subjects: Iterable[str | IRI],
    prefixes: dict[str, str],
) -> None:
    """Swap the mirror triples of ``subjects`` for the triples found in ``remote``.

    Existing triples for ``subjects`` are always removed; when ``remote`` is
    ``None`` nothing is added back.
    """
    remove_mirror_subjects(graph, subjects, prefixes)
    if remote is None:
        return
    for triple in remote:
        graph.add(triple)
def _render_data_update(opening: str, graph: Store) -> str:
    """Return ``opening`` followed by one N3 statement per triple and a closing brace.

    An empty ``graph`` produces an empty string.
    """
    if not len(graph):
        return ""
    statements = "\n".join([f" {triple_to_n3(s, p, o)} ." for s, p, o in graph])
    return opening + statements + "\n}"


def graph_to_insert_data(graph: Store) -> str:
    """Render ``graph`` as a SPARQL ``INSERT DATA`` update, or ``""`` when it is empty."""
    return _render_data_update("INSERT DATA {\n", graph)


def graph_to_delete_data(graph: Store) -> str:
    """Render ``graph`` as a SPARQL ``DELETE DATA`` update, or ``""`` when it is empty."""
    return _render_data_update("DELETE DATA {\n", graph)
def sparql_url(endpoint: str) -> str:
    """Normalize a SPARQL endpoint URL, preserving an existing query string.

    Trailing slashes on the path are dropped. A path that already ends in
    ``/sparql``, ``/query`` or ``/update`` is kept; any other path gets
    ``/sparql`` appended. Text after the first ``?`` is re-attached unchanged.
    """
    base, sep, query = endpoint.partition("?")
    # Strip before testing the suffix so ``.../sparql/`` is recognised as an
    # endpoint instead of being turned into ``.../sparql/sparql``.
    trimmed = base.rstrip("/")
    if trimmed.endswith(("/sparql", "/query", "/update")):
        path = trimmed
    else:
        path = urljoin(trimmed + "/", "sparql")
    return f"{path}?{query}" if sep else path
_SPARQL_QUERY_KIND = re.compile( r"\b(SELECT|ASK|CONSTRUCT|DESCRIBE|INSERT|DELETE)\b", re.IGNORECASE, ) _BLOCK_COMMENT_RE = re.compile(r"/\*.*?\*/", re.DOTALL) def _strip_block_comments(text: str) -> str: return _BLOCK_COMMENT_RE.sub("", text)
[docs] def is_select_query(sparql: str) -> bool: """Return True when ``sparql`` appears to be a SPARQL SELECT (not ASK/CONSTRUCT/DESCRIBE).""" text = _strip_block_comments(sparql) while True: stripped = text.lstrip() if not stripped: return False if stripped.startswith("#"): newline = stripped.find("\n") if newline == -1: return False text = stripped[newline + 1 :] continue upper = stripped.upper() if (upper.startswith("PREFIX ") or upper.startswith("BASE ")) and "\n" in stripped: newline = stripped.find("\n") text = stripped[newline + 1 :] continue match = _SPARQL_QUERY_KIND.search(stripped) if match is None: return False return match.group(1).upper() == "SELECT"
def build_request_headers(
    *,
    headers: Mapping[str, str] | None = None,
    auth: tuple[str, str] | None = None,
    bearer_token: str | None = None,
) -> dict[str, str]:
    """Return a new header dict with an ``Authorization`` entry when credentials are given.

    ``headers`` is copied, never mutated. A non-empty ``bearer_token`` sets a
    ``Bearer`` value; ``auth`` (user, password) sets a ``Basic`` value and
    takes precedence when both are supplied.
    """
    req_headers: dict[str, str] = dict(headers) if headers else {}
    if bearer_token:
        req_headers["Authorization"] = f"Bearer {bearer_token}"
    if auth is None:
        return req_headers
    user, password = auth
    raw_credentials = f"{user}:{password}".encode()
    token = base64.b64encode(raw_credentials).decode("ascii")
    req_headers["Authorization"] = f"Basic {token}"
    return req_headers


# Headers for POSTing a query as the request body.
SPARQL_QUERY_HEADERS = {
    "Content-Type": "application/sparql-query",
    "Accept": "application/sparql-results+json",
}
# Headers for SPARQL update requests.
SPARQL_UPDATE_HEADERS = {
    "Content-Type": "application/sparql-update",
    "Accept": "*/*",
}
# Accept header for Graph Store GET requests.
GSP_ACCEPT_HEADERS = {
    "Accept": "text/turtle, application/n-triples;q=0.9, */*;q=0.1",
}
def default_graph_store_url(sparql_endpoint: str) -> str | None:
    """Guess a Graph Store URL by replacing a trailing ``/sparql`` with ``/data``.

    Trailing slashes are ignored and a query string is carried over. Returns
    ``None`` when the path does not end in ``/sparql``.
    """
    base, sep, query = sparql_endpoint.partition("?")
    parent, slash, leaf = base.rstrip("/").rpartition("/")
    if not slash or leaf != "sparql":
        return None
    gsp = parent + "/data"
    if sep:
        return f"{gsp}?{query}"
    return gsp
def replace_mirror_from_graph(target: Store, remote: Store | None) -> None:
    """Empty ``target`` and refill it with the triples of ``remote``.

    With ``remote`` set to ``None`` (or empty) the mirror is simply cleared.
    """
    # Iterate over a snapshot because ``target`` shrinks during removal.
    for stale in tuple(target):
        target.remove(stale)
    if remote is None:
        return
    for triple in remote:
        target.add(triple)
def _gsp_format_from_content_type(content_type: str | None) -> str: if not content_type: return "turtle" media = content_type.split(";", 1)[0].strip().lower() if media in ("text/turtle", "application/x-turtle"): return "turtle" if media in ("application/n-triples", "text/plain"): return "nt" if media in ("application/trig", "application/x-trig"): return "trig" return "turtle"
def parse_gsp_response(content: bytes, content_type: str | None) -> Store:
    """Turn a Graph Store HTTP GET body into a :class:`~triplemodel.Store`.

    A blank body gives an empty store. The parser format is picked from
    ``content_type``; any parsing failure is re-raised as ``QueryError``.
    """
    if content.strip():
        fmt = _gsp_format_from_content_type(content_type)
        try:
            return load_graph(data=content, format=fmt)
        except Exception as exc:
            raise QueryError(f"Failed to parse Graph Store response: {exc}") from exc
    return Store()
def parse_construct_response(content: bytes, content_type: str | None) -> Store:
    """Turn a SPARQL CONSTRUCT response body into a :class:`~triplemodel.Store`.

    A blank body gives an empty store. The parser format is picked from
    ``content_type``; any parsing failure is re-raised as ``QueryError``.
    """
    if content.strip():
        fmt = _gsp_format_from_content_type(content_type)
        try:
            return load_graph(data=content, format=fmt)
        except Exception as exc:
            raise QueryError(f"Failed to parse CONSTRUCT response: {exc}") from exc
    return Store()
def fetch_graph_store(
    client: httpx.Client,
    url: str,
    *,
    max_retries: int,
    retry_backoff: float,
) -> tuple[bytes, str | None]:
    """GET ``url`` with ``GSP_ACCEPT_HEADERS``; return the body and its ``Content-Type``.

    The request goes through ``request_with_retry``. The second item is
    ``None`` when the response carries no ``Content-Type`` header.
    """
    request_options: dict[str, Any] = {
        "operation": "Graph Store GET",
        "max_retries": max_retries,
        "retry_backoff": retry_backoff,
        "headers": GSP_ACCEPT_HEADERS,
    }
    response = request_with_retry(client, "GET", url, **request_options)
    content_type = response.headers.get("Content-Type")
    return response.content, content_type
async def async_fetch_graph_store(
    client: httpx.AsyncClient,
    url: str,
    *,
    max_retries: int,
    retry_backoff: float,
) -> tuple[bytes, str | None]:
    """Async GET of ``url`` with ``GSP_ACCEPT_HEADERS``; return body and ``Content-Type``.

    The request goes through ``async_request_with_retry``. The second item is
    ``None`` when the response carries no ``Content-Type`` header.
    """
    request_options: dict[str, Any] = {
        "operation": "Graph Store GET",
        "max_retries": max_retries,
        "retry_backoff": retry_backoff,
        "headers": GSP_ACCEPT_HEADERS,
    }
    response = await async_request_with_retry(client, "GET", url, **request_options)
    content_type = response.headers.get("Content-Type")
    return response.content, content_type