# Source code for helix_ir.test.generator

"""Automatic test generation from schema metadata."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable

import pyarrow as pa

from helix_ir.schema.path import Path
from helix_ir.schema.schema import Schema
from helix_ir.types.core import HelixType


@dataclass
class Test:
    """A generated data quality test.

    Instances are produced by ``generate_tests`` and evaluated by ``run_test``.
    """

    # The class name starts with "Test", so pytest would try to collect it as a
    # test case wherever it is imported into a test module (and warn, because
    # it has an ``__init__``). Opt out explicitly; this is not a dataclass field.
    __test__ = False

    name: str  # identifier such as "not_null__<path>"
    path: Path  # schema path of the column under test
    # Test category. ``generate_tests`` in this module emits 'not_null',
    # 'null_ratio', 'semantic', 'cardinality', 'pii' and 'range'.
    kind: str
    description: str  # human-readable statement of the expectation
    severity: str = "error"  # 'error', 'warning', 'info'
    # SQL query containing a literal ``{table}`` placeholder, or None when the
    # test has no SQL form.
    sql_template: str | None = None
    # Kind-specific parameters, e.g. thresholds such as "expected_max_cardinality".
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class TestResult:
    """Result of running a single :class:`Test`."""

    # Name starts with "Test": keep pytest from trying to collect this class
    # when it is imported into a test module. Not a dataclass field.
    __test__ = False

    test: Test  # the test that was run
    passed: bool  # whether the test is reported as passing
    message: str = ""  # human-readable detail; empty when there is nothing to add
    actual_value: Any = None  # observed value, when the check computes one
    expected_value: Any = None  # value/threshold the observation was compared with
def _quote_ident(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier.

    Embedded double quotes are doubled (the SQL-standard escape) so a column
    name containing ``"`` cannot terminate the identifier early and corrupt
    the generated query.
    """
    return '"' + name.replace('"', '""') + '"'


def _null_test(path: Path, ht: HelixType, sensitivity: float) -> Test | None:
    """Build a ``not_null`` or ``null_ratio`` test from the observed null ratio."""
    path_str = str(path)
    column = _quote_ident(path_str)
    # No NULLs seen in a non-empty sample: require the column to have none.
    if ht.null_ratio == 0.0 and ht.sample_count > 0:
        return Test(
            name=f"not_null__{path_str}",
            path=path,
            kind="not_null",
            description=f"Column '{path_str}' must not contain NULL values",
            severity="error",
            sql_template=f"SELECT COUNT(*) FROM {{table}} WHERE {column} IS NULL",
        )
    # Some NULLs seen: allow up to `sensitivity` times the observed ratio,
    # capped at 100%.
    if ht.null_ratio > 0.0 and ht.sample_count > 0:
        max_null_ratio = min(1.0, ht.null_ratio * sensitivity)
        return Test(
            name=f"null_ratio__{path_str}",
            path=path,
            kind="null_ratio",
            description=(
                f"Column '{path_str}' null ratio should not exceed "
                f"{max_null_ratio:.2%}"
            ),
            severity="warning",
            # NULLIF turns an empty table's division by zero (an error in
            # dialects such as PostgreSQL) into a NULL result.
            sql_template=(
                f"SELECT CAST(COUNT(*) FILTER (WHERE {column} IS NULL) AS FLOAT) "
                f"/ NULLIF(COUNT(*), 0) FROM {{table}}"
            ),
            metadata={"expected_max_null_ratio": max_null_ratio},
        )
    return None


def _semantic_test(path: Path, ht: HelixType) -> Test | None:
    """Build a ``semantic`` test when the column carries a checkable semantic tag."""
    semantic = ht.semantic
    # 'json_blob' and 'union:*' tags are deliberately not turned into tests.
    if not semantic or semantic == "json_blob" or semantic.startswith("union:"):
        return None
    path_str = str(path)
    return Test(
        name=f"semantic__{path_str}__{semantic}",
        path=path,
        kind="semantic",
        description=f"Column '{path_str}' should match semantic type '{semantic}'",
        severity="warning",
        metadata={"semantic": semantic},
    )


def _cardinality_test(path: Path, ht: HelixType, sensitivity: float) -> Test | None:
    """Build a ``cardinality`` test for enum columns with a cardinality estimate."""
    if ht.semantic != "enum" or ht.cardinality_estimate is None:
        return None
    path_str = str(path)
    # int() truncates, e.g. an estimate of 3 with sensitivity 1.5 allows 4.
    max_card = int(ht.cardinality_estimate * sensitivity)
    return Test(
        name=f"cardinality__{path_str}",
        path=path,
        kind="cardinality",
        description=(
            f"Column '{path_str}' cardinality should not exceed {max_card}"
        ),
        severity="warning",
        sql_template=f"SELECT COUNT(DISTINCT {_quote_ident(path_str)}) FROM {{table}}",
        metadata={"expected_max_cardinality": max_card},
    )


def _pii_test(path: Path, ht: HelixType) -> Test | None:
    """Build an informational ``pii`` test for columns with a PII classification."""
    if not ht.pii_class:
        return None
    path_str = str(path)
    return Test(
        name=f"pii__{path_str}__{ht.pii_class}",
        path=path,
        kind="pii",
        description=(
            f"Column '{path_str}' is classified as PII ({ht.pii_class}). "
            "Ensure proper masking/access controls."
        ),
        severity="info",
        metadata={"pii_class": ht.pii_class},
    )


def _range_test(path: Path, ht: HelixType) -> Test | None:
    """Build a ``range`` test for integer/floating columns with known min and max."""
    is_numeric = pa.types.is_integer(ht.arrow_type) or pa.types.is_floating(ht.arrow_type)
    if not is_numeric or ht.min_value is None or ht.max_value is None:
        return None
    path_str = str(path)
    column = _quote_ident(path_str)
    return Test(
        name=f"range__{path_str}",
        path=path,
        kind="range",
        description=(
            f"Column '{path_str}' values should be between "
            f"{ht.min_value} and {ht.max_value}"
        ),
        severity="warning",
        # NOTE(review): bounds are interpolated with str(); non-finite floats
        # (inf/nan) would not render as valid SQL literals.
        sql_template=(
            f"SELECT COUNT(*) FROM {{table}} WHERE "
            f"{column} < {ht.min_value} OR {column} > {ht.max_value}"
        ),
        metadata={"min": ht.min_value, "max": ht.max_value},
    )


def generate_tests(
    schema: Schema,
    sensitivity: float = 1.5,
) -> list[Test]:
    """Generate data quality tests based on schema metadata.

    For every ``(path, type)`` pair yielded by ``schema.walk()``, tests are
    emitted in this order, each only when its metadata allows:

    - ``not_null`` (no NULLs observed) or ``null_ratio`` (some NULLs observed)
    - ``semantic`` (semantic tag other than 'json_blob' / 'union:*')
    - ``cardinality`` (enum columns with a cardinality estimate)
    - ``pii`` (columns with a PII class)
    - ``range`` (integer/floating columns with known min and max)

    SQL templates quote the column as a SQL identifier and leave a literal
    ``{table}`` placeholder, presumably substituted by the caller.

    Args:
        schema: The schema to generate tests for.
        sensitivity: Multiplier applied to the observed null ratio and to the
            cardinality estimate to derive the allowed maxima. Higher = more
            permissive (fewer false positives).

    Returns:
        A list of Test objects.
    """
    tests: list[Test] = []
    for path, ht in schema.walk():
        candidates = (
            _null_test(path, ht, sensitivity),
            _semantic_test(path, ht),
            _cardinality_test(path, ht, sensitivity),
            _pii_test(path, ht),
            _range_test(path, ht),
        )
        tests.extend(t for t in candidates if t is not None)
    return tests
def run_test(test: Test, data: list[dict[str, Any]]) -> TestResult:
    """Run a test against in-memory data.

    Each row is read with ``row.get(str(test.path))``, so a missing key counts
    the same as an explicit ``None``.

    The 'not_null', 'null_ratio', 'cardinality' and 'range' kinds are evaluated
    here using the thresholds stored in ``test.metadata``. Any other kind cannot
    be checked in memory and is reported as passed with the message
    "Test requires SQL execution".

    Args:
        test: The test to evaluate.
        data: Rows as dictionaries keyed by column name.

    Returns:
        A TestResult. Exceptions raised while evaluating a supported kind (for
        example unhashable values in a cardinality check, or values that cannot
        be compared with range bounds) are caught and reported as a failed
        result whose message names the exception type.
    """
    kind = test.kind
    if kind not in ("not_null", "null_ratio", "cardinality", "range"):
        # Default: pass (complex tests require SQL execution)
        return TestResult(test=test, passed=True, message="Test requires SQL execution")

    try:
        column = str(test.path)
        values = [row.get(column) for row in data]

        if kind == "not_null":
            nulls = sum(1 for v in values if v is None)
            if nulls > 0:
                return TestResult(
                    test=test,
                    passed=False,
                    message=f"Found {nulls} null values",
                    actual_value=nulls,
                    expected_value=0,
                )
            return TestResult(test=test, passed=True, message="No null values found")

        if kind == "null_ratio":
            nulls = sum(1 for v in values if v is None)
            # An empty dataset contains no NULLs; treat its ratio as 0.
            ratio = nulls / len(values) if values else 0.0
            expected_max = test.metadata.get("expected_max_null_ratio", 1.0)
            if ratio > expected_max:
                return TestResult(
                    test=test,
                    passed=False,
                    message=f"Null ratio {ratio:.2%} exceeds expected max {expected_max:.2%}",
                    actual_value=ratio,
                    expected_value=expected_max,
                )
            return TestResult(
                test=test,
                passed=True,
                message=f"Null ratio {ratio:.2%} within expected max {expected_max:.2%}",
                actual_value=ratio,
                expected_value=expected_max,
            )

        if kind == "cardinality":
            distinct = {v for v in values if v is not None}
            actual = len(distinct)
            expected_max = test.metadata.get("expected_max_cardinality", float("inf"))
            if actual > expected_max:
                return TestResult(
                    test=test,
                    passed=False,
                    message=f"Cardinality {actual} exceeds expected max {expected_max}",
                    actual_value=actual,
                    expected_value=expected_max,
                )
            return TestResult(test=test, passed=True)

        # kind == "range": like the range SQL built by generate_tests, count
        # non-NULL values below "min" or above "max"; a missing bound is skipped.
        lo = test.metadata.get("min")
        hi = test.metadata.get("max")
        out_of_range = sum(
            1
            for v in values
            if v is not None
            and ((lo is not None and v < lo) or (hi is not None and v > hi))
        )
        if out_of_range:
            return TestResult(
                test=test,
                passed=False,
                message=f"Found {out_of_range} values outside [{lo}, {hi}]",
                actual_value=out_of_range,
                expected_value=0,
            )
        return TestResult(test=test, passed=True, message="All values within range")
    except Exception as e:
        # Best-effort by design: an evaluation error becomes a failed result
        # rather than propagating; include the type so e.g. KeyError is legible.
        return TestResult(test=test, passed=False, message=f"{type(e).__name__}: {e}")