#!/usr/bin/env julia

using ArgParse
using LinearAlgebra
using TOML

include(joinpath(@__DIR__, "src", "ABEE.jl"))
import .ABEE

"""
    parse_cmdline()

Define this script's command-line options with ArgParse and parse `ARGS`.

Returns the result of `ArgParse.parse_args`: a dictionary keyed by option name
without the leading `--` (e.g. `args["omega-D"]`, `args["gamma-convention1"]`).
Options declared with `action = :store_true` (`--no-quadgk-auto-points`,
`--no-physicality-check`) are boolean flags.
`--omega-D` defaults to `NaN`, so the caller must reject a missing value.
"""
function parse_cmdline()
    s = ArgParseSettings()
    @add_arg_table s begin
        # --- Two-mode system: frequencies, masses, inter-mode coupling ---
        "--omega1"
            help = "Mode 1 frequency (omega_1)."
            arg_type = Float64
            default = 1.0
        "--omega2"
            help = "Mode 2 frequency (omega_2)."
            arg_type = Float64
            default = 1.2
        "--m1"
            help = "Mode 1 mass."
            arg_type = Float64
            default = 1.0
        "--m2"
            help = "Mode 2 mass."
            arg_type = Float64
            default = 1.0
        "--kxx"
            help = "x-x coupling strength."
            arg_type = Float64
            default = 0.1
        "--kpp"
            help = "p-p coupling strength (optional)."
            arg_type = Float64
            default = 0.0
        "--g"
            help = "RWA coupling strength (only if --coupling RWA)."
            arg_type = Float64
            default = 0.0
        # Selects which of kxx / kpp / g are kept (the others are zeroed by main()).
        "--coupling"
            help = "Inter-mode coupling: none, xx, pp, xxpp, RWA."
            arg_type = String
            default = "none"
        # --- Baths ---
        # NOTE(review): help says "ladder convention" while --gamma-convention
        # offers physical/snu (default snu); confirm the wording is still accurate.
        "--gamma1"
            help = "Bath damping gamma_1 (ladder convention)."
            arg_type = Float64
            default = 0.02
        "--gamma2"
            help = "Bath damping gamma_2 (ladder convention)."
            arg_type = Float64
            default = 0.02
        "--gamma-convention"
            help = "Gamma convention: physical or snu (applies to both unless overridden)."
            arg_type = String
            default = "snu"
        # Empty string means "inherit --gamma-convention".
        "--gamma-convention1"
            help = "Gamma convention for bath 1: physical or snu (overrides --gamma-convention)."
            arg_type = String
            default = ""
        "--gamma-convention2"
            help = "Gamma convention for bath 2: physical or snu (overrides --gamma-convention)."
            arg_type = String
            default = ""
        "--T1"
            help = "Bath 1 temperature."
            arg_type = Float64
            default = 0.5
        "--T2"
            help = "Bath 2 temperature."
            arg_type = Float64
            default = 0.5
        "--bath-coupling1"
            help = "Bath 1 coupling: x or p."
            arg_type = String
            default = "x"
        "--bath-coupling2"
            help = "Bath 2 coupling: x or p."
            arg_type = String
            default = "x"
        # NaN sentinel: no usable default, main() rejects non-finite / non-positive.
        "--omega-D"
            help = "Drude cutoff omega_D (required)."
            arg_type = Float64
            default = NaN
        # --- Numerics ---
        "--eps"
            help = "Small regulator for invGR."
            arg_type = Float64
            default = 1e-10
        "--method"
            help = "Integration method (quadgk only)."
            arg_type = String
            default = "quadgk"
        "--quadgk-rtol"
            help = "Relative tolerance for quadgk."
            arg_type = Float64
            default = 1e-10
        "--quadgk-atol"
            help = "Absolute tolerance for quadgk."
            arg_type = Float64
            default = 1e-12
        "--quadgk-order"
            help = "Quadrature order for quadgk."
            arg_type = Int
            default = 9
        "--quadgk-maxevals"
            help = "Max evals for quadgk (0 = unlimited)."
            arg_type = Int
            default = 0
        "--no-quadgk-auto-points"
            help = "Disable auto quadgk points around system frequencies."
            action = :store_true
        "--quadgk-point-span"
            help = "Relative span around auto points (e.g. 0.05 = ±5%)."
            arg_type = Float64
            default = 0.05
        "--quadgk-point-steps"
            help = "Number of points on each side within span."
            arg_type = Int
            default = 2
        "--quadgk-extra-points"
            help = "Comma-separated extra quadgk points (e.g. \"0.1,0.2,-0.3\")."
            arg_type = String
            default = ""
        # --- Output / checks ---
        "--out"
            help = "Output base path (without extension)."
            arg_type = String
            default = "two_mode_ellipse"
        "--no-physicality-check"
            help = "Skip physicality checks."
            action = :store_true
    end
    return ArgParse.parse_args(s)
end

"""
    parse_coupling(s::String) -> Symbol

Translate a bath-coupling option into `:x` or `:p`.

Leading/trailing whitespace and letter case are ignored. Any other value raises an
`ErrorException` naming the accepted choices.
"""
function parse_coupling(s::String)
    key = lowercase(strip(s))
    key == "x" && return :x
    key == "p" && return :p
    error("Invalid bath coupling: $s (use x or p)")
end

"""
    parse_gamma_convention(s::String) -> Symbol

Translate a damping-convention option into `:physical` or `:snu`.

Matching is case-insensitive and ignores surrounding whitespace. Unknown values
raise an `ErrorException`.
"""
function parse_gamma_convention(s::String)
    normalized = lowercase(strip(s))
    for convention in (:physical, :snu)
        normalized == String(convention) && return convention
    end
    error("Invalid gamma convention: $s (use physical or snu)")
end

"""
    parse_interaction(s::String) -> Symbol

Translate the `--coupling` option into one of `:none`, `:xx`, `:pp`, `:xxpp`, `:RWA`.

Matching is case-insensitive and ignores surrounding whitespace. Unknown values
raise an `ErrorException` listing the accepted names.
"""
function parse_interaction(s::String)
    table = ("NONE" => :none, "XX" => :xx, "PP" => :pp, "XXPP" => :xxpp, "RWA" => :RWA)
    key = uppercase(strip(s))
    for (name, sym) in table
        key == name && return sym
    end
    error("Invalid coupling: $s (use none, xx, pp, xxpp, RWA)")
end

"""
    mat_to_vecs(M::AbstractMatrix{<:Real})

Return the rows of `M` as a vector of plain vectors (row `i` → `[M[i,1], …, M[i,end]]`).
This nested-vector form is what gets written to TOML as an array of arrays.
"""
function mat_to_vecs(M::AbstractMatrix{<:Real})
    nrows = size(M, 1)
    ncols = size(M, 2)
    return map(i -> [M[i, j] for j in 1:ncols], 1:nrows)
end

"""
    parse_float_list(s::AbstractString) -> Vector{Float64}

Parse a comma-separated list of numbers such as `"0.1,0.2,-0.3"`.

Whitespace around each entry is ignored. Empty entries (from `"1,,2"` or a trailing
comma) are skipped, so an empty or all-blank string yields `Float64[]`.

Throws `ArgumentError` when an entry is not a number, or is not finite (`NaN`,
`Inf`; `parse(Float64, ...)` would otherwise accept `"nan"`/`"inf"`). In this script
the list supplies quadgk break points (`--quadgk-extra-points`), where a non-finite
value is never meaningful.
"""
function parse_float_list(s::AbstractString)
    vals = Float64[]
    for part in split(s, ',')
        token = strip(part)
        isempty(token) && continue
        x = tryparse(Float64, token)
        if x === nothing
            throw(ArgumentError("cannot parse \"$token\" as a number in list \"$s\""))
        end
        if !isfinite(x)
            throw(ArgumentError("non-finite value \"$token\" in list \"$s\""))
        end
        push!(vals, x)
    end
    return vals
end

"""
    auto_quadgk_points(sys, omega1, omega2; span, steps, extra) -> Vector{Float64}

Build a sorted, de-duplicated list of quadgk break points.

Seed points:
- `0`, `±omega1` and `±omega2`;
- `±f` for each distinct finite `f > 1e-12` in `abs.(imag.(eigvals(J * K)))`, where
  `K = ABEE.Kmatrix(sys)` and `J = ABEE.symplecticJ(2)`;
- every value in `extra`.

When `steps > 0` and `span > 0`, each nonzero seed `p` is also bracketed by
`p * (1 ± span * k / steps)` for `k = 1:steps`.

Zeros are canonicalized to `+0.0` before de-duplication. `unique` compares with
`isequal`, which treats `-0.0` and `0.0` as distinct. Without this, `omega1 == 0`,
an extra point of `-0`, or `span == 1` (where `p * (1 - 1)` is `-0.0` for negative
`p`) would produce a duplicate break point at zero.
"""
function auto_quadgk_points(sys::ABEE.SystemParams, omega1::Float64, omega2::Float64;
                            span::Float64, steps::Int, extra::Vector{Float64})
    pts = Float64[0.0, omega1, -omega1, omega2, -omega2]

    # Approximate normal-mode frequencies from the undamped Hamiltonian: the
    # magnitudes of the imaginary parts of the eigenvalues of J*K.
    K = ABEE.Kmatrix(sys)
    J = ABEE.symplecticJ(2)
    for f in unique(abs.(imag.(eigvals(J * K))))
        if isfinite(f) && f > 1e-12
            push!(pts, f, -f)
        end
    end

    append!(pts, extra)

    # Bracket every nonzero seed (including user extras) with relative offsets.
    if steps > 0 && span > 0
        seeds = copy(pts)
        for p in seeds
            iszero(p) && continue
            for k in 1:steps
                δ = span * k / steps
                push!(pts, p * (1 - δ), p * (1 + δ))
            end
        end
    end

    # Canonicalize signed zeros so unique() does not keep both -0.0 and 0.0.
    pts[iszero.(pts)] .= 0.0
    return sort!(unique(pts))
end

"""
    _ellipse_axes(Vm) -> (lambda_min, lambda_max, angle)

Principal-axis analysis of a 2×2 single-mode covariance block `Vm`.

Returns the two eigenvalues of `Symmetric(Vm)` in ascending order. Also returns the
orientation `angle` (radians, from the x axis) of the eigenvector belonging to
`lambda_min`. An eigenvector is only defined up to sign, so the angle is folded into
(-π/2, π/2]. The raw `atan` value would otherwise jump by ±π depending on which sign
the eigensolver returns.
"""
function _ellipse_axes(Vm::AbstractMatrix{<:Real})
    F = eigen(Symmetric(Vm))
    # Sort explicitly rather than relying on the ordering returned by eigen.
    order = sortperm(F.values)
    lambda_min = F.values[order[1]]
    lambda_max = F.values[order[2]]
    v_min = F.vectors[:, order[1]]
    angle = atan(v_min[2], v_min[1])
    # v and -v describe the same axis: map (-π, -π/2] and (π/2, π] into (-π/2, π/2].
    if angle <= -pi / 2
        angle += pi
    elseif angle > pi / 2
        angle -= pi
    end
    return lambda_min, lambda_max, angle
end

"""
    main()

Entry point of the script.

1. Parse the command line.
2. Build `ABEE` system/bath/integration parameters.
3. Compute the steady-state covariance `V` with `ABEE.steady_state_covariance`
   (quadgk only).
4. Derive per-mode ellipse, symplectic-eigenvalue and squeezing diagnostics, plus
   two-mode entanglement diagnostics.
5. Write everything to `<--out>.toml` and print a summary to stdout.

Raises an error for an unsupported `--method`, a missing or non-positive `--omega-D`,
or invalid enumerated options (via the `parse_*` helpers).
"""
function main()
    args = parse_cmdline()

    # --- System parameters ---
    omega1 = args["omega1"]
    omega2 = args["omega2"]
    m1 = args["m1"]
    m2 = args["m2"]
    kxx = args["kxx"]
    kpp = args["kpp"]
    g = args["g"]
    coupling = parse_interaction(args["coupling"])

    # --- Bath parameters ---
    gamma1 = args["gamma1"]
    gamma2 = args["gamma2"]
    T1 = args["T1"]
    T2 = args["T2"]
    bath1_c = parse_coupling(args["bath-coupling1"])
    bath2_c = parse_coupling(args["bath-coupling2"])
    # Per-bath conventions fall back to --gamma-convention when left empty.
    gamma_conv_default = parse_gamma_convention(args["gamma-convention"])
    gamma_conv1 = isempty(strip(args["gamma-convention1"])) ? gamma_conv_default :
        parse_gamma_convention(args["gamma-convention1"])
    gamma_conv2 = isempty(strip(args["gamma-convention2"])) ? gamma_conv_default :
        parse_gamma_convention(args["gamma-convention2"])
    omega_D = args["omega-D"]

    # --- Numerics / output ---
    eps = args["eps"]
    # Strip like every other string option so " quadgk" is accepted.
    method = Symbol(lowercase(strip(args["method"])))
    quadgk_rtol = args["quadgk-rtol"]
    quadgk_atol = args["quadgk-atol"]
    quadgk_order = args["quadgk-order"]
    quadgk_maxevals = args["quadgk-maxevals"]
    quadgk_auto = !args["no-quadgk-auto-points"]
    quadgk_span = args["quadgk-point-span"]
    quadgk_steps = args["quadgk-point-steps"]
    quadgk_extra = parse_float_list(args["quadgk-extra-points"])
    out_base = args["out"]
    check_phys = !args["no-physicality-check"]

    # Validate before constructing any ABEE objects.
    if method != :quadgk
        error("--method must be quadgk (trapz has been removed).")
    end
    # omega_D defaults to NaN, so this also rejects a missing --omega-D.
    if !isfinite(omega_D) || omega_D <= 0
        error("--omega-D must be > 0")
    end

    # Zero out the coupling constants not used by the selected interaction, so the
    # recorded params reflect what was actually simulated.
    if coupling == :none
        kxx = 0.0
        kpp = 0.0
        g = 0.0
    elseif coupling == :pp
        kxx = 0.0
        g = 0.0
    elseif coupling == :xx
        kpp = 0.0
        g = 0.0
    elseif coupling == :xxpp
        g = 0.0
    elseif coupling == :RWA
        kxx = 0.0
        kpp = 0.0
    end

    sys = ABEE.SystemParams(m1, m2, omega1, omega2, coupling, g, kxx, kpp, nothing)
    bath1 = ABEE.BathParams(γ=gamma1, T=T1, coupling=bath1_c, λ=nothing,
                            omega_D=omega_D, gamma_convention=gamma_conv1)
    bath2 = ABEE.BathParams(γ=gamma2, T=T2, coupling=bath2_c, λ=nothing,
                            omega_D=omega_D, gamma_convention=gamma_conv2)
    # NOTE(review): 2001 is a hard-coded grid size; presumably a leftover from the
    # removed trapz path — confirm quadgk ignores it.
    integ = ABEE.IntegrationParams(omega_D, 2001, eps)

    # method is guaranteed to be :quadgk at this point.
    quadgk_points = if quadgk_auto
        auto_quadgk_points(sys, omega1, omega2;
                           span=quadgk_span,
                           steps=quadgk_steps,
                           extra=quadgk_extra)
    else
        # Match the auto path, which returns sorted, de-duplicated points.
        sort(unique(quadgk_extra))
    end

    V = ABEE.steady_state_covariance(sys, bath1, bath2, integ;
                                     method=method,
                                     quadgk_rtol=quadgk_rtol,
                                     quadgk_atol=quadgk_atol,
                                     quadgk_points=quadgk_points,
                                     quadgk_order=quadgk_order,
                                     quadgk_maxevals=quadgk_maxevals,
                                     check_physicality=check_phys,
                                     throw_on_unphysical=false)

    # Single-mode 2×2 blocks: mode 1 occupies rows/cols 1:2, mode 2 rows/cols 3:4.
    V1 = Matrix{Float64}(V[1:2, 1:2])
    V2 = Matrix{Float64}(V[3:4, 3:4])

    # Standard deviations in x/p (SNU)
    stds1 = ABEE.mode_stds(V; mode=1)
    sigma_x1 = stds1[1]
    sigma_p1 = stds1[2]
    stds2 = ABEE.mode_stds(V; mode=2)
    sigma_x2 = stds2[1]
    sigma_p2 = stds2[2]

    # Principal axes of each single-mode ellipse (angle of the minor axis, mod π).
    lambda_min1, lambda_max1, angle1 = _ellipse_axes(V1)
    lambda_min2, lambda_max2, angle2 = _ellipse_axes(V2)
    angle1_deg = angle1 * 180 / pi
    angle2_deg = angle2 * 180 / pi

    # Single-mode symplectic eigenvalue: sqrt(det) of the 2×2 block, with tiny
    # negative determinants from round-off clamped to zero.
    nu1 = sqrt(max(det(V1), 0.0))
    nu2 = sqrt(max(det(V2), 0.0))

    # Squeezing in dB relative to vacuum
    sq1 = ABEE.squeezing_dB(V, sys; mode=1)
    sq2 = ABEE.squeezing_dB(V, sys; mode=2)

    # Two-mode entanglement diagnostics
    logneg = ABEE.log_negativity(V)
    nupt = minimum(ABEE.symplectic_eigenvalues(ABEE.partial_transpose(V)))

    params = Dict(
        "omega1" => omega1,
        "omega2" => omega2,
        "m1" => m1,
        "m2" => m2,
        "kxx" => kxx,
        "kpp" => kpp,
        "g" => g,
        "coupling" => String(coupling),
        "gamma1" => gamma1,
        "gamma2" => gamma2,
        "T1" => T1,
        "T2" => T2,
        "bath_coupling1" => String(bath1_c),
        "bath_coupling2" => String(bath2_c),
        "omega_D" => omega_D,
        "eps" => eps,
        "method" => String(method),
        "quadgk_rtol" => quadgk_rtol,
        "quadgk_atol" => quadgk_atol,
        "quadgk_order" => quadgk_order,
        "quadgk_maxevals" => quadgk_maxevals,
        "quadgk_auto_points" => quadgk_auto,
        "quadgk_point_span" => quadgk_span,
        "quadgk_point_steps" => quadgk_steps,
        "quadgk_extra_points" => quadgk_extra
    )

    results = Dict(
        # Backward-compatible mode-1 keys
        "sigma_x" => sigma_x1,
        "sigma_p" => sigma_p1,
        "lambda_min" => lambda_min1,
        "lambda_max" => lambda_max1,
        "angle_rad" => angle1,
        "angle_deg" => angle1_deg,
        "symplectic_nu" => nu1,
        "squeezing_x_dB" => sq1.sx_dB,
        "squeezing_p_dB" => sq1.sp_dB,
        # Explicit mode-1 and mode-2 fields
        "sigma_x_mode1" => sigma_x1,
        "sigma_p_mode1" => sigma_p1,
        "lambda_min_mode1" => lambda_min1,
        "lambda_max_mode1" => lambda_max1,
        "angle_rad_mode1" => angle1,
        "angle_deg_mode1" => angle1_deg,
        "symplectic_nu_mode1" => nu1,
        "squeezing_x_dB_mode1" => sq1.sx_dB,
        "squeezing_p_dB_mode1" => sq1.sp_dB,
        "sigma_x_mode2" => sigma_x2,
        "sigma_p_mode2" => sigma_p2,
        "lambda_min_mode2" => lambda_min2,
        "lambda_max_mode2" => lambda_max2,
        "angle_rad_mode2" => angle2,
        "angle_deg_mode2" => angle2_deg,
        "symplectic_nu_mode2" => nu2,
        "squeezing_x_dB_mode2" => sq2.sx_dB,
        "squeezing_p_dB_mode2" => sq2.sp_dB,
        # Two-mode diagnostics
        "nu_pt_min" => nupt,
        "log_negativity" => logneg
    )

    # "mode_covariance" duplicates mode 1 for backward compatibility.
    out = Dict(
        "params" => params,
        "results" => results,
        "covariance" => mat_to_vecs(V),
        "mode_covariance" => mat_to_vecs(V1),
        "mode1_covariance" => mat_to_vecs(V1),
        "mode2_covariance" => mat_to_vecs(V2)
    )

    toml_path = out_base * ".toml"
    open(toml_path, "w") do io
        TOML.print(io, out)
    end

    println("Wrote: $toml_path")
    println("Mode 1 results:")
    println("  sigma_x = $(sigma_x1)")
    println("  sigma_p = $(sigma_p1)")
    println("  lambda_min = $(lambda_min1)")
    println("  lambda_max = $(lambda_max1)")
    println("  angle = $(angle1) rad ($(angle1_deg) deg)")
    println("  symplectic_nu = $(nu1)")
    println("  squeezing_x_dB = $(sq1.sx_dB)")
    println("  squeezing_p_dB = $(sq1.sp_dB)")
    println("Mode 2 results:")
    println("  sigma_x = $(sigma_x2)")
    println("  sigma_p = $(sigma_p2)")
    println("  lambda_min = $(lambda_min2)")
    println("  lambda_max = $(lambda_max2)")
    println("  angle = $(angle2) rad ($(angle2_deg) deg)")
    println("  symplectic_nu = $(nu2)")
    println("  squeezing_x_dB = $(sq2.sx_dB)")
    println("  squeezing_p_dB = $(sq2.sp_dB)")
    println("Two-mode entanglement diagnostics:")
    println("  nu_pt_min = $(nupt)")
    println("  log_negativity = $(logneg)")
end

# Script entry point: this top-level call runs whenever the file is executed or
# `include`d.
main()
