package org.lsst.ccs.messaging.util;

import java.util.*;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.lsst.ccs.utilities.logging.Logger;

/**
 * Executor that uses {@code String} keys associated with submitted tasks to sequence their execution.
 * Tasks submitted with overlapping key sets are guaranteed to execute sequentially in order of submission.
 * Tasks that do not have common keys may run concurrently, up to the specified maximum number of threads.
 * <p>
 * All mutable scheduling state is guarded by the instance monitor. Worker threads are
 * supplied by an internal cached {@link ThreadPoolExecutor}; each worker runs a
 * {@code MasterTask} loop that keeps pulling runnable tasks from the queue until none is available.
 *
 * @author onoprien
 */
public class KeyQueueExecutor {

// -- Fields : -----------------------------------------------------------------

    private final String name;
    private final int maxThreads; // limit on the number of threads

    // Guarded by instance monitor (begin)

    private int threads = 0; // number of worker loops (MasterTask) currently alive
    private final Set<String> blockedKeys = new HashSet<>(); // keys of the tasks that are currently running
    private final LinkedList<Task> queue = new LinkedList<>(); // accepted tasks waiting for execution, in submission order
    private long taskID = 0L; // ID to be assigned to the next accepted task
    private boolean shutdownRequested = false; // set once shutdown() or shutdownNow() has been called

    private final Map<String,Integer> keyCounts = new HashMap<>(); // key --> number of accepted, not yet finished tasks carrying it

    // Guarded by instance monitor (end)

    private final ThreadPoolExecutor exec;  // caching executor with unlimited number of threads and no queue, used internally
    private volatile Logger logger; // should be set before first task, or left at {@code null} (no logging); volatile since workers read it without the monitor

// -- Life cycle : -------------------------------------------------------------

    /**
     * Constructs the executor.
     *
     * @param name Name of this executor, used for naming threads.
     * @param maxThreads Maximum number of task-executing threads that can run concurrently.
     *                   Expected to be positive: with {@code 0} every task is queued and never started.
     */
    public KeyQueueExecutor(String name, int maxThreads) {
        this.name = name;
        this.maxThreads = maxThreads;

        final ThreadFactory delegate = Executors.defaultThreadFactory();
        ThreadFactory threadFactory = r -> {
            Thread thread = delegate.newThread(r);
            thread.setDaemon(true);
            thread.setName(name);
            thread.setUncaughtExceptionHandler((t, x) ->
                    logError("Exception thrown by KeyQueueExecutor worker thread. This should not happen.", x));
            return thread;
        };
        exec = new ThreadPoolExecutor(1, Integer.MAX_VALUE, 70L, TimeUnit.SECONDS, new SynchronousQueue<>(), threadFactory);
    }

    /**
     * Sets a logger for messages on abnormal conditions.
     * This method should be called before submitting the first task.
     * If this method is never called, nothing is logged.
     *
     * @param logger Logger for abnormal conditions warnings.
     */
    public synchronized void setLogger(Logger logger) {
        this.logger = logger;
    }

    /**
     * Initiates an orderly shutdown in which previously submitted
     * tasks are executed, but no new tasks will be accepted.
     * <p>
     * If no tasks are waiting in the queue, the internal thread pool is shut down right away
     * (tasks that are already running are allowed to complete), so that
     * {@link #awaitTermination awaitTermination(...)} can succeed even if this executor is idle.
     * Otherwise, the pool is shut down by the worker that removes the last task from the queue.
     */
    public synchronized void shutdown() {
        shutdownRequested = true;
        if (queue.isEmpty()) {
            // Nothing left to hand over to the pool; without this an idle executor would never terminate.
            exec.shutdown();
        }
    }

    /**
     * Shuts down this executor.
     * Tasks that have not started execution are discarded.
     */
    public synchronized void shutdownNow() {
        shutdownRequested = true;
        exec.shutdownNow();
        queue.clear();
    }

    /**
     * Blocks until all tasks have completed execution after a shutdown request, or the
     * timeout occurs, or the current thread is interrupted, whichever happens first.
     *
     * @param timeout the maximum time to wait
     * @param unit the time unit of the timeout argument
     * @return {@code true} if this executor terminated and
     *         {@code false} if the timeout elapsed before termination
     * @throws InterruptedException if interrupted while waiting
     */
    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return exec.awaitTermination(timeout, unit);
    }


// -- Getters : ----------------------------------------------------------------

    /**
     * Returns the name of this executor.
     * @return Executor name.
     */
    public String getName() {
        return name;
    }


// -- Submitting tasks : -------------------------------------------------------

    /**
     * Submits a Runnable for execution.
     * This method submits a task and returns without waiting for its completion.
     * Tasks submitted with overlapping key sets are guaranteed to execute sequentially in order of submission.
     * Tasks that do not have common keys may run concurrently, up to the specified maximum number of threads.
     * <p>
     * This class is thread-safe, its methods can be called on any thread.
     * Visibility effects: all actions of the current thread prior to calling this method "happen
     * before" actions on the thread executing the task. Actions of a task
     * "happen before" actions of another task submitted later with overlapping key set.
     * <p>
     * The {@code keys} array is copied, later modifications by the caller have no effect on this executor.
     *
     * @param command Runnable to execute.
     * @param keys Keys
     * @throws RejectedExecutionException if this task cannot be accepted for execution.
     * @throws NullPointerException if command or the keys array is null.
     */
    public synchronized void execute(Runnable command, String... keys) {
        Objects.requireNonNull(command, "command");
        Objects.requireNonNull(keys, "keys");
        if (shutdownRequested || exec.isShutdown()) {
            throw new RejectedExecutionException("KeyQueueExecutor " + name + " is shut down");
        }
        Task task = new Task(command, keys, taskID++);
        List<String> taskKeys = task.getKeys();
        // keyCounts covers both running and queued tasks, so a new task never overtakes an earlier one sharing a key.
        if (threads == maxThreads || task.isBlocked(keyCounts.keySet())) {
            queue.addLast(task);
        } else {
            threads++;
            blockedKeys.addAll(taskKeys);
            try {
                exec.execute(new MasterTask(task));
            } catch (RuntimeException | Error x) {
                // Hand-off failed: undo the bookkeeping so that the executor state stays consistent.
                // The task was not blocked, so none of its keys was in blockedKeys before.
                threads--;
                blockedKeys.removeAll(taskKeys);
                throw x;
            }
        }
        for (String key : taskKeys) {
            keyCounts.merge(key, 1, Integer::sum);
        }
    }

    /**
     * Called on a worker thread after executing a task.
     * Releases the keys of the finished task, selects the next task for the calling worker, and,
     * if the finished task had several keys and the thread limit allows, starts additional workers
     * for other tasks that became runnable.
     *
     * @param prevTask Task that just finished executing.
     * @return Next task to be executed, or {@code null} if this thread should stop processing tasks.
     */
    private synchronized Task finishTask(Task prevTask) {
        removeTaskKeys(prevTask);
        int releasedKeys = prevTask.getKeys().size();
        Task next = null;
        if (threads == maxThreads || releasedKeys == 1) {
            // At the thread limit only this worker can take a task;
            // releasing a single key is handled by fetching a single task. FIXME: can be optimized
            next = fetchFirstNonBlockedTask();
        } else if (releasedKeys == 0) {
            // A task without keys was not blocking anything, so nothing in the queue became runnable.
            threads--;
        } else {
            // Several keys released: several queued tasks may have become runnable.
            List<Task> ready = fetchNonBlockedTasks(maxThreads - threads);
            if (ready.isEmpty()) {
                threads--;
            } else {
                next = ready.get(0); // continues on the calling worker
                for (Task extra : ready.subList(1, ready.size())) {
                    threads++;
                    exec.execute(new MasterTask(extra));
                }
            }
        }
        if (shutdownRequested && queue.isEmpty()) {
            exec.shutdown();
        }
        return next;
    }

    /**
     * Releases the keys held by a finished task: removes them from the set of keys of running
     * tasks and decrements their usage counts, dropping counts that reach zero.
     * Must be called while holding the instance monitor.
     *
     * @param task Task that finished executing.
     */
    private void removeTaskKeys(Task task) {
        for (String key : task.getKeys()) {
            blockedKeys.remove(key);
            keyCounts.computeIfPresent(key, (k, count) -> count > 1 ? count - 1 : null);
        }
    }

    /**
     * Removes from the queue and returns the first task that is not blocked by running or
     * earlier queued tasks. If there is no such task, decrements the worker count.
     * Must be called while holding the instance monitor.
     *
     * @return Task to run on the calling worker, or {@code null} if the worker should stop.
     */
    private Task fetchFirstNonBlockedTask() {
        List<Task> ready = fetchNonBlockedTasks(1);
        if (ready.isEmpty()) {
            threads--;
            return null;
        }
        return ready.get(0);
    }

    /**
     * Scans the queue in submission order and removes up to {@code limit} tasks that share no key
     * with running tasks or with tasks that precede them in the queue. Keys of the removed tasks
     * are added to {@code blockedKeys}. Does not modify the worker count.
     * Must be called while holding the instance monitor.
     *
     * @param limit Maximum number of tasks to fetch; nothing is fetched if not positive.
     * @return Fetched tasks in submission order, possibly empty.
     */
    private List<Task> fetchNonBlockedTasks(int limit) {
        List<Task> fetched = new ArrayList<>();
        if (limit <= 0 || queue.isEmpty()) {
            return fetched;
        }
        // Working copy: accumulates keys of every task seen so far, so later tasks cannot overtake earlier ones.
        Set<String> block = new HashSet<>(blockedKeys);
        Iterator<Task> it = queue.iterator();
        while (fetched.size() < limit && it.hasNext()) {
            Task task = it.next();
            List<String> taskKeys = task.getKeys();
            if (!task.isBlocked(block)) {
                it.remove();
                blockedKeys.addAll(taskKeys);
                fetched.add(task);
            }
            block.addAll(taskKeys);
        }
        return fetched;
    }

    /**
     * Logs an error if a logger has been set, does nothing otherwise.
     *
     * @param message Message to log.
     * @param x Exception to log.
     */
    private void logError(String message, Throwable x) {
        Logger log = logger; // single read, the field may be changed concurrently
        if (log != null) {
            log.error(message, x);
        }
    }

// -- Task class : -------------------------------------------------------------

    /**
     * Represents a task submitted to {@code KeyQueueExecutor}: a {@code Runnable}
     * with its keys and a sequential ID.
     */
    private static final class Task implements Runnable {

        private final Runnable runnable;
        private final List<String> keys;
        private final long id;

        /**
         * @param runnable Code to run.
         * @param keys Keys of this task; the array is copied.
         * @param id Task ID, used in the string representation only.
         */
        Task(Runnable runnable, String[] keys, long id) {
            this.runnable = runnable;
            this.keys = Collections.unmodifiableList(Arrays.asList(keys.clone()));
            this.id = id;
        }

        /** Returns an unmodifiable list of keys of this task, in the order they were given. */
        List<String> getKeys() {
            return keys;
        }

        /**
         * Checks whether this task shares at least one key with the given set.
         *
         * @param block Set of blocking keys.
         * @return {@code true} if any key of this task is in {@code block}; {@code false} for a task without keys.
         */
        boolean isBlocked(Set<String> block) {
            for (String key : keys) {
                if (block.contains(key)) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public void run() {
            runnable.run();
        }

        @Override
        public String toString() {
            return "Task " + id + " keys: " + keys;
        }

    }

    /**
     * Runs on each of the {@code KeyQueueExecutor} threads and executes {@code Task}s:
     * starts with the task given to the constructor, then keeps executing whatever
     * {@code finishTask(...)} returns until it returns {@code null}.
     * Note: constructor should run while holding {@code KeyQueueExecutor.this} monitor.
     */
    private class MasterTask implements Runnable {

        private Task current;

        MasterTask(Task firstTask) {
            current = firstTask;
        }

        @Override
        public void run() {
            try {
                while (current != null) {
                    // Expose the keys of the running task in the thread name.
                    Thread.currentThread().setName(name + "_" + String.join("_", current.getKeys()));
                    try {
                        current.run();
                    } catch (Throwable x) {
                        // A failing task must not kill the worker: its keys still have to be released.
                        logError("Exception thrown by a task submitted to KeyQueueExecutor " + name, x);
                    } finally {
                        try {
                            Thread.currentThread().setName(name);
                        } catch (Throwable ignored) {
                            // Restoring the thread name is best effort only.
                        }
                    }
                    current = finishTask(current);
                }
            } catch (Throwable t) {
                logError("Exception in MasterTask of KeyQueueExecutor " + name, t);
            }
        }

    }

}
