from __future__ import annotations

from datetime import datetime
from typing import Optional

from flask_login import UserMixin
from sqlalchemy import ForeignKey, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from werkzeug.security import check_password_hash, generate_password_hash

from .extensions import db
from .security import decrypt_text, encrypt_text, utcnow_naive


class AdminUser(UserMixin, db.Model):
    """Administrator account stored in the ``admin_users`` table.

    Mixes in flask_login's ``UserMixin`` so instances can act as the logged-in
    user. Only a werkzeug password hash is persisted (see ``set_password``);
    the plain-text password is never assigned to the model.
    """

    __tablename__ = "admin_users"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Unique, indexed login name (at most 80 characters).
    username: Mapped[str] = mapped_column(db.String(80), unique=True, nullable=False, index=True)
    # Output of werkzeug's generate_password_hash; written by set_password, read by check_password.
    password_hash: Mapped[str] = mapped_column(db.String(255), nullable=False)
    # Python-side default: utcnow_naive() is evaluated when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(default=utcnow_naive, nullable=False)
    # NULL until set elsewhere; presumably updated on successful login — confirm in the auth view.
    last_login_at: Mapped[Optional[datetime]] = mapped_column(nullable=True)

    # Batches whose ``created_by`` is this admin (inverse side of DeliveryBatch.created_by).
    batches: Mapped[list[DeliveryBatch]] = relationship(back_populates="created_by")

    def set_password(self, password: str) -> None:
        """Hash *password* with werkzeug and store the result in ``password_hash``."""
        self.password_hash = generate_password_hash(password)

    def check_password(self, password: str) -> bool:
        """Return True if *password* matches the stored ``password_hash``."""
        return check_password_hash(self.password_hash, password)


class DeliveryBatch(db.Model):
    """A named batch of encrypted credentials that access keys grant access to.

    Owns its credentials, access keys and access logs: all three relationships
    use ``cascade="all, delete-orphan"``, so deleting a batch through the ORM
    also deletes those child rows.
    """

    __tablename__ = "delivery_batches"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Display name (indexed, not unique).
    name: Mapped[str] = mapped_column(db.String(120), nullable=False, index=True)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    is_active: Mapped[bool] = mapped_column(default=True, nullable=False)
    created_at: Mapped[datetime] = mapped_column(default=utcnow_naive, nullable=False)
    # onupdate only fires for UPDATEs issued by SQLAlchemy, not for raw SQL edits.
    updated_at: Mapped[datetime] = mapped_column(default=utcnow_naive, onupdate=utcnow_naive, nullable=False)
    # Nullable: a batch may have no recorded creating admin.
    created_by_id: Mapped[Optional[int]] = mapped_column(ForeignKey("admin_users.id"), nullable=True)

    created_by: Mapped[Optional[AdminUser]] = relationship(back_populates="batches")
    # Loaded in ascending sort_order.
    credentials: Mapped[list[BatchCredential]] = relationship(
        back_populates="batch", cascade="all, delete-orphan", order_by="BatchCredential.sort_order"
    )
    # Loaded newest first.
    access_keys: Mapped[list[AccessKey]] = relationship(
        back_populates="batch", cascade="all, delete-orphan", order_by="desc(AccessKey.created_at)"
    )
    # NOTE(review): delete-orphan here discards the batch's audit log rows when the batch is
    # deleted, although AccessLog.batch_id is nullable (logs could otherwise outlive it) — confirm intent.
    logs: Mapped[list[AccessLog]] = relationship(back_populates="batch", cascade="all, delete-orphan")

    @property
    def credential_count(self) -> int:
        """Number of credentials in this batch.

        Uses ``len(self.credentials)``, so it loads the whole credentials
        collection if it is not already loaded.
        """
        return len(self.credentials)


class BatchCredential(db.Model):
    """A single secret belonging to a ``DeliveryBatch``, stored encrypted.

    Only ciphertext is persisted (``encrypted_value``); the ``value`` property
    encrypts on assignment and decrypts on read via the ``.security`` helpers.
    """

    __tablename__ = "batch_credentials"

    id: Mapped[int] = mapped_column(primary_key=True)
    batch_id: Mapped[int] = mapped_column(ForeignKey("delivery_batches.id"), nullable=False, index=True)
    # Optional human-readable label for the credential.
    label: Mapped[Optional[str]] = mapped_column(db.String(120), nullable=True)
    # Ciphertext produced by encrypt_text; never assign plain text here directly.
    encrypted_value: Mapped[str] = mapped_column(Text, nullable=False)
    # Position within the batch; DeliveryBatch.credentials is ordered by this column.
    sort_order: Mapped[int] = mapped_column(default=0, nullable=False)
    created_at: Mapped[datetime] = mapped_column(default=utcnow_naive, nullable=False)

    batch: Mapped[DeliveryBatch] = relationship(back_populates="credentials")

    @property
    def value(self) -> str:
        """Return the decrypted credential (decrypts on every access)."""
        return decrypt_text(self.encrypted_value)

    @value.setter
    def value(self, raw_value: str) -> None:
        """Encrypt *raw_value* and store the ciphertext in ``encrypted_value``."""
        self.encrypted_value = encrypt_text(raw_value)


class AccessKey(db.Model):
    """A key that grants access to one ``DeliveryBatch``.

    A key is usable only while it is enabled, not revoked, not past its
    optional expiry time and below its optional view limit. ``status`` reports
    the first of those conditions that fails.
    """

    __tablename__ = "access_keys"

    id: Mapped[int] = mapped_column(primary_key=True)
    batch_id: Mapped[int] = mapped_column(ForeignKey("delivery_batches.id"), nullable=False, index=True)
    # Only a digest of the key is stored. String(64) fits a hex SHA-256 digest;
    # presumably that is the format — confirm in the key-generation code.
    key_hash: Mapped[str] = mapped_column(db.String(64), unique=True, nullable=False, index=True)
    # Short displayable fragment of the key (up to 32 chars).
    key_hint: Mapped[str] = mapped_column(db.String(32), nullable=False)
    notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    is_active: Mapped[bool] = mapped_column(default=True, nullable=False)
    created_at: Mapped[datetime] = mapped_column(default=utcnow_naive, nullable=False)
    # NULL means the key never expires.
    expires_at: Mapped[Optional[datetime]] = mapped_column(nullable=True)
    # Any non-NULL value marks the key as revoked.
    revoked_at: Mapped[Optional[datetime]] = mapped_column(nullable=True)
    last_used_at: Mapped[Optional[datetime]] = mapped_column(nullable=True)
    view_count: Mapped[int] = mapped_column(default=0, nullable=False)
    # NULL means unlimited views.
    max_views: Mapped[Optional[int]] = mapped_column(nullable=True)

    batch: Mapped[DeliveryBatch] = relationship(back_populates="access_keys")
    logs: Mapped[list[AccessLog]] = relationship(back_populates="access_key")

    @property
    def is_expired(self) -> bool:
        """True once ``expires_at`` is set and is at or before the current naive UTC time."""
        deadline = self.expires_at
        if deadline is None:
            return False
        return deadline <= utcnow_naive()

    @property
    def is_exhausted(self) -> bool:
        """True once ``view_count`` has reached a non-NULL ``max_views`` limit."""
        limit = self.max_views
        if limit is None:
            return False
        return self.view_count >= limit

    @property
    def is_usable(self) -> bool:
        """True when the key is enabled, unrevoked, unexpired and not exhausted."""
        usable = self.is_active
        if usable:
            # Checks run in the same order as ``status`` so both agree.
            usable = self.revoked_at is None and not self.is_expired and not self.is_exhausted
        return usable

    @property
    def status(self) -> str:
        """Human-readable state: "revoked", "expired", "used up" or "active"."""
        was_revoked = not self.is_active or self.revoked_at is not None
        if was_revoked:
            return "revoked"
        if self.is_expired:
            return "expired"
        return "used up" if self.is_exhausted else "active"


class AccessLog(db.Model):
    """Audit record of a single event.

    Both foreign keys are nullable, so a log row can exist without a batch
    and/or access key. Presumably this covers events such as lookups with an
    unknown key — confirm against the code that writes logs.
    """

    __tablename__ = "access_logs"

    id: Mapped[int] = mapped_column(primary_key=True)
    batch_id: Mapped[Optional[int]] = mapped_column(ForeignKey("delivery_batches.id"), nullable=True, index=True)
    access_key_id: Mapped[Optional[int]] = mapped_column(ForeignKey("access_keys.id"), nullable=True, index=True)
    # Short event code (max 40 chars), indexed for filtering.
    event_type: Mapped[str] = mapped_column(db.String(40), nullable=False, index=True)
    detail: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    ip_address: Mapped[Optional[str]] = mapped_column(db.String(64), nullable=True)
    # NOTE(review): limited to 255 chars; confirm writers truncate longer User-Agent headers.
    user_agent: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
    # Indexed so logs can be filtered/sorted by time.
    created_at: Mapped[datetime] = mapped_column(default=utcnow_naive, nullable=False, index=True)

    batch: Mapped[Optional[DeliveryBatch]] = relationship(back_populates="logs")
    access_key: Mapped[Optional[AccessKey]] = relationship(back_populates="logs")


class DashboardStats:
    """Plain value holder for the four headline counters on the admin dashboard.

    Attributes are stored exactly as passed; no validation or conversion is done.
    """

    def __init__(self, total_batches: int, active_keys: int, total_credentials: int, total_logs: int):
        (
            self.total_batches,
            self.active_keys,
            self.total_credentials,
            self.total_logs,
        ) = (total_batches, active_keys, total_credentials, total_logs)


def build_dashboard_stats() -> DashboardStats:
    """Collect the headline counters shown on the admin dashboard.

    ``active_keys`` counts only keys that are currently usable, i.e. the keys
    for which ``AccessKey.status`` would report ``"active"``. Such a key is
    enabled, not revoked, not past ``expires_at`` and below its ``max_views``
    limit. Previously expired and used-up keys were counted as active too.

    Returns:
        A ``DashboardStats`` holding the batch, active-key, credential and
        log counts (0 for an empty table).
    """
    # Local import keeps this block self-contained; only this query needs or_.
    from sqlalchemy import or_

    # One timestamp for the whole query, mirroring AccessKey.is_expired's comparison.
    now = utcnow_naive()

    total_batches = db.session.query(func.count(DeliveryBatch.id)).scalar() or 0
    active_keys = (
        db.session.query(func.count(AccessKey.id))
        .filter(
            AccessKey.is_active.is_(True),
            AccessKey.revoked_at.is_(None),
            # Not expired: no expiry set, or expiry strictly in the future.
            or_(AccessKey.expires_at.is_(None), AccessKey.expires_at > now),
            # Not exhausted: unlimited views, or still below the limit.
            or_(AccessKey.max_views.is_(None), AccessKey.view_count < AccessKey.max_views),
        )
        .scalar()
        or 0
    )
    total_credentials = db.session.query(func.count(BatchCredential.id)).scalar() or 0
    total_logs = db.session.query(func.count(AccessLog.id)).scalar() or 0
    return DashboardStats(total_batches, active_keys, total_credentials, total_logs)
