#!/usr/bin/env python3

import json
import os
import threading
import subprocess
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import tkinter as tk
from tkinter import ttk, scrolledtext, messagebox, filedialog

try:
    import tomllib  # Python 3.11+
except ImportError:  # pragma: no cover - fallback for older Python
    try:
        import toml as tomllib  # type: ignore
    except Exception as exc:  # pragma: no cover
        raise SystemExit("Missing dependency: toml. Install with `python -m pip install toml`.") from exc

try:
    import matplotlib
    matplotlib.use("TkAgg")
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
    from matplotlib import rcParams
except Exception as exc:  # pragma: no cover
    raise SystemExit("Missing dependency: matplotlib. Install with `python -m pip install matplotlib`.") from exc

# Global Matplotlib style: serif Computer Modern look rendered with mathtext
# (usetex disabled, so no LaTeX installation is needed) and inward-pointing
# ticks drawn on all four sides of every axes.
rcParams.update(
    {
        "text.usetex": False,
        "font.family": "serif",
        "font.serif": ["Computer Modern Roman", "CMU Serif", "Computer Modern", "DejaVu Serif"],
        "mathtext.fontset": "cm",
        "xtick.direction": "in",
        "ytick.direction": "in",
        "xtick.top": True,
        "xtick.bottom": True,
        "ytick.left": True,
        "ytick.right": True,
    }
)


# Directory containing this script. The Julia solver is run with this as its
# working directory, and the data/state files below live next to the script.
ROOT = Path(__file__).resolve().parent
JULIA_SCRIPT = ROOT.joinpath("two_mode_ellipse_data.jl")  # solver launched by "Start"
STATE_PATH = ROOT.joinpath(".gui_state.json")  # remembers the last plotted TOML


def load_toml(path: Path):
    with path.open("rb") as f:
        return tomllib.load(f)


def ellipse_points(cov2, npts=400):
    """Sample the 1-sigma ellipse of a 2x2 covariance matrix.

    The matrix is diagonalised with ``numpy.linalg.eigh`` (eigenvalues in
    ascending order). Negative eigenvalues, e.g. tiny round-off, are clipped
    to zero before taking square roots. The unit circle, scaled by the
    semi-axes, is then rotated into the eigenbasis.

    Args:
        cov2: 2x2 covariance matrix (array-like).
        npts: number of points sampled over [0, 2*pi] (endpoint included).

    Returns:
        Tuple ``(points, semi_axes, eigvecs, eigvals)``, where ``points`` has
        shape (2, npts), ``semi_axes`` = sqrt(clipped eigvals), ``eigvecs``
        holds the eigenvectors as columns, and ``eigvals`` is the clipped
        eigenvalue array.
    """
    eigvals, eigvecs = np.linalg.eigh(np.array(cov2, dtype=float))
    eigvals = np.maximum(eigvals, 0.0)  # guard sqrt against negative round-off
    semi_axes = np.sqrt(eigvals)

    theta = np.linspace(0.0, 2.0 * np.pi, npts)
    unit_circle = np.stack([np.cos(theta), np.sin(theta)])
    # Row i of the scaled circle is semi_axes[i] * (cos | sin)(theta).
    scaled = semi_axes[:, None] * unit_circle
    boundary = eigvecs @ scaled
    return boundary, semi_axes, eigvecs, eigvals


def mode_block(cov4, mode):
    """Return the 2x2 diagonal block of ``cov4`` belonging to ``mode``.

    ``mode`` is 1-based: mode 1 selects rows/columns 0-1, mode 2 selects
    rows/columns 2-3, and so on. For a NumPy array this is basic slicing,
    so the result is a view into ``cov4``.
    """
    start = (mode - 1) * 2
    span = slice(start, start + 2)
    return cov4[span, span]


def render_axes(ax, cov2, title, extra_lines=None):
    """Draw the 1-sigma ellipse of a 2x2 (x, p) covariance block on ``ax``.

    ``ax`` is cleared first. It then receives:

    - the ellipse boundary;
    - sigma_x and sigma_p as coloured bars along the x and p axes;
    - both principal semi-axes as dashed lines;
    - a text box with sigma_x, sigma_p, the eigenvalues, the orientation
      angle and nu = sqrt(det(cov2)).

    Args:
        ax: Matplotlib Axes to draw on (cleared in place).
        cov2: 2x2 covariance matrix (array-like).
        title: axes title.
        extra_lines: optional iterable of extra strings appended to the text box.

    Returns:
        The (2, npts) array of boundary points from ``ellipse_points``. The
        caller uses it to pick common axis limits.
    """
    # Coerce like ellipse_points does, so nested lists work too
    # (the [0, 0] / [1, 1] indexing below needs an ndarray).
    cov2 = np.asarray(cov2, dtype=float)
    pts, axes, vecs, vals = ellipse_points(cov2)
    ax.clear()
    ax.plot(pts[0, :], pts[1, :], color="C0", lw=2, label="1-sigma ellipse")
    ax.axhline(0, color="0.7", lw=1)
    ax.axvline(0, color="0.7", lw=1)

    # Marginal standard deviations (negative variances clipped to 0).
    sigma_x = float(np.sqrt(max(cov2[0, 0], 0.0)))
    sigma_p = float(np.sqrt(max(cov2[1, 1], 0.0)))
    ax.plot([0, sigma_x], [0, 0], color="C3", lw=2, label=r"$\sigma_x$")
    ax.plot([0, 0], [0, sigma_p], color="C2", lw=2, label=r"$\sigma_p$")

    # eigh sorts eigenvalues ascending: column 0 is the minor axis, column 1 the major.
    v_min = vecs[:, 0] * axes[0]
    v_max = vecs[:, 1] * axes[1]
    ax.plot([0, v_min[0]], [0, v_min[1]], color="C4", lw=1.5, ls="--", label="principal axes")
    ax.plot([0, v_max[0]], [0, v_max[1]], color="C4", lw=1.5, ls="--")

    lam_min = float(vals[0])
    lam_max = float(vals[1])
    # Orientation of the minor (lambda_min) axis. Eigenvectors are only
    # defined up to sign, so the raw arctan2 angle can flip by 180 deg for
    # the same ellipse. An axis is a line, so fold the angle into (-90, 90].
    angle = float(np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0])))
    if angle > 90.0:
        angle -= 180.0
    elif angle <= -90.0:
        angle += 180.0
    nu = float(np.sqrt(max(np.linalg.det(cov2), 0.0)))

    lines = [
        f"$\\sigma_x$ = {sigma_x:.4g}",
        f"$\\sigma_p$ = {sigma_p:.4g}",
        f"$\\lambda_{{\\min}}$ = {lam_min:.4g}",
        f"$\\lambda_{{\\max}}$ = {lam_max:.4g}",
        f"$\\theta$ = {angle:.2f} deg",
        f"$\\nu$ = {nu:.4g}",
    ]
    if extra_lines:
        lines.extend(extra_lines)
    text = "\n".join(lines)
    ax.text(
        0.02,
        0.98,
        text,
        transform=ax.transAxes,
        ha="left",
        va="top",
        fontsize=9,
        bbox=dict(facecolor="white", edgecolor="none", boxstyle="square,pad=0.4", alpha=0.6),
    )

    ax.set_title(title)
    ax.set_xlabel(r"$x$ (SNU)")
    ax.set_ylabel(r"$p$ (SNU)")
    return pts


class App(tk.Tk):
    """Main window: parameter form, Julia runner, and two-panel ellipse plot.

    The left pane is a scrollable form of solver parameters, run controls and
    a log. The right pane embeds a Matplotlib figure with one panel per mode.

    "Start" builds a command line for ``JULIA_SCRIPT`` and runs it in a
    background thread, so the UI stays responsive. It then reads the TOML
    file derived from the "Output base" field and plots both single-mode
    covariance blocks. The last plotted TOML path is stored in
    ``STATE_PATH`` and re-plotted on startup.

    Threading rule: only the Tk main thread touches widgets. The worker
    thread receives a pre-built command and posts its result back via
    ``after``.
    """

    def __init__(self):
        super().__init__()
        self.title("ABEE Two-Mode Ellipse GUI")
        self.geometry("1200x700")
        # TOML file currently shown in the figure (None until something is plotted).
        self.last_plot_path = None
        # Figure-level header Text artist (log-negativity etc.), replaced on each plot.
        self.header_text = None

        self._build_ui()
        self._load_last_session()

    def _build_ui(self):
        """Create all widgets and the embedded Matplotlib figure.

        Populates ``self.fields`` (form key -> Entry/Combobox widget) and
        ``self.vars`` (combobox key -> StringVar).
        """
        container = ttk.Frame(self)
        container.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        left_container = ttk.Frame(container)
        left_container.pack(side=tk.LEFT, fill=tk.Y)

        # Scrollable left pane for parameters: a Frame embedded in a Canvas window item.
        left_canvas = tk.Canvas(left_container, borderwidth=0, highlightthickness=0)
        left_scroll = ttk.Scrollbar(left_container, orient="vertical", command=left_canvas.yview)
        left_canvas.configure(yscrollcommand=left_scroll.set)

        left_scroll.pack(side=tk.RIGHT, fill=tk.Y)
        left_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        left = ttk.Frame(left_canvas)
        left_window = left_canvas.create_window((0, 0), window=left, anchor="nw")

        def _sync_scroll_region(event):
            # The inner frame changed size: update the scrollable area to match.
            left_canvas.configure(scrollregion=left_canvas.bbox("all"))

        def _sync_width(event):
            # Keep the inner frame width in sync with the canvas width.
            left_canvas.itemconfigure(left_window, width=event.width)

        left.bind("<Configure>", _sync_scroll_region)
        left_canvas.bind("<Configure>", _sync_width)

        def _on_mousewheel(event):
            # Windows reports event.delta in multiples of 120 per notch.
            # macOS reports small values (often +/-1), which int(delta / 120)
            # truncates to 0, so always scroll at least one unit.
            # X11 sends Button-4/5 events with delta == 0 instead.
            if event.delta:
                steps = int(-event.delta / 120) or (-1 if event.delta > 0 else 1)
                left_canvas.yview_scroll(steps, "units")
            else:
                left_canvas.yview_scroll(1 if event.num == 5 else -1, "units")

        # bind_all: the wheel scrolls the form wherever the pointer is in the app.
        left_canvas.bind_all("<MouseWheel>", _on_mousewheel)
        left_canvas.bind_all("<Button-4>", _on_mousewheel)
        left_canvas.bind_all("<Button-5>", _on_mousewheel)

        right = ttk.Frame(container)
        right.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

        self.fields = {}
        self.vars = {}

        def add_row(label, key, default="", width=16):
            row = ttk.Frame(left)
            row.pack(fill=tk.X, pady=2)
            ttk.Label(row, text=label, width=18).pack(side=tk.LEFT)
            ent = ttk.Entry(row, width=width)
            ent.insert(0, str(default))
            ent.pack(side=tk.LEFT, fill=tk.X, expand=True)
            self.fields[key] = ent

        def add_combo(label, key, values, default):
            row = ttk.Frame(left)
            row.pack(fill=tk.X, pady=2)
            ttk.Label(row, text=label, width=18).pack(side=tk.LEFT)
            var = tk.StringVar()
            combo = ttk.Combobox(row, textvariable=var, values=values, width=14, state="readonly")
            combo.set(default)
            combo.pack(side=tk.LEFT, fill=tk.X, expand=True)
            self.fields[key] = combo
            self.vars[key] = var

        ttk.Label(left, text="System Parameters", font=("TkDefaultFont", 10, "bold")).pack(anchor="w")
        add_combo("coupling", "coupling", ["none", "xx", "pp", "xxpp", "RWA"], "none")
        add_row("g (RWA)", "g", "0.0")
        add_row("kxx", "kxx", "0.0")
        add_row("kpp", "kpp", "0.0")
        add_row("omega1", "omega1", "1.0")
        add_row("omega2", "omega2", "1.0")
        add_row("m1", "m1", "1.0")
        add_row("m2", "m2", "1.0")

        ttk.Separator(left).pack(fill=tk.X, pady=6)
        ttk.Label(left, text="Bath Parameters", font=("TkDefaultFont", 10, "bold")).pack(anchor="w")
        add_row("gamma1", "gamma1", "0.01")
        add_row("gamma2", "gamma2", "0.01")
        add_row("T1", "T1", "0.5")
        add_row("T2", "T2", "0.5")
        add_combo("bath1 coupling", "bath1", ["x", "p"], "x")
        add_combo("bath2 coupling", "bath2", ["x", "p"], "p")
        add_combo("gamma convention", "gamma_conv", ["snu", "physical"], "snu")
        add_row("omega_D", "omega_D", "5000")

        ttk.Separator(left).pack(fill=tk.X, pady=6)
        ttk.Label(left, text="Integrator", font=("TkDefaultFont", 10, "bold")).pack(anchor="w")
        add_row("eps", "eps", "1e-10")
        add_row("quadgk rtol", "rtol", "1e-10")
        add_row("quadgk atol", "atol", "1e-12")
        add_row("quadgk order", "order", "9")
        add_row("quadgk maxevals", "maxevals", "0")
        add_row("extra points", "extra_points", "auto")

        ttk.Separator(left).pack(fill=tk.X, pady=6)
        ttk.Label(left, text="Run", font=("TkDefaultFont", 10, "bold")).pack(anchor="w")
        add_row("Julia exe", "julia", "julia")
        add_row("Output base", "out", f"two_mode_gui_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
        add_row("PDF output", "pdf_out", "")
        self.save_pdf_var = tk.BooleanVar(value=False)
        save_pdf_chk = ttk.Checkbutton(left, text="Auto-save PDF", variable=self.save_pdf_var)
        save_pdf_chk.pack(anchor="w", pady=2)
        save_pdf_btn = ttk.Button(left, text="Save PDF Now", command=self.save_pdf_now)
        save_pdf_btn.pack(fill=tk.X, pady=2)
        load_toml_btn = ttk.Button(left, text="Load TOML", command=self.load_toml_now)
        load_toml_btn.pack(fill=tk.X, pady=2)

        self.run_btn = ttk.Button(left, text="Start", command=self.run)
        self.run_btn.pack(fill=tk.X, pady=6)

        self.log = scrolledtext.ScrolledText(left, width=42, height=10)
        self.log.pack(fill=tk.BOTH, expand=True)

        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        self.fig = fig
        self.axes = axes
        self.canvas = FigureCanvasTkAgg(fig, master=right)
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

    def append_log(self, text):
        """Append ``text`` plus a newline to the log widget and scroll to the end."""
        self.log.insert(tk.END, text + "\n")
        self.log.see(tk.END)

    @staticmethod
    def _resolve_path(path):
        """Return ``path`` as a Path, interpreting relative paths against ROOT.

        ROOT is also the working directory Julia runs in, so relative output
        names typed in the form point at the same files on both sides.
        """
        path = Path(path)
        return path if path.is_absolute() else ROOT / path

    def _output_base(self):
        """Return the "Output base" field, defaulting to "two_mode_gui" when blank.

        This is the single source for the output name. Both the ``--out``
        flag and the TOML path read back after the run use it, so they agree.
        """
        return self.fields["out"].get().strip() or "two_mode_gui"

    def run(self):
        """Start a Julia run in a background thread (callback for "Start")."""
        if not JULIA_SCRIPT.exists():
            messagebox.showerror("Error", f"Julia script not found: {JULIA_SCRIPT}")
            return

        # Snapshot every widget value here, on the Tk main thread, because
        # Tkinter widgets must not be read from the worker thread. This also
        # keeps the command and the expected output file consistent even if
        # the form is edited while Julia is running.
        cmd = self._build_command()
        out_base = self._output_base()

        self.run_btn.config(state=tk.DISABLED)
        self.append_log("Running Julia...")

        thread = threading.Thread(target=self._run_job, args=(cmd, out_base), daemon=True)
        thread.start()

    def _run_job(self, cmd, out_base):
        """Worker-thread body: run ``cmd`` and post the outcome to the main thread.

        Must not touch widgets. Every outcome, including failing to launch
        the executable, ends in ``_handle_run_result``, which re-enables the
        Start button.
        """
        cmd_str = " ".join(cmd)
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, cwd=str(ROOT))
        except (OSError, ValueError) as exc:
            # E.g. the Julia executable is not found. Without this, the thread
            # would die silently and the Start button would stay disabled.
            self._post_result(-1, "", f"Failed to launch Julia: {exc}", out_base, cmd_str)
            return
        self._post_result(result.returncode, result.stdout, result.stderr, out_base, cmd_str)

    def _post_result(self, *args):
        """Schedule ``_handle_run_result(*args)`` on the Tk event loop.

        Called from the worker thread. If the window was closed while Julia
        was running, Tk raises and there is nothing left to update, so the
        error is ignored.
        """
        try:
            self.after(0, self._handle_run_result, *args)
        except (RuntimeError, tk.TclError):
            pass

    def _handle_run_result(self, returncode, stdout, stderr, out_base, cmd_str):
        """Main-thread continuation of a run: log Julia's output, then plot.

        The result file is ``out_base`` with its suffix replaced by ``.toml``,
        resolved against ROOT when relative. The Start button is re-enabled
        on every path, including when plotting raises.
        """
        try:
            self.append_log("Command: " + cmd_str)
            if stdout:
                self.append_log(stdout.strip())
            if stderr:
                self.append_log(stderr.strip())
            if returncode != 0:
                self.append_log(f"Julia failed with code {returncode}")
                return
            toml_path = self._resolve_path(Path(out_base).with_suffix(".toml"))
            if not toml_path.exists():
                self.append_log(f"Output TOML not found: {toml_path}")
                return
            self._render_plot(toml_path)
        except Exception as exc:  # Tk callback boundary: report rather than lose UI state
            self.append_log(f"Failed to plot results: {exc}")
        finally:
            self.run_btn.config(state=tk.NORMAL)

    def _build_command(self):
        """Return the Julia argv list assembled from the current form values.

        The integration method is always ``quadgk``. ``--quadgk-extra-points``
        is passed only when the field is set to something other than
        "auto"/"none" (case-insensitive). Must be called on the main thread.
        """

        def val(key):
            return self.fields[key].get().strip()

        out_base = self._output_base()

        cmd = [
            val("julia"),
            "--project",
            str(JULIA_SCRIPT),
            "--coupling",
            val("coupling"),
            "--g",
            val("g"),
            "--kxx",
            val("kxx"),
            "--kpp",
            val("kpp"),
            "--omega1",
            val("omega1"),
            "--omega2",
            val("omega2"),
            "--m1",
            val("m1"),
            "--m2",
            val("m2"),
            "--gamma1",
            val("gamma1"),
            "--gamma2",
            val("gamma2"),
            "--T1",
            val("T1"),
            "--T2",
            val("T2"),
            "--bath-coupling1",
            val("bath1"),
            "--bath-coupling2",
            val("bath2"),
            "--gamma-convention",
            val("gamma_conv"),
            "--omega-D",
            val("omega_D"),
            "--method",
            "quadgk",
            "--quadgk-rtol",
            val("rtol"),
            "--quadgk-atol",
            val("atol"),
            "--quadgk-order",
            val("order"),
            "--quadgk-maxevals",
            val("maxevals"),
            "--eps",
            val("eps"),
            "--out",
            out_base,
        ]
        extra = val("extra_points")
        if extra and extra.lower() not in {"auto", "none"}:
            cmd.extend(["--quadgk-extra-points", extra])
        return cmd

    @staticmethod
    def _finite_or_none(value):
        """Return ``value`` as a float if it is a finite number, else None."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return number if np.isfinite(number) else None

    def _render_plot(self, toml_path: Path):
        """Plot both mode ellipses from a solver TOML file and update state.

        The file must contain a top-level 4x4 ``covariance`` array. Optional
        ``params`` and ``results`` tables are used for annotations, and
        missing entries display as nan or "?". On success this:

        - records ``toml_path`` as the last plot;
        - persists it to STATE_PATH;
        - writes a PDF if "Auto-save PDF" is ticked.

        Errors from reading or parsing the file propagate to the caller.
        """
        data = load_toml(toml_path)
        cov4 = np.array(data.get("covariance", []), dtype=float)
        if cov4.shape != (4, 4):
            self.append_log("TOML missing 4x4 covariance.")
            return
        cov1 = mode_block(cov4, 1)
        cov2 = mode_block(cov4, 2)

        params = data.get("params", {})
        results = data.get("results", {})

        omega1 = params.get("omega1", float("nan"))
        omega2 = params.get("omega2", float("nan"))
        gamma1 = params.get("gamma1", float("nan"))
        gamma2 = params.get("gamma2", float("nan"))
        T1 = params.get("T1", float("nan"))
        T2 = params.get("T2", float("nan"))
        bath1 = params.get("bath_coupling1", "?")
        bath2 = params.get("bath_coupling2", "?")
        omega_D = params.get("omega_D", float("nan"))
        # Header quantities are shown only when present and numeric/finite.
        logneg = self._finite_or_none(results.get("log_negativity"))
        nu_pt_min = self._finite_or_none(results.get("nu_pt_min"))

        pts1 = render_axes(
            self.axes[0],
            cov1,
            f"$\\omega_{{1}}$={omega1}",
            extra_lines=[f"bath: ${bath1}$, $\\gamma$={gamma1}, $T$={T1}", f"$\\omega_D$={omega_D}"],
        )
        pts2 = render_axes(
            self.axes[1],
            cov2,
            f"$\\omega_{{2}}$={omega2}",
            extra_lines=[f"bath: ${bath2}$, $\\gamma$={gamma2}, $T$={T2}", f"$\\omega_D$={omega_D}"],
        )

        # Shared, symmetric, equal-aspect limits so the two modes are directly comparable.
        all_pts = np.hstack((pts1, pts2))
        max_extent = np.max(np.abs(all_pts))
        if not np.isfinite(max_extent) or max_extent == 0:
            max_extent = 1.0
        lim = 1.1 * max_extent
        for ax in self.axes:
            ax.set_aspect("equal", adjustable="box")
            ax.set_xlim(-lim, lim)
            ax.set_ylim(-lim, lim)
            ax.legend(loc="lower right", fontsize=8, framealpha=0.6, fancybox=False, edgecolor="none")

        title_parts = []
        if logneg is not None:
            title_parts.append(f"$E_{{\\mathcal{{N}}}}$ = {logneg:.4g}")
        if nu_pt_min is not None:
            title_parts.append(f"$\\tilde{{\\nu}}_{{\\min}}$ = {nu_pt_min:.4g}")
        title_text = "   ".join(title_parts)

        # Layout with a fixed top margin so the header never overlaps the axes.
        header_top = 0.92
        self.fig.tight_layout(rect=[0.0, 0.0, 1.0, header_top])
        if self.header_text is not None:
            try:
                self.header_text.remove()
            except (ValueError, NotImplementedError):
                pass  # already detached from the figure; nothing to remove
            self.header_text = None
        if title_text:
            self.header_text = self.fig.text(0.5, header_top + 0.01, title_text, ha="center", va="bottom")
        self.canvas.draw()
        self.last_plot_path = toml_path
        self._save_last_session(toml_path)
        self.append_log(f"Plotted: {toml_path}")

        if self.save_pdf_var.get():
            # Plotting already succeeded, so a save failure is only logged.
            # Otherwise callers would misreport it as a failure to load.
            try:
                self._save_pdf(toml_path)
            except (OSError, ValueError) as exc:
                self.append_log(f"Failed to save PDF: {exc}")

    def _pdf_path_for(self, toml_path):
        """Return the PDF target path for a plot of ``toml_path``.

        Uses the "PDF output" field when non-empty; otherwise
        ``<toml stem>_both.pdf`` next to the TOML. Relative paths resolve
        against ROOT.
        """
        pdf_out = self.fields["pdf_out"].get().strip()
        if not pdf_out:
            pdf_out = toml_path.with_name(toml_path.stem + "_both.pdf")
        return self._resolve_path(pdf_out)

    def _save_pdf(self, toml_path):
        """Save the current figure for ``toml_path`` (see ``_pdf_path_for``) and log it."""
        pdf_path = self._pdf_path_for(toml_path)
        self.fig.savefig(pdf_path, dpi=200)
        self.append_log(f"Saved PDF: {pdf_path}")
        return pdf_path

    def save_pdf_now(self):
        """Save the current figure as a PDF (callback for "Save PDF Now")."""
        if self.last_plot_path is None:
            messagebox.showinfo("Save PDF", "No plot yet. Run a calculation first.")
            return
        try:
            self._save_pdf(self.last_plot_path)
        except (OSError, ValueError) as exc:
            messagebox.showerror("Save PDF", f"Failed to save PDF: {exc}")

    def load_toml_now(self):
        """Ask for a TOML file and plot it (callback for "Load TOML")."""
        file_path = filedialog.askopenfilename(
            title="Select TOML file",
            initialdir=str(ROOT),
            filetypes=[("TOML files", "*.toml"), ("All files", "*.*")],
        )
        if not file_path:
            return  # dialog cancelled
        toml_path = Path(file_path)
        try:
            self._render_plot(toml_path)
        except Exception as exc:  # Tk callback boundary: show the error to the user
            messagebox.showerror("Load TOML", f"Failed to load TOML: {exc}")

    def _save_last_session(self, toml_path: Path):
        """Record ``toml_path`` in STATE_PATH so it is re-plotted on next launch.

        Best effort: a failure to write the state file is logged, never raised.
        """
        payload = {"last_toml": str(toml_path)}
        try:
            STATE_PATH.write_text(json.dumps(payload), encoding="utf-8")
        except OSError as exc:
            self.append_log(f"Could not save session state: {exc}")

    def _load_last_session(self):
        """Re-plot the TOML recorded in STATE_PATH, if that file still exists.

        Best effort: any problem, such as a corrupt state file or an
        unreadable TOML, is logged and otherwise ignored so the app always
        starts.
        """
        if not STATE_PATH.exists():
            return
        try:
            payload = json.loads(STATE_PATH.read_text(encoding="utf-8"))
            last = payload.get("last_toml", "")
            if not last:
                return
            toml_path = self._resolve_path(last)
            if toml_path.exists():
                self._render_plot(toml_path)
        except Exception as exc:  # startup boundary: never block launching the GUI
            self.append_log(f"Could not restore last session: {exc}")


if __name__ == "__main__":
    # Script entry point: build the main window, then block in Tk's event
    # loop until the window is closed.
    app = App()
    app.mainloop()
