from __future__ import unicode_literals

from six import string_types
import json
import logging


logger = logging.getLogger(__name__)


class InvalidDataFilter(object):
    """
    base class dropping the rows of a dataframe whose filepath or target is invalid

    a row is kept when its filepath is a non-blank string and `check_target()` accepts its target;
    subclasses have to implement `check_target()`
    """

    def __init__(self, target_col, filepath_col):
        """
        :param target_col: name of the column holding the target of each row
        :param filepath_col: name of the column holding the file path of each row
        """
        self.target_col = target_col
        self.filepath_col = filepath_col

    def filter(self, df):
        """
        filters rows of the dataframe `df`, keeping only the rows with a non-blank string filepath
        and a target accepted by `check_target()`
        rows for which the check raises an error are logged and dropped
        the result may be empty, no error is thrown when every row is invalid
        :type df: pd.DataFrame
        :return: filtered dataframe (always a copy, `df` is left untouched)
        :rtype: pd.DataFrame
        :raises NotImplementedError: if the filter class does not implement `check_target()`
        """
        if len(df) == 0:
            # nothing to validate, and a row-wise `apply` on a frame without rows
            # does not produce the boolean mask used below
            return df.copy()

        def is_valid_row(row):
            filepath = row[self.filepath_col]
            try:
                if not isinstance(filepath, string_types) or len(filepath.strip()) == 0:
                    logger.warning("wrong filepath: %s, dropping row", filepath)
                    return False
                # coerce to a real boolean so that the mask stays usable whatever `check_target()` returns
                return bool(self.check_target(row[self.target_col], filepath))
            except NotImplementedError:
                # a missing `check_target()` implementation is a programming error, not bad data
                raise
            except Exception:
                # `Exception` only: KeyboardInterrupt and SystemExit must not be swallowed
                logger.exception("an error occurred while checking target %s for %s, dropping row", row, filepath)
                return False

        valid = df.apply(is_valid_row, axis=1)
        filtered_df = df[valid].copy()
        if len(filtered_df) < len(df):
            logger.info("dropped %d rows, keeping %d", len(df) - len(filtered_df), len(filtered_df))
        return filtered_df

    def check_target(self, target_str, filepath):
        """
        check if a target matches the format required by the filter
        must be implemented by subclasses
        :param target_str: value read from the target column of the row
        :param filepath: file path of the row, used in log messages
        :return: a truthy value if the row must be kept, a falsy one if it must be dropped
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()


class ObjectDetectionInvalidDataFilter(InvalidDataFilter):
    """
    row filter for object detection, where the target of a row is a json-encoded list of annotations
    """

    def __init__(self, target_col, filepath_col, categories):
        """
        :param target_col: name of the column holding the json-encoded annotations
        :param filepath_col: name of the column holding the file path of each row
        :param categories: collection of accepted category names
        """
        super(ObjectDetectionInvalidDataFilter, self).__init__(target_col, filepath_col)
        self.categories = categories

    def check_target(self, target_str, filepath):
        """
        check that a target is a json-encoded, non-empty list in which every annotation is valid
        errors raised by `json.loads` (malformed json, non-string input) are not caught here
        :param target_str: json string read from the target column
        :param filepath: file path of the row, used in log messages
        :return: True if the row must be kept, False otherwise
        :rtype: bool
        """
        target = json.loads(target_str)
        if not isinstance(target, list):
            # valid json which is not a list (object, number, string...) cannot hold annotations
            logger.warning("target is not a list of annotations for %s, dropping row", filepath)
            return False
        if len(target) == 0:
            logger.warning("target is empty for %s, dropping row", filepath)
            return False
        # stops at the first invalid annotation
        return all(self.check_annotation_valid(t, filepath) for t in target)

    def check_annotation_valid(self, obj, filepath):
        """
        check a single annotation, which is valid when it is a json object with:
        - bbox: 4 values, the first two >= 0 and the last two > 0
        - category: a string found in `self.categories`
        - iscrowd (optional): 0 or 1
        Other properties are ignored
        :param obj: decoded annotation
        :param filepath: file path of the row, used in log messages
        :return: True if the annotation is valid, False otherwise
        :rtype: bool
        """
        if not isinstance(obj, dict) or "bbox" not in obj:
            logger.warning("target %s has no 'bbox' attribute for %s, dropping row", obj, filepath)
            return False

        bbox = obj["bbox"]
        try:
            bad_bbox = (
                not isinstance(bbox, (list, tuple))
                or len(bbox) != 4
                or bbox[0] < 0
                or bbox[1] < 0
                or bbox[2] <= 0
                or bbox[3] <= 0
            )
        except TypeError:
            # values which cannot be compared to a number (None, nested list...) make the bbox invalid
            bad_bbox = True
        if bad_bbox:
            logger.warning("'bbox' in target %s has wrong properties (%s) for %s, dropping row", obj, bbox, filepath)
            return False

        category = obj.get("category")
        if category is None or not isinstance(category, string_types):
            logger.warning("target %s has no or wrong 'category' attribute for %s, dropping row", obj, filepath)
            return False

        if category not in self.categories:
            logger.warning("category '%s' not found in categories %s for %s, dropping row", category, self.categories, filepath)
            return False

        iscrowd = obj.get("iscrowd")
        # will also accept True and False, and that is alright in our context, thanks python
        # membership is tested against a tuple so that unhashable values (list, dict) are rejected instead of raising
        if iscrowd is not None and iscrowd not in (0, 1):
            logger.warning("target %s has wrong 'iscrowd' attribute for %s, dropping row", obj, filepath)
            return False
        return True


class ImageClassificationInvalidDataFilter(InvalidDataFilter):
    """
    row filter for image classification, where the target of a row is a single category
    """

    def __init__(self, target_col, filepath_col, categories):
        """
        :param target_col: name of the column holding the category of each row
        :param filepath_col: name of the column holding the file path of each row
        :param categories: collection of accepted categories
        """
        super(ImageClassificationInvalidDataFilter, self).__init__(target_col, filepath_col)
        self.categories = categories

    def check_target(self, target, filepath):
        """
        check that a target is neither None nor an empty string and belongs to `self.categories`
        :param target: value read from the target column
        :param filepath: file path of the row, used in log messages
        :return: True if the row must be kept, False otherwise
        :rtype: bool
        """
        is_empty_string = isinstance(target, string_types) and not target
        if target is None or is_empty_string:
            logger.warning("target is empty or has wrong format for %s, dropping row", filepath)
            return False

        if target in self.categories:
            return True

        logger.warning("category '%s' not found in categories %s for %s, dropping row", target, self.categories, filepath)
        return False