package org.lsst.ccs.messaging;

import java.io.Serializable;
import java.text.DateFormat;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;

import org.lsst.ccs.bus.messages.BusMessage;
import org.lsst.ccs.bus.messages.CommandAck;
import org.lsst.ccs.bus.messages.CommandNack;
import org.lsst.ccs.bus.messages.CommandRequest;
import org.lsst.ccs.bus.messages.CommandResult;
import org.lsst.ccs.bus.messages.EmbeddedObjectDeserializationException;
import org.lsst.ccs.bus.messages.StatusMessage;
import org.lsst.ccs.utilities.logging.Logger;

/**
 * Utility class to synchronously or asynchronously invoke commands or listen
 * for events on the buses.
 *
 * Synchronous command execution will wait a given time before throwing a
 * {@code TimeoutException}. The timeout used when throwing an exception depends
 * on the method invoked, the ConcurrentMessagingUtils configuration and the
 * command invoked.
 * If a method that takes a timeout is invoked, that timeout is used, otherwise
 * a cascade approach is used to determine the timeout among the following:
 *
 * <ul>
 * <li>The default timeout set on the ConcurrentMessagingUtils object, either at
 * construction time or by invoking the {@link #setDefaultTimeout(Duration) setDefaultTimeout}
 * method.</li>
 * <li>The timeout value that comes from the annotated {@code Command} that is invoked</li>
 * <li>The timeout returned as part of the {@code CommandAck} when the invoked
 * {@code Command} is accepted</li>
 * </ul>
 *
 * Of the above possible timeouts, the last one defined, in the order presented,
 * is picked and used.
 *
 * If none of the above timeouts is defined, an Exception will be thrown.
 *
 * @author The LSST CCS Team
 */
public final class ConcurrentMessagingUtils {

    private static final Logger log = Logger.getLogger("org.lsst.ccs.messaging");
    private final AgentMessagingLayer agentMessagingLayer;
    private volatile Duration defaultTimeout;

    /**
     * ConcurrentMessagingUtils constructor, using a default timeout of 5 seconds.
     *
     * @param agentMessagingLayer The {@code AgentMessagingLayer} used to send messages on the buses.
     */
    public ConcurrentMessagingUtils(AgentMessagingLayer agentMessagingLayer) {
        this(agentMessagingLayer, Duration.ofSeconds(5));
    }

    /**
     * Build a ConcurrentMessagingUtils object with a default timeout.
     *
     * @param agentMessagingLayer The {@code AgentMessagingLayer} used to send messages on the buses.
     * @param defaultTimeout The default timeout used when synchronous commands are invoked.
     */
    public ConcurrentMessagingUtils(AgentMessagingLayer agentMessagingLayer, Duration defaultTimeout) {
        this.agentMessagingLayer = agentMessagingLayer;
        this.defaultTimeout = defaultTimeout;
    }

    /**
     * Set the default timeout for this ConcurrentMessagingUtils object.
     * The default timeout will be used to determine the timeout to be used
     * when a synchronous method without a timeout is invoked.
     *
     * @param defaultTimeout The default timeout.
     */
    public void setDefaultTimeout(Duration defaultTimeout) {
        this.defaultTimeout = defaultTimeout;
    }

    /**
     * Get the default timeout for this ConcurrentMessagingUtils object.
     *
     * @return The default timeout Duration.
     */
    public Duration getDefaultTimeout() {
        return defaultTimeout;
    }

    /**
     * Send synchronously a command on the Buses without a timeout.
     * The timeout used will be the last one defined out of the following:
     * <ul>
     * <li>default timeout</li>
     * <li>timeout defined in the {@code Command} annotation</li>
     * <li>timeout returned as part of the {@code CommandAck}</li>
     * </ul>
     * If none of the above is defined, an exception will be thrown.
     *
     * @param command
     * The CommandRequest object to be sent on the buses.
     * @return The reply of the CommandRequest.
     * @throws Exception
     * If an exception was fired by the remote execution of the
     * command or the timeout expired or no valid timeout could
     * be found.
     */
    public Object sendSynchronousCommand(CommandRequest command) throws Exception {
        return invokeIt(false, command, defaultTimeout, false);
    }

    /**
     * Send a command on the Buses and wait for the reply within the provided
     * Duration timeout. If the command does not yield a reply within
     * the provided timeout a TimeoutException will be thrown.
     *
     * @param command
     * The CommandRequest object to be sent on the buses.
     * @param timeout
     * The Duration timeout. If the reply is not received within
     * the timeout a TimeoutException will be thrown.
     * @return The reply of the CommandRequest.
     * @throws Exception
     * If an exception was fired by the remote execution of the
     * command or the timeout expired.
     */
    public Object sendSynchronousCommand(CommandRequest command, Duration timeout) throws Exception {
        return invokeIt(false, command, timeout, true);
    }

    /**
     * Send a command on the buses and immediately return a Future that will
     * asynchronously listen for the command reply.
     *
     * @param command
     * The CommandRequest object to be sent on the buses.
     * @return A Future on the reply of the command execution. The future will
     * also contain any possible exception thrown during the command
     * execution (returned as the value, not thrown).
     */
    public Future<Object> sendAsynchronousCommand(CommandRequest command) {
        LinkedCommandOriginator commandOriginator = new LinkedCommandOriginator(false, agentMessagingLayer);
        LinkedFuture<Object> linkedFuture = new LinkedFuture<>(commandOriginator, false);
        linkedFuture.init();
        agentMessagingLayer.sendCommandRequest(command, commandOriginator);
        return linkedFuture;
    }

    /**
     * Send a CommandRequest on the buses and synchronously wait for the Ack to
     * come back within the provided Duration timeout.
     *
     * @param command
     * The CommandRequest object to be sent on the buses.
     * @param timeout
     * Duration timeout. If the {@code CommandAck} is not received within the
     * timeout a TimeoutException will be thrown.
     * @return The CommandAck for the CommandRequest.
     * @throws Exception
     * If an exception was fired by the remote execution of the
     * command or the timeout expired.
     */
    public Object getAckForCommand(CommandRequest command, Duration timeout) throws Exception {
        return invokeIt(true, command, timeout, true);
    }

    /**
     * Invokes the command synchronously.
     *
     * @param ackOnly If {@code true}, waits for ACK and returns it.
     * @param command The command to invoke.
     * @param timeout Time to wait for the result before throwing {@code TimeoutException}.
     *                May be {@code null} only when {@code isTimeoutUserProvided} is {@code false}.
     * @param isTimeoutUserProvided {@code True} if the specified timeout is not the default timeout.
     * @return The result sent back by the command target.
     * @throws IllegalArgumentException if a user-provided timeout is {@code null}.
     * @throws TimeoutException if the result is not received within the specified timeout.
     * @throws Exception If the result received from the command target is an {@code Exception}, it is thrown rather than returned.
     */
    private Object invokeIt(boolean ackOnly, CommandRequest command, Duration timeout, boolean isTimeoutUserProvided) throws Exception {

        final long timeoutMillis;
        if (timeout != null) {
            timeoutMillis = timeout.toMillis();
        } else if (isTimeoutUserProvided) {
            throw new IllegalArgumentException("Provided timeout cannot be null");
        } else {
            // No default timeout: SynchLinkedFuture will not wait for the ACK,
            // so only an ACK timeout that is already available can apply.
            timeoutMillis = -1L;
        }

        LinkedCommandOriginator commandOriginator = new LinkedCommandOriginator(ackOnly, agentMessagingLayer);
        SynchLinkedFuture<Object> linkedFuture = new SynchLinkedFuture<>(commandOriginator, false, isTimeoutUserProvided);
        linkedFuture.init();
        commandOriginator.addEvent("Sending command "+ command.getBasicCommand().getCommand());
        agentMessagingLayer.sendCommandRequest(command, commandOriginator);
        commandOriginator.addEvent("Command sent.");

        Object res;
        try {
            res = linkedFuture.get(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (TimeoutException x) {
            // Keep the timing diagnostics available when a command times out.
            commandOriginator.addEvent("Timing out.");
            log.debug(commandOriginator.getTrace());
            throw x;
        }
        log.debug(commandOriginator.getTrace());
        if (res instanceof Exception) {
            throw (Exception) res;
        }
        return res;
    }

    /**
     * Get a Future on a StatusMessage. The Future is completed with the first
     * StatusMessage that satisfies the provided filter. When the Future is
     * exercised it returns that message, or an ExecutionException wrapping a
     * TimeoutException is thrown if the timeout was reached first.
     *
     * @param filter
     * The message filter.
     * @param timeout
     * Duration timeout, measured from the time this method is invoked.
     * A non-positive timeout means no timeout. If this timeout is reached an
     * exception will be thrown if/when the Future is exercised.
     * @return A Future on a StatusMessage.
     */
    public Future<StatusMessage> startListeningForStatusBusMessage(
            Predicate<BusMessage<? extends Serializable, ?>> filter, Duration timeout) {

        LinkedStatusBusListener innerListener = new LinkedStatusBusListener(
                filter, timeout.toMillis(), this.agentMessagingLayer);
        LinkedFuture<StatusMessage> future = new LinkedFuture<>(innerListener, true);
        future.init();
        this.agentMessagingLayer.addStatusMessageListener(innerListener, filter);

        return future;
    }

    /**
     * Get a Future on a StatusMessage, without timeout. The Future is completed
     * with the first StatusMessage that satisfies the provided filter. When the
     * Future is exercised it will return that message.
     *
     * @param f
     * The message filter.
     * @return A Future on a StatusMessage.
     */
    public Future<StatusMessage> startListeningForStatusBusMessage(Predicate<BusMessage<? extends Serializable, ?>> f) {
        return startListeningForStatusBusMessage(f, Duration.ofMillis(-1));
    }

    /**
     * Status message listener that completes its LinkedFuture with the first
     * received message, or with a TimeoutException when a positive timeout elapses.
     */
    class LinkedStatusBusListener extends LinkedTask<StatusMessage> implements StatusMessageListener {

        private final Predicate<BusMessage<? extends Serializable, ?>> filter;
        // Only created when a positive timeout is requested. Daemon so that an
        // abandoned listener can never keep the JVM alive.
        private final Timer timeoutTimer;
        // Cleanup can be triggered concurrently by the timer, the bus and the user.
        private final AtomicBoolean cleanedUp = new AtomicBoolean(false);
        private final AgentMessagingLayer agentMessagingLayer;
        private final long timeout;

        LinkedStatusBusListener(Predicate<BusMessage<? extends Serializable, ?>> filter, long timeout, AgentMessagingLayer agentMessagingLayer) {
            this.filter = filter;
            this.agentMessagingLayer = agentMessagingLayer;
            this.timeout = timeout;
            this.timeoutTimer = timeout > 0 ? new Timer("LinkedStatusBusListener", true) : null;
        }

        @Override
        public void start() {
            if (timeoutTimer != null) {
                timeoutTimer.schedule(new TimerTask() {
                    @Override
                    public void run() {
                        TimeoutException ex = new TimeoutException(
                                "Timeout listening for filtered events "
                                + filter.toString());
                        // addToQueue() calls stop(), which unregisters the
                        // listener and cancels the timer.
                        getLinkedFuture().addToQueue(ex);
                    }
                }, timeout);
            }
        }

        @Override
        public Duration getTaskInternalTimeout(long tout) {
            return null;
        }

        @Override
        public void stop() {
            cancel();
        }

        /**
         * Release resources exactly once: stop the timeout timer (if any) and
         * unregister this listener from the messaging layer.
         */
        @Override
        public void cancel() {
            if (cleanedUp.compareAndSet(false, true)) {
                if (timeoutTimer != null) {
                    timeoutTimer.cancel();
                }
                agentMessagingLayer.removeStatusMessageListener(this);
            }
        }

        @Override
        public void onStatusMessage(StatusMessage bm) {
            if (!getLinkedFuture().isDone()) {
                // addToQueue() calls stop(), which also cancels the timeout timer.
                getLinkedFuture().addToQueue(bm);
            }
        }

    }

    /**
     * CommandOriginator that feeds the replies of a command (ACK, NACK, result)
     * into its LinkedFuture and records a timestamped trace for diagnostics.
     */
    private class LinkedCommandOriginator extends LinkedTask<Object> implements CommandOriginator {

        private final boolean getAckOnly;
        // Guarded by ackLock.
        private Duration timeout;
        private boolean gotAck = false;
        private final Object ackLock = new Object();

        LinkedCommandOriginator(boolean ackOnly, AgentMessagingLayer agentMessagingLayer) {
            getAckOnly = ackOnly;
            addEvent("Constructed command originator.");
        }

        @Override
        public void cancel() {
        }

        @Override
        public void start() {
        }

        @Override
        public void stop() {
        }

        /**
         * Wait up to {@code tout} milliseconds for the ACK (or NACK) and return
         * the timeout carried by the ACK, if any.
         *
         * @param tout Maximum wait in milliseconds; non-positive means do not wait.
         * @return The ACK timeout, or {@code null} if none was received.
         * @throws RuntimeException if the waiting thread is interrupted
         * (the interrupt status is restored).
         */
        @Override
        public Duration getTaskInternalTimeout(long tout) {
            addEvent("Asking for custom timeout, wait for "+ tout +" ms.");
            long deadline = System.currentTimeMillis() + tout;
            Duration ackTimeout;
            try {
                synchronized (ackLock) {
                    long remaining = tout;
                    // remaining > 0 guard: wait(0) would block indefinitely.
                    while (!gotAck && remaining > 0L) {
                        ackLock.wait(remaining);
                        remaining = deadline - System.currentTimeMillis();
                    }
                    ackTimeout = timeout;
                }
            } catch (InterruptedException ie) {
                addEvent("Wait for custom timeout is interrupted");
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted while waiting for ACK ", ie);
            }
            addEvent("Obtained custom timeout: "+ (ackTimeout == null ? "none" : (ackTimeout.getSeconds() +" seconds.")));
            return ackTimeout;
        }

        @Override
        public void processNack(CommandNack nack) {
            addEvent("Received NACK.");
            CommandRejectedException rejection = new CommandRejectedException(nack);
            getLinkedFuture().addToQueue(rejection);
            synchronized (ackLock) {
                gotAck = true;
                ackLock.notifyAll();
            }
        }

        @Override
        public void processResult(CommandResult result) {
            addEvent("Received result: "+ result.getEncodedData());
            if (getAckOnly) {
                return;
            }
            Object resultContent;
            //Try to collect the actual result : if that's not possible, return
            // the deserialization exception
            try {
                resultContent = result.getResult();
            } catch (EmbeddedObjectDeserializationException e) {
                String message = e.getMessage() +"\nORIGINAL ENCODED DATA:\n"+result.getEncodedData();
                resultContent = new EmbeddedObjectDeserializationException(message, e);
            }
            getLinkedFuture().addToQueue(resultContent);
        }

        @Override
        public void processAck(CommandAck ack) {
            addEvent("Received ACK.");
            synchronized(ackLock) {
                timeout = ack.getTimeout();
                gotAck = true;
                ackLock.notifyAll();
            }
            if (getAckOnly) {
                getLinkedFuture().addToQueue(ack);
            }
        }

        // Diagnostics:

        private final ArrayList<Event> trace = new ArrayList<>();

        final void addEvent(Object event) {
            synchronized(trace) {
                trace.add(new Event(event));
            }
        }

        private class Event {
            final long time;
            final Object content;
            Event(Object content) {
                time = System.currentTimeMillis();
                this.content = content;
            }
            public long getTime() {
                return time;
            }
            public Object getContent() {
                return content;
            }
        }

        /**
         * Format the recorded events, one per line, with the delay since the
         * previous event.
         */
        String getTrace() {
            DateFormat form = DateFormat.getTimeInstance(DateFormat.MEDIUM);
            StringBuilder sb = new StringBuilder("Command trace:\n");
            long prev = 0L;
            synchronized(trace) {
                for (Event e : trace) {
                    sb.append(form.format(new Date(e.time))).append(" ").append(e.content);
                    if (prev > 0L) {
                        sb.append(" ( + ").append(e.time - prev).append(" ms)");
                    }
                    prev = e.time;
                    sb.append("\n");
                }
            }
            return sb.toString();
        }

    }

    /**
     * A task whose outcome is delivered through a LinkedFuture. The task is
     * started when the future is attached to it.
     */
    abstract class LinkedTask<T> {

        LinkedFuture<T> future = null;

        public abstract void cancel();

        public abstract void start();

        public abstract void stop();

        public abstract Duration getTaskInternalTimeout(long timeout);

        void setLinkedFuture(LinkedFuture<T> future) {
            this.future = future;
            start();
        }

        LinkedFuture<T> getLinkedFuture() {
            return future;
        }

    }

    /**
     * Future completed by its LinkedTask. The first reply delivered through
     * {@link #addToQueue(Object)} wins; later replies are ignored so that
     * repeated {@code get} calls keep returning the same value.
     */
    class LinkedFuture<T> implements Future<T> {

        // Holds the first reply as a plain value (including Exception instances
        // and null). It is only completed exceptionally on cancellation.
        private final CompletableFuture<Object> reply = new CompletableFuture<>();
        protected final LinkedTask<T> task;
        private final boolean throwException;

        private boolean initialized = false;
        private final Object initLock = new Object();

        LinkedFuture(LinkedTask<T> task, boolean throwException) {
            this.task = task;
            this.throwException = throwException;
        }

        protected void init() {
            synchronized (initLock) {
                if (initialized) {
                    throw new RuntimeException("LinkedFuture must be initialized only once");
                }
                initialized = true;
            }
            task.setLinkedFuture(this);
        }

        /** Fail fast if {@link #init()} has not been called yet. */
        private void checkInitialized() {
            synchronized (initLock) {
                if (!initialized) {
                    throw new RuntimeException("LinkedFuture must be initialized first");
                }
            }
        }

        @Override
        public boolean isCancelled() {
            checkInitialized();
            return reply.isCancelled();
        }

        @Override
        public boolean isDone() {
            checkInitialized();
            return reply.isDone();
        }

        /**
         * Cancel this future unless it already completed or was cancelled.
         * On success the linked task is cancelled and subsequent {@code get}
         * calls throw {@code CancellationException}.
         */
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            checkInitialized();
            // Only the call that actually transitions the future releases the task.
            if (!reply.completeExceptionally(new CancellationException("LinkedFuture was cancelled"))) {
                return false;
            }
            task.cancel();
            return true;
        }

        @Override
        public T get() throws InterruptedException, ExecutionException {
            checkInitialized();
            return processReply(reply.get());
        }

        @Override
        public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            checkInitialized();
            // A non-positive timeout still waits one time unit.
            long time = timeout > 0L ? timeout : 1L;
            Object value;
            try {
                value = reply.get(time, unit);
            } catch (TimeoutException x) {
                throw new TimeoutException("Could not get reply within the specified timeout of " + timeout + " " + unit.toString().toLowerCase(Locale.ROOT));
            }
            return processReply(value);
        }

        /** Wrap an Exception reply in an ExecutionException when this future is configured to throw. */
        @SuppressWarnings("unchecked")
        private T processReply(Object value) throws ExecutionException {
            if (value instanceof Exception && throwException) {
                throw new ExecutionException("Execution Exception", (Exception) value);
            }
            return (T) value;
        }

        /**
         * Deliver a reply (ignored if a reply was already delivered or the
         * future was cancelled), then stop the linked task.
         */
        void addToQueue(Object obj) {
            checkInitialized();
            reply.complete(obj);
            task.stop();
        }

    }

    /**
     * LinkedFuture used for synchronous invocations. When the timeout is not
     * user provided, it first waits for the ACK and lets the ACK timeout
     * replace the default one.
     */
    class SynchLinkedFuture<T> extends LinkedFuture<T> {

        private final boolean isTimeoutUserProvided;

        SynchLinkedFuture(LinkedTask<T> task, boolean throwException, boolean isTimeoutUserProvided) {
            super(task, throwException);
            this.isTimeoutUserProvided = isTimeoutUserProvided;
        }

        @Override
        public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            // Work in milliseconds throughout so the elapsed-time arithmetic and
            // the error message are consistent whatever the caller's unit.
            long timeoutMillis = unit.toMillis(timeout);
            long pollMillis;
            if (isTimeoutUserProvided) {
                pollMillis = timeoutMillis;
            } else {
                long start = System.currentTimeMillis();
                Duration internalTaskDuration = task.getTaskInternalTimeout(timeoutMillis);
                if (internalTaskDuration != null) {
                    timeoutMillis = internalTaskDuration.toMillis();
                }
                // Deduct the time already spent waiting for the ACK.
                pollMillis = timeoutMillis - (System.currentTimeMillis() - start);
            }
            try {
                return super.get(pollMillis, TimeUnit.MILLISECONDS);
            } catch (TimeoutException x) {
                throw (TimeoutException) new TimeoutException("Timed out after "+ TimeUnit.MILLISECONDS.toSeconds(timeoutMillis) +" seconds.").initCause(x);
            }
        }
    }

}
