"""Schema class — the central immutable data structure of Helix IR."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Iterable
import pyarrow as pa
from helix_ir.exceptions import PathNotFoundError
from helix_ir.schema.path import Path, PathSegment
from helix_ir.types.core import HelixType
[docs]
@dataclass(frozen=True)
class Schema:
"""Immutable schema descriptor.
Fields are stored as a tuple of (name, HelixType) pairs to preserve order
while remaining hashable.
"""
name: str
fields: tuple[tuple[str, HelixType], ...] = field(default_factory=tuple)
# -------------------------------------------------------------------------
# Field access
# -------------------------------------------------------------------------
[docs]
def field(self, name: str) -> HelixType:
"""Return the HelixType for a top-level field by name."""
for fname, ftype in self.fields:
if fname == name:
return ftype
raise PathNotFoundError(f"Field {name!r} not found in schema {self.name!r}")
[docs]
def get_field(self, name: str) -> HelixType | None:
"""Return the HelixType for a top-level field, or None if not found."""
for fname, ftype in self.fields:
if fname == name:
return ftype
return None
[docs]
def field_names(self) -> list[str]:
"""Return a list of top-level field names."""
return [name for name, _ in self.fields]
[docs]
def path(self, path: str) -> HelixType:
"""Return the HelixType at the given dotted path.
Example: schema.path("customer.address.city")
"""
p = Path.parse(path)
return self._resolve_path(p)
def _resolve_path(self, p: Path) -> HelixType:
"""Recursively resolve a Path within this schema."""
if not p.segments:
raise PathNotFoundError("Empty path")
seg = p.segments[0]
rest = Path(segments=p.segments[1:])
if seg.kind == "array_element":
raise PathNotFoundError("Cannot start a path with an array element segment")
current = self.field(seg.name) # type: ignore[arg-type]
# Walk remaining segments
for seg in rest.segments:
if seg.kind == "array_element":
if not pa.types.is_list(current.arrow_type):
raise PathNotFoundError(
f"Expected list type at {p}, got {current.arrow_type}"
)
current = HelixType(arrow_type=current.arrow_type.value_type)
else:
if not pa.types.is_struct(current.arrow_type):
raise PathNotFoundError(
f"Expected struct type, got {current.arrow_type}"
)
arrow_struct = current.arrow_type
try:
idx = arrow_struct.get_field_index(seg.name) # type: ignore[attr-defined]
except Exception:
idx = -1
if idx < 0:
raise PathNotFoundError(
f"Field {seg.name!r} not found in struct"
)
current = HelixType(arrow_type=arrow_struct.field(idx).type)
return current
# -------------------------------------------------------------------------
# Walking
# -------------------------------------------------------------------------
[docs]
def walk(self) -> Iterable[tuple[Path, HelixType]]:
"""Yield (path, HelixType) for every leaf field in the schema."""
yield from self._walk_fields(Path.ROOT, self.fields)
def _walk_fields(
self, base: Path, fields: tuple[tuple[str, HelixType], ...]
) -> Iterable[tuple[Path, HelixType]]:
for name, ht in fields:
p = base.append(name)
yield p, ht
if pa.types.is_struct(ht.arrow_type):
nested = _struct_to_fields(ht.arrow_type)
yield from self._walk_struct(p, nested)
elif pa.types.is_list(ht.arrow_type):
elem_type = ht.arrow_type.value_type
elem_ht = HelixType(arrow_type=elem_type)
elem_path = p.array_element()
yield elem_path, elem_ht
if pa.types.is_struct(elem_type):
nested = _struct_to_fields(elem_type)
yield from self._walk_struct(elem_path, nested)
def _walk_struct(
self, base: Path, fields: list[tuple[str, HelixType]]
) -> Iterable[tuple[Path, HelixType]]:
for name, ht in fields:
p = base.append(name)
yield p, ht
if pa.types.is_struct(ht.arrow_type):
nested = _struct_to_fields(ht.arrow_type)
yield from self._walk_struct(p, nested)
elif pa.types.is_list(ht.arrow_type):
elem_type = ht.arrow_type.value_type
elem_ht = HelixType(arrow_type=elem_type)
elem_path = p.array_element()
yield elem_path, elem_ht
if pa.types.is_struct(elem_type):
nested = _struct_to_fields(elem_type)
yield from self._walk_struct(elem_path, nested)
[docs]
def walk_arrays(self) -> Iterable[tuple[Path, HelixType]]:
"""Yield (path, HelixType) for every array field in the schema."""
for p, ht in self.walk():
if pa.types.is_list(ht.arrow_type):
yield p, ht
# -------------------------------------------------------------------------
# Arrow interop
# -------------------------------------------------------------------------
[docs]
def to_arrow(self) -> pa.Schema:
"""Convert this schema to a PyArrow Schema."""
from helix_ir.types.arrow_interop import helix_schema_to_arrow
return helix_schema_to_arrow(list(self.fields))
[docs]
@classmethod
def from_arrow(cls, name: str, schema: pa.Schema) -> "Schema":
"""Create a Schema from a PyArrow Schema."""
from helix_ir.types.arrow_interop import arrow_schema_to_helix
pairs = arrow_schema_to_helix(schema)
return cls(name=name, fields=tuple(pairs))
# -------------------------------------------------------------------------
# JSON serialization
# -------------------------------------------------------------------------
[docs]
def to_json(self) -> dict[str, Any]:
"""Serialize this schema to a JSON-compatible dict."""
from helix_ir.schema.serialization import helix_type_to_json
return {
"name": self.name,
"fields": [
{"name": fname, "type": helix_type_to_json(ftype)}
for fname, ftype in self.fields
],
}
[docs]
@classmethod
def from_json(cls, data: dict[str, Any]) -> "Schema":
"""Deserialize a Schema from a JSON-compatible dict."""
from helix_ir.schema.serialization import helix_type_from_json
fields = tuple(
(f["name"], helix_type_from_json(f["type"])) for f in data["fields"]
)
return cls(name=data["name"], fields=fields)
# -------------------------------------------------------------------------
# Mutation (returns new Schema)
# -------------------------------------------------------------------------
[docs]
def add_field(self, name: str, ftype: HelixType) -> "Schema":
"""Return a new Schema with the given field added (or replaced)."""
new_fields = tuple(
(n, t) for n, t in self.fields if n != name
) + ((name, ftype),)
return Schema(name=self.name, fields=new_fields)
[docs]
def drop_field(self, name: str) -> "Schema":
"""Return a new Schema with the given field removed."""
return Schema(
name=self.name,
fields=tuple((n, t) for n, t in self.fields if n != name),
)
[docs]
def rename(self, new_name: str) -> "Schema":
"""Return a new Schema with a different name."""
return Schema(name=new_name, fields=self.fields)
def __len__(self) -> int:
return len(self.fields)
def __contains__(self, name: str) -> bool:
return any(n == name for n, _ in self.fields)
def __repr__(self) -> str:
field_strs = ", ".join(f"{n}: {t.arrow_type}" for n, t in self.fields)
return f"Schema({self.name!r}, [{field_strs}])"
def _struct_to_fields(t: pa.StructType) -> list[tuple[str, HelixType]]:
"""Convert a pa.StructType to a list of (name, HelixType) pairs."""
return [
(t.field(i).name, HelixType(arrow_type=t.field(i).type))
for i in range(t.num_fields)
]