/*
 * Decompiled with CFR 0.152.
 */
package org.renjin.pipeliner.fusion;

import java.lang.reflect.Method;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import org.renjin.eval.EvalException;
import org.renjin.pipeliner.fusion.LoopKernelCache;
import org.renjin.pipeliner.fusion.LoopKernels;
import org.renjin.pipeliner.fusion.kernel.CompiledKernel;
import org.renjin.pipeliner.fusion.kernel.LoopKernel;
import org.renjin.pipeliner.fusion.node.BinaryVectorOpNode;
import org.renjin.pipeliner.fusion.node.DistanceMatrixNode;
import org.renjin.pipeliner.fusion.node.DoubleArrayNode;
import org.renjin.pipeliner.fusion.node.IntArrayNode;
import org.renjin.pipeliner.fusion.node.IntBufferNode;
import org.renjin.pipeliner.fusion.node.IntSeqNode;
import org.renjin.pipeliner.fusion.node.LoopNode;
import org.renjin.pipeliner.fusion.node.RepeatingNode;
import org.renjin.pipeliner.fusion.node.TransposeNode;
import org.renjin.pipeliner.fusion.node.UnaryVectorOpNode;
import org.renjin.pipeliner.fusion.node.VirtualVectorNode;
import org.renjin.pipeliner.node.DeferredNode;
import org.renjin.pipeliner.node.FunctionNode;
import org.renjin.pipeliner.node.NodeShape;
import org.renjin.primitives.sequence.IntSequence;
import org.renjin.primitives.vector.MemoizedComputation;
import org.renjin.repackaged.asm.Type;
import org.renjin.sexp.DoubleArrayVector;
import org.renjin.sexp.IntArrayVector;
import org.renjin.sexp.IntBufferVector;
import org.renjin.sexp.LogicalArrayVector;
import org.renjin.sexp.Vector;

public class FusedNode
extends DeferredNode
implements Runnable {
    private LoopKernel kernel;
    private LoopNode[] kernelOperands;
    private MemoizedComputation memoizedComputation;
    private DoubleArrayVector resultVector;
    private Future<CompiledKernel> compiledKernel;

    public FusedNode(FunctionNode node) {
        this.kernel = LoopKernels.INSTANCE.get(node);
        this.kernelOperands = new LoopNode[node.getOperands().size()];
        this.memoizedComputation = (MemoizedComputation)node.getVector();
        for (int i = 0; i < this.kernelOperands.length; ++i) {
            this.kernelOperands[i] = this.addLoopNode(node.getOperand(i));
        }
    }

    private LoopNode addLoopNode(DeferredNode node) {
        if (node instanceof FusedNode) {
            int inputIndex = this.addInput(node);
            node.addOutput(this);
            return new DoubleArrayNode(inputIndex, Type.getType(DoubleArrayVector.class));
        }
        if (node instanceof FunctionNode) {
            Method binaryOperator;
            Method unaryOperator;
            FunctionNode computation = (FunctionNode)node;
            String name = computation.getComputationName();
            if (name.equals("dist")) {
                return new DistanceMatrixNode(this.addLoopNode(computation.getOperand(0)));
            }
            if (name.equals("rep")) {
                return new RepeatingNode(this.addLoopNode(node.getOperand(0)), this.addLoopNode(node.getOperand(1)));
            }
            if (name.equals("t")) {
                return new TransposeNode(this.addLoopNode(node.getOperand(0)), this.addLoopNode(node.getOperand(1)));
            }
            int arity = node.getOperands().size();
            if (arity == 1 && (unaryOperator = UnaryVectorOpNode.findMethod(node.getVector())) != null) {
                return new UnaryVectorOpNode(name, unaryOperator, this.addLoopNode(node.getOperand(0)));
            }
            if (arity == 2 && (binaryOperator = BinaryVectorOpNode.findMethod(node.getVector())) != null) {
                return new BinaryVectorOpNode(name, binaryOperator, this.addLoopNode(node.getOperand(0)), this.addLoopNode(node.getOperand(1)));
            }
        }
        return this.addLoopInput(node);
    }

    private LoopNode addLoopInput(DeferredNode node) {
        int inputIndex = this.addInput(node);
        node.addOutput(this);
        if (node.getVector() instanceof IntBufferVector) {
            return new IntBufferNode(inputIndex);
        }
        if (node.getVector() instanceof IntSequence) {
            return new IntSeqNode(inputIndex);
        }
        if (node.getVector() instanceof DoubleArrayVector) {
            return new DoubleArrayNode(inputIndex, node.getResultVectorType());
        }
        if (node.getVector() instanceof IntArrayVector) {
            return new IntArrayNode(inputIndex, node.getResultVectorType());
        }
        if (node.getVector() instanceof LogicalArrayVector) {
            return new IntArrayNode(inputIndex, node.getResultVectorType());
        }
        return new VirtualVectorNode(inputIndex, node.getVector());
    }

    @Override
    public String getDebugLabel() {
        return this.kernel.debugLabel(this.kernelOperands);
    }

    @Override
    public NodeShape getShape() {
        return NodeShape.ELLIPSE;
    }

    @Override
    public Type getResultVectorType() {
        return Type.getType(DoubleArrayVector.class);
    }

    public void startCompilation(LoopKernelCache loopKernelCache) {
        this.compiledKernel = loopKernelCache.get(this.kernel, this.kernelOperands);
    }

    @Override
    public void run() {
        CompiledKernel kernel;
        try {
            kernel = this.compiledKernel.get();
        }
        catch (InterruptedException | ExecutionException e) {
            throw new EvalException("Exception compiling kernel", e);
        }
        Vector[] vectorOperands = new Vector[this.getOperands().size()];
        for (int i = 0; i < vectorOperands.length; ++i) {
            vectorOperands[i] = this.getOperand(i).getVector();
        }
        double[] result = kernel.compute(vectorOperands);
        this.resultVector = DoubleArrayVector.unsafe(result, this.memoizedComputation.getAttributes());
        this.memoizedComputation.setResult(this.resultVector);
    }

    @Override
    public DoubleArrayVector getVector() {
        if (this.resultVector == null) {
            throw new IllegalStateException("Not computed yet.");
        }
        return this.resultVector;
    }
}

