/*
 * Decompiled with CFR 0.152.
 */
package jadx.core.dex.visitors;

import jadx.core.clsp.ClspClass;
import jadx.core.clsp.ClspMethod;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.MethodBridgeAttr;
import jadx.core.dex.attributes.nodes.MethodOverrideAttr;
import jadx.core.dex.attributes.nodes.RenameReasonAttr;
import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.IMethodDetails;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.rename.RenameVisitor;
import jadx.core.dex.visitors.typeinference.TypeCompare;
import jadx.core.dex.visitors.typeinference.TypeCompareEnum;
import jadx.core.dex.visitors.typeinference.TypeInferenceVisitor;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxException;
import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

@JadxVisitor(name="OverrideMethodVisitor", desc="Mark override methods and revert type erasure", runBefore={TypeInferenceVisitor.class, RenameVisitor.class})
/**
 * Attaches {@link MethodOverrideAttr} to methods that override methods of their super types and,
 * when exactly one base method is found, rewrites erased return/argument types to the generic
 * types implied by the class hierarchy.
 */
public class OverrideMethodVisitor
extends AbstractVisitor {
    /**
     * Processes every method of {@code cls} against the super types collected for that class.
     * Classes without any super type other than {@code Object} are skipped.
     *
     * @return always {@code true}
     */
    @Override
    public boolean visit(ClassNode cls) throws JadxException {
        SuperTypesData superData = collectSuperTypes(cls);
        if (superData == null) {
            return true;
        }
        for (MethodNode mth : cls.getMethods()) {
            processMth(mth, superData);
        }
        return true;
    }

    /**
     * Computes and attaches the override attribute for one method, then tries to fix its
     * signature types against the base method. Constructors, static and private methods are ignored.
     *
     * @throws JadxRuntimeException if an override attribute was built without any base method
     */
    private void processMth(MethodNode mth, SuperTypesData superData) {
        if (mth.isConstructor() || mth.getAccessFlags().isStatic() || mth.getAccessFlags().isPrivate()) {
            return;
        }
        MethodOverrideAttr attr = processOverrideMethods(mth, superData);
        if (attr == null) {
            return;
        }
        if (attr.getBaseMethods().isEmpty()) {
            throw new JadxRuntimeException("No base methods for override attribute: " + attr.getOverrideList());
        }
        mth.addAttr(attr);
        // NOTE(review): Utils.getOne presumably yields null unless there is a single base method — confirm
        IMethodDetails baseMth = Utils.getOne(attr.getBaseMethods());
        if (baseMth == null) {
            return;
        }
        boolean updated = fixMethodReturnType(mth, baseMth, superData);
        // non-short-circuit on purpose: argument types are fixed even if the return type already changed
        updated |= fixMethodArgTypes(mth, baseMth, superData);
        if (updated) {
            checkMethodSignatureCollisions(mth, mth.root().getArgs().isRenameValid());
        }
    }

    /**
     * Walks the super types in order and collects the methods {@code mth} overrides.
     * Stops early when a found method already carries an override attribute and merges with it.
     *
     * @return the existing attribute of {@code mth}, a newly built one, or {@code null} if nothing is overridden
     */
    private MethodOverrideAttr processOverrideMethods(MethodNode mth, SuperTypesData superData) {
        MethodOverrideAttr existing = mth.get(AType.METHOD_OVERRIDE);
        if (existing != null) {
            return existing;
        }
        RootNode root = mth.root();
        ClassNode cls = mth.getParentClass();
        String signature = mth.getMethodInfo().makeSignature(false);
        List<IMethodDetails> overrideList = new ArrayList<>();
        Set<IMethodDetails> baseMethods = new HashSet<>();
        for (ArgType superType : superData.getSuperTypes()) {
            ClassNode superCls = root.resolveClass(superType);
            if (superCls != null) {
                MethodNode ovrdMth = searchOverriddenMethod(superCls, mth, signature);
                if (ovrdMth != null && isMethodVisibleInCls(ovrdMth, cls)) {
                    overrideList.add(ovrdMth);
                    MethodOverrideAttr ovrdAttr = ovrdMth.get(AType.METHOD_OVERRIDE);
                    if (ovrdAttr != null) {
                        // the rest of the hierarchy is already described by ovrdMth's attribute: merge and stop
                        addBaseMethod(superData, overrideList, baseMethods, superType);
                        return buildOverrideAttr(mth, overrideList, baseMethods, ovrdAttr);
                    }
                }
            } else {
                ClspClass clsDetails = root.getClsp().getClsDetails(superType);
                if (clsDetails != null) {
                    // classpath methods are matched only by prefix; the first match is taken
                    for (Map.Entry<String, ClspMethod> entry : clsDetails.getMethodsMap().entrySet()) {
                        if (entry.getKey().startsWith(signature)) {
                            overrideList.add(entry.getValue());
                            break;
                        }
                    }
                }
            }
            addBaseMethod(superData, overrideList, baseMethods, superType);
        }
        return buildOverrideAttr(mth, overrideList, baseMethods, null);
    }

    /**
     * If {@code superType} is an end (top-most) type of the hierarchy, records the most recently
     * found overridden method as a base method.
     */
    private void addBaseMethod(SuperTypesData superData, List<IMethodDetails> overrideList, Set<IMethodDetails> baseMethods, ArgType superType) {
        if (!superData.getEndTypes().contains(superType.getObject())) {
            return;
        }
        IMethodDetails last = Utils.last(overrideList);
        if (last != null) {
            baseMethods.add(last);
        }
    }

    /**
     * Finds a non-static method in {@code cls} that {@code mth} overrides.
     * First tries an exact short-id match; then falls back to methods whose short id starts with
     * {@code signature} and whose return type compares as wider than {@code mth}'s.
     * Ambiguous candidates (UNKNOWN/CONFLICT comparison) only produce a debug comment on {@code mth}.
     */
    @Nullable
    private MethodNode searchOverriddenMethod(ClassNode cls, MethodNode mth, String signature) {
        String shortId = mth.getMethodInfo().getShortId();
        for (MethodNode supMth : cls.getMethods()) {
            if (supMth.getMethodInfo().getShortId().equals(shortId) && !supMth.getAccessFlags().isStatic()) {
                return supMth;
            }
        }
        TypeCompare typeCompare = cls.root().getTypeCompare();
        ArgType mthRetType = mth.getMethodInfo().getReturnType();
        for (MethodNode supMth : cls.getMethods()) {
            if (!supMth.getMethodInfo().getShortId().startsWith(signature) || supMth.getAccessFlags().isStatic()) {
                continue;
            }
            ArgType supRetType = supMth.getMethodInfo().getReturnType();
            TypeCompareEnum res = typeCompare.compareTypes(supRetType, mthRetType);
            if (res.isWider()) {
                return supMth;
            }
            if (res == TypeCompareEnum.UNKNOWN || res == TypeCompareEnum.CONFLICT) {
                mth.addDebugComment("Possible override for method " + supMth.getMethodInfo().getFullId());
            }
        }
        return null;
    }

    /**
     * Builds the override attribute from the collected lists, merging with {@code attr}
     * (an already existing attribute of an overridden method) when it is given.
     *
     * @return {@code null} when nothing is overridden and there is nothing to merge
     */
    @Nullable
    private MethodOverrideAttr buildOverrideAttr(MethodNode mth, List<IMethodDetails> overrideList, Set<IMethodDetails> baseMethods, @Nullable MethodOverrideAttr attr) {
        if (attr == null) {
            if (overrideList.isEmpty()) {
                return null;
            }
            List<IMethodDetails> cleanOverrideList = overrideList.stream().distinct().collect(Collectors.toList());
            return applyOverrideAttr(mth, cleanOverrideList, baseMethods, false);
        }
        List<IMethodDetails> mergedOverrideList = Utils.mergeLists(overrideList, attr.getOverrideList());
        List<IMethodDetails> cleanOverrideList = mergedOverrideList.stream().distinct().collect(Collectors.toList());
        Set<IMethodDetails> mergedBaseMethods = Utils.mergeSets(baseMethods, attr.getBaseMethods());
        return applyOverrideAttr(mth, cleanOverrideList, mergedBaseMethods, true);
    }

    /**
     * Shares one related-methods set across {@code mth} and all resolved overridden methods,
     * attaches override attributes to the overridden methods, and returns the attribute for {@code mth}.
     * If any overridden method is not a {@link MethodNode}, every node in the chain gets DONT_RENAME.
     *
     * @param update {@code true} when merging with existing attributes; existing ones are re-pointed
     *               to the shared related set instead of being replaced
     */
    private MethodOverrideAttr applyOverrideAttr(MethodNode mth, List<IMethodDetails> overrideList, Set<IMethodDetails> baseMethods, boolean update) {
        boolean dontRename = overrideList.stream().anyMatch(md -> !(md instanceof MethodNode));
        List<MethodNode> mthNodes = getMethodNodes(mth, overrideList);
        SortedSet<MethodNode> relatedMethods = update ? mergeRelatedMethods(mthNodes) : new TreeSet<>(mthNodes);
        for (int i = 0; i < mthNodes.size(); i++) {
            MethodNode mthNode = mthNodes.get(i);
            if (dontRename) {
                mthNode.add(AFlag.DONT_RENAME);
            }
            if (i == 0) {
                // the first node is 'mth' itself: its attribute is the return value
                continue;
            }
            if (update) {
                MethodOverrideAttr ovrdAttr = mthNode.get(AType.METHOD_OVERRIDE);
                if (ovrdAttr != null) {
                    ovrdAttr.setRelatedMthNodes(relatedMethods);
                    continue;
                }
            }
            // A node's own override list is the part of overrideList after that node. The position
            // is looked up in overrideList itself: a counter over mthNodes would lag behind when
            // non-MethodNode entries are interleaved or nodes are skipped, making a node list itself.
            int pos = overrideList.indexOf(mthNode);
            mthNode.addAttr(new MethodOverrideAttr(Utils.listTail(overrideList, pos + 1), relatedMethods, baseMethods));
        }
        return new MethodOverrideAttr(overrideList, relatedMethods, baseMethods);
    }

    /**
     * Reuses the related-methods set of the first node that already has an override attribute
     * (or allocates a new one) and fills it with all nodes plus every other existing related set.
     */
    private static SortedSet<MethodNode> mergeRelatedMethods(List<MethodNode> mthNodes) {
        SortedSet<MethodNode> related = null;
        for (MethodNode mthNode : mthNodes) {
            MethodOverrideAttr ovrdAttr = mthNode.get(AType.METHOD_OVERRIDE);
            if (ovrdAttr != null) {
                related = ovrdAttr.getRelatedMthNodes();
                break;
            }
        }
        if (related == null) {
            related = new TreeSet<>(mthNodes);
        } else {
            related.addAll(mthNodes);
        }
        for (MethodNode mthNode : mthNodes) {
            MethodOverrideAttr ovrdAttr = mthNode.get(AType.METHOD_OVERRIDE);
            if (ovrdAttr != null) {
                SortedSet<MethodNode> otherSet = ovrdAttr.getRelatedMthNodes();
                if (otherSet != related) {
                    related.addAll(otherSet);
                }
            }
        }
        return related;
    }

    /** Returns {@code mth} followed by every {@link MethodNode} in {@code overrideList}, in order. */
    @NotNull
    private List<MethodNode> getMethodNodes(MethodNode mth, List<IMethodDetails> overrideList) {
        List<MethodNode> nodes = new ArrayList<>(1 + overrideList.size());
        nodes.add(mth);
        for (IMethodDetails details : overrideList) {
            if (details instanceof MethodNode) {
                nodes.add((MethodNode) details);
            }
        }
        return nodes;
    }

    /**
     * Checks whether {@code superMth} can be overridden from {@code cls}: private never,
     * public/protected always, package-private only within the same package.
     */
    private boolean isMethodVisibleInCls(MethodNode superMth, ClassNode cls) {
        AccessInfo accessFlags = superMth.getAccessFlags();
        if (accessFlags.isPrivate()) {
            return false;
        }
        if (accessFlags.isPublic() || accessFlags.isProtected()) {
            return true;
        }
        return Objects.equals(superMth.getParentClass().getPackage(), cls.getPackage());
    }

    /**
     * Collects all super types (excluding {@code Object}) of {@code cls} in discovery order,
     * plus the set of end types (types with no further super types).
     *
     * @return {@code null} if {@code cls} has no such super types
     * @throws JadxRuntimeException if super types exist but no end type was found
     */
    @Nullable
    private SuperTypesData collectSuperTypes(ClassNode cls) {
        Set<ArgType> superTypes = new LinkedHashSet<>();
        Set<String> endTypes = new HashSet<>();
        collectSuperTypes(cls, superTypes, endTypes);
        if (superTypes.isEmpty()) {
            return null;
        }
        if (endTypes.isEmpty()) {
            throw new JadxRuntimeException("No end types in class hierarchy: " + cls);
        }
        return new SuperTypesData(new ArrayList<>(superTypes), endTypes);
    }

    /** Adds the super class and interfaces of {@code cls}; {@code cls} itself becomes an end type if none were added. */
    private void collectSuperTypes(ClassNode cls, Set<ArgType> superTypes, Set<String> endTypes) {
        RootNode root = cls.root();
        int added = 0;
        ArgType superClass = cls.getSuperClass();
        if (superClass != null) {
            added += addSuperType(root, superTypes, endTypes, superClass);
        }
        for (ArgType iface : cls.getInterfaces()) {
            added += addSuperType(root, superTypes, endTypes, iface);
        }
        if (added == 0) {
            endTypes.add(cls.getType().getObject());
        }
    }

    /**
     * Adds {@code superType} and, recursively, its parents (from loaded classes or the classpath).
     * Types that cannot be resolved are treated as end types.
     *
     * @return 1 if the type was newly added, 0 if it is {@code Object} or was already present
     */
    private int addSuperType(RootNode root, Set<ArgType> superTypes, Set<String> endTypes, ArgType superType) {
        if (Objects.equals(superType, ArgType.OBJECT)) {
            return 0;
        }
        if (!superTypes.add(superType)) {
            // already visited (guards against loops in the hierarchy)
            return 0;
        }
        ClassNode classNode = root.resolveClass(superType);
        if (classNode != null) {
            collectSuperTypes(classNode, superTypes, endTypes);
            return 1;
        }
        ClspClass clsDetails = root.getClsp().getClsDetails(superType);
        if (clsDetails != null) {
            int added = 0;
            for (ArgType parentType : clsDetails.getParents()) {
                added += addSuperType(root, superTypes, endTypes, parentType);
            }
            if (added == 0) {
                endTypes.add(superType.getObject());
            }
            return 1;
        }
        endTypes.add(superType.getObject());
        return 1;
    }

    /** Replaces a non-void return type with the generic-resolved base return type, adding a debug comment on change. */
    private boolean fixMethodReturnType(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData) {
        ArgType returnType = mth.getReturnType();
        if (returnType == ArgType.VOID) {
            return false;
        }
        boolean updated = updateReturnType(mth, baseMth, superData);
        if (updated) {
            mth.addDebugComment("Return type fixed from '" + returnType + "' to match base method");
        }
        return updated;
    }

    /** Updates the return type of {@code mth} if the base return type contains type variables resolvable via a super type. */
    private boolean updateReturnType(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData) {
        ArgType baseReturnType = baseMth.getReturnType();
        ArgType currentType = mth.getReturnType();
        if (currentType.equals(baseReturnType) || !baseReturnType.containsTypeVariable()) {
            return false;
        }
        ArgType targetRetType = resolveGenericType(mth, baseMth, superData, baseReturnType, currentType);
        if (targetRetType == null) {
            return false;
        }
        mth.updateReturnType(targetRetType);
        return true;
    }

    /**
     * Rebuilds the argument type list of {@code mth} using generic-resolved base argument types.
     * Does nothing if the lists are already equal or differ in length.
     */
    private boolean fixMethodArgTypes(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData) {
        List<ArgType> mthArgTypes = mth.getArgTypes();
        List<ArgType> baseArgTypes = baseMth.getArgTypes();
        if (mthArgTypes.equals(baseArgTypes)) {
            return false;
        }
        int argCount = mthArgTypes.size();
        if (argCount != baseArgTypes.size()) {
            return false;
        }
        boolean changed = false;
        List<ArgType> newArgTypes = new ArrayList<>(argCount);
        for (int argNum = 0; argNum < argCount; argNum++) {
            ArgType arg = mthArgTypes.get(argNum);
            ArgType newType = updateArgType(mth, baseMth, superData, arg, baseArgTypes.get(argNum));
            if (newType != null) {
                changed = true;
                newArgTypes.add(newType);
            } else {
                newArgTypes.add(arg);
            }
        }
        if (changed) {
            mth.updateArgTypes(newArgTypes, "Method arguments types fixed to match base method");
        }
        return changed;
    }

    /** Returns the resolved replacement for {@code arg}, or {@code null} if it should stay unchanged. */
    @Nullable
    private ArgType updateArgType(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData, ArgType arg, ArgType baseArg) {
        if (arg.equals(baseArg) || !baseArg.containsTypeVariable()) {
            return null;
        }
        return resolveGenericType(mth, baseMth, superData, baseArg, arg);
    }

    /**
     * Finds the first super type that compares as NARROW_BY_GENERIC to the base method's declaring
     * class and substitutes its generics into {@code baseType}.
     *
     * @return a fully resolved type (no type variables) different from {@code currentType}, or {@code null}
     */
    @Nullable
    private ArgType resolveGenericType(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData, ArgType baseType, ArgType currentType) {
        RootNode root = mth.root();
        TypeCompare typeCompare = root.getTypeUpdate().getTypeCompare();
        ArgType baseCls = baseMth.getMethodInfo().getDeclClass().getType();
        for (ArgType superType : superData.getSuperTypes()) {
            if (typeCompare.compareTypes(superType, baseCls) != TypeCompareEnum.NARROW_BY_GENERIC) {
                continue;
            }
            ArgType target = root.getTypeUtils().replaceClassGenerics(superType, baseType);
            if (target != null && !target.containsTypeVariable() && !target.equals(currentType)) {
                return target;
            }
        }
        return null;
    }

    /**
     * After a signature fix, looks for another method in the same class with the same alias and the
     * new signature. The first such method is renamed (when allowed and {@code rename} is set) and
     * gets a {@link MethodBridgeAttr} referencing {@code mth}.
     */
    private void checkMethodSignatureCollisions(MethodNode mth, boolean rename) {
        String mthName = mth.getMethodInfo().getAlias();
        String newSignature = MethodInfo.makeShortId(mthName, mth.getArgTypes(), null);
        for (MethodNode otherMth : mth.getParentClass().getMethods()) {
            if (!otherMth.getAlias().equals(mthName) || otherMth == mth) {
                continue;
            }
            String otherSignature = otherMth.getMethodInfo().makeSignature(true, false);
            if (!otherSignature.equals(newSignature)) {
                continue;
            }
            if (rename) {
                if (otherMth.contains(AFlag.DONT_RENAME) || otherMth.contains(AType.METHOD_OVERRIDE)) {
                    otherMth.addWarnComment("Can't rename method to resolve collision");
                } else {
                    otherMth.getMethodInfo().setAlias(makeNewAlias(otherMth));
                    otherMth.addAttr(new RenameReasonAttr("avoid collision after fix types in other method"));
                }
            }
            otherMth.addAttr(new MethodBridgeAttr(mth));
            return;
        }
    }

    /** Returns the method alias with the smallest numeric suffix (starting at 2) not found by {@code searchMethodByShortName}. */
    private static String makeNewAlias(MethodNode mth) {
        ClassNode cls = mth.getParentClass();
        String baseName = mth.getAlias();
        int suffix = 2;
        while (true) {
            String alias = baseName + suffix;
            if (cls.searchMethodByShortName(alias) == null) {
                return alias;
            }
            suffix++;
        }
    }

    @Override
    public String getName() {
        return "OverrideMethodVisitor";
    }

    /** Super types of a class in discovery order, plus raw names of the hierarchy's end types. */
    private static final class SuperTypesData {
        private final List<ArgType> superTypes;
        private final Set<String> endTypes;

        private SuperTypesData(List<ArgType> superTypes, Set<String> endTypes) {
            this.superTypes = superTypes;
            this.endTypes = endTypes;
        }

        public List<ArgType> getSuperTypes() {
            return this.superTypes;
        }

        public Set<String> getEndTypes() {
            return this.endTypes;
        }
    }
}

