/*
 * Decompiled with CFR 0.152.
 */
package ghidra.feature.vt.api.correlator.program;

import generic.DominantPair;
import generic.lsh.vector.LSHCosineVectorAccum;
import generic.lsh.vector.LSHVector;
import generic.lsh.vector.VectorCompare;
import ghidra.feature.vt.api.main.VTAssociation;
import ghidra.feature.vt.api.main.VTAssociationStatus;
import ghidra.feature.vt.api.main.VTAssociationType;
import ghidra.feature.vt.api.main.VTMatch;
import ghidra.feature.vt.api.main.VTMatchInfo;
import ghidra.feature.vt.api.main.VTMatchSet;
import ghidra.feature.vt.api.main.VTScore;
import ghidra.feature.vt.api.main.VTSession;
import ghidra.feature.vt.api.util.VTAbstractProgramCorrelator;
import ghidra.framework.options.ToolOptions;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressSetView;
import ghidra.program.model.listing.CodeUnit;
import ghidra.program.model.listing.CodeUnitIterator;
import ghidra.program.model.listing.Data;
import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.FunctionManager;
import ghidra.program.model.listing.Instruction;
import ghidra.program.model.listing.Listing;
import ghidra.program.model.listing.Program;
import ghidra.program.model.symbol.Reference;
import ghidra.program.model.symbol.ReferenceIterator;
import ghidra.program.model.symbol.ReferenceManager;
import ghidra.util.datastruct.Counter;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.collections4.map.LazyMap;

/**
 * Base class for correlators that pair up functions by the already-accepted matches they
 * reference.
 * <p>
 * For every ACCEPTED association (of a type accepted by {@link #isExpectedRefType(VTAssociationType)})
 * found in the session's other match sets, the functions referring to the association's source
 * address and to its destination address are collected. Each such association becomes one shared
 * feature in an LSH cosine vector per referring function. References counted by
 * {@link #countFunctionRefs(Program, Address)} that are not covered by a shared feature are added
 * as features unique to that function. Each destination vector is then compared against every
 * source vector, and the qualifying pairs are added to the match set as function matches.
 */
public abstract class VTAbstractReferenceProgramCorrelator
extends VTAbstractProgramCorrelator {
    /** Maximum recursion depth when following thunks and data references back to functions. */
    private static final int MAX_DEPTH = 30;
    /** Maximum number of candidate matches kept per destination function when refining. */
    private static final int TOP_N = 5;
    /** When refining, candidates scoring more than this below the best candidate are dropped. */
    private static final double DIFFERENTIAL = 0.2;
    /** Two similarity scores closer than this are treated as equal (a tie) when refining. */
    private static final double EQUALS_EPSILON = 1.0E-5;
    /** Orders matches by descending similarity score. */
    private static final Comparator<VTMatchInfo> SCORE_COMPARATOR = (o1, o2) -> o2.getSimilarityScore().compareTo(o1.getSimilarityScore());

    private static final String REFINE_RESULTS_OPTION = "Refine Results";
    private static final String CONFIDENCE_THRESHOLD_OPTION = "Confidence threshold (info content)";
    private static final String SIMILARITY_THRESHOLD_OPTION = "Minimum similarity threshold (score)";

    private final String correlatorName;
    /** Feature vectors of source functions, keyed by function entry point; created lazily on get. */
    private Map<Address, LSHCosineVectorAccum> srcVectorsByAddress;
    /** Feature vectors of destination functions, keyed by function entry point; created lazily on get. */
    private Map<Address, LSHCosineVectorAccum> destVectorsByAddress;
    private final Program sourceProgram;
    private final Program destinationProgram;
    private final Listing sourceListing;
    private final Listing destinationListing;

    public VTAbstractReferenceProgramCorrelator(Program sourceProgram, AddressSetView sourceAddressSet, Program destinationProgram, AddressSetView destinationAddressSet, String correlatorName, ToolOptions options) {
        super(sourceProgram, sourceAddressSet, destinationProgram, destinationAddressSet, options);
        this.correlatorName = correlatorName;
        this.sourceProgram = sourceProgram;
        this.destinationProgram = destinationProgram;
        this.sourceListing = sourceProgram.getListing();
        this.destinationListing = destinationProgram.getListing();
    }

    @Override
    public String getName() {
        return correlatorName;
    }

    /**
     * Builds the reference feature vectors for both programs, then compares them and adds the
     * resulting function matches to {@code matchSet}.
     */
    @Override
    protected void doCorrelate(VTMatchSet matchSet, TaskMonitor monitor) throws CancelledException {
        monitor.setMessage("Finding reference features");
        extractReferenceFeatures(matchSet, monitor);
        monitor.setMessage("Finding destination functions");
        findDestinations(matchSet, monitor);
    }

    /**
     * Compares every destination vector against every source vector. Source functions with a
     * positive similarity are handed to {@link #transform}, and the resulting matches are added to
     * {@code matchSet}.
     */
    private void findDestinations(VTMatchSet matchSet, TaskMonitor monitor) throws CancelledException {
        monitor.initialize(destVectorsByAddress.size());
        for (Map.Entry<Address, LSHCosineVectorAccum> destEntry : destVectorsByAddress.entrySet()) {
            monitor.checkCancelled();
            monitor.incrementProgress(1);
            Function destFunc = destinationListing.getFunctionAt(destEntry.getKey());
            LSHCosineVectorAccum dstVector = destEntry.getValue();
            Map<Address, DominantPair<Double, VectorCompare>> srcNeighbors = new HashMap<>();
            for (Map.Entry<Address, LSHCosineVectorAccum> srcEntry : srcVectorsByAddress.entrySet()) {
                VectorCompare vectorCompare = new VectorCompare();
                // Compare once: the same value serves as the filter and as the stored score.
                double similarity = dstVector.compare(srcEntry.getValue(), vectorCompare);
                // NaN fails this test, so NaN similarities are discarded here.
                if (similarity > 0.0) {
                    srcNeighbors.put(srcEntry.getKey(), new DominantPair<>(similarity, vectorCompare));
                }
            }
            List<VTMatchInfo> members = transform(matchSet, destFunc, dstVector, srcNeighbors, monitor);
            for (VTMatchInfo member : members) {
                if (member != null) {
                    matchSet.addMatch(member);
                }
            }
        }
    }

    /**
     * Turns the candidate source functions for one destination function into match infos.
     * <p>
     * A candidate is dropped if its similarity is below the "Minimum similarity threshold (score)"
     * option, is NaN, or its comparison dot product is below the
     * "Confidence threshold (info content)" option. The surviving list is passed through
     * {@link #refine} when the "Refine Results" option is set.
     *
     * @param matchSet the match set the infos are created for
     * @param destinationFunction the destination function being matched
     * @param destinationVector the destination function's feature vector (not read here)
     * @param neighbors candidate source entry points mapped to (similarity, comparison details)
     * @param monitor checked for cancellation
     * @return the candidate matches, possibly refined
     * @throws CancelledException if the monitor is cancelled
     */
    private List<VTMatchInfo> transform(VTMatchSet matchSet, Function destinationFunction, LSHCosineVectorAccum destinationVector, Map<Address, DominantPair<Double, VectorCompare>> neighbors, TaskMonitor monitor) throws CancelledException {
        boolean refineResult = getOptions().getBoolean(REFINE_RESULTS_OPTION, true);
        double confidenceThreshold = getOptions().getDouble(CONFIDENCE_THRESHOLD_OPTION, 1.0);
        double similarityThreshold = getOptions().getDouble(SIMILARITY_THRESHOLD_OPTION, 0.5);
        Address destinationAddress = destinationFunction.getEntryPoint();
        int destinationLength = (int) destinationFunction.getBody().getNumAddresses();
        List<VTMatchInfo> result = new ArrayList<>();
        for (Map.Entry<Address, DominantPair<Double, VectorCompare>> neighbor : neighbors.entrySet()) {
            monitor.checkCancelled();
            double similarity = neighbor.getValue().first;
            VectorCompare veccompare = neighbor.getValue().second;
            veccompare.fillOut();
            double confidence = veccompare.dotproduct;
            if (similarity < similarityThreshold || Double.isNaN(similarity) || confidence < confidenceThreshold) {
                continue;
            }
            // The reported confidence is the dot product scaled by 10.
            confidence *= 10.0;
            Function sourceFunction = sourceListing.getFunctionAt(neighbor.getKey());
            VTMatchInfo match = new VTMatchInfo(matchSet);
            match.setSimilarityScore(new VTScore(similarity));
            match.setConfidenceScore(new VTScore(confidence));
            match.setSourceLength((int) sourceFunction.getBody().getNumAddresses());
            match.setDestinationLength(destinationLength);
            match.setSourceAddress(sourceFunction.getEntryPoint());
            match.setDestinationAddress(destinationAddress);
            match.setTag(null);
            match.setAssociationType(VTAssociationType.FUNCTION);
            result.add(match);
        }
        if (refineResult) {
            result = refine(result);
        }
        return result;
    }

    /**
     * Narrows a candidate list to clearly-ranked best matches.
     * <ol>
     * <li>Sorts by descending similarity (mutating {@code list}) and keeps at most {@code TOP_N + 1}.</li>
     * <li>Walks down the list; at the first pair of entries that tie (within {@code EQUALS_EPSILON}),
     *     truncates so that both tied entries and everything after them are removed.</li>
     * <li>Keeps at most {@code TOP_N}.</li>
     * <li>Truncates at the first entry scoring more than {@code DIFFERENTIAL} below the best.</li>
     * </ol>
     *
     * @param list the candidates; sorted in place
     * @return a prefix view of the sorted list
     */
    private List<VTMatchInfo> refine(List<VTMatchInfo> list) {
        Collections.sort(list, SCORE_COMPARATOR);
        // Take one extra so a tie between entries TOP_N and TOP_N + 1 can still be detected.
        list = list.subList(0, Math.min(TOP_N + 1, list.size()));

        if (list.size() > 1) {
            double previousScore = list.get(0).getSimilarityScore().getScore();
            int cutoffIndex = 1;
            for (int i = 1; i < list.size(); ++i) {
                double currentScore = list.get(i).getSimilarityScore().getScore();
                if (currentScore > previousScore - EQUALS_EPSILON) {
                    // Tie: drop the previous entry as well as this one.
                    --cutoffIndex;
                    break;
                }
                ++cutoffIndex;
                previousScore = currentScore;
            }
            list = list.subList(0, cutoffIndex);
        }

        list = list.subList(0, Math.min(TOP_N, list.size()));

        if (list.size() > 1) {
            double bestScore = list.get(0).getSimilarityScore().getScore();
            int cutoffIndex = list.size();
            for (int i = 1; i < list.size(); ++i) {
                if (list.get(i).getSimilarityScore().getScore() < bestScore - DIFFERENTIAL) {
                    cutoffIndex = i;
                    break;
                }
            }
            list = list.subList(0, cutoffIndex);
        }
        return list;
    }

    /**
     * Collects into {@code functions} the functions that reference {@code address}.
     * <p>
     * References from instructions contribute their containing function. A thunk function is not
     * added itself; references to the thunk are followed instead. References from data are
     * followed recursively from the data's address. Thunks of a function located at
     * {@code address} are also followed. Recursion stops at {@code MAX_DEPTH}.
     *
     * @param depth current recursion depth
     * @param functions accumulator for the referring functions
     * @param program the program to search
     * @param address the referenced address
     */
    private void accumulateFunctionReferences(int depth, Set<Function> functions, Program program, Address address) {
        if (depth >= MAX_DEPTH) {
            return;
        }
        FunctionManager functionManager = program.getFunctionManager();
        Function addressFunction = functionManager.getFunctionAt(address);
        if (addressFunction != null) {
            Address[] thunkAddresses = addressFunction.getFunctionThunkAddresses();
            if (thunkAddresses != null) {
                for (Address thunkAddress : thunkAddresses) {
                    accumulateFunctionReferences(depth + 1, functions, program, thunkAddress);
                }
            }
        }
        if (address.isStackAddress() || address.isRegisterAddress()) {
            return;
        }
        Listing listing = program.getListing();
        ReferenceManager refManager = program.getReferenceManager();
        ReferenceIterator it = refManager.getReferencesTo(address);
        while (it.hasNext()) {
            Address fromAddress = it.next().getFromAddress();
            CodeUnit codeUnit = listing.getCodeUnitAt(fromAddress);
            if (codeUnit instanceof Instruction) {
                Function function = functionManager.getFunctionContaining(fromAddress);
                if (function == null) {
                    continue;
                }
                if (function.isThunk()) {
                    accumulateFunctionReferences(depth + 1, functions, program, function.getEntryPoint());
                }
                else {
                    functions.add(function);
                }
            }
            else if (codeUnit instanceof Data) {
                accumulateFunctionReferences(depth + 1, functions, program, fromAddress);
            }
        }
    }

    /**
     * Decides whether an existing association of this type should be used as a shared feature.
     *
     * @param associationType the type of an existing association
     * @return true if associations of this type are considered
     */
    protected abstract boolean isExpectedRefType(VTAssociationType associationType);

    /**
     * Decides whether a reference made from a function's body is counted toward that function's
     * total references.
     *
     * @param reference a reference originating in a function body
     * @return true if the reference is counted
     */
    protected abstract boolean isExpectedRefType(Reference reference);

    /**
     * Builds the source and destination feature vectors.
     * <p>
     * Each ACCEPTED match (from the other correlators' match sets) whose source and destination
     * addresses are both referenced by at least one function becomes a shared feature. Its weight
     * is {@code sqrt(-ln(p))}, where {@code p} is the fraction of all functions (in both programs)
     * that reference it. Unmatched references are added afterwards as unique features.
     */
    private void extractReferenceFeatures(VTMatchSet matchSet, TaskMonitor monitor) throws CancelledException {
        srcVectorsByAddress = LazyMap.lazyMap(new HashMap<Address, LSHCosineVectorAccum>(), addr -> new LSHCosineVectorAccum());
        destVectorsByAddress = LazyMap.lazyMap(new HashMap<Address, LSHCosineVectorAccum>(), addr -> new LSHCosineVectorAccum());
        int srcFunctionCount = sourceProgram.getFunctionManager().getFunctionCount();
        int destFunctionCount = destinationProgram.getFunctionManager().getFunctionCount();

        Counter totalMatches = new Counter();
        Collection<VTMatchSet> matchSets = getMatchSets(matchSet.getSession(), totalMatches);
        monitor.initialize(totalMatches.count());
        Map<VTMatch, Set<Function>> sourceRefMap = new HashMap<>();
        Map<VTMatch, Set<Function>> destinationRefMap = new HashMap<>();
        for (VTMatchSet ms : matchSets) {
            for (VTMatch match : ms.getMatches()) {
                monitor.checkCancelled();
                monitor.incrementProgress(1);
                accumulateMatchFunctionReferences(sourceRefMap, destinationRefMap, match);
            }
        }

        monitor.setMessage("Adding ACCEPTED matches to feature vectors.");
        // Restart progress for this phase instead of overrunning the previous maximum.
        monitor.initialize(sourceRefMap.size());
        int featureID = 1;
        for (Map.Entry<VTMatch, Set<Function>> entry : sourceRefMap.entrySet()) {
            monitor.checkCancelled();
            monitor.incrementProgress(1);
            Set<Function> srcRefFuncs = entry.getValue();
            if (srcRefFuncs.isEmpty()) {
                continue;
            }
            // Both maps are always populated together, so this is non-null.
            Set<Function> destRefFuncs = destinationRefMap.get(entry.getKey());
            double altPraw = (double) (srcRefFuncs.size() + destRefFuncs.size()) / (double) (srcFunctionCount + destFunctionCount);
            double weight = Math.sqrt(-Math.log(altPraw));
            for (Function function : srcRefFuncs) {
                srcVectorsByAddress.get(function.getEntryPoint()).addHash(featureID, weight);
            }
            for (Function function : destRefFuncs) {
                destVectorsByAddress.get(function.getEntryPoint()).addHash(featureID, weight);
            }
            ++featureID;
        }
        updateSourceAndDestinationVectors(featureID, monitor);
    }

    /**
     * Returns the session's match sets, excluding those produced by this correlator.
     * <p>
     * Only the match set with the highest ID is kept for each correlator name.
     *
     * @param session the session to read
     * @param totalMatches receives the total match count of the returned match sets
     * @return the de-duplicated match sets
     */
    private Collection<VTMatchSet> getMatchSets(VTSession session, Counter totalMatches) {
        Map<String, VTMatchSet> dedupedMatchSets = new HashMap<>();
        for (VTMatchSet ms : session.getMatchSets()) {
            String name = ms.getProgramCorrelatorInfo().getName();
            if (name.equals(correlatorName)) {
                continue;
            }
            VTMatchSet existing = dedupedMatchSets.get(name);
            if (existing == null || ms.getID() >= existing.getID()) {
                dedupedMatchSets.put(name, ms);
            }
        }
        // Count only the match sets that survived de-duplication; replaced ones are not processed.
        for (VTMatchSet ms : dedupedMatchSets.values()) {
            totalMatches.add(ms.getMatchCount());
        }
        return dedupedMatchSets.values();
    }

    /**
     * Records the referring functions for {@code match}, but only if its association has an
     * expected type, is ACCEPTED, and both its source and destination addresses have at least one
     * referring function.
     */
    private void accumulateMatchFunctionReferences(Map<VTMatch, Set<Function>> sourceRefMap, Map<VTMatch, Set<Function>> destinationRefMap, VTMatch match) {
        VTAssociation association = match.getAssociation();
        if (!isExpectedRefType(association.getType())) {
            return;
        }
        if (association.getStatus() != VTAssociationStatus.ACCEPTED) {
            return;
        }
        Set<Function> sourceReferences = new HashSet<>();
        accumulateFunctionReferences(0, sourceReferences, sourceProgram, association.getSourceAddress());
        if (sourceReferences.isEmpty()) {
            return;
        }
        Set<Function> destinationReferences = new HashSet<>();
        accumulateFunctionReferences(0, destinationReferences, destinationProgram, association.getDestinationAddress());
        if (destinationReferences.isEmpty()) {
            return;
        }
        sourceRefMap.put(match, sourceReferences);
        destinationRefMap.put(match, destinationReferences);
    }

    /**
     * Pads every vector with unique features, one per counted reference not already covered by a
     * shared feature. Feature IDs continue from {@code featureID} and never repeat, so these
     * features never match across programs.
     */
    private void updateSourceAndDestinationVectors(int featureID, TaskMonitor monitor) throws CancelledException {
        monitor.setMessage("Adding unmatched references to feature vectors.");
        double pSwitch = 0.5;
        double uniqueWeight = Math.sqrt(-Math.log(pSwitch));
        int nextFeatureID = addUnmatchedReferenceFeatures(sourceProgram, srcVectorsByAddress, featureID, uniqueWeight, monitor);
        addUnmatchedReferenceFeatures(destinationProgram, destVectorsByAddress, nextFeatureID, uniqueWeight, monitor);
    }

    /**
     * Adds {@code countFunctionRefs - numEntries} fresh features to each vector in {@code vectors}.
     *
     * @return the next unused feature ID
     */
    private int addUnmatchedReferenceFeatures(Program program, Map<Address, LSHCosineVectorAccum> vectors, int featureID, double weight, TaskMonitor monitor) throws CancelledException {
        for (Map.Entry<Address, LSHCosineVectorAccum> entry : vectors.entrySet()) {
            monitor.checkCancelled();
            LSHCosineVectorAccum vector = entry.getValue();
            int unmatched = countFunctionRefs(program, entry.getKey()) - vector.numEntries();
            for (int i = 0; i < unmatched; ++i) {
                vector.addHash(featureID, weight);
                ++featureID;
            }
        }
        return featureID;
    }

    /**
     * Counts the references made from the body of the function at {@code addr} that satisfy
     * {@link #isExpectedRefType(Reference)}.
     *
     * @return the count, or 0 if no function starts at {@code addr}
     */
    private int countFunctionRefs(Program program, Address addr) {
        Function function = program.getFunctionManager().getFunctionAt(addr);
        if (function == null) {
            return 0;
        }
        CodeUnitIterator it = program.getListing().getCodeUnits(function.getBody(), true);
        int totalRefs = 0;
        while (it.hasNext()) {
            for (Reference memRef : it.next().getReferencesFrom()) {
                if (isExpectedRefType(memRef)) {
                    ++totalRefs;
                }
            }
        }
        return totalRefs;
    }
}

