#!/usr/bin/env python3
"""Export DOA spectrum data to JSON for the web visualizer."""

import numpy as np
from scipy.fft import fft
import os
import json

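# Make the parent directory importable (presumably for `constants`).
# Note: '..' is resolved against the current working directory, not this
# file's location, so the script must be run from its own directory.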
import sys
sys.path.insert(0, '..')

import constants
from constants import create_rotated_uca, CHANNEL_ORDER, C, RADIUS, ARRAY_ROTATION, NUM_ELEMENTS
from doa_py.algorithm.music_based import music


def extract_tone_phases(X, fft_size=8192, n_snapshots=100):
    """Extract per-channel phases at the dominant tone using a short-time FFT.

    The tone bin is located once, from the magnitude spectrum of the first
    ``fft_size`` samples of channel 0 (DC and the upper half of the bins are
    excluded from the search). Windows of ``fft_size`` samples with 50 %
    overlap are then transformed, and the phase of that bin is recorded for
    every channel.

    Args:
        X: 2-D array indexed as ``X[channel, sample]``.
        fft_size: Length of each FFT window; the hop is ``fft_size // 2``.
        n_snapshots: Maximum number of windows to use.

    Returns:
        Array of shape ``(n_channels, n_windows)`` holding phases in radians,
        with ``n_windows <= n_snapshots``. ``n_windows`` is 0 when ``X`` is
        shorter than ``fft_size``.
    """
    n_channels, n_samples = X.shape

    hop = fft_size // 2
    # "+ 1" so that a window ending exactly at the last sample is included.
    starts = range(0, n_samples - fft_size + 1, hop)[:max(n_snapshots, 0)]
    if len(starts) == 0:
        return np.empty((n_channels, 0))

    # NOTE(review): only bins [1, fft_size//2) are searched. For complex
    # baseband input a tone in the negative-frequency half would be missed;
    # confirm the capture setup always places the tone in the positive half.
    spectrum = np.abs(fft(X[0, :fft_size]))
    tone_bin = np.argmax(spectrum[1:fft_size // 2]) + 1

    phases = [np.angle(fft(X[:, s:s + fft_size], axis=1)[:, tone_bin]) for s in starts]
    return np.array(phases).T  # Shape: (n_channels, n_windows)


def _circular_stats_deg(diff):
    """Return (circular mean, spread), in degrees, of phase differences given in radians.

    The mean is the angle of the average unit phasor, so samples straddling
    +/-pi average correctly instead of cancelling towards 0. The spread is the
    standard deviation of each sample's deviation from that mean, wrapped to
    [-pi, pi].
    """
    mean_rad = np.angle(np.mean(np.exp(1j * diff)))
    deviation = np.angle(np.exp(1j * (diff - mean_rad)))
    return np.rad2deg(mean_rad), np.rad2deg(np.std(deviation))


def compute_phase_metrics(X):
    """Summarise inter-channel phase relationships at the tone frequency.

    Args:
        X: 2-D array indexed as ``X[channel, sample]``.

    Returns:
        dict with
          ``phase_diffs``: circular-mean phase (degrees, 1 decimal) of every
              channel relative to channel 0 (keys ``'ch{k}_ch0'``), plus each
              adjacent pair around the ring (keys ``'ch{j}_ch{i}'`` with
              ``j = (i + 1) % n_channels``).
          ``phase_stds``: spread (degrees, 2 decimals) of each ``'ch{k}_ch0'``
              difference.
          ``coherence``: average of those spreads; lower is better.

    Raises:
        ValueError: if fewer than two channels or no full FFT window are
            available.
    """
    tone_phases = extract_tone_phases(X, n_snapshots=200)
    if tone_phases.ndim != 2 or tone_phases.shape[0] < 2 or tone_phases.shape[1] == 0:
        raise ValueError("need at least 2 channels and one full FFT window to compute phase metrics")
    n_channels = tone_phases.shape[0]

    phase_diffs = {}
    phase_stds = {}
    raw_stds = []

    # Phase of every channel relative to channel 0
    for ch in range(1, n_channels):
        mean_deg, std_deg = _circular_stats_deg(tone_phases[ch] - tone_phases[0])
        phase_diffs[f'ch{ch}_ch0'] = round(mean_deg, 1)
        phase_stds[f'ch{ch}_ch0'] = round(std_deg, 2)
        raw_stds.append(std_deg)

    # Adjacent channel differences around the ring (circular array). The
    # (1, 0) pair recomputes the first entry above with the same value.
    for i in range(n_channels):
        j = (i + 1) % n_channels
        mean_deg, _ = _circular_stats_deg(tone_phases[j] - tone_phases[i])
        phase_diffs[f'ch{j}_ch{i}'] = round(mean_deg, 1)

    return {
        'phase_diffs': phase_diffs,
        'phase_stds': phase_stds,
        # Averaged from unrounded spreads; lower is better.
        'coherence': round(np.mean(raw_stds), 2),
    }


def load_cs16(filepath):
    """Load a ``cs16`` capture of interleaved int16 I/Q pairs as complex samples.

    The file is read using NumPy's native int16 byte order. A trailing
    unpaired int16 (e.g. a capture cut off mid-sample) is dropped instead of
    causing a shape-mismatch error.

    Args:
        filepath: Path of the ``.cs16`` file.

    Returns:
        1-D complex128 array: real part from even int16 positions (I),
        imaginary part from odd positions (Q).
    """
    raw = np.fromfile(filepath, dtype=np.int16)
    raw = raw[: raw.size - (raw.size % 2)]
    # Widen to float64, then reinterpret each consecutive (I, Q) pair as one complex128.
    return raw.astype(np.float64).view(np.complex128)


def load_measurement(base_dir, angle, n_channels=4):
    """Load every channel capture for one bearing and apply ``CHANNEL_ORDER``.

    Reads ``<base_dir>/<angle>deg/channel-<i>.cs16`` for ``i`` in
    ``range(n_channels)``, stacks them row-wise and indexes the rows with
    ``CHANNEL_ORDER`` (a reordering defined in ``constants``).

    Captures of unequal length are trimmed to the shortest one, keeping the
    common leading samples aligned, instead of failing in ``np.vstack``.

    Args:
        base_dir: Dataset directory containing the ``<angle>deg`` folders.
        angle: Bearing used to build the folder name.
        n_channels: Number of ``channel-<i>.cs16`` files to read.

    Returns:
        2-D complex array indexed as ``[channel, sample]``.
    """
    data_dir = os.path.join(base_dir, f"{angle}deg")
    channels = [
        load_cs16(os.path.join(data_dir, f"channel-{i}.cs16"))
        for i in range(n_channels)
    ]

    lengths = [len(ch) for ch in channels]
    common_len = min(lengths)
    if max(lengths) != common_len:
        print(f"  Warning: channel lengths differ in {data_dir} {lengths}; "
              f"trimming to {common_len} samples")

    raw_data = np.vstack([ch[:common_len] for ch in channels])
    return raw_data[CHANNEL_ORDER]


def extract_tone_snapshots(X, n_snapshots=2048, fft_size=1024):
    """Build narrowband array snapshots at the dominant tone via a short-time FFT.

    The tone is located with a full-length FFT of channel 0 (searching bins
    ``[1, n_samples // 2)``, i.e. excluding DC and the upper half). It is then
    mapped to the nearest bin of an ``fft_size``-point FFT. Each 50 %-overlapping
    window contributes that bin's complex value for every channel.

    Args:
        X: 2-D array indexed as ``X[channel, sample]``.
        n_snapshots: Maximum number of windows (snapshots) to return.
        fft_size: Length of each short FFT window.

    Returns:
        Complex array of shape ``(n_channels, n_windows)`` with
        ``n_windows <= n_snapshots``. ``n_windows`` is 0 when ``X`` is shorter
        than ``fft_size``.
    """
    n_channels, n_samples = X.shape

    hop = fft_size // 2
    # "+ 1" so that a window ending exactly at the last sample is included.
    starts = range(0, n_samples - fft_size + 1, hop)[:max(n_snapshots, 0)]
    if len(starts) == 0:
        return np.empty((n_channels, 0), dtype=complex)

    fft_full = fft(X[0])
    peak_idx = np.argmax(np.abs(fft_full[1:n_samples // 2])) + 1
    # Round (not truncate) to the nearest short-FFT bin. Flooring can land
    # almost a full bin away from the tone, where a rectangular window
    # retains very little of its energy.
    tone_bin = int(round(peak_idx * fft_size / n_samples))

    snapshots = [fft(X[:, s:s + fft_size], axis=1)[:, tone_bin] for s in starts]
    return np.array(snapshots).T


def circular_error(est, true):
    """Return ``est - true`` folded into the range [-180, 180] degrees.

    Values already inside the range (including exactly +/-180) are returned
    unchanged; values outside are shifted by whole turns of 360 degrees.
    NaN passes through untouched.
    """
    delta = est - true
    while delta > 180 or delta < -180:
        delta = delta - 360 if delta > 180 else delta + 360
    return delta


def _list_angle_dirs(data_dir):
    """Return the sorted bearings (degrees) encoded by ``<N>deg`` sub-directories.

    Entries that are not directories, or whose name before ``deg`` is not an
    integer, are skipped instead of aborting the whole export.
    """
    angles = []
    for entry in os.listdir(data_dir):
        if not entry.endswith('deg'):
            continue
        if not os.path.isdir(os.path.join(data_dir, entry)):
            continue
        try:
            angles.append(int(entry[:-len('deg')]))
        except ValueError:
            continue
    return sorted(angles)


def _normalize_spectrum(spectrum):
    """Return the spectrum in dB relative to its peak, rescaled to [0, 1].

    Follows the doa_py-style normalisation used before: convert to dB, shift
    so the minimum is 0, then scale so the maximum is 1. A flat spectrum
    (zero dynamic range) maps to all zeros rather than NaN, which is not
    valid strict JSON.
    """
    spectrum_db = 10 * np.log10(spectrum / np.max(spectrum) + 1e-10)
    spectrum_shifted = spectrum_db - np.min(spectrum_db)
    span = np.max(spectrum_shifted)
    if span == 0:
        return np.zeros_like(spectrum_shifted)
    return spectrum_shifted / span


def export_dataset(name, freq_hz, data_dir, step):
    """Run MUSIC on every bearing of one capture set and package the results for JSON.

    Args:
        name: Dataset label stored in the output.
        freq_hz: Tone frequency in Hz, passed to MUSIC and used for the
            wavelength figures.
        data_dir: Directory containing one ``<angle>deg`` sub-directory per
            bearing.
        step: Nominal angular increment of the capture set in degrees.
            NOTE(review): unused here; it is accepted because ``main`` passes
            each dataset's config dict (which includes ``step``) as kwargs.

    Returns:
        A dict ready for ``json.dump`` (metadata, config, aggregate stats and
        per-bearing measurements). Returns None when the directory is
        missing, holds no bearing sub-directories, or no bearing could be
        processed.
    """
    print(f"\nExporting {name}...")

    if not os.path.exists(data_dir):
        print(f"  Skipping - directory not found: {data_dir}")
        return None

    # Get available angles
    angles = _list_angle_dirs(data_dir)
    print(f"  Found {len(angles)} angles")
    if not angles:
        print("  Skipping - no <angle>deg sub-directories found")
        return None

    uca = create_rotated_uca()
    angle_grids = np.arange(-180, 180, 1)
    wavelength = C / freq_hz

    results = []
    for idx, true_angle in enumerate(angles):
        # Best effort per bearing: a bad capture is reported and skipped so
        # the remaining bearings still export.
        try:
            X = load_measurement(data_dir, true_angle)
            snapshots = extract_tone_snapshots(X, n_snapshots=2048)

            spectrum = music(
                received_data=snapshots,
                num_signal=1,
                array=uca,
                signal_fre=freq_hz,
                angle_grids=angle_grids,
                unit="deg"
            )

            est = int(angle_grids[np.argmax(spectrum)])
            # Fold bearings above 180 into the negative range (e.g. 270 -> -90)
            true_adj = true_angle if true_angle <= 180 else true_angle - 360
            err = circular_error(est, true_adj)

            spectrum_norm = _normalize_spectrum(spectrum)
            phase_metrics = compute_phase_metrics(X)

            results.append({
                'true_angle': true_angle,
                'estimate': est,
                'error': int(err),
                # Every 2nd grid point (2-degree spacing) keeps the JSON small
                'spectrum': spectrum_norm[::2].tolist(),
                'angles': angle_grids[::2].tolist(),
                'phase_diffs': phase_metrics['phase_diffs'],
                'phase_stds': phase_metrics['phase_stds'],
                'coherence': phase_metrics['coherence']
            })

            if (idx + 1) % 10 == 0:
                print(f"  Processed {idx + 1}/{len(angles)}")

        except Exception as e:
            print(f"  Error at {true_angle}°: {e}")

    if not results:
        # The statistics below (np.max etc.) are undefined for an empty set.
        print("  Skipping - no angle could be processed")
        return None

    # Compute dataset-wide statistics
    errors = [m['error'] for m in results]
    coherences = [m['coherence'] for m in results]

    return {
        'name': name,
        'freq_hz': freq_hz,
        'freq_ghz': freq_hz / 1e9,
        'wavelength_mm': wavelength * 1000,
        'n_angles': len(results),
        # Configuration
        'config': {
            # Plain ints so json.dump accepts it even if CHANNEL_ORDER is a NumPy array
            'channel_order': [int(c) for c in CHANNEL_ORDER],
            'array_rotation': ARRAY_ROTATION,
            'radius_mm': RADIUS * 1000,
            'num_elements': NUM_ELEMENTS,
            'radius_wavelengths': round(RADIUS / wavelength, 3)
        },
        # Dataset statistics
        'stats': {
            'mean_error': round(np.mean(np.abs(errors)), 1),
            'max_error': int(np.max(np.abs(errors))),
            'rmse': round(np.sqrt(np.mean(np.array(errors)**2)), 1),
            'mean_coherence': round(np.mean(coherences), 2),
            'good_angles': sum(1 for e in errors if abs(e) < 10),
            'moderate_angles': sum(1 for e in errors if 10 <= abs(e) < 25),
            'poor_angles': sum(1 for e in errors if abs(e) >= 25)
        },
        'measurements': results
    }


def main():
    """Export every known capture set into a single ``doa_data.json`` file.

    Each dataset is described by its tone frequency (MHz) and bearing step
    (degrees). Directories are looked up under ``../data`` relative to the
    current working directory. Datasets that ``export_dataset`` skips
    (returns a falsy value for) are left out of the output.
    """
    # (frequency in MHz, bearing step in degrees), in export order
    sweep = [(mhz, 10) for mhz in (1200, 1500, 2000, 2400, 3000, 4000, 5000, 5700, 5800, 6000)]
    sweep += [(1200, 1), (5700, 1)]

    datasets = {
        f'{mhz}MHz_{step}deg': {
            'freq_hz': mhz * 1e6,
            'data_dir': f'../data/{mhz}MHz, 0dB, {step}deg increments, outside',
            'step': step,
        }
        for mhz, step in sweep
    }

    all_data = {}
    for label, settings in datasets.items():
        exported = export_dataset(label, **settings)
        if exported:
            all_data[label] = exported

    output_file = 'doa_data.json'
    with open(output_file, 'w') as fh:
        json.dump(all_data, fh)

    size_mb = os.path.getsize(output_file) / (1024 * 1024)
    print(f"\nExported to {output_file} ({size_mb:.1f} MB)")
    print(f"Datasets: {list(all_data.keys())}")


# Run the export only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
