package org.lsst.ccs.subsystem.ocsbridge.sim;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.stream.DoubleStream;
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.fitting.SimpleCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoint;
import org.lsst.ccs.subsystem.ocsbridge.events.ShutterMotionProfileFitResult;
import org.lsst.ccs.subsystem.shutter.status.MotionDone;

/**
 * A simple wrapper around code to perform fits to a shutter motion profile.
 * Both the Hall-sensor transitions and the motor-encoder samples of a
 * {@link MotionDone} are fitted with a three-constant-jerk model, and the time
 * at which the blade reaches its mid point is derived from each fit.
 * Results can be returned as either a JSON-style map, or as
 * {@link ShutterMotionProfileFitResult}s.
 *
 * @author tonyj
 */
public final class ShutterMotionProfileFitter {

    /**
     * Positions are divided by this value before fitting, so a full stroke
     * maps to roughly [0, 1]. NOTE(review): presumably the nominal blade
     * travel, in the same units as MotionDone positions — confirm.
     */
    private static final double POSITION_SCALE = 750;

    /** Initial guesses for the fit, in physical (unscaled) units. */
    private static final double T0_GUESS = 0.00;
    private static final double T1_GUESS = 0.225;
    private static final double T2_GUESS = 0.675;
    private static final double JERK_GUESS = 30000;

    /** Maximum number of iterations allowed to the curve fitter. */
    private static final int MAX_FIT_ITERATIONS = 1000;

    /** Damping factor applied to each Newton step in {@link #midPointCalculator}. */
    private static final double NEWTON_DAMPING = 0.2;
    /** Newton iteration stops once successive estimates differ by no more than this. */
    private static final double NEWTON_TOLERANCE = 1e-5;
    /** Safety cap so a non-converging Newton iteration cannot spin forever. */
    private static final int MAX_NEWTON_ITERATIONS = 10_000;

    private final ShutterMotionProfileFitResult hallSensorFit;
    private final ShutterMotionProfileFitResult motorEncoderFit;
    private final double aveMidPoint;
    private final double midPointFromZeroAccelerationHall;
    private final double midPointFromNewtonMethodHall;
    private final double midPointFromZeroAccelerationMotorEncoder;
    private final double midPointFromNewtonMethodMotorEncoder;
    private final double[] hallFitParametersPhysical;
    private final double[] motorEncoderFitParametersPhysical;

    /**
     * Performs the fits for {@code md} asynchronously on {@code executor}.
     *
     * @param executor executor on which the (potentially slow) fits are run
     * @param md the completed motion to fit
     * @return a future completed with the fitter, or completed exceptionally
     *         with whatever the fit threw
     */
    public static CompletableFuture<ShutterMotionProfileFitter> fit(Executor executor, MotionDone md) {
        CompletableFuture<ShutterMotionProfileFitter> result = new CompletableFuture<>();
        executor.execute(() -> {
            try {
                result.complete(new ShutterMotionProfileFitter(md));
            } catch (Throwable t) {
                // Propagate any failure (including Errors) to the future rather than losing it on the executor thread.
                result.completeExceptionally(t);
            }
        });
        return result;
    }

    private ShutterMotionProfileFitter(MotionDone md) {
        // Each fit runs exactly once. The result object gets its own copy so the
        // physical-parameter arrays held here stay independent of it.
        hallFitParametersPhysical = doHallSensorFit(md);
        hallSensorFit = new ShutterMotionProfileFitResult(hallFitParametersPhysical.clone());
        motorEncoderFitParametersPhysical = doMotorEncoderFit(md);
        motorEncoderFit = new ShutterMotionProfileFitResult(motorEncoderFitParametersPhysical.clone());

        aveMidPoint = (md.startPosition() + md.endPosition()) / 2.0;

        double[] hallMidPoints = midPointCalculator(aveMidPoint, hallFitParametersPhysical);
        midPointFromZeroAccelerationHall = hallMidPoints[0];
        midPointFromNewtonMethodHall = hallMidPoints[1];

        double[] encoderMidPoints = midPointCalculator(aveMidPoint, motorEncoderFitParametersPhysical);
        midPointFromZeroAccelerationMotorEncoder = encoderMidPoints[0];
        midPointFromNewtonMethodMotorEncoder = encoderMidPoints[1];
    }

    /** @return the average of the motion's start and end positions */
    public double getAveMidPoint() {
        return aveMidPoint;
    }

    /** @return the Hall-fit time at which the second-segment acceleration is zero (maximum velocity) */
    public double getMidPointFromZeroAccelerationHall() {
        return midPointFromZeroAccelerationHall;
    }

    /** @return the Hall-fit mid-point time found by damped Newton iteration, or NaN if it did not converge */
    public double getMidPointFromNewtonMethodHall() {
        return midPointFromNewtonMethodHall;
    }

    /** @return the encoder-fit time at which the second-segment acceleration is zero (maximum velocity) */
    public double getMidPointFromZeroAccelerationMotorEncoder() {
        return midPointFromZeroAccelerationMotorEncoder;
    }

    /** @return the encoder-fit mid-point time found by damped Newton iteration, or NaN if it did not converge */
    public double getMidPointFromNewtonMethodMotorEncoder() {
        return midPointFromNewtonMethodMotorEncoder;
    }

    /** @return a copy of the Hall-sensor fit parameters {t0, t1, t2, j0, j1, j2} in physical units */
    public double[] getHallFitParametersPhysical() {
        return hallFitParametersPhysical.clone();
    }

    /** @return a copy of the motor-encoder fit parameters {t0, t1, t2, j0, j1, j2} in physical units */
    public double[] getMotorEncoderFitParametersPhysical() {
        return motorEncoderFitParametersPhysical.clone();
    }

    /**
     * Returns both fits as an insertion-ordered map suitable for JSON output.
     *
     * @return map with the model name and the named parameters of each fit
     */
    Map<String, Object> getJsonResult() {
        Map<String, Object> fits = new LinkedHashMap<>();
        fits.put("Model", "ThreeJerksModelv1");
        fits.put("hallSensorFit", hallSensorFit.getNamedFitParameters());
        fits.put("motorEncoderFit", motorEncoderFit.getNamedFitParameters());
        return fits;
    }

    /** Fits the motor-encoder samples, returning physical parameters {t0, t1, t2, j0, j1, j2}. */
    private static double[] doMotorEncoderFit(MotionDone md) {
        long startTime = md.startTime().getTAIInstant().toEpochMilli();
        double actualDuration = actualDurationSeconds(md);
        double endPosition = md.endPosition();

        // Sample times in seconds since startTime (millisecond resolution), with the end of motion appended.
        double[] timesPhysical = DoubleStream.concat(
                md.encoderSamples().stream().mapToDouble(es -> (es.getTime().getTAIInstant().toEpochMilli() - startTime) / 1000.0),
                DoubleStream.of(actualDuration)).toArray();
        double[] positionsPhysical = DoubleStream.concat(
                md.encoderSamples().stream().mapToDouble(es -> es.getPosition()),
                DoubleStream.of(endPosition)).toArray();
        return doFit(timesPhysical, positionsPhysical, md.startPosition(), endPosition, actualDuration);
    }

    /** Fits the Hall-sensor transitions, returning physical parameters {t0, t1, t2, j0, j1, j2}. */
    private static double[] doHallSensorFit(MotionDone md) {
        long startTime = md.startTime().getTAIInstant().toEpochMilli();
        double actualDuration = actualDurationSeconds(md);
        double endPosition = md.endPosition();

        // Transition times in seconds since startTime (millisecond resolution), with the end of motion appended.
        double[] timesPhysical = DoubleStream.concat(
                md.hallTransitions().stream().mapToDouble(ht -> (ht.getTime().getTAIInstant().toEpochMilli() - startTime) / 1000.0),
                DoubleStream.of(actualDuration)).toArray();
        double[] positionsPhysical = DoubleStream.concat(
                md.hallTransitions().stream().mapToDouble(ht -> ht.getPosition()),
                DoubleStream.of(endPosition)).toArray();
        return doFit(timesPhysical, positionsPhysical, md.startPosition(), endPosition, actualDuration);
    }

    /** Returns the motion's actual duration in (fractional) seconds. */
    private static double actualDurationSeconds(MotionDone md) {
        return md.actualDuration().getSeconds() + md.actualDuration().getNano() * 1e-9;
    }

    /**
     * Fits the six-parameter three-jerk model to the given samples.
     * The fit is done in normalized units: time is divided by
     * {@code actualDuration}, and displacement from {@code startPosition} (in
     * the direction of travel) is divided by {@link #POSITION_SCALE}. The result
     * is converted back to physical units.
     *
     * @param timesPhysical sample times, seconds since the start of motion
     * @param positionsPhysical sample positions, same length as {@code timesPhysical}
     * @param startPosition position at the start of motion
     * @param endPosition position at the end of motion
     * @param actualDuration duration of the motion in seconds
     * @return fitted {t0, t1, t2, j0, j1, j2} in physical units
     */
    private static double[] doFit(double[] timesPhysical, double[] positionsPhysical,
            double startPosition, double endPosition, double actualDuration) {
        // +1 for forward travel, -1 for reverse travel, 0 when start == end (all scaled positions become 0).
        double sign = Math.signum(endPosition - startPosition);
        double durationCubed = Math.pow(actualDuration, 3);

        List<WeightedObservedPoint> points = new ArrayList<>(timesPhysical.length);
        for (int i = 0; i < timesPhysical.length; i++) {
            double timeScaled = timesPhysical[i] / actualDuration;
            double positionScaled = sign * (positionsPhysical[i] - startPosition) / POSITION_SCALE;
            points.add(new WeightedObservedPoint(1.0, timeScaled, positionScaled));
        }

        // Initial guess expressed in normalized units: times scale by 1/T, jerks by T^3/POSITION_SCALE.
        double[] startPointScaled = {
            T0_GUESS / actualDuration,
            T1_GUESS / actualDuration,
            T2_GUESS / actualDuration,
            JERK_GUESS * durationCubed / POSITION_SCALE,
            -JERK_GUESS * durationCubed / POSITION_SCALE,
            JERK_GUESS * durationCubed / POSITION_SCALE
        };

        SimpleCurveFitter curveFitter = SimpleCurveFitter.create(new SixParameterModel(), startPointScaled)
                .withMaxIterations(MAX_FIT_ITERATIONS);
        double[] fitScaled = curveFitter.fit(points);

        // Undo the normalization.
        double[] fitPhysical = new double[fitScaled.length];
        for (int i = 0; i < 3; i++) {
            fitPhysical[i] = fitScaled[i] * actualDuration;
        }
        for (int i = 3; i < 6; i++) {
            fitPhysical[i] = fitScaled[i] * POSITION_SCALE / durationCubed;
        }
        return fitPhysical;
    }

    /** @return the fit result for the motor-encoder samples */
    ShutterMotionProfileFitResult getMotorEncoderFit() {
        return this.motorEncoderFit;
    }

    /** @return the fit result for the Hall-sensor transitions */
    ShutterMotionProfileFitResult getHallSensorFit() {
        return this.hallSensorFit;
    }

    /**
     * Piecewise-cubic position model made of three constant-jerk segments.
     * Parameters are {t0, t1, t2, j0, j1, j2}. With tau = t - t0, the jerk is j0
     * for tau < t1, j1 for t1 <= tau < t2, and j2 afterwards. Position, velocity
     * and acceleration are continuous at tau = t1 and tau = t2, and the position
     * is 0 at tau = 0. For tau < 0 the first-segment cubic is simply extrapolated.
     * Derivation: https://www.overleaf.com/project/637813358c20c4b04c7ff190
     */
    private static class SixParameterModel implements ParametricUnivariateFunction {

        @Override
        public double value(double t, double... parameters) {
            double t0 = parameters[0];
            double t1 = parameters[1];
            double t2 = parameters[2];
            double j0 = parameters[3];
            double j1 = parameters[4];
            double j2 = parameters[5];

            double tau = t - t0;
            if (tau < t1) {
                return j0 * tau * tau * tau / 6;
            }
            // Acceleration/velocity/position offsets that make segment 2 continuous with segment 1 at t1.
            double A1 = (j0 - j1) * t1;
            double V1 = (j0 - j1) * t1 * t1 / 2 - A1 * t1;
            double S1 = (j0 - j1) * t1 * t1 * t1 / 6 - A1 * t1 * t1 / 2 - V1 * t1;
            if (tau < t2) {
                return j1 * tau * tau * tau / 6 + A1 * tau * tau / 2 + V1 * tau + S1;
            }
            // Offsets that make segment 3 continuous with segment 2 at t2.
            double A2 = A1 + (j1 - j2) * t2;
            double V2 = (j1 - j2) * t2 * t2 / 2 + (A1 - A2) * t2 + V1;
            double S2 = (j1 - j2) * t2 * t2 * t2 / 6 + (A1 - A2) * t2 * t2 / 2 + (V1 - V2) * t2 + S1;
            return j2 * tau * tau * tau / 6 + A2 * tau * tau / 2 + V2 * tau + S2;
        }

        @Override
        public double[] gradient(double t, double... parameters) {
            // Central finite differences. Analytical gradients would be more accurate.
            // Works on a copy so the caller's parameter array is never modified.
            double delta = 0.001;
            double[] shifted = parameters.clone();
            double[] result = new double[parameters.length];
            for (int i = 0; i < parameters.length; i++) {
                shifted[i] = parameters[i] - delta;
                double below = value(t, shifted);
                shifted[i] = parameters[i] + delta;
                double above = value(t, shifted);
                shifted[i] = parameters[i];
                result[i] = (above - below) / (2 * delta);
            }
            return result;
        }
    }

    // The code below is taken with thanks from https://github.com/shuang92/StandaloneJavaShutterFitter/blob/main/src/main/java/org/lsst/ccs/standalone/shutter/fitter/StandaloneJavaShutterFitter.java
    // the Author is Luang Shuang of SLAC. This code returns the time at the midpoint of motion in two ways:
    // by setting the second-segment acceleration (j1*t + A1) equal to zero, and by solving the second-segment
    // cubic for the mid position using a damped Newton's method.
    // Incorporated by Farrukh, Tony, July 28th 2024.
    /**
     * Computes two estimates of the mid-point time from a set of physical fit parameters.
     * NOTE(review): t0 (parameters[0]) is not added to the returned times, and
     * {@code sMid} is an absolute position while the model position is a displacement
     * from the start position. Both only agree in special cases — confirm intent.
     *
     * @param sMid position whose crossing time is sought
     * @param parameters physical fit parameters {t0, t1, t2, j0, j1, j2}
     * @return {zero-acceleration time, Newton-method time}; the Newton time is NaN
     *         if the iteration does not converge within {@link #MAX_NEWTON_ITERATIONS}
     */
    private static double[] midPointCalculator(double sMid, double... parameters) {
        double t1 = parameters[1];
        double j0 = parameters[3];
        double j1 = parameters[4];

        double A1 = (j0 - j1) * t1;
        double V1 = (j0 - j1) * t1 * t1 / 2 - A1 * t1;
        double S1 = (j0 - j1) * t1 * t1 * t1 / 6 - A1 * t1 * t1 / 2 - V1 * t1;

        // In the second segment the acceleration is j1*t + A1. It vanishes (velocity is maximal) at t = -A1/j1.
        double tMaxV = -A1 / j1;

        // Second-segment position equation j1/6 t^3 + A1/2 t^2 + V1 t + S1 = sMid,
        // normalized to the monic cubic t^3 + b t^2 + c t + d = 0.
        double a = j1 / 6;
        double b = (A1 / 2) / a;
        double c = V1 / a;
        double d = (S1 - sMid) / a;

        // Start from the maximum-velocity time, where the slope of the cubic is steepest.
        double tOld = tMaxV;
        double tNew = dampedNewtonStep(tOld, b, c, d);
        int iterations = 1;
        while (Math.abs(tNew - tOld) > NEWTON_TOLERANCE) {
            if (iterations >= MAX_NEWTON_ITERATIONS) {
                tNew = Double.NaN;
                break;
            }
            tOld = tNew;
            tNew = dampedNewtonStep(tOld, b, c, d);
            iterations++;
        }
        return new double[]{tMaxV, tNew};
    }

    /** One damped Newton step for the monic cubic t^3 + b t^2 + c t + d. */
    private static double dampedNewtonStep(double t, double b, double c, double d) {
        double f = t * t * t + b * t * t + c * t + d;
        double fPrime = 3 * t * t + 2 * b * t + c;
        return t - NEWTON_DAMPING * f / fPrime;
    }
}
