/*
 * Decompiled with CFR 0.152.
 */
package org.renjin.pipeliner;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.renjin.pipeliner.fusion.FusedNode;
import org.renjin.pipeliner.fusion.LoopKernelCache;
import org.renjin.pipeliner.fusion.LoopKernels;
import org.renjin.pipeliner.node.CallNode;
import org.renjin.pipeliner.node.DataNode;
import org.renjin.pipeliner.node.DeferredNode;
import org.renjin.pipeliner.node.FunctionNode;
import org.renjin.pipeliner.node.OutputNode;
import org.renjin.pipeliner.optimize.Optimizers;
import org.renjin.primitives.ni.DeferredNativeCall;
import org.renjin.primitives.ni.NativeOutputVector;
import org.renjin.primitives.vector.DeferredComputation;
import org.renjin.repackaged.guava.base.Preconditions;
import org.renjin.repackaged.guava.collect.HashMultimap;
import org.renjin.repackaged.guava.collect.Lists;
import org.renjin.repackaged.guava.collect.Maps;
import org.renjin.repackaged.guava.collect.Multimap;
import org.renjin.repackaged.guava.collect.Sets;
import org.renjin.sexp.Vector;

/**
 * A graph of {@link DeferredNode}s built from one or more root vectors or native calls.
 *
 * <p>Nodes are created by walking the operands of each root. Vectors and native calls are
 * mapped to nodes by identity, so the same object reached along two paths yields a single
 * shared node. Function nodes are additionally de-duplicated by computation name and
 * operand equivalence.
 */
public class DeferredGraph {
    /** Root nodes, in the order in which the roots were added. */
    private final List<DeferredNode> rootNodes = new ArrayList<DeferredNode>();
    /** Every node registered with this graph. */
    private final List<DeferredNode> nodes = Lists.newArrayList();
    /** Identity-keyed lookup from a source vector to the node that represents it. */
    private final IdentityHashMap<Vector, DeferredNode> vectorMap = Maps.newIdentityHashMap();
    /** Identity-keyed lookup from a native call to the node that represents it. */
    private final IdentityHashMap<DeferredNativeCall, CallNode> callMap = Maps.newIdentityHashMap();
    /** Function nodes indexed by computation name, used to find equivalent computations. */
    private final Multimap<String, FunctionNode> computationIndex = HashMultimap.create();

    /**
     * Creates a graph whose single root is the given native call.
     *
     * @param call2 the native call to use as root
     */
    public DeferredGraph(DeferredNativeCall call2) {
        this.addRoot(call2);
    }

    /**
     * Creates a graph whose single root is the given vector.
     *
     * @param root the vector to use as root
     */
    public DeferredGraph(Vector root) {
        this.addRoot(root);
    }

    /** Creates an empty graph with no roots. */
    public DeferredGraph() {
    }

    /**
     * Runs the {@link Optimizers} over this graph and then fuses the nodes that support it.
     *
     * @param loopKernelCache cache handed to each fused node when its compilation is started
     */
    public void optimize(LoopKernelCache loopKernelCache) {
        Optimizers optimizers = new Optimizers();
        optimizers.optimize(this);
        this.fuse(loopKernelCache);
    }

    /**
     * Walks the graph from every root, operands first, and replaces each node supported by
     * {@link LoopKernels} with a {@link FusedNode}.
     *
     * @param loopKernelCache cache handed to each fused node when its compilation is started
     */
    public void fuse(LoopKernelCache loopKernelCache) {
        Set<DeferredNode> visited = Sets.newIdentityHashSet();
        // Iterate over a snapshot: replaceNode() updates rootNodes while we traverse.
        List<DeferredNode> toCheck = new ArrayList<DeferredNode>(this.rootNodes);
        for (DeferredNode rootNode : toCheck) {
            this.fuse(loopKernelCache, visited, rootNode);
        }
    }

    /**
     * Recursive step of {@link #fuse(LoopKernelCache)}: descends into the operands of
     * {@code node} the first time it is seen, then attempts to fuse {@code node} itself.
     */
    private void fuse(LoopKernelCache loopKernelCache, Set<DeferredNode> visited, DeferredNode node) {
        if (visited.add(node)) {
            for (DeferredNode operand : node.getOperands()) {
                this.fuse(loopKernelCache, visited, operand);
            }
        }
        // NOTE(review): the fusion attempt runs on every visit, not only the first one.
        // Kept as-is because it is not clear whether callers rely on it - confirm.
        FusedNode fused = this.tryFuse(node);
        if (fused != null) {
            fused.startCompilation(loopKernelCache);
            this.replaceNode(node, fused);
        }
    }

    /**
     * Returns a new {@link FusedNode} wrapping {@code root} if {@link LoopKernels} reports
     * support for it, otherwise {@code null}.
     */
    private FusedNode tryFuse(DeferredNode root) {
        if (LoopKernels.INSTANCE.supports(root)) {
            // Assumes supports() only accepts FunctionNode instances - the cast relies on it.
            return new FusedNode((FunctionNode)root);
        }
        return null;
    }

    /**
     * Adds {@code root} (and, transitively, its operands) to the graph and records the
     * resulting node as a root.
     *
     * @param root the vector to add as a root
     */
    void addRoot(Vector root) {
        DeferredNode rootNode = this.addNode(root);
        this.rootNodes.add(rootNode);
    }

    /**
     * Returns the node for {@code vector2}, creating it if this exact instance has not been
     * seen before.
     *
     * <p>Non-deferred vectors become data nodes, {@link NativeOutputVector}s become output
     * nodes and {@link DeferredComputation}s become compute nodes.
     *
     * @throws UnsupportedOperationException if the vector is deferred but of neither
     *     supported kind
     */
    private DeferredNode addNode(Vector vector2) {
        DeferredNode node = this.vectorMap.get(vector2);
        if (node != null) {
            return node;
        }
        if (!vector2.isDeferred()) {
            return this.addDataNode(vector2);
        }
        if (vector2 instanceof NativeOutputVector) {
            return this.addOutputNode(vector2);
        }
        if (vector2 instanceof DeferredComputation) {
            return this.addComputeNode((DeferredComputation)vector2);
        }
        throw new UnsupportedOperationException("deferred: " + vector2.getClass().getName());
    }

    /** Creates, registers and returns a {@link DataNode} for the given vector. */
    private DataNode addDataNode(Vector vector2) {
        DataNode dataNode = new DataNode(vector2);
        this.vectorMap.put(vector2, dataNode);
        this.nodes.add(dataNode);
        return dataNode;
    }

    /**
     * Adds the operands of {@code vector2} and then returns either an already registered
     * function node with the same computation name and equivalent operands, or a newly
     * created {@link FunctionNode}.
     */
    private DeferredNode addComputeNode(DeferredComputation vector2) {
        Vector[] operands = vector2.getOperands();
        DeferredNode[] children = new DeferredNode[operands.length];
        for (int i = 0; i < operands.length; ++i) {
            children[i] = this.addNode(operands[i]);
        }
        String computationName = vector2.getComputationName();
        // Multimap.get() returns an empty collection for an unknown key.
        for (FunctionNode existingNode : this.computationIndex.get(computationName)) {
            if (this.equivalent(children, existingNode.getOperands())) {
                // Remember the mapping so later lookups of this same vector are answered by
                // vectorMap instead of walking the operand tree again.
                this.vectorMap.put(vector2, existingNode);
                return existingNode;
            }
        }
        FunctionNode newNode = new FunctionNode(vector2);
        newNode.addInputs(children);
        this.nodes.add(newNode);
        this.vectorMap.put(vector2, newNode);
        this.computationIndex.put(computationName, newNode);
        return newNode;
    }

    /** Returns true if both operand sequences have the same length and are pairwise equivalent. */
    private boolean equivalent(DeferredNode[] a, List<DeferredNode> b) {
        if (a.length != b.size()) {
            return false;
        }
        for (int i = 0; i < a.length; ++i) {
            if (!this.equivalent(a[i], b.get(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns true if {@code a} and {@code b} are the same node, or if {@code a} is a
     * {@link DataNode} that reports {@code b} as equivalent.
     */
    private boolean equivalent(DeferredNode a, DeferredNode b) {
        if (a == b) {
            return true;
        }
        if (a instanceof DataNode) {
            return ((DataNode)a).equivalent(b);
        }
        return false;
    }

    /**
     * Creates and registers an {@link OutputNode} for the given {@link NativeOutputVector}
     * and links it to the node of the call that produces it.
     */
    private DeferredNode addOutputNode(Vector vector2) {
        NativeOutputVector outputVector = (NativeOutputVector)vector2;
        OutputNode node = new OutputNode(outputVector);
        this.vectorMap.put(vector2, node);
        this.nodes.add(node);
        this.addCallChild(node, outputVector.getCall());
        return node;
    }

    /**
     * Returns the node for {@code call2}, creating it and adding the call's operands as
     * children if this exact call instance has not been seen before.
     */
    private CallNode addNode(DeferredNativeCall call2) {
        CallNode node = this.callMap.get(call2);
        if (node != null) {
            return node;
        }
        node = new CallNode(call2);
        this.nodes.add(node);
        // Registered before the operands are walked.
        this.callMap.put(call2, node);
        this.addChildren(node, call2.getOperands());
        return node;
    }

    /** Links {@code parentNode} to the node of {@code call2} as input/output of each other. */
    private void addCallChild(DeferredNode parentNode, DeferredNativeCall call2) {
        CallNode callNode = this.addNode(call2);
        parentNode.addInput(callNode);
        callNode.addOutput(parentNode);
    }

    /**
     * Creates a {@link CallNode} for {@code call2}, records it as a root and adds the call's
     * operands as its children.
     *
     * <p>NOTE(review): unlike {@link #addNode(DeferredNativeCall)}, the node is not
     * registered in {@code callMap} - confirm whether that is intentional.
     */
    private void addRoot(DeferredNativeCall call2) {
        CallNode rootNode = new CallNode(call2);
        this.rootNodes.add(rootNode);
        this.nodes.add(rootNode);
        this.addChildren(rootNode, call2.getOperands());
    }

    /** Adds a node for every operand and links it to {@code parent2} in both directions. */
    private void addChildren(DeferredNode parent2, Vector[] operands) {
        for (Vector operand : operands) {
            DeferredNode node = this.addNode(operand);
            parent2.addInput(node);
            node.addOutput(parent2);
        }
    }

    /**
     * Writes the graph in DOT format to a new temporary file and prints the file's path to
     * standard output.
     *
     * <p>This is a best-effort debugging aid: failures are reported on standard error
     * rather than thrown.
     */
    public void dumpGraph() {
        try {
            File tempFile = File.createTempFile("deferred", ".dot");
            // try-with-resources closes the writer even if printGraph() throws.
            try (PrintWriter writer = new PrintWriter(tempFile)) {
                this.printGraph(writer);
                // PrintWriter never throws on write; it only records that an error happened.
                if (writer.checkError()) {
                    System.err.println("Error while writing compute graph to " + tempFile.getAbsolutePath());
                }
            }
            System.out.println("Dumping compute graph to " + tempFile.getAbsolutePath());
        } catch (IOException e) {
            System.err.println("Failed to dump compute graph: " + e);
        }
    }

    /**
     * Prints every node reachable from the roots, and the edges between them, as a DOT
     * {@code digraph}, then flushes the writer. The writer is not closed.
     *
     * @param writer destination for the DOT output
     */
    public void printGraph(PrintWriter writer) {
        // Collect the reachable nodes by following operand links from the roots.
        Set<DeferredNode> reachable = Sets.newIdentityHashSet();
        ArrayDeque<DeferredNode> workingList = new ArrayDeque<DeferredNode>(this.rootNodes);
        while (!workingList.isEmpty()) {
            DeferredNode node = workingList.poll();
            if (reachable.add(node)) {
                workingList.addAll(node.getOperands());
            }
        }
        writer.println("digraph G {");
        this.printEdges(writer, reachable);
        this.printNodes(writer, reachable);
        writer.println("}");
        writer.flush();
    }

    /** Prints one {@code operand -> node} edge for every operand of every given node. */
    private void printEdges(PrintWriter writer, Set<DeferredNode> nodes) {
        for (DeferredNode node : nodes) {
            for (DeferredNode operand : node.getOperands()) {
                writer.println(operand.getDebugId() + " -> " + node.getDebugId());
            }
        }
    }

    /** Prints one DOT node statement, with label and shape, for every given node. */
    private void printNodes(PrintWriter writer, Set<DeferredNode> nodes) {
        for (DeferredNode node : nodes) {
            // Locale.ROOT keeps the shape name stable regardless of the default locale.
            String shape = node.getShape().name().toLowerCase(Locale.ROOT);
            writer.println(node.getDebugId() + " [ label=\"" + node.getDebugLabel() + "\",  " + "shape=\"" + shape + "\"]");
        }
    }

    /**
     * Returns the list of root nodes. This is the internal list itself, not a copy.
     *
     * @return the root nodes
     */
    public List<DeferredNode> getRoots() {
        return this.rootNodes;
    }

    /**
     * Returns the vector held by the root node at the given position.
     *
     * @param rootIndex zero-based index into the root list
     * @return the vector of that root node
     */
    public Vector getRootResult(int rootIndex) {
        return this.rootNodes.get(rootIndex).getVector();
    }

    /**
     * Returns the only root of this graph.
     *
     * @return the single root node
     * @throws IllegalStateException if the graph does not have exactly one root
     */
    public DeferredNode getRoot() {
        Preconditions.checkState(this.rootNodes.size() == 1,
            (Object)("expected exactly one root but found " + this.rootNodes.size()));
        return this.rootNodes.get(0);
    }

    /**
     * Returns the list of registered nodes. This is the internal list itself, not a copy.
     *
     * @return the nodes of this graph
     */
    public List<DeferredNode> getNodes() {
        return this.nodes;
    }

    /**
     * Substitutes {@code replacementNode} for {@code toReplace} in the node list, in the
     * root list and in the operands of every node that uses {@code toReplace}.
     *
     * <p>A replaced root keeps its position in the root list, so indices passed to
     * {@link #getRootResult(int)} stay valid.
     *
     * @param toReplace the node to take out of the graph
     * @param replacementNode the node to put in its place
     */
    public void replaceNode(DeferredNode toReplace, DeferredNode replacementNode) {
        this.nodes.remove(toReplace);
        if (!this.nodes.contains(replacementNode)) {
            this.nodes.add(replacementNode);
        }
        // Replace in place instead of remove-then-append, which would reorder the roots.
        int rootIndex = this.rootNodes.indexOf(toReplace);
        if (rootIndex >= 0) {
            this.rootNodes.set(rootIndex, replacementNode);
        }
        for (DeferredNode operand : toReplace.getOperands()) {
            operand.removeUse(toReplace);
        }
        for (DeferredNode node : toReplace.getUses()) {
            node.replaceOperand(toReplace, replacementNode);
        }
    }
}

