"""
Regenerate docs/assets/early_exercise.svg.

Two-panel figure for the "Where the Model Breaks: Early Exercise" section
of docs/concepts/greek-validation.md.

- Left panel: European vs American put price as a function of spot.
  Gold band between the two curves is the early-exercise premium.
- Right panel: exercise boundary B(t) vs time-to-expiry. Below the
  boundary, immediate exercise is optimal.

Run from the docs repo root:
    python scripts/early_exercise.py
"""

from pathlib import Path

import numpy as np
from scipy.stats import norm

import matplotlib.pyplot as plt

import theme


# ----- model inputs shared by both panels -----
K = 100.0     # strike
r = 0.05      # risk-free rate (used with exp(-r * t) discounting)
sigma = 0.30  # volatility
T = 1.0       # time to expiry, in years
N = 400       # tree steps for the exercise-boundary panel

# The SVG is written relative to this script's location (one directory above
# the folder holding the script), not relative to the working directory.
_REPO_ROOT = Path(__file__).resolve().parents[1]
OUT = _REPO_ROOT / "docs" / "assets" / "early_exercise.svg"


# ----- European put (closed form) -----
def euro_put(S, K, r, sigma, T):
    """
    Black-Scholes price of a European put (no dividend term).

    S and T may be scalars or array-likes; they are broadcast against each
    other. Scalar inputs give a NumPy scalar, array inputs give an array.

    At or after expiry (T <= 0) the price is the intrinsic value
    max(K - S, 0). The closed form on its own divides by sigma * sqrt(T),
    which is 0/0 (NaN) at S == K when T == 0, so that case is handled
    explicitly instead of being left to the formula.

    The formula assumes S > 0, K > 0 and sigma > 0; those are not checked.
    """
    S = np.asarray(S, dtype=float)
    T = np.asarray(T, dtype=float)

    live = T > 0
    # Where the option has already expired, evaluate the formula at a
    # harmless placeholder horizon so nothing divides by zero; those entries
    # are replaced by the intrinsic value below.
    horizon = np.where(live, T, 1.0)

    vol_sqrt_t = sigma * np.sqrt(horizon)
    d1 = (np.log(S / K) + (r + 0.5 * sigma**2) * horizon) / vol_sqrt_t
    d2 = d1 - vol_sqrt_t
    formula = K * np.exp(-r * horizon) * norm.cdf(-d2) - S * norm.cdf(-d1)

    intrinsic = np.maximum(K - S, 0.0)
    # Indexing with () turns a 0-d result back into a scalar and leaves
    # proper arrays untouched.
    return np.where(live, formula, intrinsic)[()]


# ----- American put via CRR tree, plus exercise boundary per step -----
def american_put_and_boundary(S0, K, r, sigma, T, N):
    """
    Price an American put on a CRR recombining tree and record, per time
    step, where early exercise starts.

    Returns (price_at_S0, boundary) where boundary has N + 1 entries indexed
    by time step (0 = now, N = expiry). boundary[i] is the largest node spot
    at step i that is in the money and whose intrinsic value is at least its
    continuation value. boundary[N] is set to K. An entry stays NaN when no
    node at that step qualifies, which includes every step whose nodes all
    lie above the exercise region.
    """
    step = T / N
    up = np.exp(sigma * np.sqrt(step))
    down = 1.0 / up
    discount = np.exp(-r * step)
    prob_up = (np.exp(r * step) - down) / (up - down)

    # Option values at expiry, indexed by the number of up-moves.
    ups = np.arange(N + 1)
    values = np.maximum(K - S0 * (up ** ups) * (down ** (N - ups)), 0.0)

    boundary = np.full(N + 1, np.nan)
    boundary[N] = K

    # Roll back one step at a time, from the step before expiry to today.
    for level in reversed(range(N)):
        ups = np.arange(level + 1)
        spots = S0 * (up ** ups) * (down ** (level - ups))
        payoff = np.maximum(K - spots, 0.0)
        hold = discount * (
            prob_up * values[1:level + 2]
            + (1.0 - prob_up) * values[:level + 1]
        )
        exercise_now = payoff >= hold
        values = np.where(exercise_now, payoff, hold)

        # Only in-the-money nodes count towards the boundary: far
        # out-of-the-money nodes have payoff == hold == 0, pass the >= test
        # trivially, and would otherwise drag the maximum up to huge spots.
        candidates = spots[exercise_now & (payoff > 0)]
        if candidates.size:
            boundary[level] = candidates.max()

    return values[0], boundary


def american_put_grid(S_grid, K, r, sigma, T, N):
    """
    American put price for every spot in S_grid, returned as a NumPy array.

    Each spot gets its own full tree run via american_put_and_boundary, so
    the cost grows with len(S_grid); the boundary from each run is discarded.
    """
    prices = []
    for spot in S_grid:
        price, _boundary = american_put_and_boundary(spot, K, r, sigma, T, N)
        prices.append(price)
    return np.array(prices)


# ----- figure -----
def _boundary_lead_steps(r, sigma, dt):
    """
    Extra tree steps to run ahead of the plotted window so the tree already
    has nodes inside the exercise region at the longest plotted maturity.

    A tree started at S0 = K has its lowest node at K * exp(-sigma*sqrt(dt)*i)
    after i steps. The perpetual American put boundary, 2rK / (2r + sigma^2),
    lies below the finite-maturity boundary, so reaching it is sufficient.
    Returns 0 when r <= 0 or sigma <= 0, where that bound is not defined.
    """
    if r <= 0 or sigma <= 0:
        return 0
    log_gap = np.log1p(sigma**2 / (2.0 * r))  # ln(K / perpetual boundary)
    steps = int(np.ceil(log_gap / (sigma * np.sqrt(dt)))) + 2  # small margin
    # An even lead keeps the odd/even node lattice at each plotted step the
    # same as for a tree started exactly at the window's left edge.
    return steps + steps % 2


def build_figure():
    """
    Build the two-panel early-exercise figure and return the Figure.

    Left panel: European and American put prices against spot, with the gap
    between them shaded. Right panel: the early-exercise boundary from the
    tree against time to expiry, with the exercise and hold regions shaded.
    Reads the module-level K, r, sigma, T and N.
    """
    theme.apply()

    fig, (ax_left, ax_right) = plt.subplots(
        1, 2, figsize=(11.75, 4.2), gridspec_kw={"wspace": 0.25}
    )

    # ----- left panel: European vs American put -----
    S_grid = np.linspace(60.0, 140.0, 81)
    euro_p = euro_put(S_grid, K, r, sigma, T)
    # american_put_grid re-runs the tree once per spot; this panel uses a
    # fixed 200-step tree rather than the module-level N.
    amer_p = american_put_grid(S_grid, K, r, sigma, T, N=200)
    intrinsic = np.maximum(K - S_grid, 0.0)

    ax_left.plot(S_grid, intrinsic, color=theme.TICK_LABEL, linewidth=1.0,
                 linestyle="--", label="Intrinsic  max(K-S, 0)")
    ax_left.plot(S_grid, euro_p, color=theme.BLUE, linewidth=2.0,
                 label="European put")
    ax_left.plot(S_grid, amer_p, color=theme.PURPLE, linewidth=2.0,
                 label="American put")
    ax_left.fill_between(S_grid, euro_p, amer_p,
                         color=theme.GOLD, alpha=0.20,
                         label="Early exercise premium")

    ax_left.set_xlabel("Spot price S")
    ax_left.set_ylabel("Put price")
    ax_left.set_title("European vs American put", color=theme.AXIS_LABEL)
    ax_left.grid(True)
    ax_left.legend(loc="upper right", fontsize=8)
    ax_left.set_xlim(S_grid.min(), S_grid.max())
    ax_left.set_ylim(bottom=0)

    # ----- right panel: exercise boundary B(t) -----
    # A tree started at S0 = K has no node low enough to be exercised during
    # its first steps, so those boundary entries are NaN and the curve would
    # stop short of tau = T. Start the tree `lead` steps earlier (same step
    # size) and keep only the last N + 1 steps, which cover tau in [0, T].
    dt = T / N
    lead = _boundary_lead_steps(r, sigma, dt)
    total_steps = N + lead
    _, boundary_ext = american_put_and_boundary(
        K, K, r, sigma, dt * total_steps, total_steps
    )
    boundary = boundary_ext[lead:]
    tau = T - np.arange(N + 1) * dt   # time remaining until expiry

    # tau runs from T down to 0; reverse so the curve is drawn left to right.
    tau_sorted = tau[::-1]
    b_sorted = boundary[::-1]
    # Drop any step that still has no boundary node (e.g. when r <= 0).
    mask = ~np.isnan(b_sorted)
    tau_sorted = tau_sorted[mask]
    b_sorted = b_sorted[mask]

    ax_right.fill_between(tau_sorted, 0.0, b_sorted,
                          color=theme.RED, alpha=0.15,
                          label="Exercise optimal")
    ax_right.fill_between(tau_sorted, b_sorted, K * 1.05,
                          color=theme.GREEN, alpha=0.10,
                          label="Hold")
    ax_right.plot(tau_sorted, b_sorted, color=theme.GOLD, linewidth=2.0,
                  label="Boundary B(tau)")
    ax_right.axhline(K, color=theme.TICK_LABEL, linewidth=1.0,
                     linestyle="--", alpha=0.6)
    ax_right.text(T * 0.02, K + 0.5, f"K = {K:g}",
                  color=theme.TICK_LABEL, fontsize=8)

    ax_right.set_xlabel("Time to expiry (years)")
    ax_right.set_ylabel("Spot price")
    ax_right.set_title("Early-exercise boundary", color=theme.AXIS_LABEL)
    ax_right.grid(True)
    ax_right.legend(loc="lower right", fontsize=8)
    ax_right.set_xlim(0, T)
    ax_right.set_ylim(0, K * 1.05)

    for ax in (ax_left, ax_right):
        for spine in ax.spines.values():
            spine.set_edgecolor(theme.FRAME)

    fig.tight_layout()
    return fig


if __name__ == "__main__":
    # Render first, then make sure the asset directory exists and write the
    # SVG to the path held in OUT.
    figure = build_figure()
    asset_dir = OUT.parent
    asset_dir.mkdir(parents=True, exist_ok=True)
    figure.savefig(OUT, format="svg", bbox_inches="tight")
    print(f"wrote {OUT}")
