package org.lsst.ccs.subsystem.imagehandling;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import nom.tam.util.BufferedFile;
import nom.tam.fits.BinaryTable;
import nom.tam.fits.BinaryTableHDU;
import nom.tam.fits.FitsException;
import nom.tam.fits.Header;
import nom.tam.fits.HeaderCardException;

/**
 * Builds FITS binary-table HDUs describing shutter motion profiles and appends
 * them to existing FITS files.
 * <p>
 * A motion-profile JSON document is converted into one {@code BINTABLE} HDU
 * whose rows combine the profile's encoder samples ({@code encodeSamples}) and
 * hall-sensor transitions ({@code hallTransitions}). Scalar profile fields and
 * the flattened {@code fitResults} object are written as
 * {@code HIERARCH SHUTTER ...} header cards.
 */
public class ShutterMotionProfileHandler {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** Prefix of every motion-profile header card; spaces become dots in {@link #addHierarch}. */
    private static final String HIERARCH_PREFIX = "HIERARCH SHUTTER ";

    /**
     * JSON {@code version} value of the legacy layout, in which row times live
     * under {@code time.tai}/{@code time.mjd} and the start time under
     * {@code startTime.tai}/{@code startTime.mjd}, instead of the newer
     * {@code tai.isot}/{@code tai.mjd} layout.
     */
    private static final String LEGACY_VERSION = "1.0";

    /** Lower-case key names searched for by {@link #findObsid}. */
    private static final Set<String> OBSID_KEYS = Set.of("obsid", "observationid", "obs_id", "obsidstr");

    /** Lower-case key names searched for by {@link #findVersion}. */
    private static final Set<String> VERSION_KEYS = Set.of("version");

    /**
     * Motion profile header fields: maps the HIERARCH keyword suffix to the
     * JSON dot path under {@code motionProfile}, in output order. Read-only;
     * the legacy layout uses an adjusted copy (see {@link #headerFieldMapFor}).
     */
    private static final Map<String, String> HEADER_FIELD_MAP;

    static {
        Map<String, String> fields = new LinkedHashMap<>();
        fields.put("STARTTIME TAI ISOT", "startTime.tai.isot");
        fields.put("STARTTIME TAI MJD", "startTime.tai.mjd");
        fields.put("STARTPOSITION", "startPosition");
        fields.put("TARGETPOSITION", "targetPosition");
        fields.put("ENDPOSITION", "endPosition");
        fields.put("TARGETDURATION", "targetDuration");
        fields.put("ACTIONDURATION", "actionDuration");
        fields.put("SIDE", "side");
        HEADER_FIELD_MAP = Collections.unmodifiableMap(fields);
    }

    /**
     * Parses a shutter motion-profile JSON document and builds the matching
     * binary-table HDU.
     *
     * @param json the motion-profile JSON text
     * @return a table HDU with {@code EXTNAME} {@code OPEN} when
     *         {@code motionProfile.isOpen} is absent or true, {@code CLOSE}
     *         otherwise
     * @throws JsonProcessingException if {@code json} cannot be parsed
     * @throws FitsException if the table or its header cannot be built
     */
    static BinaryTableHDU createHDU(String json) throws JsonProcessingException, FitsException {
        JsonNode doc = MAPPER.readTree(json);
        String version = findVersion(doc);
        JsonNode profile = safeNode(doc, "motionProfile");
        boolean open = profile.path("isOpen").asBoolean(true);
        JsonNode fitResults = safeNode(profile, "fitResults");
        return buildCombinedTableHDU(profile, fitResults, open ? "OPEN" : "CLOSE", version);
    }

    /**
     * Appends {@code openHDU} and then {@code closeHDU} to the end of
     * {@code file}.
     * <p>
     * NOTE(review): assumes {@code file} already holds a complete FITS file, so
     * that its current end is a valid place to start a new HDU. Confirm that
     * callers guarantee this.
     *
     * @throws IOException if the file cannot be opened or written
     * @throws FitsException if an HDU cannot be serialized
     */
    static void appendHDUsToFitsFile(File file, BinaryTableHDU openHDU, BinaryTableHDU closeHDU) throws IOException, FitsException {
        try (BufferedFile bf2 = new BufferedFile(file, "rw")) {
            bf2.seek(bf2.length());
            openHDU.write(bf2);
            closeHDU.write(bf2);
        }
    }

    // ------------------------ Build one table HDU ------------------------

    /**
     * Builds the combined ENCODE/HALL table HDU and fills in its header.
     *
     * @param profile the {@code motionProfile} JSON object (possibly empty)
     * @param fitResults the {@code fitResults} JSON object (possibly empty)
     * @param extname value for the {@code EXTNAME} card
     * @param version JSON layout version; may be null
     */
    private static BinaryTableHDU buildCombinedTableHDU(JsonNode profile, JsonNode fitResults, String extname, String version) throws FitsException {
        List<Row> rows = collectRows(profile, version);

        // Build column arrays
        final int n = rows.size();
        String[] types = new String[n];
        String[] timeTai = new String[n];
        double[] timeMjd = new double[n];
        double[] positions = new double[n];
        String[] sensorIds = new String[n];
        boolean[] isOn = new boolean[n];

        // Minimum character-column widths; grown to the longest value below.
        int maxType = 4, maxTai = 1, maxSid = 1;
        for (int i = 0; i < n; i++) {
            Row r = rows.get(i);
            types[i] = r.TYPE;
            timeTai[i] = r.TIME_TAI;
            timeMjd[i] = r.TIME_MJD;
            positions[i] = r.POSITION;
            sensorIds[i] = r.SENSORID == -1 ? "" : String.valueOf(r.SENSORID);
            isOn[i] = r.ISON;

            // Row's constructor guarantees non-null TYPE/TIME_TAI strings.
            maxType = Math.max(maxType, types[i].length());
            maxTai = Math.max(maxTai, timeTai[i].length());
            maxSid = Math.max(maxSid, sensorIds[i].length());
        }

        BinaryTable bt = new BinaryTable(new Object[]{types, timeTai, timeMjd, positions, sensorIds, isOn});
        BinaryTableHDU hdu = new BinaryTableHDU(new Header(), bt);
        final Header header = hdu.getHeader();

        // Row width in bytes: two 'D' columns (8 bytes each), one 'L' column
        // (1 byte) and three 'A' columns of their computed widths.
        int naxis1 = 2 * Double.BYTES + 1 + maxType + maxTai + maxSid;

        header.addValue("XTENSION", "BINTABLE", "binary table extension");
        header.addValue("BITPIX", 8, "array data type");
        header.addValue("NAXIS", 2, "number of array dimensions");
        header.addValue("NAXIS1", naxis1, "width of table in bytes");
        header.addValue("NAXIS2", n, "number of rows in table");
        header.addValue("PCOUNT", 0, "number of group parameters");
        header.addValue("GCOUNT", 1, "number of groups");
        header.addValue("TFIELDS", 6, "number of table fields");

        // Column metadata: TTYPEn (with descriptive comment) and TUNITn
        setColumnMeta(hdu, 1, "TYPE", null, "Row type: 'ENCODE' (encoder sample) or 'HALL' (hall transition)");
        setColumnMeta(hdu, 2, "TIME_TAI", "ISOT", "ISO-8601 time (TAI) as provided in JSON");
        setColumnMeta(hdu, 3, "TIME_MJD", "d", "Modified Julian Date (TAI) as provided in JSON");
        setColumnMeta(hdu, 4, "POSITION", "mm", "Shutter position");
        setColumnMeta(hdu, 5, "SENSORID", null, "Hall sensor identifier (empty for ENCODE rows)");
        setColumnMeta(hdu, 6, "ISON", null, "Hall state: True if sensor ON; False otherwise");

        setTFORM(header, 1, "A" + maxType); // TYPE (char[n])
        setTFORM(header, 2, "A" + maxTai);  // TIME_TAI (char[n])
        setTFORM(header, 3, "D");           // TIME_MJD (float64)
        setTFORM(header, 4, "D");           // POSITION (float64)
        setTFORM(header, 5, "A" + maxSid);  // SENSORID (char[n])
        setTFORM(header, 6, "L");           // ISON (logical)

        header.addValue("EXTNAME", extname, "extension name");

        // Add HIERARCH headers from motionProfile
        addMotionProfileHeaders(header, profile, headerFieldMapFor(version));

        // Add fitResults headers (flattened)
        addFitResultsHeaders(header, fitResults);

        return hdu;
    }

    /** Collects all encoder-sample rows followed by all hall-transition rows. */
    private static List<Row> collectRows(JsonNode profile, String version) {
        List<Row> rows = new ArrayList<>();
        for (JsonNode s : safeArray(profile, "encodeSamples")) {
            rows.add(Row.fromEncode(s, version));
        }
        for (JsonNode h : safeArray(profile, "hallTransitions")) {
            rows.add(Row.fromHall(h, version));
        }
        return rows;
    }

    /**
     * Returns the header-field map for {@code version}. For the legacy layout,
     * this is a copy with the start-time paths remapped; insertion order is
     * kept because {@code put} on an existing key does not reorder a
     * {@link LinkedHashMap}.
     */
    private static Map<String, String> headerFieldMapFor(String version) {
        if (!isLegacyVersion(version)) {
            return HEADER_FIELD_MAP;
        }
        Map<String, String> legacy = new LinkedHashMap<>(HEADER_FIELD_MAP);
        legacy.put("STARTTIME TAI ISOT", "startTime.tai");
        legacy.put("STARTTIME TAI MJD", "startTime.mjd");
        return legacy;
    }

    /** Null-safe check for the legacy JSON layout version. */
    private static boolean isLegacyVersion(String version) {
        return LEGACY_VERSION.equals(version);
    }

    /**
     * Sets (or overwrites) the TFORMi card.
     */
    private static void setTFORM(Header hdr, int colIndex1Based, String tform) throws HeaderCardException {
        String key = "TFORM" + colIndex1Based;
        var existing = hdr.findCard(key);
        if (existing != null) {
            existing.setValue(tform);
        } else {
            hdr.addValue(key, tform, "column data format");
        }
    }

    // ------------------------ Header helpers ------------------------

    /**
     * Writes one HIERARCH card per entry of {@code headerFieldMap}, reading each
     * value from {@code profile} by dot path. Missing or null values still
     * produce a card with an empty string value, to mirror the Python script.
     */
    private static void addMotionProfileHeaders(Header hdr, JsonNode profile, Map<String, String> headerFieldMap) throws HeaderCardException {
        for (Map.Entry<String, String> e : headerFieldMap.entrySet()) {
            String key = HIERARCH_PREFIX + e.getKey();         // e.g. "... STARTPOSITION"
            JsonNode vNode = getByPath(profile, e.getValue()); // e.g. "startPosition"
            if (vNode != null && !vNode.isMissingNode() && !vNode.isNull()) {
                addHierarch(hdr, key, jsonNodeToScalar(vNode), null);
            } else {
                addHierarch(hdr, key, "", null);
            }
        }
    }

    /**
     * Flattens {@code fitResults} and writes each leaf as a HIERARCH card whose
     * suffix is derived from its path. If a card is rejected (for example,
     * because the key/value pair is too long), the value is written instead as
     * {@code FITRES<n>}, with its original path in {@code FITRES<n> NAME}.
     */
    private static void addFitResultsHeaders(Header hdr, JsonNode fitResults) throws HeaderCardException {
        if (fitResults == null || fitResults.isMissingNode() || fitResults.isNull()) {
            return;
        }

        List<FlatEntry> flat = new ArrayList<>();
        flatten("", fitResults, flat);

        int fallbackIdx = 0;
        for (FlatEntry fe : flat) {
            // Drop a leading 'fitResults.' if present; suffix is UPPERCASE words separated by spaces
            String path = fe.path.startsWith("fitResults.") ? fe.path.substring("fitResults.".length()) : fe.path;
            String key = HIERARCH_PREFIX + suffixFromPath(path); // e.g. "... HALLSENSORFIT MODELSTARTTIME"

            try {
                addHierarch(hdr, key, fe.value, null);
            } catch (HeaderCardException ex) {
                fallbackIdx++;
                addHierarch(hdr, HIERARCH_PREFIX + "FITRES" + fallbackIdx, fe.value, null);
                addHierarch(hdr, HIERARCH_PREFIX + "FITRES" + fallbackIdx + " NAME", fe.path, null);
            }
        }
    }

    /**
     * Adds a card, converting spaces in {@code keyword} to dots.
     * {@code Double}, {@code Long} and {@code Boolean} values keep their FITS
     * type. Anything else (including null, written as "") is written as a
     * string.
     */
    private static void addHierarch(Header hdr, String keyword, Object value, String comment) throws HeaderCardException {
        // nom.tam handles HIERARCH when FitsFactory.setUseHierarch(true)
        String dottedKey = keyword.replace(' ', '.');
        if (value instanceof Double aDouble) {
            hdr.addValue(dottedKey, aDouble, comment);
        } else if (value instanceof Long aLong) {
            // jsonNodeToScalar yields Long for JSON integers: keep them numeric
            hdr.addValue(dottedKey, aLong, comment);
        } else if (value instanceof Boolean aBoolean) {
            hdr.addValue(dottedKey, aBoolean, comment);
        } else {
            hdr.addValue(dottedKey, value == null ? "" : String.valueOf(value), comment);
        }
    }

    /**
     * Splits {@code path} on non-alphanumerics and joins the upper-cased parts
     * with single spaces. Upper-casing uses {@link Locale#ROOT}, so the default
     * locale (for example Turkish dotted/dotless i) cannot inject non-ASCII
     * characters into FITS keywords.
     */
    private static String suffixFromPath(String path) {
        StringBuilder sb = new StringBuilder();
        for (String p : path.split("[^A-Za-z0-9]+")) {
            if (p.isEmpty()) {
                continue;
            }
            if (sb.length() > 0) {
                sb.append(' ');
            }
            sb.append(p.toUpperCase(Locale.ROOT));
        }
        return sb.toString();
    }

    /**
     * Sets (or overwrites) TTYPEi and, when {@code unit} is non-empty, TUNITi.
     */
    private static void setColumnMeta(BinaryTableHDU hdu, int colIndex1Based, String name, String unit, String ttypeComment) throws HeaderCardException {
        Header hdr = hdu.getHeader();
        String ttypeKey = "TTYPE" + colIndex1Based;
        var ttypeCard = hdr.findCard(ttypeKey);
        if (ttypeCard != null) {
            ttypeCard.setValue(name);
            ttypeCard.setComment(ttypeComment);
        } else {
            hdr.addValue(ttypeKey, name, ttypeComment);
        }
        if (unit != null && !unit.isEmpty()) {
            String tunitKey = "TUNIT" + colIndex1Based;
            var tunitCard = hdr.findCard(tunitKey);
            if (tunitCard != null) {
                tunitCard.setValue(unit);
                tunitCard.setComment(unitCommentFor(unit, name));
            } else {
                hdr.addValue(tunitKey, unit, unitCommentFor(unit, name));
            }
        }
    }

    /** Comment text for a TUNIT card, or null for units without a canned comment. */
    private static String unitCommentFor(String unit, String colName) {
        return switch (unit) {
            case "ISOT" ->
                "Time scale indicator for " + colName + " ('ISOT')";
            case "d" ->
                "Days";
            case "mm" ->
                "Millimetres";
            default ->
                null;
        };
    }

    // ------------------------ JSON helpers ------------------------

    /** Returns {@code root.get(field)}, or an empty object if root is null or lacks the field. */
    private static JsonNode safeNode(JsonNode root, String field) {
        return root != null && root.has(field) ? root.get(field) : MAPPER.createObjectNode();
    }

    /** Returns the elements of array {@code root.field}, or an empty list if absent or not an array. */
    private static List<JsonNode> safeArray(JsonNode root, String field) {
        JsonNode arr = root == null ? null : root.get(field);
        if (arr == null || !arr.isArray()) {
            return Collections.emptyList();
        }
        List<JsonNode> list = new ArrayList<>(arr.size());
        arr.forEach(list::add);
        return list;
    }

    /**
     * Follows a dot-separated path of field names. Returns null if any
     * intermediate step is absent; the final step may yield null.
     */
    private static JsonNode getByPath(JsonNode node, String dotPath) {
        if (node == null) {
            return null;
        }
        JsonNode cur = node;
        for (String p : dotPath.split("\\.")) {
            if (cur == null || cur.isMissingNode()) {
                return null;
            }
            cur = cur.get(p);
        }
        return cur;
    }

    /**
     * Converts a JSON node into a header value. Text becomes a String, int/long
     * a Long, float/double/BigDecimal a Double, and boolean a Boolean; null
     * becomes "". Any other node is rendered as compact JSON text.
     */
    private static Object jsonNodeToScalar(JsonNode node) {
        if (node == null || node.isNull()) {
            return "";
        }
        if (node.isTextual()) {
            return node.asText();
        }
        if (node.isInt() || node.isLong()) {
            return node.asLong();
        }
        if (node.isFloat() || node.isDouble() || node.isBigDecimal()) {
            return node.asDouble();
        }
        if (node.isBoolean()) {
            return node.asBoolean();
        }
        return node.toString();
    }

    /**
     * Finds an observation id anywhere in {@code doc} (key match is
     * case-insensitive). Currently not called from within this class.
     */
    private static String findObsid(JsonNode doc) {
        return findFirstTextForKeys(doc, OBSID_KEYS);
    }

    /**
     * Finds a {@code version} value anywhere in {@code doc} (key match is
     * case-insensitive).
     *
     * @return the value as text, or null if no such key exists
     */
    private static String findVersion(JsonNode doc) {
        return findFirstTextForKeys(doc, VERSION_KEYS);
    }

    /**
     * Stack-based depth-first search for the first field whose lower-cased name
     * is in {@code lowerCaseKeys}. Within one object, a matching field is
     * returned before any of that object's children are visited.
     *
     * @return the matching field's value as text, or null if none is found or
     *         {@code doc} is null
     */
    private static String findFirstTextForKeys(JsonNode doc, Set<String> lowerCaseKeys) {
        if (doc == null) {
            return null;
        }
        Deque<JsonNode> stack = new ArrayDeque<>();
        stack.push(doc);
        while (!stack.isEmpty()) {
            JsonNode cur = stack.pop();
            if (cur.isObject()) {
                Iterator<Map.Entry<String, JsonNode>> it = cur.fields();
                while (it.hasNext()) {
                    Map.Entry<String, JsonNode> e = it.next();
                    if (lowerCaseKeys.contains(e.getKey().toLowerCase(Locale.ROOT))) {
                        return e.getValue().asText();
                    }
                    if (e.getValue().isContainerNode()) {
                        stack.push(e.getValue());
                    }
                }
            } else if (cur.isArray()) {
                cur.forEach(stack::push);
            }
        }
        return null;
    }

    // ------------------------ Flatten helper ------------------------

    /** One flattened leaf: dot-joined path and scalar value. */
    private static class FlatEntry {

        final String path;
        final Object value;

        FlatEntry(String path, Object value) {
            this.path = path;
            this.value = value;
        }
    }

    /**
     * Recursively flattens objects into dot-joined paths. Null nodes are
     * skipped, and arrays are stored whole as JSON text.
     */
    private static void flatten(String prefix, JsonNode node, List<FlatEntry> out) {
        if (node == null || node.isNull()) {
            return;
        }
        if (node.isObject()) {
            Iterator<Map.Entry<String, JsonNode>> it = node.fields();
            while (it.hasNext()) {
                Map.Entry<String, JsonNode> e = it.next();
                String path = prefix.isEmpty() ? e.getKey() : prefix + "." + e.getKey();
                flatten(path, e.getValue(), out);
            }
        } else if (node.isArray()) {
            // Follow the Python writer: do NOT expand arrays into fitResults;
            // if present, store them as a JSON string.
            out.add(new FlatEntry(prefix, node.toString()));
        } else {
            out.add(new FlatEntry(prefix, jsonNodeToScalar(node)));
        }
    }

    // ------------------------ Row model ------------------------

    /** One table row: either an encoder sample or a hall-sensor transition. */
    private static class Row {

        final String TYPE;     // ENCODE | HALL
        final String TIME_TAI; // ISOT string; "" when absent
        final double TIME_MJD; // days (as provided); NaN when absent
        final double POSITION; // mm; NaN when absent
        final int SENSORID;    // -1 for ENCODE rows or when absent
        final boolean ISON;    // false for ENCODE rows or when absent

        Row(String type, String tai, double mjd, double position, Integer sensorId, boolean isOn) {
            this.TYPE = type == null ? "" : type;
            this.TIME_TAI = tai == null ? "" : tai;
            this.TIME_MJD = mjd;
            this.POSITION = position;
            this.SENSORID = sensorId == null ? -1 : sensorId;
            this.ISON = isOn;
        }

        static Row fromEncode(JsonNode s, String version) {
            return fromJson("ENCODE", s, version, null, false);
        }

        static Row fromHall(JsonNode h, String version) {
            // Integer (not int): a missing or non-numeric sensorId yields null,
            // which the constructor maps to -1 instead of failing on unboxing.
            Integer sensorId = optInteger(h.get("sensorId"));
            boolean isOn = optBool(h.get("isOn"), false);
            return fromJson("HALL", h, version, sensorId, isOn);
        }

        /**
         * Reads the time and position fields common to both row types, using
         * the legacy or current time paths depending on {@code version}
         * (null-safe).
         */
        private static Row fromJson(String type, JsonNode node, String version, Integer sensorId, boolean isOn) {
            boolean legacy = isLegacyVersion(version);
            String tai = optText(getByPath(node, legacy ? "time.tai" : "tai.isot"), "");
            Double mjd = optDouble(getByPath(node, legacy ? "time.mjd" : "tai.mjd"));
            Double pos = optDouble(node.get("position"));
            return new Row(type, tai, mjd != null ? mjd : Double.NaN, pos != null ? pos : Double.NaN, sensorId, isOn);
        }

        private static String optText(JsonNode n, String def) {
            return (n != null && n.isTextual()) ? n.asText() : def;
        }

        private static Double optDouble(JsonNode n) {
            return (n != null && n.isNumber()) ? n.asDouble() : null;
        }

        private static Integer optInteger(JsonNode n) {
            return (n != null && n.isNumber()) ? n.asInt() : null;
        }

        private static boolean optBool(JsonNode n, boolean def) {
            return (n != null && n.isBoolean()) ? n.asBoolean() : def;
        }
    }
}
