# Source code for helix_ir.transform.expression

"""Expression DSL for building SQL expressions."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING

import pyarrow as pa

if TYPE_CHECKING:
    from helix_ir.types.core import HelixType


class Expression:
    """Base class for all expressions in the Helix IR query DSL.

    The overloaded Python operators and the helper methods below do not
    evaluate anything: each one returns a new expression node.  Concrete
    subclasses must implement ``_copy`` and ``to_sql``.
    """

    def __init__(self) -> None:
        self._alias: str | None = None

    def __hash__(self) -> int:
        """Hash by identity.

        ``__eq__`` is overloaded below to build SQL, which makes Python set
        ``__hash__`` to ``None`` implicitly.  Defining it explicitly keeps
        expressions usable as dict keys and set members.
        """
        return object.__hash__(self)

    def alias(self, name: str) -> "Expression":
        """Return a copy of this expression carrying the alias *name*.

        The receiver is left untouched; the copy comes from ``_copy``.
        """
        e = self._copy()
        e._alias = name
        return e

    def _copy(self) -> "Expression":
        """Return a shallow copy of this node; subclasses must override."""
        raise NotImplementedError

    # Comparison operators
    # NOTE: these return expression nodes, not booleans, so the result of
    # ``a == b`` is always truthy when used in a plain ``if``.
    def __eq__(self, other: Any) -> "BinaryExpr":  # type: ignore[override]
        return BinaryExpr(self, "=", _wrap(other))

    def __ne__(self, other: Any) -> "BinaryExpr":  # type: ignore[override]
        return BinaryExpr(self, "<>", _wrap(other))

    def __lt__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(self, "<", _wrap(other))

    def __le__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(self, "<=", _wrap(other))

    def __gt__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(self, ">", _wrap(other))

    def __ge__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(self, ">=", _wrap(other))

    # Arithmetic operators (the ``__r*__`` variants handle ``1 + expr``)
    def __add__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(self, "+", _wrap(other))

    def __radd__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(_wrap(other), "+", self)

    def __sub__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(self, "-", _wrap(other))

    def __rsub__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(_wrap(other), "-", self)

    def __mul__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(self, "*", _wrap(other))

    def __rmul__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(_wrap(other), "*", self)

    def __truediv__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(self, "/", _wrap(other))

    def __rtruediv__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(_wrap(other), "/", self)

    # Logical operators (``&`` -> AND, ``|`` -> OR, ``~`` -> NOT)
    def __and__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(self, "AND", _wrap(other))

    def __rand__(self, other: Any) -> "BinaryExpr":
        # Reflected form so that ``True & expr`` works like ``1 + expr`` does.
        return BinaryExpr(_wrap(other), "AND", self)

    def __or__(self, other: Any) -> "BinaryExpr":
        return BinaryExpr(self, "OR", _wrap(other))

    def __ror__(self, other: Any) -> "BinaryExpr":
        # Reflected form so that ``False | expr`` works like ``1 + expr`` does.
        return BinaryExpr(_wrap(other), "OR", self)

    def __invert__(self) -> "UnaryExpr":
        return UnaryExpr("NOT", self)

    # Null checks
    def is_null(self) -> "FunctionExpr":
        """Build a postfix ``IS NULL`` test on this expression."""
        return FunctionExpr("IS NULL", [self], postfix=True)

    def is_not_null(self) -> "FunctionExpr":
        """Build a postfix ``IS NOT NULL`` test on this expression."""
        return FunctionExpr("IS NOT NULL", [self], postfix=True)

    # String methods
    def like(self, pattern: str) -> "BinaryExpr":
        """Build ``self LIKE pattern``; *pattern* is passed as a literal."""
        return BinaryExpr(self, "LIKE", Literal(pattern))

    def ilike(self, pattern: str) -> "BinaryExpr":
        """Build ``self ILIKE pattern``; *pattern* is passed as a literal."""
        return BinaryExpr(self, "ILIKE", Literal(pattern))

    def contains(self, substring: str) -> "FunctionExpr":
        """Build a ``CONTAINS(self, substring)`` function call."""
        return FunctionExpr("CONTAINS", [self, Literal(substring)])

    def startswith(self, prefix: str) -> "BinaryExpr":
        """Build ``self LIKE '<prefix>%'``.

        NOTE: ``%`` and ``_`` inside *prefix* are not escaped, so they act
        as LIKE wildcards.
        """
        return BinaryExpr(self, "LIKE", Literal(f"{prefix}%"))

    def endswith(self, suffix: str) -> "BinaryExpr":
        """Build ``self LIKE '%<suffix>'``.

        NOTE: ``%`` and ``_`` inside *suffix* are not escaped, so they act
        as LIKE wildcards.
        """
        return BinaryExpr(self, "LIKE", Literal(f"%{suffix}"))

    def upper(self) -> "FunctionExpr":
        """Build an ``UPPER(self)`` call."""
        return FunctionExpr("UPPER", [self])

    def lower(self) -> "FunctionExpr":
        """Build a ``LOWER(self)`` call."""
        return FunctionExpr("LOWER", [self])

    def trim(self) -> "FunctionExpr":
        """Build a ``TRIM(self)`` call."""
        return FunctionExpr("TRIM", [self])

    def cast(self, sql_type: str) -> "CastExpr":
        """Build a cast of this expression to *sql_type*."""
        return CastExpr(self, sql_type)

    def in_(self, values: list[Any]) -> "InExpr":
        """Build a membership test; plain values are wrapped as literals."""
        return InExpr(self, [_wrap(v) for v in values])

    def between(self, low: Any, high: Any) -> "BetweenExpr":
        """Build a range test; plain bounds are wrapped as literals."""
        return BetweenExpr(self, _wrap(low), _wrap(high))

    # Aggregation
    def sum(self) -> "FunctionExpr":
        """Build a ``SUM(self)`` call."""
        return FunctionExpr("SUM", [self])

    def avg(self) -> "FunctionExpr":
        """Build an ``AVG(self)`` call."""
        return FunctionExpr("AVG", [self])

    def min(self) -> "FunctionExpr":
        """Build a ``MIN(self)`` call."""
        return FunctionExpr("MIN", [self])

    def max(self) -> "FunctionExpr":
        """Build a ``MAX(self)`` call."""
        return FunctionExpr("MAX", [self])

    def count(self) -> "FunctionExpr":
        """Build a ``COUNT(self)`` call."""
        return FunctionExpr("COUNT", [self])

    def count_distinct(self) -> "FunctionExpr":
        """Build a function node named ``COUNT_DISTINCT`` over this expression."""
        return FunctionExpr("COUNT_DISTINCT", [self])

    # Sorting helpers
    def asc(self) -> "SortExpr":
        """Return an ascending sort specification for this expression."""
        return SortExpr(self, descending=False)

    def desc(self) -> "SortExpr":
        """Return a descending sort specification for this expression."""
        return SortExpr(self, descending=True)

    def to_sql(self, dialect: str = "duckdb") -> str:
        """Compile this expression to a SQL string.

        Raises:
            NotImplementedError: always, in this base class; subclasses
                provide the implementation.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement to_sql()")

    def __repr__(self) -> str:
        # repr() must never raise (it is used by debuggers and logging), so
        # any failure while rendering SQL falls back to a placeholder.
        try:
            return f"{type(self).__name__}({self.to_sql()})"
        except Exception:
            return f"{type(self).__name__}(?)"


class Column(Expression):
    """A reference to a column by name, optionally qualified by a table."""

    def __init__(self, name: str, table: str | None = None) -> None:
        super().__init__()
        self.name = name
        self.table = table

    def _copy(self) -> "Column":
        """Return a new Column with the same name, table and alias."""
        c = Column(self.name, self.table)
        c._alias = self._alias
        return c

    @staticmethod
    def _quote(identifier: str) -> str:
        """Wrap *identifier* in double quotes, doubling any embedded ones.

        Without the doubling, a name containing ``"`` would terminate the
        quoted identifier early and produce broken (or injectable) SQL.
        """
        return '"' + str(identifier).replace('"', '""') + '"'

    def to_sql(self, dialect: str = "duckdb") -> str:
        """Render as ``"table"."name"`` or ``"name"``, plus ``AS "alias"`` if set.

        *dialect* is accepted for interface compatibility and is not used.
        """
        if self.table:
            base = f"{self._quote(self.table)}.{self._quote(self.name)}"
        else:
            base = self._quote(self.name)
        if self._alias:
            return f"{base} AS {self._quote(self._alias)}"
        return base


class Literal(Expression):
    """Expression node holding a single Python scalar to embed in SQL."""

    def __init__(self, value: Any) -> None:
        super().__init__()
        self.value = value

    def _copy(self) -> "Literal":
        """Return a new Literal with the same value and alias."""
        clone = Literal(self.value)
        clone._alias = self._alias
        return clone

    def to_sql(self, dialect: str = "duckdb") -> str:
        """Render the value via ``_literal_to_sql``, plus ``AS "alias"`` if set.

        *dialect* is accepted for interface compatibility and is not used.
        """
        rendered = _literal_to_sql(self.value)
        if not self._alias:
            return rendered
        return f'{rendered} AS "{self._alias}"'


class BinaryExpr(Expression):
    """Infix operation joining two operand expressions: ``left OP right``."""

    def __init__(self, left: Expression, op: str, right: Expression) -> None:
        super().__init__()
        self.left, self.op, self.right = left, op, right

    def _copy(self) -> "BinaryExpr":
        """Return a new node sharing the same operands, operator and alias."""
        clone = BinaryExpr(self.left, self.op, self.right)
        clone._alias = self._alias
        return clone

    def to_sql(self, dialect: str = "duckdb") -> str:
        """Render as ``(left OP right)``, plus ``AS "alias"`` if set.

        The parentheses are always emitted, so nesting does not depend on
        SQL operator precedence.
        """
        lhs = self.left.to_sql(dialect)
        rhs = self.right.to_sql(dialect)
        rendered = f"({lhs} {self.op} {rhs})"
        return f'{rendered} AS "{self._alias}"' if self._alias else rendered


class UnaryExpr(Expression):
    """Prefix operation applied to one operand: ``OP expr``."""

    def __init__(self, op: str, operand: Expression) -> None:
        super().__init__()
        self.op, self.operand = op, operand

    def _copy(self) -> "UnaryExpr":
        """Return a new node sharing the same operator, operand and alias."""
        clone = UnaryExpr(self.op, self.operand)
        clone._alias = self._alias
        return clone

    def to_sql(self, dialect: str = "duckdb") -> str:
        """Render as ``(OP operand)``, plus ``AS "alias"`` if set."""
        inner = self.operand.to_sql(dialect)
        rendered = f"({self.op} {inner})"
        if not self._alias:
            return rendered
        return f'{rendered} AS "{self._alias}"'


class FunctionExpr(Expression):
    """Function call ``NAME(args...)`` or postfix operator ``arg NAME``."""

    def __init__(
        self,
        func_name: str,
        args: list[Expression],
        postfix: bool = False,
    ) -> None:
        super().__init__()
        self.func_name, self.args, self.postfix = func_name, args, postfix

    def _copy(self) -> "FunctionExpr":
        """Return a new node with a copied argument list and the same alias."""
        clone = FunctionExpr(self.func_name, list(self.args), self.postfix)
        clone._alias = self._alias
        return clone

    def to_sql(self, dialect: str = "duckdb") -> str:
        """Render the call or postfix form, plus ``AS "alias"`` if set."""
        rendered = self._render_call(dialect)
        return f'{rendered} AS "{self._alias}"' if self._alias else rendered

    def _render_call(self, dialect: str) -> str:
        """Render the un-aliased SQL for this node."""
        if self.postfix:
            # Postfix operators (e.g. "col IS NULL") use only the first argument.
            operand = self.args[0].to_sql(dialect) if self.args else ""
            return f"({operand} {self.func_name})"
        if self.func_name == "COUNT_DISTINCT":
            # Pseudo-function: spelled COUNT(DISTINCT x) in SQL.
            target = self.args[0].to_sql(dialect) if self.args else "*"
            return f"COUNT(DISTINCT {target})"
        rendered_args = [arg.to_sql(dialect) for arg in self.args]
        return f"{self.func_name}({', '.join(rendered_args)})"


class CastExpr(Expression):
    """Type conversion of an expression: ``CAST(expr AS sql_type)``."""

    def __init__(self, expr: Expression, sql_type: str) -> None:
        super().__init__()
        self.expr, self.sql_type = expr, sql_type

    def _copy(self) -> "CastExpr":
        """Return a new node sharing the same operand, target type and alias."""
        clone = CastExpr(self.expr, self.sql_type)
        clone._alias = self._alias
        return clone

    def to_sql(self, dialect: str = "duckdb") -> str:
        """Render the cast, plus ``AS "alias"`` if set.

        ``sql_type`` is inserted into the SQL text verbatim.
        """
        inner = self.expr.to_sql(dialect)
        rendered = f"CAST({inner} AS {self.sql_type})"
        if not self._alias:
            return rendered
        return f'{rendered} AS "{self._alias}"'


class InExpr(Expression):
    """A membership test: ``expr IN (v1, v2, ...)``."""

    def __init__(self, expr: Expression, values: list[Expression]) -> None:
        super().__init__()
        self.expr = expr
        self.values = values

    def _copy(self) -> "InExpr":
        """Return a new node with a copied value list and the same alias."""
        e = InExpr(self.expr, list(self.values))
        e._alias = self._alias
        return e

    def to_sql(self, dialect: str = "duckdb") -> str:
        """Render as ``(expr IN (v1, ...))``, plus ``AS "alias"`` if set.

        An empty value list renders as the constant ``FALSE``: ``IN ()`` is
        not valid SQL, and membership in an empty set can never hold.
        """
        if self.values:
            vals = ", ".join(v.to_sql(dialect) for v in self.values)
            base = f"({self.expr.to_sql(dialect)} IN ({vals}))"
        else:
            base = "FALSE"
        if self._alias:
            return f"{base} AS \"{self._alias}\""
        return base


class BetweenExpr(Expression):
    """Range test: ``expr BETWEEN low AND high``."""

    def __init__(self, expr: Expression, low: Expression, high: Expression) -> None:
        super().__init__()
        self.expr, self.low, self.high = expr, low, high

    def _copy(self) -> "BetweenExpr":
        """Return a new node sharing the same operand, bounds and alias."""
        clone = BetweenExpr(self.expr, self.low, self.high)
        clone._alias = self._alias
        return clone

    def to_sql(self, dialect: str = "duckdb") -> str:
        """Render as ``(expr BETWEEN low AND high)``, plus ``AS "alias"`` if set."""
        subject = self.expr.to_sql(dialect)
        lower = self.low.to_sql(dialect)
        upper = self.high.to_sql(dialect)
        rendered = f"({subject} BETWEEN {lower} AND {upper})"
        return f'{rendered} AS "{self._alias}"' if self._alias else rendered


class SortExpr:
    """Sort specification pairing an expression with a direction.

    This is not an ``Expression`` subclass; it only offers ``to_sql``.
    """

    def __init__(self, expr: Expression, descending: bool = False) -> None:
        self.expr, self.descending = expr, descending

    def to_sql(self, dialect: str = "duckdb") -> str:
        """Render as ``<expr> ASC`` or ``<expr> DESC``."""
        keyword = "ASC"
        if self.descending:
            keyword = "DESC"
        rendered = self.expr.to_sql(dialect)
        return f"{rendered} {keyword}"


class StarExpr(Expression):
    """SELECT * expression."""

    def _copy(self) -> "StarExpr":
        """Return a new StarExpr carrying the same alias.

        The alias is copied for consistency with the other expression
        classes, even though ``to_sql`` never renders it.
        """
        clone = StarExpr()
        clone._alias = self._alias
        return clone

    def to_sql(self, dialect: str = "duckdb") -> str:
        """Render as ``*``; any alias is ignored."""
        return "*"


# -------------------------------------------------------------------------
# Convenience constructors
# -------------------------------------------------------------------------


def col(name: str, table: str | None = None) -> Column:
    """Create a Column reference.

    Args:
        name: Column name.
        table: Optional table name used to qualify the column.
    """
    return Column(name, table)


def lit(value: Any) -> Literal:
    """Create a Literal expression wrapping the Python value *value*."""
    return Literal(value)


def star() -> StarExpr:
    """Create a SELECT * expression."""
    return StarExpr()


def _wrap(value: Any) -> Expression:
    """Wrap a Python value in an Expression if it isn't one already.

    Expression instances are returned unchanged; anything else becomes a
    ``Literal``.
    """
    if isinstance(value, Expression):
        return value
    return Literal(value)


def _literal_to_sql(value: Any) -> str:
    """Convert a Python value to a SQL literal string.

    Mapping:
        * ``None`` -> ``NULL``
        * ``bool`` -> ``TRUE`` / ``FALSE``
        * non-finite ``float`` -> ``CAST('NaN' AS DOUBLE)`` and friends
        * other ``int`` / ``float`` -> ``str(value)``
        * ``str`` -> single-quoted, with embedded ``'`` doubled
        * ``datetime.datetime`` -> ``TIMESTAMP '<isoformat>'``
        * ``datetime.date`` -> ``DATE '<isoformat>'``
        * anything else -> ``str(value)``, quoted like a string
    """
    import datetime
    import math

    if value is None:
        return "NULL"
    # bool must be tested before int: bool is a subclass of int.
    if isinstance(value, bool):
        return "TRUE" if value else "FALSE"
    # str() of nan/inf gives bare ``nan``/``inf``, which is not a valid SQL
    # numeric literal, so spell the special values as casts of quoted strings.
    if isinstance(value, float) and not math.isfinite(value):
        if math.isnan(value):
            spelled = "NaN"
        else:
            spelled = "Infinity" if value > 0 else "-Infinity"
        return f"CAST('{spelled}' AS DOUBLE)"
    if isinstance(value, (int, float)):
        return str(value)
    if isinstance(value, str):
        escaped = value.replace("'", "''")
        return f"'{escaped}'"
    # datetime must be tested before date: datetime is a subclass of date.
    if isinstance(value, datetime.datetime):
        return f"TIMESTAMP '{value.isoformat()}'"
    if isinstance(value, datetime.date):
        return f"DATE '{value.isoformat()}'"
    escaped = str(value).replace("'", "''")
    return f"'{escaped}'"