# MIT License
#
# Copyright (c) 2019-2024 Iván de Paz Centeno
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# pylint: disable=duplicate-code

from mtcnn.network.rnet import RNet

from mtcnn.utils.tensorflow import load_weights
from mtcnn.utils.images import extract_patches
from mtcnn.utils.bboxes import replace_confidence, adjust_bboxes, pick_matches, smart_nms_from_bboxes, resize_to_square

from mtcnn.stages.base import StageBase


class StageRNet(StageBase):
    """
    Refinement Network (RNet) stage of the MTCNN cascade.

    Receives the candidate boxes proposed by the previous stage and crops a 24x24 patch for each of them. It scores
    and regresses the patches with RNet, then keeps only the boxes that pass the confidence threshold and
    non-maximum suppression.

    Args:
        stage_name (str): Human-readable name of the stage. Defaults to "Stage RNET".
        stage_id (int): Numeric identifier of the stage. Defaults to 2.
        weights (str): Weights resource passed to ``load_weights``. Defaults to "rnet.lz4".
    """

    def __init__(self, stage_name="Stage RNET", stage_id=2, weights="rnet.lz4"):
        """
        Create the RNet network, load its pre-trained weights and hand it to the base stage.

        Args:
            stage_name (str, optional): Name given to this stage. Defaults to "Stage RNET".
            stage_id (int, optional): Identifier given to this stage. Defaults to 2.
            weights (str, optional): Weights resource for RNet, resolved by ``load_weights``. Defaults to "rnet.lz4".
        """
        network = RNet()
        # The network must be built before weights can be assigned to its variables.
        network.build()
        pretrained = load_weights(weights)
        network.set_weights(pretrained)

        super().__init__(stage_name=stage_name, stage_id=stage_id, model=network)

    def __call__(self, images_normalized, bboxes_batch, threshold_rnet=0.7, nms_rnet=0.7, **kwargs):
        """
        Refine a batch of candidate boxes with RNet.

        The pipeline is: crop 24x24 patches, run RNet, replace the confidences, apply the predicted offsets,
        threshold, run NMS ("union" method), and finally make the boxes square.

        Args:
            images_normalized (tf.Tensor): Normalized input images. Presumably laid out as
                (batch, height, width, 3); TODO confirm against ``extract_patches``.
            bboxes_batch (np.ndarray): Candidate boxes from the previous stage. Presumably one row per box, laid
                out as [image_id, x1, y1, x2, y2, confidence, landmark_x1, landmark_y1, ...]; verify against PNet.
            threshold_rnet (float, optional): Minimum RNet confidence for a box to be kept. Defaults to 0.7.
            nms_rnet (float, optional): Overlap threshold used by non-maximum suppression. Defaults to 0.7.
            **kwargs: Accepted but ignored here. Presumably present so every stage can receive the same
                pipeline-wide keyword arguments.

        Returns:
            The refined boxes as produced by ``resize_to_square``. Presumably an np.ndarray in the same row
            layout as ``bboxes_batch``, which is then consumed by the next stage.
        """
        # RNet consumes fixed-size 24x24 crops, one per candidate box.
        crops = extract_patches(images_normalized, bboxes_batch, expected_size=(24, 24))
        regressions, confidences = self._model(crops)

        # Overwrite the earlier stage's confidences with RNet's, then shift each box by its predicted offsets.
        rescored = replace_confidence(bboxes_batch, confidences)
        refined = adjust_bboxes(rescored, regressions)

        # Discard low-confidence boxes first so that suppression only considers viable candidates.
        survivors = pick_matches(refined, score_threshold=threshold_rnet)
        survivors = smart_nms_from_bboxes(survivors, threshold=nms_rnet, method="union", initial_sort=True)

        # Square the surviving boxes before handing them to the next stage.
        return resize_to_square(survivors)
