"""This module provides geolocation operations specific to the AVHRR instrument.

In particular, it provides functions to match ground control point (GCP) locations in swath coordinates to reference
positions, and then to minimise the distance to these positions by adjusting the time offset and attitude.
"""

import logging

import numpy as np
from pyproj import Geod

from pyorbital.geoloc import ScanGeometry, compute_pixels, get_lonlatalt

# Module-level logger; the estimation functions report GCP distance statistics through it at debug level.
logger = logging.getLogger(__name__)
# Geodesic calculator on the WGS84 ellipsoid, used to measure distances between computed and reference GCP positions.
geod = Geod(ellps="WGS84")

def compute_avhrr_gcps_lonlatalt(gcps, max_scan_angle, rpy, start_time, tle) -> tuple:
    """Compute the longitude, latitude and altitude of given gcps (scanlines, columns of the swath).

    The gcps are arbitrary locations in swath coordinates, for example (10.3, 7.7) for a gcp at line 10.3 in the swath,
    and column 7.7. This function returns the geographical coordinates of the gcps.

    The scanlines are relative to the pass scanline numbers, zero-based.

    Args:
        gcps: Array-like of (scanline, column) pairs, indexable as ``gcps[:, 0]`` (scanlines) and ``gcps[:, 1]``
            (columns). Plain nested lists are accepted as well as numpy arrays.
        max_scan_angle: Scan angle, in degrees, reached at the outermost columns 0 and 2047.
        rpy: ``(roll, pitch, yaw)`` attitude, forwarded unchanged to ``compute_pixels``.
        start_time: Start time of the pass (scanline 0, column 0); anything ``np.datetime64`` accepts.
        tle: Orbital elements, forwarded unchanged to ``compute_pixels``.

    Returns:
        The ``(lons, lats, alts)`` triple produced by ``get_lonlatalt`` for the gcp positions.
    """
    # AVHRR timing: a new scanline every 1/6 and 25e-6 between consecutive samples of a line
    # (presumably seconds, as consumed by ScanGeometry -- TODO confirm).
    time_line_interval = 1 / 6
    time_row_interval = 25e-6

    gcps = np.asarray(gcps)
    lines = gcps[:, 0]
    columns = gcps[:, 1]

    # Map columns linearly onto across-track scan angles: column 1023.5 gets angle 0, while columns 0 and 2047 get
    # +max_scan_angle and -max_scan_angle respectively (note the sign flip). There is no along-track scanning.
    scan_angles_across = (columns / 1023.5 - 1) * np.deg2rad(-max_scan_angle)
    scan_angles_along = np.zeros_like(scan_angles_across)
    scan_angles = np.vstack((scan_angles_across, scan_angles_along))

    # Each gcp is sampled at the time of its scanline plus the intra-line delay of its column.
    time_offsets = columns * time_row_interval + lines * time_line_interval
    geom = ScanGeometry(scan_angles, time_offsets)
    scan_times = geom.times(np.datetime64(start_time))

    pixels_pos = compute_pixels(tle, geom, scan_times, rpy)
    return get_lonlatalt(pixels_pos, scan_times)


def estimate_time_and_attitude_deviations(gcps, ref_lons, ref_lats, start_time, tle, max_scan_angle):
    """Estimate the time offset and attitude deviations that best align the gcps with their reference positions.

    A bounded least-squares search (``scipy.optimize.minimize``) adjusts a time offset together with roll, pitch and
    yaw so that the geolocated gcps land as close as possible to ``ref_lons`` / ``ref_lats``.

    Returns:
        ``(time_diff, (roll, pitch, yaw), (original_distances, distances))`` where ``time_diff`` is in seconds, the
        attitude angles are in radians, and the two arrays hold the gcp-to-reference distances before and after the
        adjustment (gcps whose computed position is not finite are left out).

    Raises:
        RuntimeError: if the optimizer reports that it did not converge.
    """
    from scipy.optimize import minimize

    shared_args = (gcps, start_time, tle, max_scan_angle, (ref_lons, ref_lats))

    distances_before = compute_gcp_distances_to_reference_lonlats((0, 0, 0, 0), *shared_args)
    logger.debug(f"GCP distances: median {np.median(distances_before)}, std {np.std(distances_before)}")

    # The optimizer's time variable is expressed in units of 1e3 seconds (it is scaled back to nanoseconds by
    # compute_gcp_distances_to_reference_lonlats). This presumably keeps the optimizer's steps clear of
    # nanosecond truncation. The attitude variables are in radians.
    search_bounds = (
        (-0.007, 0.007),  # time offset: +/- 7 s
        (-0.5, 0.5),  # roll
        (-0.5, 0.5),  # pitch
        (-0.5, 0.5),  # yaw
    )
    outcome = minimize(compute_gcp_accumulated_squared_distances_to_reference_lonlats,
                       x0=(0, 0, 0, 0),
                       args=shared_args,
                       bounds=search_bounds)
    if not outcome.success:
        raise RuntimeError("Time and attitude estimation did not converge")

    # Undo the time scaling so time_diff is reported in seconds.
    time_diff, roll, pitch, yaw = outcome.x * [1e3, 1, 1, 1]
    logger.debug(f"Estimated time difference to {time_diff} seconds, "
                 f"attitude to {np.rad2deg(roll)}, {np.rad2deg(pitch)}, {np.rad2deg(yaw)} degrees")

    distances_after = compute_gcp_distances_to_reference_lonlats(outcome.x, *shared_args)
    logger.debug(f"Remaining GCP distances: median {np.median(distances_after)}, std {np.std(distances_after)}")

    return time_diff, (roll, pitch, yaw), (distances_before, distances_after)


def estimate_time_offset(gcps, ref_lons, ref_lats, start_time, tle, max_scan_angle):
    """Estimate the time offset that best aligns the gcps with their reference positions.

    Only the time offset is searched for; roll, pitch and yaw are held at zero throughout. The search is a bounded
    ``scipy.optimize.minimize`` call on the summed squared gcp-to-reference distances.

    Returns:
        ``(time_diff, (original_distances, distances))`` where ``time_diff`` is in seconds and the two arrays hold the
        gcp-to-reference distances before and after applying the offset.

    Raises:
        RuntimeError: if the optimizer reports that it did not converge.
    """
    from scipy.optimize import minimize

    references = (ref_lons, ref_lats)

    distances_before = compute_gcp_distances_to_reference_lonlats((0, 0, 0, 0), gcps, start_time, tle, max_scan_angle,
                                                                  references)
    logger.debug(f"GCP distances: median {np.median(distances_before)}, std {np.std(distances_before)}")

    def objective(scaled_time):
        # Attitude stays pinned at zero; only the scaled time offset varies.
        return compute_gcp_accumulated_squared_distances_to_reference_lonlats(
            (scaled_time[0], 0, 0, 0), gcps, start_time, tle, max_scan_angle, references)

    # Time is optimised in units of 1e3 seconds, so the +/- 0.03 bound corresponds to +/- 30 s. This is the same
    # scaling that compute_gcp_distances_to_reference_lonlats undoes when building the timestamp.
    outcome = minimize(objective, x0=(0,), bounds=((-0.03, 0.03),), options={"ftol": 1e-1})
    if not outcome.success:
        raise RuntimeError("Time offset estimation did not converge")

    scaled_offset = outcome.x[0]
    time_diff = scaled_offset * 1e3
    logger.debug(f"Estimated time difference to {time_diff} seconds")

    distances_after = compute_gcp_distances_to_reference_lonlats((scaled_offset, 0, 0, 0), gcps, start_time, tle,
                                                                 max_scan_angle, references)
    logger.debug(f"Remaining GCP distances: median {np.median(distances_after)}, std {np.std(distances_after)}")

    return time_diff, (distances_before, distances_after)


def compute_gcp_accumulated_squared_distances_to_reference_lonlats(
        variables, gcps, start_time, tle, max_scan_angle, refs):
    """Compute the sum of squared distances from the gcps to their reference lonlats.

    This is the least-squares objective used by the estimation functions. The gcps (in swath coordinates) are geolocated
    with the time offset and attitude held in ``variables``. The squared distances to ``refs`` are then summed.

    Args:
        variables: ``(time_diff, roll, pitch, yaw)``; ``time_diff`` is in units of 1e3 seconds.
        gcps: Gcp positions in swath coordinates (scanline, column).
        start_time: Start time of the pass.
        tle: Orbital elements.
        max_scan_angle: Scan angle, in degrees, at the outermost columns.
        refs: ``(ref_lons, ref_lats)`` reference coordinates of the gcps.
    """
    per_gcp = compute_gcp_distances_to_reference_lonlats(variables, gcps, start_time, tle, max_scan_angle, refs)
    return np.sum(np.square(per_gcp))


def compute_gcp_distances_to_reference_lonlats(variables, gcps, start_time, tle, max_scan_angle, refs):
    """Compute the gcp distances to references lonlats.

    Args:
        variables: ``(time_diff, roll, pitch, yaw)``. ``time_diff`` is in units of 1e3 seconds. The attitude angles
            are forwarded as the ``rpy`` attitude used for geolocation.
        gcps: Gcp positions in swath coordinates (scanline, column).
        start_time: Start time of the pass; anything ``np.datetime64`` accepts.
        tle: Orbital elements.
        max_scan_angle: Scan angle, in degrees, at the outermost columns.
        refs: ``(ref_lons, ref_lats)`` reference coordinates, one pair per gcp.

    Returns:
        Geodesic distances on the WGS84 ellipsoid between reference and computed positions. Only gcps whose computed
        and reference coordinates are all finite are included, so the result can be shorter than ``gcps``.
    """
    time_diff, roll, pitch, yaw = variables
    # time_diff is expressed in units of 1e3 seconds, so a factor 1e12 converts it to nanoseconds.
    # Presumably this scaling keeps the optimizer's steps clear of nanosecond truncation.
    time = np.datetime64(start_time) + np.timedelta64(int(time_diff * 1e12), "ns")
    lons, lats, _ = compute_avhrr_gcps_lonlatalt(gcps, max_scan_angle, (roll, pitch, yaw), time, tle)
    ref_lons, ref_lats = refs
    ref_lons = np.asarray(ref_lons)
    ref_lats = np.asarray(ref_lats)
    # Drop gcps whose computed position is not finite, as well as any with missing (non-finite) reference
    # coordinates. Otherwise a single NaN turns every distance statistic and the optimizer's objective into NaN.
    valid = np.isfinite(lons) & np.isfinite(lats) & np.isfinite(ref_lons) & np.isfinite(ref_lats)
    _, _, distances = geod.inv(ref_lons[valid], ref_lats[valid], lons[valid], lats[valid])
    return distances
