#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import functools
import os
from typing import (
    Any,
    Callable,
    Optional,
    List,
    Sequence,
    TYPE_CHECKING,
    cast,
    TypeVar,
    Union,
    Type,
)
# For backward compatibility.
from pyspark.errors import (  # noqa: F401
    AnalysisException,
    ParseException,
    IllegalArgumentException,
    StreamingQueryException,
    QueryExecutionException,
    PythonException,
    UnknownException,
    SparkUpgradeException,
    PySparkNotImplementedError,
    PySparkRuntimeError,
)
from pyspark.util import is_remote_only, JVM_INT_MAX
from pyspark.errors.exceptions.captured import CapturedException  # noqa: F401
from pyspark.find_spark_home import _find_spark_home
if TYPE_CHECKING:
    from py4j.java_collections import JavaArray
    from py4j.java_gateway import (
        JavaClass,
        JavaGateway,
        JavaObject,
    )
    from pyspark import SparkContext
    from pyspark.sql.session import SparkSession
    from pyspark.sql.dataframe import DataFrame
    from pyspark.sql.column import Column
    from pyspark.sql.window import Window
    from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
# Record whether numpy can be imported in this environment.
has_numpy: bool
try:
    import numpy as np  # noqa: F401
except ImportError:
    has_numpy = False
else:
    has_numpy = True

# Type variable used by the decorators in this module so that a decorated
# callable is annotated with the same type as the callable passed in.
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
def toJArray(gateway: "JavaGateway", jtype: "JavaClass", arr: Sequence[Any]) -> "JavaArray":
    """
    Convert python list to java type array

    Parameters
    ----------
    gateway :
        Py4j Gateway
    jtype :
        java type of element in array
    arr :
        python type list

    Returns
    -------
    JavaArray
        The array obtained from ``gateway.new_array(jtype, len(arr))`` with every
        element of ``arr`` stored at the same index.
    """
    jarray: "JavaArray" = gateway.new_array(jtype, len(arr))
    # Walk the sequence once instead of indexing back into it for every position.
    for i, item in enumerate(arr):
        jarray[i] = item
    return jarray
def require_test_compiled() -> None:
    """
    Verify that the SQL core test classes have been compiled.

    Raises
    ------
    PySparkRuntimeError
        With error class ``TEST_CLASS_NOT_COMPILED`` when no path matches
        ``<spark home>/sql/core/target/*/test-classes``.
    """
    import os
    import glob

    pattern = os.path.join(_find_spark_home(), "sql", "core", "target", "*", "test-classes")
    if glob.glob(pattern):
        # At least one matching path exists, nothing to report.
        return
    raise PySparkRuntimeError(
        error_class="TEST_CLASS_NOT_COMPILED",
        message_parameters={"test_class_path": pattern},
    )
class ForeachBatchFunction:
    """
    This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps
    the user-defined 'foreachBatch' function such that it can be called from the JVM when
    the query is active.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, session: "SparkSession", func: Callable[["DataFrame", int], None]):
        # User callback invoked with (DataFrame, batch_id) for every batch.
        self.func = func
        # Session whose SparkContext is reused when wrapping the JVM session.
        self.session = session
        # Most recent exception raised while handling a batch. Initialised here so
        # that reading it before any failure yields None rather than AttributeError.
        self.error: Optional[Exception] = None

    def call(self, jdf: "JavaObject", batch_id: int) -> None:
        """
        Wrap the JVM DataFrame ``jdf`` and pass it, with ``batch_id``, to the user function.

        Any exception raised while doing so is stored in ``self.error`` and re-raised.
        """
        from pyspark.sql.dataframe import DataFrame
        from pyspark.sql.session import SparkSession

        try:
            session_jdf = jdf.sparkSession()
            # assuming that spark context is still the same between JVM and PySpark
            wrapped_session_jdf = SparkSession(self.session.sparkContext, session_jdf)
            self.func(DataFrame(jdf, wrapped_session_jdf), batch_id)
        except Exception as e:
            self.error = e
            raise e

    class Java:
        implements = ["org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction"]
# Python implementation of 'org.apache.spark.sql.catalyst.util.StringConcat'
class StringConcat:
    """
    Collect string fragments while keeping at most ``maxLength`` characters.

    ``length`` tracks the total size of everything passed to ``append`` (capped at
    ``JVM_INT_MAX - 15``), while ``strings`` only holds the text that fitted
    below ``maxLength``.
    """

    def __init__(self, maxLength: int = JVM_INT_MAX - 15):
        self.length: int = 0
        self.strings: List[str] = []
        self.maxLength: int = maxLength

    def atLimit(self) -> bool:
        """Return True once the tracked length has reached ``maxLength``."""
        return self.maxLength <= self.length

    def append(self, s: str) -> None:
        """
        Add ``s`` to the buffer; ``None`` is ignored.

        A fragment that would cross ``maxLength`` is cut to the remaining room, and
        fragments arriving after the limit was reached are counted but not stored.
        """
        if s is None:
            return
        incoming = len(s)
        if not self.atLimit():
            room = self.maxLength - self.length
            if room >= incoming:
                self.strings.append(s)
            else:
                self.strings.append(s[0:room])
        # The counter keeps growing past maxLength but never beyond the cap.
        self.length = min(self.length + incoming, JVM_INT_MAX - 15)

    def toString(self) -> str:
        """Return the stored fragments joined into a single string."""
        return "".join(self.strings)
# Translation table mapping each escaped control character to its two-character
# backslash form.
_META_CHARACTER_ESCAPES = str.maketrans(
    {
        "\n": "\\n",
        "\r": "\\r",
        "\t": "\\t",
        "\f": "\\f",
        "\b": "\\b",
        "\u000B": "\\v",
        "\u0007": "\\a",
    }
)


# Python implementation of 'org.apache.spark.util.SparkSchemaUtils.escapeMetaCharacters'
def escape_meta_characters(s: str) -> str:
    """
    Return ``s`` with newline, carriage return, tab, form feed, backspace, vertical tab
    and bell characters replaced by their backslash escape sequences.
    """
    # No replacement text contains a character that is itself replaced, so a
    # single translation pass gives the same result as successive replacements.
    return s.translate(_META_CHARACTER_ESCAPES)
def to_str(value: Any) -> Optional[str]:
    """
    Convert ``value`` to a string with two special cases.

    ``None`` is passed through unchanged instead of becoming ``"None"``, and bool
    values are rendered in lower case (``"true"`` / ``"false"``). Everything else
    goes through ``str()``.
    """
    if value is None:
        return None
    text = str(value)
    return text.lower() if isinstance(value, bool) else text
def is_timestamp_ntz_preferred() -> bool:
    """
    Tell whether TimestampNTZType is preferred according to the SQL configuration set.

    When ``is_remote()`` is true, the ``spark.sql.timestampType`` conf of the active
    Connect session is compared with ``"TIMESTAMP_NTZ"``; without an active session the
    answer is ``False``. Otherwise the JVM's ``PythonSQLUtils.isTimestampNTZPreferred()``
    is consulted, and ``False`` is returned when ``SparkContext._jvm`` is None.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm = SparkContext._jvm
        if jvm is None:
            return False
        return jvm.PythonSQLUtils.isTimestampNTZPreferred()

    from pyspark.sql.connect.session import SparkSession as ConnectSparkSession

    active_session = ConnectSparkSession.getActiveSession()
    if active_session is None:
        return False
    return active_session.conf.get("spark.sql.timestampType", None) == "TIMESTAMP_NTZ"
def is_remote() -> bool:
    """
    Returns if the current running environment is for Spark Connect.

    .. versionadded:: 4.0.0

    Notes
    -----
    This will only return ``True`` if there is a remote session running.
    Otherwise, it returns ``False``.

    This API is unstable, and for developers.

    Returns
    -------
    bool

    Examples
    --------
    >>> from pyspark.sql import is_remote
    >>> is_remote()
    False
    """
    # Either the SPARK_CONNECT_MODE_ENABLED environment variable is present
    # (its value is not inspected) or is_remote_only() reports true.
    return ("SPARK_CONNECT_MODE_ENABLED" in os.environ) or is_remote_only()
def try_remote_functions(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When ``is_remote()`` is true and the ``PYSPARK_NO_NAMESPACE_SHARE`` environment
    variable is not set, calls are forwarded to the attribute of
    ``pyspark.sql.connect.functions`` named like ``f``; otherwise ``f`` is called.
    """

    @functools.wraps(f)
    def dispatcher(*args: Any, **kwargs: Any) -> Any:
        if not is_remote() or "PYSPARK_NO_NAMESPACE_SHARE" in os.environ:
            return f(*args, **kwargs)

        from pyspark.sql.connect import functions as connect_functions

        connect_impl = getattr(connect_functions, f.__name__)
        return connect_impl(*args, **kwargs)

    return cast(FuncT, dispatcher)
def try_partitioning_remote_functions(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When ``is_remote()`` is true and the ``PYSPARK_NO_NAMESPACE_SHARE`` environment
    variable is not set, calls are forwarded to the attribute of
    ``pyspark.sql.connect.functions.partitioning`` named like ``f``; otherwise ``f``
    is called.
    """

    @functools.wraps(f)
    def dispatcher(*args: Any, **kwargs: Any) -> Any:
        if not is_remote() or "PYSPARK_NO_NAMESPACE_SHARE" in os.environ:
            return f(*args, **kwargs)

        from pyspark.sql.connect.functions import partitioning as connect_partitioning

        connect_impl = getattr(connect_partitioning, f.__name__)
        return connect_impl(*args, **kwargs)

    return cast(FuncT, dispatcher)
def try_remote_avro_functions(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When ``is_remote()`` is true and the ``PYSPARK_NO_NAMESPACE_SHARE`` environment
    variable is not set, calls are forwarded to the attribute of
    ``pyspark.sql.connect.avro.functions`` named like ``f``; otherwise ``f`` is called.
    """

    @functools.wraps(f)
    def dispatcher(*args: Any, **kwargs: Any) -> Any:
        if not is_remote() or "PYSPARK_NO_NAMESPACE_SHARE" in os.environ:
            return f(*args, **kwargs)

        from pyspark.sql.connect.avro import functions as connect_avro_functions

        connect_impl = getattr(connect_avro_functions, f.__name__)
        return connect_impl(*args, **kwargs)

    return cast(FuncT, dispatcher)
def try_remote_protobuf_functions(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When ``is_remote()`` is true and the ``PYSPARK_NO_NAMESPACE_SHARE`` environment
    variable is not set, calls are forwarded to the attribute of
    ``pyspark.sql.connect.protobuf.functions`` named like ``f``; otherwise ``f`` is
    called.
    """

    @functools.wraps(f)
    def dispatcher(*args: Any, **kwargs: Any) -> Any:
        if not is_remote() or "PYSPARK_NO_NAMESPACE_SHARE" in os.environ:
            return f(*args, **kwargs)

        from pyspark.sql.connect.protobuf import functions as connect_protobuf_functions

        connect_impl = getattr(connect_protobuf_functions, f.__name__)
        return connect_impl(*args, **kwargs)

    return cast(FuncT, dispatcher)
def try_remote_window(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When ``is_remote()`` is true and the ``PYSPARK_NO_NAMESPACE_SHARE`` environment
    variable is not set, calls are forwarded to the attribute of
    ``pyspark.sql.connect.window.Window`` named like ``f``; otherwise ``f`` is called.
    """

    @functools.wraps(f)
    def dispatcher(*args: Any, **kwargs: Any) -> Any:
        if not is_remote() or "PYSPARK_NO_NAMESPACE_SHARE" in os.environ:
            return f(*args, **kwargs)

        from pyspark.sql.connect.window import Window as ConnectWindow

        connect_impl = getattr(ConnectWindow, f.__name__)
        return connect_impl(*args, **kwargs)

    return cast(FuncT, dispatcher)
def try_remote_windowspec(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When ``is_remote()`` is true and the ``PYSPARK_NO_NAMESPACE_SHARE`` environment
    variable is not set, calls are forwarded to the attribute of
    ``pyspark.sql.connect.window.WindowSpec`` named like ``f``; otherwise ``f`` is
    called.
    """

    @functools.wraps(f)
    def dispatcher(*args: Any, **kwargs: Any) -> Any:
        if not is_remote() or "PYSPARK_NO_NAMESPACE_SHARE" in os.environ:
            return f(*args, **kwargs)

        from pyspark.sql.connect.window import WindowSpec as ConnectWindowSpec

        connect_impl = getattr(ConnectWindowSpec, f.__name__)
        return connect_impl(*args, **kwargs)

    return cast(FuncT, dispatcher)
def get_active_spark_context() -> "SparkContext":
    """
    Return the active SparkContext.

    Raises
    ------
    PySparkRuntimeError
        With error class ``SESSION_OR_CONTEXT_NOT_EXISTS`` when there is no active
        SparkContext or its ``_jvm`` attribute is None.
    """
    from pyspark import SparkContext

    active_context = SparkContext._active_spark_context
    if active_context is not None and active_context._jvm is not None:
        return active_context
    raise PySparkRuntimeError(
        error_class="SESSION_OR_CONTEXT_NOT_EXISTS",
        message_parameters={},
    )
def try_remote_session_classmethod(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When ``is_remote()`` is true and the ``PYSPARK_NO_NAMESPACE_SHARE`` environment
    variable is not set, the first positional argument (asserted to be a class) is
    dropped and the call is forwarded to the attribute of
    ``pyspark.sql.connect.session.SparkSession`` named like ``f``; otherwise ``f`` is
    called with the arguments unchanged.
    """

    @functools.wraps(f)
    def dispatcher(*args: Any, **kwargs: Any) -> Any:
        if not is_remote() or "PYSPARK_NO_NAMESPACE_SHARE" in os.environ:
            return f(*args, **kwargs)

        from pyspark.sql.connect.session import SparkSession as ConnectSparkSession

        assert inspect.isclass(args[0])
        connect_impl = getattr(ConnectSparkSession, f.__name__)
        return connect_impl(*args[1:], **kwargs)

    return cast(FuncT, dispatcher)
def dispatch_df_method(f: FuncT) -> FuncT:
    """
    For the usecases of direct DataFrame.union(df, ...), it checks if self
    is a Connect DataFrame or Classic DataFrame, and dispatches.

    The expected class is the Connect DataFrame when ``is_remote()`` is true and
    ``PYSPARK_NO_NAMESPACE_SHARE`` is not set, and the Classic DataFrame otherwise.
    If the first positional argument is not an instance of the expected class, a
    ``PySparkNotImplementedError`` with error class ``NOT_IMPLEMENTED`` is raised.
    """

    @functools.wraps(f)
    def dispatcher(*args: Any, **kwargs: Any) -> Any:
        expected_class: Any
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame

            expected_class = ConnectDataFrame
        else:
            from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame

            expected_class = ClassicDataFrame

        if not isinstance(args[0], expected_class):
            raise PySparkNotImplementedError(
                error_class="NOT_IMPLEMENTED",
                message_parameters={"feature": f"DataFrame.{f.__name__}"},
            )
        return getattr(expected_class, f.__name__)(*args, **kwargs)

    return cast(FuncT, dispatcher)
def dispatch_col_method(f: FuncT) -> FuncT:
    """
    For the usecases of direct Column.method(col, ...), it checks if self
    is a Connect Column or Classic Column, and dispatches.

    The Connect Column is expected when ``is_remote()`` is true and
    ``PYSPARK_NO_NAMESPACE_SHARE`` is not set, and the Classic Column otherwise.
    If the first positional argument is not an instance of the expected class, a
    ``PySparkNotImplementedError`` with error class ``NOT_IMPLEMENTED`` is raised.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.column import Column as ConnectColumn

            if isinstance(args[0], ConnectColumn):
                return getattr(ConnectColumn, f.__name__)(*args, **kwargs)
        else:
            from pyspark.sql.classic.column import Column as ClassicColumn

            if isinstance(args[0], ClassicColumn):
                return getattr(ClassicColumn, f.__name__)(*args, **kwargs)

        # Report the feature under the Column namespace: this decorator wraps
        # Column methods, not DataFrame methods.
        raise PySparkNotImplementedError(
            error_class="NOT_IMPLEMENTED",
            message_parameters={"feature": f"Column.{f.__name__}"},
        )

    return cast(FuncT, wrapped)
def pyspark_column_op(
    func_name: str, left: "IndexOpsLike", right: Any, fillna: Any = None
) -> Union["SeriesOrIndex", None]:
    """
    Wrapper function for column_op to get proper Column class.

    The method named ``func_name`` is looked up on the Connect Column when
    ``is_remote()`` is true and on the regular Column otherwise, then applied to
    ``left`` and ``right`` through ``column_op``. When ``fillna`` is given and neither
    operand is reported as an extension dtype, the result is passed through
    ``fillna(fillna)``.
    """
    from pyspark.pandas.base import column_op
    from pyspark.sql.column import Column as PySparkColumn
    from pyspark.pandas.data_type_ops.base import _is_extension_dtypes

    if is_remote():
        from pyspark.sql.connect.column import Column as ConnectColumn

        Column = ConnectColumn
    else:
        Column = PySparkColumn  # type: ignore[assignment]

    column_method = getattr(Column, func_name)
    result = column_op(column_method)(left, right)

    # It works as expected on extension dtype, so we don't need to call `fillna` for this case.
    if fillna is None or _is_extension_dtypes(left) or _is_extension_dtypes(right):
        return result
    # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
    return result.fillna(fillna)
def get_column_class() -> Type["Column"]:
    """
    Return the Column class for the current mode: the Connect Column when
    ``is_remote()`` is true, the ``pyspark.sql.column`` Column otherwise.
    """
    from pyspark.sql.column import Column as PySparkColumn

    if not is_remote():
        return PySparkColumn

    from pyspark.sql.connect.column import Column as ConnectColumn

    return ConnectColumn
def get_dataframe_class() -> Type["DataFrame"]:
    """
    Return the DataFrame class for the current mode: the Connect DataFrame when
    ``is_remote()`` is true, the ``pyspark.sql.dataframe`` DataFrame otherwise.
    """
    from pyspark.sql.dataframe import DataFrame as PySparkDataFrame

    if not is_remote():
        return PySparkDataFrame

    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame

    return ConnectDataFrame
def get_window_class() -> Type["Window"]:
    """
    Return the Window class for the current mode: the Connect Window when
    ``is_remote()`` is true, the ``pyspark.sql.window`` Window otherwise.
    """
    from pyspark.sql.window import Window as PySparkWindow

    if not is_remote():
        return PySparkWindow

    from pyspark.sql.connect.window import Window as ConnectWindow

    return ConnectWindow  # type: ignore[return-value]
def get_lit_sql_str(val: str) -> str:
    """
    Render ``val`` as a single-quoted SQL string literal.

    Backslashes are doubled and single quotes are backslash-escaped before the
    surrounding quotes are added.
    """
    # Equivalent to `lit(val)._jc.expr().sql()` for string typed val
    # See `sql` definition in `sql/catalyst/src/main/scala/org/apache/spark/
    # sql/catalyst/expressions/literals.scala`
    # Backslashes must be doubled first so that the backslash added in front of
    # each quote is not doubled again.
    backslashes_escaped = val.replace("\\", "\\\\")
    fully_escaped = backslashes_escaped.replace("'", "\\'")
    return f"'{fully_escaped}'"