from operator import attrgetter, itemgetter
from typing import (
    Any,
    Collection,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Union,
    cast,
)

from ..error import GraphQLError
from ..pyutils import inspect
from ..language import (
    DirectiveNode,
    InputValueDefinitionNode,
    NamedTypeNode,
    Node,
    OperationType,
    SchemaDefinitionNode,
    SchemaExtensionNode,
)
from .definition import (
    GraphQLEnumType,
    GraphQLInputField,
    GraphQLInputObjectType,
    GraphQLInterfaceType,
    GraphQLObjectType,
    GraphQLUnionType,
    is_enum_type,
    is_input_object_type,
    is_input_type,
    is_interface_type,
    is_named_type,
    is_non_null_type,
    is_object_type,
    is_output_type,
    is_union_type,
    is_required_argument,
    is_required_input_field,
)
from ..utilities.type_comparators import is_equal_type, is_type_sub_type_of
from .directives import is_directive, GraphQLDeprecatedDirective
from .introspection import is_introspection_type
from .schema import GraphQLSchema, assert_schema

__all__ = ["validate_schema", "assert_valid_schema"]


def validate_schema(schema: GraphQLSchema) -> List[GraphQLError]:
    """Validate a GraphQL schema.

    Implements the "Type Validation" sub-sections of the specification's "Type System"
    section.

    Validation is synchronous. The returned list holds every error that was found and
    is empty when the schema is valid. The result is stored on the schema, so calling
    this again for the same schema returns the stored list instead of re-validating.
    """
    # Refuse anything that is not actually a GraphQLSchema.
    assert_schema(schema)

    # noinspection PyProtectedMember
    cached_errors = schema._validation_errors
    if cached_errors is not None:
        # Already validated earlier: hand back the previous results.
        return cached_errors

    context = SchemaValidationContext(schema)
    context.validate_root_types()
    context.validate_directives()
    context.validate_types()

    # Remember the outcome on the schema so validation only ever runs once for it.
    schema._validation_errors = context.errors
    return context.errors


def assert_valid_schema(schema: GraphQLSchema) -> None:
    """Utility function which asserts a schema is valid.

    Raises a TypeError if the schema is invalid. The exception message is made of the
    messages of all validation errors, separated by blank lines.
    """
    errors = validate_schema(schema)
    if not errors:
        return
    messages = [error.message for error in errors]
    raise TypeError("\n\n".join(messages))


class SchemaValidationContext:
    """Utility class providing a context for schema validation.

    It keeps a reference to the schema under validation and collects every error
    reported by the ``validate_*`` methods in the ``errors`` list.
    """

    errors: List[GraphQLError]
    schema: GraphQLSchema

    def __init__(self, schema: GraphQLSchema):
        self.errors = []
        self.schema = schema

    def report_error(
        self,
        message: str,
        nodes: Union[Optional[Node], Collection[Optional[Node]]] = None,
    ) -> None:
        """Append a GraphQLError with the given message to the collected errors.

        ``nodes`` may be None, a single AST node, or an iterable of AST nodes. In the
        iterable case, falsy entries (e.g. missing ``ast_node`` values) are dropped.
        """
        if nodes and not isinstance(nodes, Node):
            nodes = [node for node in nodes if node]
        nodes = cast(Optional[Collection[Node]], nodes)
        self.errors.append(GraphQLError(message, nodes))

    def validate_root_types(self) -> None:
        """Validate the query, mutation and subscription root types.

        The query root type is mandatory. Every root type that is provided must be an
        Object type.
        """
        schema = self.schema
        query_type = schema.query_type
        if not query_type:
            self.report_error("Query root type must be provided.", schema.ast_node)
        elif not is_object_type(query_type):
            self.report_error(
                f"Query root type must be Object type, it cannot be {query_type}.",
                get_operation_type_node(schema, OperationType.QUERY)
                or query_type.ast_node,
            )

        mutation_type = schema.mutation_type
        if mutation_type and not is_object_type(mutation_type):
            self.report_error(
                "Mutation root type must be Object type if provided,"
                f" it cannot be {mutation_type}.",
                get_operation_type_node(schema, OperationType.MUTATION)
                or mutation_type.ast_node,
            )

        subscription_type = schema.subscription_type
        if subscription_type and not is_object_type(subscription_type):
            self.report_error(
                "Subscription root type must be Object type if provided,"
                f" it cannot be {subscription_type}.",
                get_operation_type_node(schema, OperationType.SUBSCRIPTION)
                or subscription_type.ast_node,
            )

    def validate_directives(self) -> None:
        """Validate all directives of the schema.

        Checks that each entry is a directive and that its name and argument names
        are valid. Also checks that argument types are input types and that required
        arguments are not deprecated.
        """
        directives = self.schema.directives
        for directive in directives:
            # Ensure all directives are in fact GraphQL directives.
            if not is_directive(directive):
                self.report_error(
                    f"Expected directive but got: {inspect(directive)}.",
                    getattr(directive, "ast_node", None),
                )
                continue

            # Ensure they are named correctly.
            self.validate_name(directive)

            # Ensure the arguments are valid.
            for arg_name, arg in directive.args.items():
                # Ensure they are named correctly.
                self.validate_name(arg, arg_name)

                # Ensure the type is an input type.
                if not is_input_type(arg.type):
                    self.report_error(
                        f"The type of @{directive.name}({arg_name}:)"
                        f" must be Input Type but got: {inspect(arg.type)}.",
                        arg.ast_node,
                    )

                if is_required_argument(arg) and arg.deprecation_reason is not None:
                    self.report_error(
                        f"Required argument @{directive.name}({arg_name}:)"
                        " cannot be deprecated.",
                        [
                            get_deprecated_directive_node(arg.ast_node),
                            arg.ast_node and arg.ast_node.type,
                        ],
                    )

    def validate_name(self, node: Any, name: Optional[str] = None) -> None:
        """Report an error if the name starts with the reserved ``__`` prefix.

        ``name`` defaults to ``node.name``. Objects that lack a ``name`` or
        ``ast_node`` attribute are skipped silently. Callers are responsible for
        exempting introspection types.
        """
        # Ensure names are valid, however introspection types opt out.
        try:
            if not name:
                name = node.name
            name = cast(str, name)
            ast_node = node.ast_node
        except AttributeError:  # pragma: no cover
            pass
        else:
            if name.startswith("__"):
                self.report_error(
                    f"Name {name!r} must not begin with '__',"
                    " which is reserved by GraphQL introspection.",
                    ast_node,
                )

    def validate_types(self) -> None:
        """Validate every type in the schema's type map, dispatching on its kind."""
        validate_input_object_circular_refs = InputObjectCircularRefsValidator(self)
        for type_ in self.schema.type_map.values():

            # Ensure all provided types are in fact GraphQL type.
            if not is_named_type(type_):
                self.report_error(
                    f"Expected GraphQL named type but got: {inspect(type_)}.",
                    # Not a named type, so an AST node may be absent; look it up
                    # defensively (as validate_directives does for non-directives).
                    getattr(type_, "ast_node", None),
                )
                continue

            # Ensure it is named correctly (excluding introspection types).
            if not is_introspection_type(type_):
                self.validate_name(type_)

            if is_object_type(type_):
                type_ = cast(GraphQLObjectType, type_)
                # Ensure fields are valid
                self.validate_fields(type_)

                # Ensure objects implement the interfaces they claim to.
                self.validate_interfaces(type_)
            elif is_interface_type(type_):
                type_ = cast(GraphQLInterfaceType, type_)
                # Ensure fields are valid.
                self.validate_fields(type_)

                # Ensure interfaces implement the interfaces they claim to.
                self.validate_interfaces(type_)
            elif is_union_type(type_):
                type_ = cast(GraphQLUnionType, type_)
                # Ensure Unions include valid member types.
                self.validate_union_members(type_)
            elif is_enum_type(type_):
                type_ = cast(GraphQLEnumType, type_)
                # Ensure Enums have valid values.
                self.validate_enum_values(type_)
            elif is_input_object_type(type_):
                type_ = cast(GraphQLInputObjectType, type_)
                # Ensure Input Object fields are valid.
                self.validate_input_fields(type_)

                # Ensure Input Objects do not contain non-nullable circular references
                validate_input_object_circular_refs(type_)

    def validate_fields(
        self, type_: Union[GraphQLObjectType, GraphQLInterfaceType]
    ) -> None:
        """Validate the fields of an Object or Interface type.

        The type needs at least one field. Every field needs a valid name and an
        output type. Its arguments need valid names and input types, and required
        arguments must not be deprecated.
        """
        fields = type_.fields

        # Objects and Interfaces both must define one or more fields.
        if not fields:
            self.report_error(
                f"Type {type_.name} must define one or more fields.",
                [type_.ast_node, *type_.extension_ast_nodes],
            )

        for field_name, field in fields.items():

            # Ensure they are named correctly.
            self.validate_name(field, field_name)

            # Ensure the type is an output type
            if not is_output_type(field.type):
                self.report_error(
                    f"The type of {type_.name}.{field_name}"
                    f" must be Output Type but got: {inspect(field.type)}.",
                    field.ast_node and field.ast_node.type,
                )

            # Ensure the arguments are valid.
            for arg_name, arg in field.args.items():
                # Ensure they are named correctly.
                self.validate_name(arg, arg_name)

                # Ensure the type is an input type.
                if not is_input_type(arg.type):
                    self.report_error(
                        f"The type of {type_.name}.{field_name}({arg_name}:)"
                        f" must be Input Type but got: {inspect(arg.type)}.",
                        arg.ast_node and arg.ast_node.type,
                    )

                if is_required_argument(arg) and arg.deprecation_reason is not None:
                    self.report_error(
                        f"Required argument {type_.name}.{field_name}({arg_name}:)"
                        " cannot be deprecated.",
                        [
                            get_deprecated_directive_node(arg.ast_node),
                            arg.ast_node and arg.ast_node.type,
                        ],
                    )

    def validate_interfaces(
        self, type_: Union[GraphQLObjectType, GraphQLInterfaceType]
    ) -> None:
        """Validate the interfaces declared by an Object or Interface type.

        Each declared interface must be an Interface type other than ``type_``
        itself, and may be listed only once. Its ancestors and fields must also be
        implemented correctly.
        """
        iface_type_names: Set[str] = set()
        for iface in type_.interfaces:
            if not is_interface_type(iface):
                self.report_error(
                    f"Type {type_.name} must only implement Interface"
                    f" types, it cannot implement {inspect(iface)}.",
                    get_all_implements_interface_nodes(type_, iface),
                )
                continue

            if type_ is iface:
                self.report_error(
                    f"Type {type_.name} cannot implement itself"
                    " because it would create a circular reference.",
                    get_all_implements_interface_nodes(type_, iface),
                )
                # Comparing a type against itself is meaningless and would only
                # yield a spurious "only implement once" error on repeated entries.
                continue

            if iface.name in iface_type_names:
                self.report_error(
                    f"Type {type_.name} can only implement {iface.name} once.",
                    get_all_implements_interface_nodes(type_, iface),
                )
                continue

            iface_type_names.add(iface.name)

            self.validate_type_implements_ancestors(type_, iface)
            self.validate_type_implements_interface(type_, iface)

    def validate_type_implements_interface(
        self,
        type_: Union[GraphQLObjectType, GraphQLInterfaceType],
        iface: GraphQLInterfaceType,
    ) -> None:
        """Check that ``type_`` correctly implements the fields of ``iface``.

        Every interface field must exist with a subtype of its type (covariant).
        Every interface argument must exist with an equal type (invariant). Extra
        arguments on the implementing field must not be required.
        """
        type_fields, iface_fields = type_.fields, iface.fields

        # Assert each interface field is implemented.
        for field_name, iface_field in iface_fields.items():
            type_field = type_fields.get(field_name)

            # Assert interface field exists on object.
            if not type_field:
                self.report_error(
                    f"Interface field {iface.name}.{field_name}"
                    f" expected but {type_.name} does not provide it.",
                    [
                        iface_field.ast_node,
                        type_.ast_node,
                        *type_.extension_ast_nodes,
                    ],
                )
                continue

            # Assert interface field type is satisfied by type field type, by being
            # a valid subtype (covariant).
            if not is_type_sub_type_of(self.schema, type_field.type, iface_field.type):
                self.report_error(
                    f"Interface field {iface.name}.{field_name}"
                    f" expects type {iface_field.type}"
                    f" but {type_.name}.{field_name}"
                    f" is type {type_field.type}.",
                    [
                        iface_field.ast_node and iface_field.ast_node.type,
                        type_field.ast_node and type_field.ast_node.type,
                    ],
                )

            # Assert each interface field arg is implemented.
            for arg_name, iface_arg in iface_field.args.items():
                type_arg = type_field.args.get(arg_name)

                # Assert interface field arg exists on object field.
                if not type_arg:
                    self.report_error(
                        "Interface field argument"
                        f" {iface.name}.{field_name}({arg_name}:)"
                        f" expected but {type_.name}.{field_name}"
                        " does not provide it.",
                        [iface_arg.ast_node, type_field.ast_node],
                    )
                    continue

                # Assert interface field arg type matches object field arg type
                # (invariant).
                if not is_equal_type(iface_arg.type, type_arg.type):
                    self.report_error(
                        "Interface field argument"
                        f" {iface.name}.{field_name}({arg_name}:)"
                        f" expects type {iface_arg.type}"
                        f" but {type_.name}.{field_name}({arg_name}:)"
                        f" is type {type_arg.type}.",
                        [
                            iface_arg.ast_node and iface_arg.ast_node.type,
                            type_arg.ast_node and type_arg.ast_node.type,
                        ],
                    )

            # Assert additional arguments must not be required.
            for arg_name, type_arg in type_field.args.items():
                iface_arg = iface_field.args.get(arg_name)
                if not iface_arg and is_required_argument(type_arg):
                    self.report_error(
                        f"Object field {type_.name}.{field_name} includes"
                        f" required argument {arg_name} that is missing from"
                        f" the Interface field {iface.name}.{field_name}.",
                        [type_arg.ast_node, iface_field.ast_node],
                    )

    def validate_type_implements_ancestors(
        self,
        type_: Union[GraphQLObjectType, GraphQLInterfaceType],
        iface: GraphQLInterfaceType,
    ) -> None:
        """Check that ``type_`` also implements every interface ``iface`` implements.

        If the missing transitive interface is ``type_`` itself, a circular
        reference error is reported instead.
        """
        type_interfaces, iface_interfaces = type_.interfaces, iface.interfaces
        for transitive in iface_interfaces:
            if transitive not in type_interfaces:
                self.report_error(
                    (
                        f"Type {type_.name} cannot implement {iface.name}"
                        " because it would create a circular reference."
                        if transitive is type_
                        else f"Type {type_.name} must implement {transitive.name}"
                        f" because it is implemented by {iface.name}."
                    ),
                    get_all_implements_interface_nodes(iface, transitive)
                    + get_all_implements_interface_nodes(type_, iface),
                )

    def validate_union_members(self, union: GraphQLUnionType) -> None:
        """Check that a union has members, all Object types, each listed once."""
        member_types = union.types

        if not member_types:
            self.report_error(
                f"Union type {union.name} must define one or more member types.",
                [union.ast_node, *union.extension_ast_nodes],
            )

        included_type_names: Set[str] = set()
        for member_type in member_types:
            if is_object_type(member_type):
                if member_type.name in included_type_names:
                    self.report_error(
                        f"Union type {union.name} can only include type"
                        f" {member_type.name} once.",
                        get_union_member_type_nodes(union, member_type.name),
                    )
                else:
                    included_type_names.add(member_type.name)
            else:
                self.report_error(
                    f"Union type {union.name} can only include Object types,"
                    f" it cannot include {inspect(member_type)}.",
                    get_union_member_type_nodes(union, str(member_type)),
                )

    def validate_enum_values(self, enum_type: GraphQLEnumType) -> None:
        """Check that an enum defines values and that their names are valid."""
        enum_values = enum_type.values

        if not enum_values:
            self.report_error(
                f"Enum type {enum_type.name} must define one or more values.",
                [enum_type.ast_node, *enum_type.extension_ast_nodes],
            )

        for value_name, enum_value in enum_values.items():
            # Ensure valid name.
            self.validate_name(enum_value, value_name)

    def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
        """Validate the fields of an Input Object type.

        The type needs at least one field. Every field needs a valid name and an
        input type, and required fields must not be deprecated.
        """
        fields = input_obj.fields

        if not fields:
            self.report_error(
                f"Input Object type {input_obj.name}"
                " must define one or more fields.",
                [input_obj.ast_node, *input_obj.extension_ast_nodes],
            )

        # Ensure the arguments are valid
        for field_name, field in fields.items():

            # Ensure they are named correctly.
            self.validate_name(field, field_name)

            # Ensure the type is an input type.
            if not is_input_type(field.type):
                self.report_error(
                    f"The type of {input_obj.name}.{field_name}"
                    f" must be Input Type but got: {inspect(field.type)}.",
                    field.ast_node.type if field.ast_node else None,
                )

            if is_required_input_field(field) and field.deprecation_reason is not None:
                self.report_error(
                    f"Required input field {input_obj.name}.{field_name}"
                    " cannot be deprecated.",
                    [
                        get_deprecated_directive_node(field.ast_node),
                        field.ast_node and field.ast_node.type,
                    ],
                )


def get_operation_type_node(
    schema: GraphQLSchema, operation: OperationType
) -> Optional[Node]:
    """Locate the AST node naming the root type of the given operation.

    The schema definition node is searched first, followed by the schema extension
    nodes. Returns the named type of the first matching operation type definition,
    or None if there is no match.
    """
    schema_nodes = [schema.ast_node, *(schema.extension_ast_nodes or ())]
    # Skip missing nodes (e.g. a schema built programmatically has no ast_node).
    for schema_node in filter(None, schema_nodes):
        op_type_defs = schema_node.operation_types
        if op_type_defs:  # pragma: no cover else
            for op_type_def in op_type_defs:
                if op_type_def.operation == operation:
                    return op_type_def.type
    return None


class InputObjectCircularRefsValidator:
    """Modified copy of algorithm from validation.rules.NoFragmentCycles

    Reports Input Object types that reference themselves through a chain of
    non-null input object fields.
    """

    def __init__(self, context: SchemaValidationContext):
        self.context = context
        # Names of the input object types explored so far. This keeps the search
        # linear and prevents the same cycle from being reported more than once.
        self.visited_types: Set[str] = set()
        # (field name, field) pairs along the current search path, used to build
        # meaningful error messages.
        self.field_path: List[Tuple[str, GraphQLInputField]] = []
        # Index into field_path at which each type on the current path was entered.
        self.field_path_index_by_type_name: Dict[str, int] = {}

    def __call__(self, input_obj: GraphQLInputObjectType) -> None:
        """Detect cycles recursively."""
        # Straight-forward depth-first search. It deliberately keeps exploring
        # after finding a cycle so that all cycles get reported.
        type_name = input_obj.name
        if type_name in self.visited_types:
            return
        self.visited_types.add(type_name)

        path = self.field_path
        index_by_name = self.field_path_index_by_type_name
        index_by_name[type_name] = len(path)

        for field_name, field in input_obj.fields.items():
            field_type = field.type
            # Only non-null references to input objects can form a cycle.
            if not (
                is_non_null_type(field_type) and is_input_object_type(field_type.of_type)
            ):
                continue

            target = cast(GraphQLInputObjectType, field_type.of_type)
            cycle_start = index_by_name.get(target.name)

            path.append((field_name, field))
            if cycle_start is None:
                self(target)
            else:
                cycle_path = path[cycle_start:]
                cycle_names = ".".join(map(itemgetter(0), cycle_path))
                cycle_nodes = list(
                    map(attrgetter("ast_node"), map(itemgetter(1), cycle_path))
                )
                self.context.report_error(
                    f"Cannot reference Input Object '{target.name}'"
                    " within itself through a series of non-null fields:"
                    f" '{cycle_names}'.",
                    cycle_nodes,
                )
            path.pop()

        del index_by_name[type_name]


def get_all_implements_interface_nodes(
    type_: Union[GraphQLObjectType, GraphQLInterfaceType], iface: GraphQLInterfaceType
) -> List[NamedTypeNode]:
    """Collect the AST nodes where ``type_`` declares that it implements ``iface``.

    Searches the type's definition node (when present) followed by all of its
    extension nodes.
    """
    definition_nodes = list(type_.extension_ast_nodes)
    if type_.ast_node is not None:
        definition_nodes.insert(0, type_.ast_node)

    implements_nodes: List[NamedTypeNode] = []
    for definition_node in definition_nodes:
        named_types = definition_node.interfaces
        if named_types:  # pragma: no cover else
            implements_nodes += [
                named_type
                for named_type in named_types
                if named_type.name.value == iface.name
            ]
    return implements_nodes


def get_union_member_type_nodes(
    union: GraphQLUnionType, type_name: str
) -> List[NamedTypeNode]:
    """Collect the AST nodes where ``union`` lists the member type ``type_name``.

    Searches the union's definition node (when present) followed by all of its
    extension nodes.
    """
    definition_nodes = list(union.extension_ast_nodes)
    if union.ast_node is not None:
        definition_nodes.insert(0, union.ast_node)

    member_type_nodes: List[NamedTypeNode] = []
    for definition_node in definition_nodes:
        named_types = definition_node.types
        if named_types:  # pragma: no cover else
            member_type_nodes += [
                named_type
                for named_type in named_types
                if named_type.name.value == type_name
            ]
    return member_type_nodes


def get_deprecated_directive_node(
    definition_node: Optional[Union[InputValueDefinitionNode]],
) -> Optional[DirectiveNode]:
    """Return the ``@deprecated`` directive node of an input value definition.

    Returns None if there is no definition node or it carries no such directive.
    """
    deprecated_name = GraphQLDeprecatedDirective.name
    directives = definition_node.directives if definition_node else None
    for directive in directives or ():
        if directive.name.value == deprecated_name:  # pragma: no cover else
            return directive
    return None  # pragma: no cover
