from datetime import date, timedelta
from sqlmodel import Session, select
from app.db import engine, init_db
from app.models import Tenant, User, Control, Assessment, Answer, Risk, Severity, Status, Policy, Supplier, Questionnaire, Role
from app.security import hash_password

# Control catalogue rows consumed by seed(). Column order per tuple:
#   (code, family, title, requirement, nist_800_171, cmmc_level, insurance_control)
CONTROL_SEED = [
    (
        "AC-1",
        "Access Control",
        "Access Control Policy",
        "Limit system access to authorized users and document account management procedures.",
        "3.1.1",
        "Level 2",
        True,
    ),
    (
        "AC-2",
        "Access Control",
        "Account Management",
        "Review accounts, remove terminated users, and approve privileged access.",
        "3.1.2",
        "Level 2",
        True,
    ),
    (
        "IA-1",
        "Identification and Authentication",
        "MFA",
        "Use MFA for privileged and remote access where supported.",
        "3.5.3",
        "Level 2",
        True,
    ),
    (
        "IR-1",
        "Incident Response",
        "Incident Response Plan",
        "Establish incident handling, reporting, and lessons-learned procedures.",
        "3.6.1",
        "Level 2",
        True,
    ),
    (
        "CP-1",
        "Contingency Planning",
        "Backups",
        "Perform, protect, and test backups for critical systems.",
        "3.8.9",
        "Level 2",
        True,
    ),
    (
        "RA-1",
        "Risk Assessment",
        "Risk Register",
        "Identify, document, and track cybersecurity risks and remediation.",
        "3.11.1",
        "Level 2",
        False,
    ),
    (
        "CM-1",
        "Configuration Management",
        "Secure Configuration",
        "Baseline and manage changes to systems.",
        "3.4.1",
        "Level 2",
        False,
    ),
    (
        "AT-1",
        "Awareness and Training",
        "Security Awareness",
        "Train users on cybersecurity risks and responsibilities.",
        "3.2.1",
        "Level 1",
        True,
    ),
]

# Titles of the starter policies created for the demo tenant; seed() uses the
# first word of each title as the policy category.
POLICIES = [
    "Access control policy",
    "Password/MFA policy",
    "Acceptable use policy",
    "Incident response policy",
    "Backup policy",
    "Vendor management policy",
    "Remote access policy",
    "Data classification policy",
]


def seed():
    """Populate the database with the control catalogue and demo-tenant data.

    Calls ``init_db()`` and then, inside a single session, inserts whatever is
    missing of:

    * the controls listed in ``CONTROL_SEED`` (matched by ``Control.code``),
    * the "Acme Manufacturing Demo" tenant and its admin user,
    * one assessment per control for that tenant,
    * one policy per title in ``POLICIES``,
    * one sample risk, supplier and questionnaire.

    Every insert is guarded by an existence query, so running the seed again
    does not duplicate rows. Prints the demo login when finished.
    """
    init_db()
    with Session(engine) as s:
        # --- Control catalogue -------------------------------------------
        for code, family, title, requirement, nist_ref, cmmc_level, insurance in CONTROL_SEED:
            if s.exec(select(Control).where(Control.code == code)).first():
                continue
            s.add(
                Control(
                    code=code,
                    family=family,
                    title=title,
                    requirement=requirement,
                    nist_800_171=nist_ref,
                    cmmc_level=cmmc_level,
                    insurance_control=insurance,
                )
            )
        s.commit()

        # --- Demo tenant --------------------------------------------------
        tenant = s.exec(select(Tenant).where(Tenant.name == "Acme Manufacturing Demo")).first()
        if not tenant:
            tenant = Tenant(name="Acme Manufacturing Demo", industry="Aerospace Supplier")
            s.add(tenant)
            s.commit()
            # Refresh so tenant.id is loaded before it is used as a foreign key.
            s.refresh(tenant)

        # --- Demo admin user ----------------------------------------------
        admin = s.exec(select(User).where(User.email == "admin@secureflow.test")).first()
        if not admin:
            admin = User(
                tenant_id=tenant.id,
                email="admin@secureflow.test",
                full_name="Demo Admin",
                role=Role.owner,
                password_hash=hash_password("SecureFlow123!"),
            )
            s.add(admin)
            s.commit()
            s.refresh(admin)

        # --- Assessments: one per control ---------------------------------
        # Order by primary key so the index-based rotation of answers and
        # severities below does not depend on unspecified row order.
        controls = s.exec(select(Control).order_by(Control.id)).all()
        answer_cycle = [Answer.yes, Answer.partial, Answer.no]
        severity_cycle = [Severity.low, Severity.medium, Severity.high]
        today = date.today()
        for i, control in enumerate(controls):
            existing = s.exec(
                select(Assessment).where(
                    Assessment.tenant_id == tenant.id,
                    Assessment.control_id == control.id,
                )
            ).first()
            if existing:
                continue
            s.add(
                Assessment(
                    tenant_id=tenant.id,
                    control_id=control.id,
                    answer=answer_cycle[i % 3],
                    notes="Seeded readiness assessment item",
                    owner_id=admin.id,
                    due_date=today + timedelta(days=30 + i),
                    risk_level=severity_cycle[i % 3],
                    # Every third item (i % 3 == 0) is marked completed.
                    remediation_status=Status.in_progress if i % 3 else Status.completed,
                )
            )

        # --- Starter policies ---------------------------------------------
        for p in POLICIES:
            if s.exec(select(Policy).where(Policy.tenant_id == tenant.id, Policy.title == p)).first():
                continue
            content = f"# {p}\n\nThis starter policy defines responsibilities, review cadence, evidence expectations, and management approval. Tailor before production use."
            s.add(
                Policy(
                    tenant_id=tenant.id,
                    title=p,
                    category=p.split()[0],
                    content=content,
                    owner_id=admin.id,
                    next_review_date=today + timedelta(days=365),
                )
            )

        # --- Sample risk --------------------------------------------------
        if not s.exec(select(Risk).where(Risk.tenant_id == tenant.id)).first():
            # Resolve the MFA control by its code instead of assuming its
            # primary key is 3; ids differ when the table already held rows.
            # "IA-1" is part of CONTROL_SEED and was committed above.
            control_ids = {control.code: control.id for control in controls}
            s.add(
                Risk(
                    tenant_id=tenant.id,
                    title="MFA coverage gap",
                    description="Some remote access paths do not yet enforce MFA.",
                    control_id=control_ids["IA-1"],
                    severity=Severity.high,
                    likelihood=4,
                    owner_id=admin.id,
                    due_date=today + timedelta(days=45),
                    status=Status.in_progress,
                    remediation_plan="Inventory remote access, enforce MFA, attach screenshots and access control policy evidence.",
                )
            )

        # --- Sample supplier ----------------------------------------------
        if not s.exec(select(Supplier).where(Supplier.tenant_id == tenant.id)).first():
            s.add(
                Supplier(
                    tenant_id=tenant.id,
                    name="Precision Parts Vendor",
                    criticality=Severity.high,
                    contact_email="security@example.com",
                    risk_score=72,
                    review_date=today + timedelta(days=90),
                    notes="Critical supplier awaiting questionnaire.",
                )
            )

        # --- Sample questionnaire -----------------------------------------
        if not s.exec(select(Questionnaire).where(Questionnaire.tenant_id == tenant.id)).first():
            s.add(
                Questionnaire(
                    tenant_id=tenant.id,
                    customer_name="Prime Aerospace Customer",
                    title="2026 Cyber Questionnaire",
                    questions=[{"question": "Do you enforce MFA for remote access?", "status": "Draft"}],
                )
            )

        s.commit()
        # NOTE(review): demo-only credentials are hard-coded and echoed here;
        # do not run this seed against a production database.
        print("Seed complete. Demo login: admin@secureflow.test / SecureFlow123!")


# Run the seed only when this file is executed as a script, not on import.
if __name__ == "__main__":
    seed()
