import numpy as np
from scipy import stats


def _moment(a, mean, order, weights=None):
    """Return the (optionally weighted) average of ``(a - mean) ** order``.

    Parameters
    ----------
    a : array_like
        Input data.
    mean : float
        Centre about which the moment is taken (usually the weighted mean).
    order : int
        Power applied to each deviation from `mean`.
    weights : array_like, optional
        Passed straight to :func:`numpy.average`; ``None`` weights all
        elements equally.
    """
    return np.average(np.power(a - mean, order), weights=weights)

# Reimplementation from scipy.stats._stats_py to add support for weights
# https://github.com/scipy/scipy/blob/v1.15.2/scipy/stats/_stats_py.py#L1223
def skew(a, weights=None):
    r"""Compute the (optionally weighted) sample skewness of a data set.

    This is the biased estimator ``m3 / m2**1.5``, where ``mk`` is the k-th
    (weighted) central moment. With ``weights=None`` it matches
    ``scipy.stats.skew(a, bias=True)`` for non-constant data.

    For normally distributed data, the skewness should be about zero. For
    unimodal continuous distributions, a skewness value greater than zero means
    that there is more weight in the right tail of the distribution.

    Parameters
    ----------
    a : array_like
        One-dimensional input data.
    weights : array_like, optional
        The importance that each element has in the computation of the skewness.
        Must have the same length as `a`. If ``weights=None``, then all data in
        `a` are assumed to have a weight equal to one.

    Returns
    -------
    skewness : float
        The skewness of `a`. Returns 0.0 when the data are constant, i.e. when
        the variance is zero up to floating-point rounding.

    """
    mean = np.average(a, axis=0, weights=weights)

    m2 = _moment(a, mean, 2, weights=weights)
    m3 = _moment(a, mean, 3, weights=weights)

    # An exact ``m2 == 0`` test misses constant data whose computed mean is a
    # few ulps off (e.g. [0.1, 0.1, 0.1]). Every deviation is then the same
    # tiny non-zero value, and the skewness comes out as a spurious +/-1.
    # Treat the variance as zero when it is within rounding of zero relative
    # to the mean, similar to the tolerance used by scipy.stats.skew.
    # np.max keeps this a scalar comparison even if `mean` is not a scalar.
    eps = np.finfo(np.asarray(m2).dtype).eps
    if m2 <= (eps * np.max(np.abs(mean))) ** 2:
        return 0.0
    return m3 / m2 ** 1.5


# Reimplementation from scipy.stats._stats_py to add support for weights
# https://github.com/scipy/scipy/blob/v1.15.2/scipy/stats/_stats_py.py#L1325
def kurtosis(a, weights=None):
    """Compute the (optionally weighted) Fisher kurtosis of a dataset.

    Kurtosis is the fourth central moment divided by the square of the
    variance. Fisher's definition is used and 3.0 is subtracted from
    the result to give 0.0 for a normal distribution. This is the biased
    estimator ``m4 / m2**2 - 3``. With ``weights=None`` it matches
    ``scipy.stats.kurtosis(a, fisher=True, bias=True)`` for non-constant data.

    Parameters
    ----------
    a : array_like
        One-dimensional data for which the kurtosis is calculated.
    weights : array_like, optional
        The importance that each element has in the computation of the kurtosis.
        Must have the same length as `a`. If ``weights=None``, then all data in
        `a` are assumed to have a weight equal to one.

    Returns
    -------
    kurtosis : float
        The excess (Fisher) kurtosis of `a`. Returns 0.0 when the data are
        constant, i.e. when the variance is zero up to floating-point rounding.

    """
    mean = np.average(a, axis=0, weights=weights)

    m2 = _moment(a, mean, 2, weights=weights)
    m4 = _moment(a, mean, 4, weights=weights)

    # An exact ``m2 == 0`` test misses constant data whose computed mean is a
    # few ulps off (e.g. [0.1, 0.1, 0.1]). Every deviation is then the same
    # tiny non-zero value, and the result comes out as a spurious -2.
    # Treat the variance as zero when it is within rounding of zero relative
    # to the mean. np.max keeps this a scalar comparison even if `mean` is
    # not a scalar.
    eps = np.finfo(np.asarray(m2).dtype).eps
    if m2 <= (eps * np.max(np.abs(mean))) ** 2:
        return 0.0
    return (m4 / m2 ** 2.0) - 3


# Reimplementation from statsmodels.stats.stattools to add support for weights
# https://github.com/statsmodels/statsmodels/blob/v0.14.6/statsmodels/stats/stattools.py#L81
def jarque_bera(resids, weights=None):
    r"""
    The (optionally weighted) Jarque-Bera test of normality.

    The statistic is ``JB = n / 6 * (S**2 + K**2 / 4)``, where ``S`` is the
    skewness and ``K`` the excess (Fisher) kurtosis of `resids`, both computed
    with `weights`. ``n`` is the number of elements in `resids`. The p-value
    is the survival function of a chi-squared distribution with 2 degrees of
    freedom evaluated at JB.

    Parameters
    ----------
    resids : array_like
        One-dimensional data to test for normality. Usually regression model
        residuals that are mean 0. Converted to a float array.
    weights : array_like, optional
        Weight for each data point in `resids` (same length). When not
        provided, every data point gets weight 1. Default is None.

    Returns
    -------
    JB : float
        The Jarque-Bera test statistic.
    JBpv : float
        The p-value of the test statistic.
    skew : float
        Estimated skewness of the data.
    kurtosis : float
        Estimated Pearson (non-excess) kurtosis of the data, i.e. the excess
        kurtosis plus 3. A normal distribution gives 3.

    Raises
    ------
    ValueError
        If `resids` contains fewer than 2 elements.

    """
    resids = np.atleast_1d(np.asarray(resids, dtype=float))
    if resids.size < 2:
        raise ValueError("resids must contain at least 2 elements")

    # Calculate residual skewness and excess (Fisher) kurtosis
    skew_value = skew(resids, weights=weights)
    kurtosis_value = kurtosis(resids, weights=weights)

    # Calculate the Jarque-Bera test for normality.
    # NOTE(review): `sample_size` is the number of data points, not
    # ``sum(weights)``. If `weights` are meant as frequency counts, the
    # effective sample size would be ``sum(weights)``. Confirm the intended
    # weight semantics before relying on the p-value.
    sample_size = resids.shape[0]
    jb = (sample_size / 6.) * (skew_value ** 2 + (1 / 4.) * kurtosis_value ** 2)
    jb_pv = stats.chi2.sf(jb, 2)

    # Report Pearson kurtosis (normal distribution == 3)
    return jb, jb_pv, skew_value, (kurtosis_value + 3)
