from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

from ..sources import Source
from .config import SOURCE_TABLE_SEPARATOR
from .utils import get_schema, truncate_iterable, truncate_string


@dataclass
class Column:
    """Schema for a column with its description.

    Only ``name`` is required; the remaining fields default to "absent".
    """

    # Column name as it is rendered in, and filtered for, the generated contexts.
    name: str
    # Optional free-text description; the context formatters skip it when empty.
    description: str | None = None
    # Free-form extra information; a fresh dict is created for every instance.
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class VectorMetadata:
    """Schema for vector lookup data for a single table.

    Holds what the context formatters in this module print for one table:
    its identifier, a similarity score, optional description/SQL and the
    list of columns.
    """

    # Combined source_name and table_name with separator.
    table_slug: str
    # Similarity score; the formatters render it with three decimals.
    similarity: float
    # Optional table description; only printed when non-empty.
    description: str | None = None
    # Optional SQL expression for the table; only printed when non-empty.
    base_sql: str | None = None
    # Columns of the table; defaults to a fresh empty list per instance.
    columns: list[Column] = field(default_factory=list)
    # Free-form extra information; a fresh dict is created for every instance.
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class VectorMetaset:
    """Schema container for vector data for multiple tables_metadata.

    Attributes:
        query: The query this set of tables was collected for.
        vector_metadata_map: Per-table vector metadata, keyed by table slug.
        selected_columns: Optional selection mapping a table slug to the
            column names to keep. When empty, nothing is filtered out.
    """

    query: str
    vector_metadata_map: dict[str, VectorMetadata]
    selected_columns: dict[str, list[str]] = field(default_factory=dict)

    def _generate_context(self, truncate: bool | None = None) -> str:
        """
        Generate formatted text representation of the context.

        Args:
            truncate: Controls truncation behavior.
                    None: Show all tables (max_context)
                    False: Filter by selected_columns without truncation (selected_context)
                    True: Filter by selected_columns with truncation (min_context)

        Returns:
            The formatted context string.
        """
        parts = ["Below are the relevant tables and columns to use:\n\n"]

        apply_selection = truncate is not None
        # Use selected tables if specified, otherwise use all tables
        tables_to_show = self.selected_columns or self.vector_metadata_map
        max_length = 20

        for table_slug, vector_metadata in self.vector_metadata_map.items():
            # Skip tables not in the selected list if using selected_context or min_context
            if apply_selection and table_slug not in tables_to_show:
                continue

            parts.append(f"\n\n{table_slug!r} Similarity: ({vector_metadata.similarity:.3f})\n")

            if vector_metadata.description:
                parts.append(f"Description: {vector_metadata.description}\n")

            if vector_metadata.base_sql:
                parts.append(f"Base SQL: {vector_metadata.base_sql}\n")

            cols_to_show = list(vector_metadata.columns)
            # Only filter columns when a selection exists; with an empty
            # selection every table is shown, so its columns must be shown
            # too (otherwise the default context lists tables with no columns).
            if apply_selection and self.selected_columns:
                wanted = self.selected_columns.get(table_slug, [])
                cols_to_show = [col for col in cols_to_show if col.name in wanted]

            if truncate:
                # NOTE(review): presumably returns the kept items, their
                # positions in the input and whether anything was dropped —
                # confirm against utils.truncate_iterable.
                cols_to_show, original_indices, show_ellipsis = truncate_iterable(cols_to_show, max_length)
            else:
                original_indices = range(len(cols_to_show))
                show_ellipsis = False

            # The ellipsis marker goes in front of the middle kept column.
            ellipsis_at = len(cols_to_show) // 2
            for i, (col, orig_idx) in enumerate(zip(cols_to_show, original_indices)):
                if show_ellipsis and i == ellipsis_at:
                    parts.append("...\n")

                col_name = truncate_string(col.name) if truncate else col.name
                line = f"{orig_idx}. {col_name!r}"
                if col.description:
                    col_desc = (
                        truncate_string(col.description, max_length=100)
                        if truncate
                        else col.description
                    )
                    line += f": {col_desc}"
                parts.append(line + "\n")
        return "".join(parts)

    @property
    def max_context(self) -> str:
        """Generate formatted text representation of the context."""
        return self._generate_context(truncate=None)

    @property
    def selected_context(self) -> str:
        """Generate formatted text representation of the context with selected tables cols"""
        return self._generate_context(truncate=False)

    @property
    def min_context(self) -> str:
        """Generate formatted text representation of the context with selected tables cols and truncated strings"""
        return self._generate_context(truncate=True)

    def __str__(self) -> str:
        """String representation is the formatted context."""
        return self.selected_context


@dataclass
class SQLMetadata:
    """Schema for SQL schema data for a single table.

    Carries the SQL-side information that the SQL context formatter in this
    module merges with the vector metadata of the same table.
    """

    # Identifier of the table this SQL data belongs to.
    table_slug: str
    # Schema information looked up by column name when rendering columns.
    # May also carry a "__len__" entry, which the formatter checks before
    # printing a "Row count" line.
    schema: dict[str, Any]
    # Optional SQL expression for the table.
    base_sql: str | None = None
    # Optional view definition; not read by the formatters in this module.
    view_definition: str | None = None
    # Free-form extra information; a fresh dict is created for every instance.
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class SQLMetaset:
    """Schema container for SQL data for multiple tables_metadata that builds on vector context.

    Attributes:
        vector_metaset: Vector-side metadata (descriptions, columns,
            similarity and the optional column selection).
        sql_metadata_map: SQL-side metadata per table, using the same keys
            as ``vector_metaset.vector_metadata_map``.
    """

    vector_metaset: VectorMetaset
    sql_metadata_map: dict[str, SQLMetadata]

    def _generate_context(self, truncate: bool | None = None) -> str:
        """
        Generate formatted context with both vector and SQL data.

        Tables that have no entry in the vector metadata map are skipped.

        Args:
            truncate: Controls truncation behavior.
                      None: Show all tables (max_context)
                      False: Filter by selected_columns without truncation (selected_context)
                      True: Filter by selected_columns with truncation (min_context)

        Returns:
            Formatted context string
        """
        parts = ["Below are the relevant tables and columns to use:\n\n"]

        vector_metaset = self.vector_metaset
        selected_columns = vector_metaset.selected_columns
        vector_map = vector_metaset.vector_metadata_map
        # Use selected tables if specified, otherwise use all tables
        tables_to_show = selected_columns or vector_map
        apply_selection = truncate is not None
        max_length = 20

        for table_slug, sql_data in self.sql_metadata_map.items():
            # Skip tables not in selected list for sub/min context
            if apply_selection and table_slug not in tables_to_show:
                continue

            vector_metadata = vector_map.get(table_slug)
            if not vector_metadata:
                continue

            parts.append(f"{table_slug!r} Similarity: ({vector_metadata.similarity:.3f})\n")

            if vector_metadata.description:
                desc = vector_metadata.description
                if truncate:
                    desc = truncate_string(desc, max_length=100)
                parts.append(f"Description: {desc}\n")

            if sql_data:
                # base_sql is optional: never print "Base SQL: None" and
                # never hand None to truncate_string.
                if sql_data.base_sql:
                    base_sql = sql_data.base_sql
                    if truncate:
                        base_sql = truncate_string(base_sql, max_length=200)
                    parts.append(f"Base SQL: {base_sql}\n")

                # The row count is the value stored under "__len__", not the
                # number of keys in the schema dict.
                row_count = sql_data.schema.get("__len__")
                if row_count:
                    parts.append(f"Row count: {row_count}\n")

            cols_to_show = list(vector_metadata.columns)
            if apply_selection and selected_columns:
                wanted = selected_columns.get(table_slug, [])
                cols_to_show = [col for col in cols_to_show if col.name in wanted]

            if truncate:
                # NOTE(review): presumably returns the kept items, their
                # positions in the input and whether anything was dropped —
                # confirm against utils.truncate_iterable.
                cols_to_show, original_indices, show_ellipsis = truncate_iterable(cols_to_show, max_length)
            else:
                original_indices = range(len(cols_to_show))
                show_ellipsis = False

            # The ellipsis marker goes in front of the middle kept column.
            ellipsis_at = len(cols_to_show) // 2
            for i, (col, orig_idx) in enumerate(zip(cols_to_show, original_indices)):
                if show_ellipsis and i == ellipsis_at:
                    parts.append("...\n")

                schema_data = None
                if sql_data and col.name in sql_data.schema:
                    schema_data = sql_data.schema[col.name]
                    # Columns whose schema entry is the "<null>" marker are
                    # left out of the selected/min contexts.
                    if apply_selection and schema_data == "<null>":
                        continue

                # Get column name with optional truncation
                col_name = truncate_string(col.name) if truncate else col.name
                line = f"{orig_idx}. {col_name!r}"

                # Get column description with optional truncation
                if col.description:
                    col_desc = truncate_string(col.description, max_length=100) if truncate else col.description
                    line += f": {col_desc}"

                # Add schema info for the column if available
                if schema_data:
                    if truncate:
                        schema_data = truncate_string(str(schema_data), max_length=50)
                    line += f" `{schema_data}`"

                parts.append(line + "\n")

        return "".join(parts)

    @property
    def max_context(self) -> str:
        """Generate comprehensive formatted context with both vector and SQL data."""
        return self._generate_context(truncate=None)

    @property
    def selected_context(self) -> str:
        """Generate context with selected tables and columns, without truncation."""
        return self._generate_context(truncate=False)

    @property
    def min_context(self) -> str:
        """Generate context with selected tables and columns, with truncation."""
        return self._generate_context(truncate=True)

    @property
    def query(self) -> str:
        """Get the original query that generated this context."""
        return self.vector_metaset.query

    def __str__(self) -> str:
        """String representation is the formatted context."""
        return self.selected_context

@dataclass
class PreviousState:
    """Snapshot of an earlier query together with the columns chosen for it.

    Attributes:
        query: Text of the previous query.
        selected_columns: Column names that were selected, grouped by table
            slug. Defaults to an empty mapping.
    """

    query: str
    selected_columns: dict[str, list[str]] = field(default_factory=dict)

    @property
    def max_context(self) -> str:
        """Render the previous query followed by each table and its columns."""
        pieces = [f"Previous query: {self.query}\n\n"]
        # A falsy selection (empty or missing) contributes no table lines.
        selection = self.selected_columns or {}
        for slug, column_names in selection.items():
            pieces.append(f"Table: {slug}\n")
            pieces.extend(f"- {column_name!r}\n" for column_name in column_names)
        return "".join(pieces)

    def __str__(self) -> str:
        """Return the same text as ``max_context``."""
        return self.max_context


async def get_metaset(sources: dict[str, Source], tables: list[str]) -> SQLMetaset:
    """
    Get the metaset for the given sources and tables.

    Each table is resolved to a source, its schema (including the row
    count) is fetched, and the result is wrapped into SQL and vector
    metadata with a fixed similarity of 1.

    Parameters
    ----------
    sources: dict[str, Source]
        The sources to get the metaset for.
    tables: list[str]
        The tables to get the metaset for. A table may be qualified as
        ``<source><SOURCE_TABLE_SEPARATOR><table>``; unqualified names are
        only accepted when exactly one source is provided.

    Returns
    -------
    metaset: SQLMetaset
        The metaset for the given sources and tables.

    Raises
    ------
    ValueError
        If a table is unqualified and zero or multiple sources are provided.
    KeyError
        If a qualified table names a source that is not in ``sources``.
    """
    tables_info, tables_metadata = {}, {}
    for table_slug in tables:
        if SOURCE_TABLE_SEPARATOR in table_slug:
            # Split on the first separator only so a table name that itself
            # contains the separator does not break the unpacking.
            source_name, table_name = table_slug.split(SOURCE_TABLE_SEPARATOR, maxsplit=1)
        elif len(sources) > 1:
            raise ValueError(
                f"Cannot resolve table {table_slug} without providing "
                "the source, when multiple sources are provided. Ensure "
                "that you qualify the table name as follows:\n\n"
                f"    <source>{SOURCE_TABLE_SEPARATOR}<table>"
            )
        elif not sources:
            # next(iter({})) would leak StopIteration out of the coroutine,
            # which Python turns into an opaque RuntimeError.
            raise ValueError(
                f"Cannot resolve table {table_slug} because no sources were provided."
            )
        else:
            source_name = next(iter(sources))
            table_name = table_slug
        source = sources[source_name]
        schema = await get_schema(source, table_name, include_count=True)
        # NOTE(review): both maps are keyed by the bare table name, so tables
        # with the same name in different sources overwrite each other —
        # confirm whether callers expect table_slug keys instead.
        tables_info[table_name] = SQLMetadata(
            table_slug=table_slug,
            schema=schema,
            base_sql=source.get_sql_expr(source.normalize_table(table_name)),
            view_definition=None,
        )
        metadata = source.get_metadata(table_name)
        # The description is optional; tolerate metadata without one.
        description = metadata.get("description") if metadata else None
        tables_metadata[table_name] = VectorMetadata(
            table_slug=table_slug,
            similarity=1,
            description=description,
        )
    # No user query is associated with a metaset built directly from tables.
    vector_metaset = VectorMetaset(vector_metadata_map=tables_metadata, query=None)
    return SQLMetaset(
        vector_metaset=vector_metaset,
        sql_metadata_map=tables_info,
    )
