package org.lsst.ccs.messaging.util;

import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;
import org.lsst.ccs.bus.definition.Bus;
import org.lsst.ccs.utilities.logging.Logger;

/**
 * Base class to facilitate implementing {@link Dispatcher}.
 * If we need statistic-based actions or MBean access, this is the place to add it.
 * <p>
 * Optional status monitoring is configured with a {@code "periods=P1:P2:..."} constructor
 * argument, where each {@code Pi} is a monitoring period in seconds and {@code 0} stands for
 * the dispatcher lifetime. Monitoring is disabled if the argument is not given. When enabled,
 * statistics reported through {@link #report} are accumulated, and {@link Status} snapshots are
 * periodically published through {@link #getStatus()} and to registered status listeners.
 *
 * @author onoprien
 */
abstract public class AbstractDispatcher implements Dispatcher {

// -- Fields : -----------------------------------------------------------------
    
    static private final Logger LOGGER = Logger.getLogger("org.lsst.ccs.messaging.Dispatcher");
    static private final int[] DEF_PERIODS = new int[0];  // status monitoring is disabled by default (was {0,60})
    
    /** Daemon timer shared by all dispatchers; runs the periodic status-updating tasks. */
    static protected final Timer timer = new Timer("Dispatcher timer", true);
    
    private final int[] periods; // monitoring periods in ascending order (seconds, 0 means lifetime); empty if no monitoring
    private final Status statusAccumulator; // mutable Status instance used to accumulate statistics; {@code null} if no monitoring
    private volatile Status statusOut;  // latest available Status snapshot; {@code null} if no monitoring
    private final Task[] statusTasks; // periodic tasks; each creates a new Status snapshot and assigns it to {@code statusOut}
    private final CopyOnWriteArrayList<StatusListener> listeners = new CopyOnWriteArrayList<>();

// -- Life cycle : -------------------------------------------------------------
    
    /**
     * Constructs the dispatcher and, if requested, sets up status monitoring.
     *
     * @param args Configuration arguments in {@code key=value} form. The {@code periods} key
     *             takes a colon-separated list of monitoring periods in seconds
     *             ({@code 0} means lifetime), for example {@code "periods=0:60"}.
     * @throws NumberFormatException If the {@code periods} value contains a non-integer.
     */
    protected AbstractDispatcher(String... args) {
        
        int[] pp = getIntArrayArg("periods", args);
        if (pp == null) {
            periods = DEF_PERIODS;
        } else {
            Arrays.sort(pp);
            periods = pp;
        }
        
        if (periods.length == 0) {
            statusAccumulator = null;
            statusTasks = new Task[0];
        } else {
            statusAccumulator = new Status(periods, true);
            // Group positive periods into as few timer tasks as possible: since periods are sorted,
            // a period that is a multiple of an existing task's period is handled by that task,
            // which counts its own ticks to decide when the longer period has ended.
            // NOTE(review): if no period is positive (e.g. "periods=0"), no task is created and
            // the published status is never refreshed.
            ArrayList<Task> tasks = new ArrayList<>();
            for (int pid = 0; pid < periods.length; pid++) {
                int seconds = periods[pid];
                if (seconds > 0) {
                    boolean needsNewTask = true;
                    for (Task task : tasks) {
                        if (seconds % task.getPeriod() == 0) {
                            task.add(pid);
                            needsNewTask = false;
                            break;
                        }
                    }
                    if (needsNewTask) {
                        tasks.add(new Task(pid));
                    }
                }
            }
            statusTasks = tasks.toArray(new Task[tasks.size()]);
            statusOut = new Status(periods, true);
        }
    }

    /** Schedules the periodic status-updating tasks, if any; each runs for the first time immediately. */
    @Override
    public void initialize() {
        for (Task task : statusTasks) {
            timer.scheduleAtFixedRate(task, 0L, task.getPeriod()*1000L);
        }
    }

    /** Cancels the periodic status-updating tasks and removes them from the shared timer. */
    @Override
    public void shutdown() {
        for (Task task : statusTasks) {
            task.cancel();
        }
        if (statusTasks.length > 0) {
            // Cancelled tasks stay in the shared timer's queue until purged, and each of them
            // references this dispatcher (Task is an inner class).
            timer.purge();
        }
    }
    
    
// -- Statistics : -------------------------------------------------------------
    
    /**
     * Records the end of one or more processing stages of a task.
     * If {@code run} is a {@link Dispatcher.Task}, it is told that the stages have ended;
     * if monitoring is enabled, the report is added to the accumulated statistics.
     *
     * @param run Runnable being processed.
     * @param taskID Task identifier (not currently used by the statistics).
     * @param incoming {@code true} for incoming tasks, {@code false} for outgoing ones.
     * @param bus Bus used to classify the task in the statistics.
     * @param order Order used to classify the task in the statistics.
     * @param duration Time spent, added to the statistics of every listed stage.
     * @param stages Stages that have ended.
     */
    protected void report(Runnable run, long taskID, boolean incoming, Bus bus, Order order, long duration, Stage... stages) {
        if (run instanceof Dispatcher.Task) {
            ((Dispatcher.Task)run).stageEnded(stages);
        }
        if (statusAccumulator != null) {
            statusAccumulator.addReport(taskID, incoming, bus, order, duration, stages);
        }
    }
    
    /**
     * Publishes a new status snapshot and notifies listeners.
     * Statistics for the listed periods are snapshotted and reset. Lifetime statistics
     * (period {@code 0}), if monitored, are refreshed on every call without being reset.
     * Snapshots for all other periods are carried over from the previous status.
     * Does nothing if monitoring is disabled.
     *
     * @param ids Indices (into the sorted periods array) of the periods that have just ended.
     *            The lifetime period index may be appended to this list.
     */
    protected void updateStatistics(ArrayList<Integer> ids) {
        if (statusAccumulator == null) return; // monitoring is disabled
        if (periods[0] == 0 && !ids.contains(0)) ids.add(0);
        Status out = statusAccumulator.compute(ids, statusOut);
        statusOut = out;
        notifyListeners(out);
    }

    /**
     * Returns the latest status snapshot.
     *
     * @return Latest snapshot, or {@code null} if status monitoring is disabled.
     */
    @Override
    public Status getStatus() {
        return statusOut;
    }    
    
    /**
     * Execution-time statistics broken down by direction, bus, order, monitoring period and stage.
     * The dispatcher keeps one mutable instance as an accumulator; published instances are
     * snapshots produced by {@link #compute} and are not modified by this class afterwards.
     */
    static protected class Status implements Dispatcher.Status {
        
        private final int[] periods;
        private final Bucket[] buckets; // layout: [incoming/outgoing][bus][order][period]
        private final int cIncoming, cBus, nPeriods; // strides of the bucket array layout
        
        /**
         * Creates a status instance.
         *
         * @param periods Monitoring periods.
         * @param createBuckets If {@code true}, all buckets are created empty;
         *                      otherwise they are left {@code null} for the caller to fill.
         */
        protected Status(int[] periods, boolean createBuckets) {
            this.periods = periods;
            nPeriods = periods.length;
            cBus = Order.values().length * nPeriods;
            cIncoming = Bus.values().length * cBus;
            buckets = new Bucket[2 * cIncoming];
            if (createBuckets) {
                for (int i=0; i < 2 * cIncoming; i++) {
                    buckets[i] = new Bucket();
                }
            }
        }

        /**
         * Returns a time statistic.
         * For {@code MAX}, the maximum reported duration. For {@code AVERAGE}, the stored
         * time value, which in published snapshots is the average per completed task.
         * Any other statistic yields {@code 0}.
         */
        @Override
        public long getTime(boolean incoming, Bus bus, Order order, Stage stage, Stat stat, int periodID) {
            Bin bin = getBin(incoming, bus, order, stage, periodID);
            switch (stat) {
                case MAX:
                    return bin.timeMax;
                case AVERAGE:
                    return bin.timeSum;
                default:
                    return 0L;
            }
        }

        /** Returns the number of reports received for the given category, stage and period. */
        @Override
        public long getCompletedTasks(boolean incoming, Bus bus, Order order, Stage stage, int periodID) {
            return getBin(incoming, bus, order, stage, periodID).n;
        }
        
        /** Returns the index into the bucket array for the given category and period. */
        protected final int getBucketIndex(boolean incoming, Bus bus, Order order, int periodID) {
            return (incoming ? 0 : 1)*cIncoming + bus.ordinal()*cBus + order.ordinal()*nPeriods + periodID ;
        }
        
        /** Returns the bucket for the given category and period. */
        protected final Bucket getBucket(boolean incoming, Bus bus, Order order, int periodID) {
            return buckets[getBucketIndex(incoming, bus, order, periodID)];
        }
        
        /** Returns the bin for the given category, stage and period. */
        protected Bin getBin(boolean incoming, Bus bus, Order order, Stage stage, int periodID) {
            return getBucket(incoming, bus, order, periodID).bins.get(stage);
        }
        
        /**
         * Adds a report to every period's bucket of the given category.
         * Synchronizes on the category's first bucket, the same lock {@link #compute} uses.
         */
        protected void addReport(long taskID, boolean incoming, Bus bus, Order order, long duration, Stage... stages) {
            int i0 = getBucketIndex(incoming, bus, order, 0);
            synchronized (buckets[i0]) {
                for (int i=i0; i<i0+nPeriods; i++) {
                    Bucket bucket = buckets[i];
                    for (Stage stage : stages) {
                        Bin bin = bucket.bins.get(stage);
                        bin.n++;
                        bin.timeSum += duration;
                        if (duration > bin.timeMax) bin.timeMax = duration;
                    }
                }
            }
        }
        
        /**
         * Creates a snapshot from this accumulator.
         * For each listed period, the accumulated statistics (with average times) are taken into
         * the snapshot; they are reset unless the period is {@code 0} (lifetime). Buckets for
         * unlisted periods are taken from {@code previous}.
         *
         * @param ids Indices of the periods to compute.
         * @param previous Previously published snapshot.
         * @return New snapshot.
         */
        protected Status compute(List<Integer> ids, Status previous) {
            Status out = new Status(previous.periods, false);
            for (int i0 = 0; i0 < buckets.length; i0 += nPeriods) {
                synchronized (buckets[i0]) {
                    for (int i : ids) {
                        Bucket bucket = buckets[i0 + i];
                        // Lifetime statistics must keep accumulating, so copy them instead of resetting.
                        out.buckets[i0 + i] = (periods[i] == 0) ? bucket.snapshot() : bucket.compute();
                    }
                }
            }
            for (int i=0; i < buckets.length; i++) {
                if (out.buckets[i] == null) {
                    out.buckets[i] = previous.buckets[i];
                }
            }
            return out;
        }
        
    }
    
    /** Per-stage statistics for one category and monitoring period. */
    static protected class Bucket {
        
        protected EnumMap<Stage,Bin> bins;
        
        /** Creates a bucket with an empty bin for every stage. */
        protected Bucket() {
            init();
        }
        
        /** Creates a bucket sharing the bins of another bucket. */
        protected Bucket(Bucket other) {
            bins = other.bins;
        }
        
        /** Replaces the bins with new empty ones, one per stage. */
        protected void init() {
            bins = new EnumMap<>(Stage.class);
            for (Stage stage : Stage.values()) {
                bins.put(stage, new Bin());
            }            
        }
        
        /**
         * Detaches the accumulated bins into a new bucket and resets this one.
         * In the returned bucket, {@code timeSum} holds the average time per report.
         */
        protected Bucket compute() {
            Bucket out = new Bucket(this);
            for (Bin bin : out.bins.values()) {
                if (bin.n > 0) {
                    bin.timeSum /= bin.n;
                }
            }
            init();
            return out;
        }
        
        /**
         * Returns a copy of this bucket, with {@code timeSum} converted to the average time
         * per report, leaving this bucket unchanged.
         */
        protected Bucket snapshot() {
            Bucket out = new Bucket();
            for (Map.Entry<Stage,Bin> entry : bins.entrySet()) {
                Bin bin = entry.getValue();
                Bin copy = new Bin();
                copy.n = bin.n;
                copy.timeSum = bin.n > 0 ? bin.timeSum / bin.n : bin.timeSum;
                copy.timeMax = bin.timeMax;
                out.bins.put(entry.getKey(), copy);
            }
            return out;
        }
    } 
    
    /** Counters for one stage: number of reports, total (or, in snapshots, average) time, maximum time. */
    static protected class Bin {
        protected long n;
        protected long timeSum;
        protected long timeMax;
    } 

// -- Handling listeners : -----------------------------------------------------
    
    @Override    
    public void addStatusListener(StatusListener listener) {
        listeners.add(listener);
    }
    
    @Override
    public void removeStatusListener(StatusListener listener) {
        listeners.remove(listener);
    }
    
    /** Notifies all listeners; an exception thrown by one listener is logged and does not affect the others. */
    private void notifyListeners(Status status) {
        for (StatusListener listener : listeners) {
            try {
                listener.statusChanged(status);
            } catch (Exception x) {
                Logger logger = getLogger();
                if (logger != null) {
                    logger.warn("Error notifying " + listener +" of "+ getClass().getSimpleName() +" status change.", x);
                }
            }
        }
    }

// -- Utility methods : --------------------------------------------------------
    
    /** Returns the logger used by this dispatcher; subclasses may override. */
    protected Logger getLogger() {
        return LOGGER;
    }
    
    /**
     * Returns the integer value of the first {@code key=value} argument with the given key.
     *
     * @return Parsed value, or {@code null} if there is no such argument.
     * @throws NumberFormatException If the value is not an integer.
     */
    protected final Integer getIntegerArg(String key, String[] args) {
        String value = findArgValue(key, args);
        if (value == null) {
            return null;
        }
        return Integer.parseInt(value);
    }
    
    /**
     * Returns the colon-separated integer list value of the first {@code key=value} argument
     * with the given key, for example {@code "periods=0:60"}.
     *
     * @return Parsed values, or {@code null} if there is no such argument.
     * @throws NumberFormatException If any element is not an integer.
     */
    protected final int[] getIntArrayArg(String key, String[] args) {
        String value = findArgValue(key, args);
        if (value == null) {
            return null;
        }
        String[] ss = value.split(":");
        int[] out = new int[ss.length];
        for (int i=0; i<ss.length; i++) {
            out[i] = Integer.parseInt(ss[i]);
        }
        return out;
    }
    
    /**
     * Returns the value of the first argument of the form {@code key=value} (exactly one
     * {@code '='}) with the given key, or {@code null} if none. Null arrays and elements are skipped.
     */
    private static String findArgValue(String key, String[] args) {
        if (args == null) {
            return null;
        }
        for (String arg : args) {
            if (arg == null) {
                continue;
            }
            String[] ss = arg.split("=");
            if (ss.length == 2 && ss[0].equals(key)) {
                return ss[1];
            }
        }
        return null;
    }
    
    
// -- Local classes : ----------------------------------------------------------
    
    /**
     * Timer task that runs at the period of the first period ID it was created with, and on each
     * run reports every one of its periods whose tick countdown has reached zero.
     */
    private class Task extends TimerTask {
        
        private int[] periodIDs; // period IDs handled by this task; the first defines the task period
        private int[] count;     // remaining ticks until the corresponding period ends
        
        Task(int periodID) {
            periodIDs = new int[] {periodID};
            count = new int[] {1};
        }

        @Override
        public void run() {
            try {
                int n = count.length;
                ArrayList<Integer> ended = new ArrayList<>(n);
                for (int i=0; i<n; i++) {
                    if (--count[i] == 0) {
                        ended.add(periodIDs[i]);
                        count[i] = periods[periodIDs[i]]/getPeriod();
                    }
                }
                updateStatistics(ended);
            } catch (RuntimeException x) {
                // An exception escaping a TimerTask terminates the shared timer's thread,
                // which would stop status updates for every dispatcher.
                Logger logger = getLogger();
                if (logger != null) {
                    logger.warn("Error updating " + AbstractDispatcher.this.getClass().getSimpleName() + " status.", x);
                }
            }
        }
        
        /** Returns the period of this task, in seconds. */
        int getPeriod() {
            return periods[periodIDs[0]];
        }
        
        /**
         * Adds a period whose length is a multiple of this task's period; it ends every
         * {@code period / getPeriod()} ticks.
         */
        void add(int periodID) {
            int n = periodIDs.length;
            periodIDs = Arrays.copyOf(periodIDs, n+1);
            periodIDs[n] = periodID;
            count = Arrays.copyOf(count, n+1); // keep existing countdowns intact
            count[n] = periods[periodID] / getPeriod();
        }
        
    }
    
}
