# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt

"""Classes representing different types of constraints on inference values."""

from __future__ import annotations

import sys
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import TYPE_CHECKING

from astroid import helpers, nodes, util
from astroid.exceptions import AstroidTypeError, InferenceError, MroError
from astroid.typing import InferenceResult

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

if TYPE_CHECKING:
    from astroid import bases

# Node types accepted as the ``node`` argument of ``Constraint.match`` and
# ``_match_constraint``, and as the ``expr`` argument of ``get_constraints``.
_NameNodes = nodes.AssignAttr | nodes.Attribute | nodes.AssignName | nodes.Name


class Constraint(ABC):
    """Abstract base class for one restriction placed on a variable.

    Concrete subclasses supply ``match`` to recognise a condition shape and
    ``satisfied_by`` to test an inferred value against the restriction.
    """

    def __init__(self, node: nodes.NodeNG, negate: bool) -> None:
        self.node = node
        """The node the constraint is about."""
        self.negate = negate
        """Whether the constraint is inverted, e.g. "is not" rather than "is"."""

    @classmethod
    @abstractmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Try to build a constraint on node out of the condition expr.

        Return a new constraint when expr has the shape this class
        recognises, otherwise None. When negate is True the resulting
        constraint is negated.
        """

    @abstractmethod
    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the inferred value is compatible with this constraint."""


class NoneConstraint(Constraint):
    """Constraint derived from an "x is None" or "x is not None" comparison."""

    CONST_NONE: nodes.Const = nodes.Const(None)

    @classmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on node when expr is a None identity check of it.

        Only a single comparison with node on the left-hand side and None on
        the right-hand side is recognised. An "is not" operator flips the
        requested negate value. Return None for any other expression.
        """
        if not isinstance(expr, nodes.Compare) or len(expr.ops) != 1:
            return None

        comparison_op, comparator = expr.ops[0]
        if comparison_op not in ("is", "is not"):
            return None
        if not _matches(expr.left, node):
            return None
        if not _matches(comparator, cls.CONST_NONE):
            return None

        # "is" keeps the caller's polarity, "is not" inverts it.
        effective_negate = negate if comparison_op == "is" else not negate
        return cls(node=node, negate=effective_negate)

    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the inferred value is compatible with this constraint.

        Uninferable values are always accepted. Otherwise the value must be
        None when negate is False, and must not be None when negate is True.
        """
        if inferred is util.Uninferable:
            return True

        is_none = _matches(inferred, self.CONST_NONE)
        # XOR: a negated constraint accepts exactly the non-None values.
        return self.negate ^ is_none


class BooleanConstraint(Constraint):
    """Constraint derived from a bare truthiness test: "x" or "not x"."""

    @classmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on node when expr tests its truthiness.

        Recognised shapes:

        - expr is node itself: the given negate value is kept
        - expr is `not node`: the given negate value is flipped

        Return None for any other expression.
        """
        if _matches(expr, node):
            return cls(node=node, negate=negate)

        is_not_expression = isinstance(expr, nodes.UnaryOp) and expr.op == "not"
        if is_not_expression and _matches(expr.operand, node):
            return cls(node=node, negate=not negate)

        return None

    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the inferred value is compatible with this constraint.

        A value that is uninferable, or whose boolean value is uninferable,
        is always accepted. Otherwise:

        - negate=False: the boolean value must be True
        - negate=True: the boolean value must be False
        """
        if inferred is util.Uninferable:
            return True

        truthiness = inferred.bool_value()
        if truthiness is util.Uninferable:
            return True

        return self.negate ^ truthiness


class TypeConstraint(Constraint):
    """Constraint derived from an "isinstance(x, y)" call."""

    def __init__(
        self, node: nodes.NodeNG, classinfo: nodes.NodeNG, negate: bool
    ) -> None:
        super().__init__(node=node, negate=negate)
        # Second argument of the isinstance() call, kept as an unevaluated node.
        self.classinfo = classinfo

    @classmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on node when expr is "isinstance(node, y)".

        The call must go through the plain name ``isinstance``, take exactly
        two positional arguments and no keywords, and have node as its first
        argument. Return None for any other expression.
        """
        calls_isinstance = (
            isinstance(expr, nodes.Call)
            and isinstance(expr.func, nodes.Name)
            and expr.func.name == "isinstance"
        )
        if not calls_isinstance:
            return None
        if expr.keywords or len(expr.args) != 2:
            return None

        subject, classinfo = expr.args
        if not _matches(subject, node):
            return None

        return cls(node=node, classinfo=classinfo, negate=negate)

    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the inferred value is compatible with this constraint.

        The value is always accepted when it is uninferable, when the
        instance check itself is uninferable, or when evaluating the check
        raises InferenceError, AstroidTypeError or MroError. Otherwise:

        - negate=False: inferred must be an instance of the checked types
        - negate=True: inferred must not be an instance of the checked types
        """
        if inferred is util.Uninferable:
            return True

        try:
            checked_types = helpers.class_or_tuple_to_container(self.classinfo)
            is_instance = helpers.object_isinstance(inferred, checked_types)
        except (InferenceError, AstroidTypeError, MroError):
            # The check could not be evaluated: do not filter the value out.
            return True

        if is_instance is util.Uninferable:
            return True

        return self.negate ^ is_instance


class EqualityConstraint(Constraint):
    """Constraint derived from an "==" or "!=" comparison."""

    def __init__(self, node: nodes.NodeNG, operand: nodes.NodeNG, negate: bool) -> None:
        super().__init__(node=node, negate=negate)
        # The other side of the comparison, kept as an unevaluated node.
        self.operand = operand

    @classmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on node when expr compares it for (in)equality.

        Recognised shapes, with node on either side of a single comparison:

        - "node == operand" or "operand == node": the given negate value is kept
        - "node != operand" or "operand != node": the given negate value is flipped

        Return None for any other expression.
        """
        if not isinstance(expr, nodes.Compare) or len(expr.ops) != 1:
            return None

        comparison_op, right = expr.ops[0]
        if comparison_op not in ("==", "!="):
            return None

        # The side that is not node becomes the operand; the left side is
        # checked first, so it wins when both sides match.
        left = expr.left
        if _matches(left, node):
            operand = right
        elif _matches(right, node):
            operand = left
        else:
            return None

        effective_negate = negate if comparison_op == "==" else not negate
        return cls(node=node, operand=operand, negate=effective_negate)

    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the inferred value is compatible with this constraint.

        The value is always accepted when it or the operand cannot be
        inferred, or when the pair is neither two constants nor two
        callables. Otherwise:

        - negate=False: both sides must be equal
        - negate=True: both sides must differ

        Constants are compared by value, callables by identity.
        """
        if inferred is util.Uninferable:
            return True

        other = util.safe_infer(self.operand)
        if other is None or other is util.Uninferable:
            return True

        if isinstance(inferred, nodes.Const) and isinstance(other, nodes.Const):
            are_equal = inferred.value == other.value
        elif inferred.callable() and other.callable():
            are_equal = inferred is other
        else:
            # Unsupported kind of comparison: do not filter the value out.
            return True

        return self.negate ^ are_equal


def get_constraints(
    expr: _NameNodes, frame: nodes.LocalsDictNodeNG
) -> dict[nodes.If | nodes.IfExp, set[Constraint]]:
    """Collect the constraints that enclosing conditionals place on expr.

    Walks up the parent chain from expr until frame (or the root) is reached.
    For every ``If`` / ``IfExp`` ancestor, the test is matched as-is when the
    walk came from its body, and inverted when it came from its orelse part.

    The result maps each conditional node that produced at least one
    constraint to the set of constraints it produced. Constraints are derived
    statically; only if conditions are supported.
    """
    # Whether the test must be inverted, keyed by the branch we came from.
    invert_for_branch = {"body": False, "orelse": True}

    found: dict[nodes.If | nodes.IfExp, set[Constraint]] = {}
    child: nodes.NodeNG | None = expr
    while child is not None and child is not frame:
        parent = child.parent
        if isinstance(parent, (nodes.If, nodes.IfExp)):
            branch, _ = parent.locate_child(child)
            invert = invert_for_branch.get(branch)
            # Any other branch (e.g. the test itself) yields no constraint.
            if invert is not None:
                matched = set(_match_constraint(expr, parent.test, invert=invert))
                if matched:
                    found[parent] = matched
        child = parent

    return found


# Registry iterated by ``_match_constraint``: the ``match`` classmethod of every
# class listed here is tried against a condition. Being a frozenset, the order
# in which the classes are tried is unspecified.
ALL_CONSTRAINT_CLASSES = frozenset(
    (
        NoneConstraint,
        BooleanConstraint,
        TypeConstraint,
        EqualityConstraint,
    )
)
"""All supported constraint types."""


def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool:
    """Tell whether the two nodes are structurally the same reference or value.

    Two names match on their name, two attributes on their attribute name and
    (recursively) their base expression, two constants on their value. Any
    other pairing does not match.
    """
    if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name):
        same = node1.name == node2.name
    elif isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute):
        same = node1.attrname == node2.attrname and _matches(node1.expr, node2.expr)
    elif isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const):
        same = node1.value == node2.value
    else:
        same = False

    return same


def _match_constraint(
    node: _NameNodes, expr: nodes.NodeNG, invert: bool = False
) -> Iterator[Constraint]:
    """Yield a constraint on node for every constraint class that matches expr.

    invert is forwarded to each class as its negate argument.
    """
    attempts = (
        constraint_cls.match(node, expr, invert)
        for constraint_cls in ALL_CONSTRAINT_CLASSES
    )
    # filter(None, ...) keeps only truthy results, dropping failed matches.
    yield from filter(None, attempts)
