a
    +ca7                    @   sf  d Z ddlZddlZddlZddlZddlZddlZddlmZ ddl	Z	ddl
ZddlZddlZddlmZ ddlmZ ddlmZ ddlmZ ddlZddlmZ eejedksJ eejedksJ d_ddZG d	d
 d
ej Z!dd Z"d`ddZ#daddZ$dbddZ%dd Z&dd Z'G dd dej(Z)dd Z*G dd dej(Z+d d! Z,d"d# Z-G d$d% d%ej(Z.d&d' Z/G d(d) d)ej(Z0d*d+ Z1d,d- Z2dcd/d0Z3ddd1d2Z4d3d4 Z5d5d6 Z6d7d8 Z7d9d: Z8d;d< Z9d=d> Z:ded?d@Z;dAdB Z<dCdD Z=dEdF Z>dfdHdIZ?G dJdK dKZ@dLdM ZAdNdO ZBdPdQ ZCdRdS ZDdTdU ZEdgdWdXZFdYdZ ZGd[d\ ZHd]d^ ZIdS )haH  
Mask R-CNN
The main Mask R-CNN model implementation.

Copyright (c) 2017 Matterport, Inc.
Licensed under the MIT License (see LICENSE for details)
Written by Waleed Abdulla
Modified by Ondrej Pesek
    Written error handling in load_image_gt and data_generator
    Changed the verbosity to correspond GRASS verbosity in detect
    N)OrderedDict)LooseVersionz1.3z2.0.8c                 C   sr   |durf|  d} | dt|j7 } |jrF| d| | 7 } n| ddd7 } | d|j7 } t|  dS )zxPrints a text message. And, optionally, if a Numpy array is provided it
    prints it's shape, min, and max values.
    N   zshape: {:20}  zmin: {:10.5f}  max: {:10.5f}zmin: {:10}  max: {:10} z  {})	ljustformatstrshapesizeminmaxdtypeprint)textarray r   UC:/Users/landamar/grass_packager/grass786/addons/i.ann.maskrcnn/etc/maskrcnn/model.pylog+   s    
r   c                       s"   e Zd ZdZd fdd	Z  ZS )	BatchNorma  Extends the Keras BatchNormalization class to allow a central place
    to make changes if needed.

    Batch normalization has a negative effect on training if batches are small
    so this layer is often frozen (via setting in Config class) and functions
    as linear layer.
    Nc                    s   t | j| j||dS )a  
        Note about training values:
            None: Train BN layers. This is the normal mode
            False: Freeze BN layers. Good when batch size is small
            True: (don't use). Set layer in training mode even when making inferences
        training)super	__class__call)selfinputsr   r   r   r   r   C   s    zBatchNorm.call)N)__name__
__module____qualname____doc__r   __classcell__r   r   r   r   r   :   s   r   c                    s<   t | jr|  S | jdv s"J t fdd| jD S )zComputes the width and height of each stage of the backbone network.

    Returns:
        [N, (height, width)]. Where N is the number of stages
    Zresnet50Z	resnet101c                    s8   g | ]0}t t d  | t t d | gqS r      )intmathZceil).0Zstrideimage_shaper   r   
<listcomp>Y   s   z+compute_backbone_shapes.<locals>.<listcomp>)callableBACKBONEZCOMPUTE_BACKBONE_SHAPEnpr   BACKBONE_STRIDES)configr)   r   r(   r   compute_backbone_shapesM   s    


r0   Tc                 C   s  |\}}}	dt | | d }
dt | | d }tj|d|
d |d| }t|d d||d}td	|}tj|||fd
|
d |d|}t|d d||d}td	|}tj|	d|
d |d|}t|d d||d}t || g}tjd	dt | | d d|}|S )a9  The identity_block is the block that has no conv layer at shortcut
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of middle conv layer at main path
        filters: list of integers, the nb_filters of 3 conv layer at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
        use_bias: Boolean. To use or not use a bias in conv layers.
        train_bn: Boolean. Train or freeze Batch Norm layers
    res_branchbnr$   r$   2anameuse_biasr7   r   relusame2bpaddingr7   r8   2c_outr   KLConv2Dr   
ActivationAdd)input_tensorkernel_sizefiltersstageblockr8   train_bn
nb_filter1
nb_filter2
nb_filter3conv_name_basebn_name_basexr   r   r   identity_blockk   s0    
"rR      rT   c                 C   sN  |\}}	}
dt | | d }dt | | d }tj|d||d |d| }t|d d||d}td	|}tj|	||fd
|d |d|}t|d d||d}td	|}tj|
d|d |d|}t|d d||d}tj|
d||d |d| }t|d d||d}t ||g}tjd	dt | | d d|}|S )a  conv_block is the block that has a conv layer at shortcut
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of middle conv layer at main path
        filters: list of integers, the nb_filters of 3 conv layer at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
        use_bias: Boolean. To use or not use a bias in conv layers.
        train_bn: Boolean. Train or freeze Batch Norm layers
    Note that from stage 3, the first conv layer at main path is with subsample=(2,2)
    And the shortcut should have subsample=(2,2) as well
    r1   r2   r3   r4   r5   stridesr7   r8   r9   r   r:   r;   r<   r=   r?   r6   1r@   rA   )rF   rG   rH   rI   rJ   rV   r8   rK   rL   rM   rN   rO   rP   rQ   Zshortcutr   r   r   
conv_block   sP    
"rX   Fc              	   C   s  |dv sJ t d| }t jdddddd|}td	d
||d}t d|}t jdddd| }}t|dg dddd|d}t|dg ddd|d}t|dg ddd|d }}t|dg ddd|d}t|dg ddd|d}t|dg ddd|d}t|dg ddd|d }}t|dg ddd|d}ddd| }t|D ]&}	t|dg ddt	d|	 |d}qJ|}
|rt|dg ddd|d}t|dg ddd|d}t|dg ddd|d }}nd }||||
|gS )!zBuild a ResNet graph.
    architecture: Can be resnet50 or resnet101
    stage5: Boolean. If False, stage5 of the network is not created
    train_bn: Boolean. Train or freeze Batch Norm layers
    r"      rZ   @   )   r\   rS   Zconv1TrU   Zbn_conv1r9   r   r:   r;   )rV   r>   rZ   )r[   r[      rT   ar4   )rI   rJ   rV   rK   b)rI   rJ   rK   c)   ra      d)r]   r]               b   )rb   rb   i   N)
rB   ZZeroPadding2DrC   r   rD   MaxPooling2DrX   rR   rangechr)input_imageZarchitecturestage5rK   rQ   ZC1C2C3Zblock_countiC4C5r   r   r   resnet_graph   sD    

rs   c                 C   s  | dddf | dddf  }| dddf | dddf  }| dddf d|  }| dddf d|  }||dddf | 7 }||dddf | 7 }|t |dddf 9 }|t |dddf 9 }|d|  }|d|  }|| }|| }	t j||||	gddd}
|
S )	zApplies the given deltas to the given boxes.
    boxes: [N, (y1, x1, y2, x2)] boxes to update
    deltas: [N, (dy, dx, log(dh), log(dw))] refinements to apply
    NrT   r   rZ   r$         ?Zapply_box_deltas_outaxisr7   )tfZexpstack)boxesdeltasZheightwidthZcenter_yZcenter_xy1x1y2x2resultr   r   r   apply_box_deltas_graph  s      r   c                 C   s   t |d\}}}}t j| ddd\}}}}	t t |||}t t |||}t t |||}t t |	||}	t j||||	gddd}
|
|
jd df |
S )zQ
    boxes: [N, (y1, x1, y2, x2)]
    window: [4] in the form y1, x1, y2, x2
    re   r$   rv   Zclipped_boxesru   r   )rw   splitmaximumminimumconcat	set_shaper	   )ry   windowwy1wx1wy2wx2r|   r}   r~   r   Zclippedr   r   r   clip_boxes_graph   s    r   c                       s2   e Zd ZdZd	 fdd	Zdd Zdd Z  ZS )
ProposalLayera6  Receives anchor scores and selects a subset to pass as proposals
    to the second stage. Filtering is done based on anchor scores and
    non-max suppression to remove overlaps. It also applies bounding
    box refinement deltas to anchors.

    Inputs:
        rpn_probs: [batch, num_anchors, (bg prob, fg prob)]
        rpn_bbox: [batch, num_anchors, (dy, dx, log(dh), log(dw))]
        anchors: [batch, num_anchors, (y1, x1, y2, x2)] anchors in normalized coordinates

    Returns:
        Proposals in normalized coordinates [batch, rois, (y1, x1, y2, x2)]
    Nc                    s,   t t| jf i | || _|| _|| _d S N)r   r   __init__r/   proposal_countnms_threshold)r   r   r   r/   kwargsr   r   r   r   A  s    zProposalLayer.__init__c                    s@  |d d d d d df }|d }|t  jjg d }|d }t jjt|d }tjj	||dddj
}t||gdd	  jj}t||gd
d	  jj}tj||gdd	  jjdgd}tj||gdd	  jjdgd}t jg dt jdtj|fdd	 jjdgd} fdd}	t||g|	 jj}
|
S )Nr   r$   )r$   r$   re   rT   TZtop_anchors)sortedr7   c                 S   s   t | |S r   rw   gatherrQ   yr   r   r   <lambda>U      z$ProposalLayer.call.<locals>.<lambda>c                 S   s   t | |S r   r   r   r   r   r   r   X  r   c                 S   s   t | |S r   r   )r^   rQ   r   r   r   r   \  r   pre_nms_anchorsnamesc                 S   s
   t | |S r   )r   r   r   r   r   r   e  r   Zrefined_anchors)r   r   r$   r$   r   c                    s
   t |  S r   )r   rQ   )r   r   r   r   o  r   Zrefined_anchors_clippedc                    sZ   t jj| | j jdd}t | |}t  jt |d  d}t |d|fdg}|S )NZrpn_non_max_suppressionr9   r   r   r   )	rw   imagenon_max_suppressionr   r   r   r   r	   pad)ry   scoresindices	proposalsr>   r   r   r   nmsy  s    zProposalLayer.call.<locals>.nms)r-   reshaper/   RPN_BBOX_STD_DEVrw   r   ZPRE_NMS_LIMITr	   nntop_kr   utilsbatch_sliceIMAGES_PER_GPUr   float32)r   r   r   rz   anchorsZpre_nms_limitixr   ry   r   r   r   )r   r   r   r   G  sD    		
zProposalLayer.callc                 C   s   d | j dfS )Nre   )r   r   Zinput_shaper   r   r   compute_output_shape  s    z"ProposalLayer.compute_output_shape)Nr   r   r   r    r   r   r   r!   r   r   r   r   r   2  s   Cr   c                 C   s   t | t d S )z@Implementation of Log2. TF doesn't have a native implementation.g       @)rw   r   r   r   r   r   
log2_graph  s    r   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )PyramidROIAligna  Implements ROI Pooling on multiple levels of the feature pyramid.

    Params:
    - pool_shape: [pool_height, pool_width] of the output pooled regions. Usually [7, 7]

    Inputs:
    - boxes: [batch, num_boxes, (y1, x1, y2, x2)] in normalized
             coordinates. Possibly padded with zeros if not enough
             boxes to fill the array.
    - image_meta: [batch, (meta data)] Image details. See compose_image_meta()
    - feature_maps: List of feature maps from different levels of the pyramid.
                    Each is [batch, height, width, channels]

    Output:
    Pooled regions in the shape: [batch, num_boxes, pool_height, pool_width, channels].
    The width and height are those specific in the pool_shape in the layer
    constructor.
    c                    s$   t t| jf i | t|| _d S r   )r   r   r   tuple
pool_shape)r   r   r   r   r   r   r     s    zPyramidROIAlign.__init__c                 C   sZ  |d }|d }|dd  }t j|ddd\}}}}|| }	|| }
t|d d }t |d |d  t j}tt |	|
 dt |  }t dt ddt t 	|t j
 }t |d}g }g }ttdd	D ]\}}t t ||}t ||}t |d d df t j
}|| t |}t |}|t jj|| ||| jd
d qt j|dd}t j|dd}t t t |d d}t jt |t j
|gdd}|d d df d |d d df  }t jj|t |d djd d d }t |d d df |}t ||}t jt |d d t |dd  gdd}t ||}|S )Nr   r$   rT   re   r   r)   g      l@rf      Zbilinear)methodi )k)rw   r   parse_image_meta_graphcastr   r   Zsqrtr   r   roundint32squeeze	enumeraterj   whereequal	gather_ndappendZstop_gradientr   crop_and_resizer   r   expand_dimsr	   r   r   r   r   r   )r   r   ry   
image_metafeature_mapsr|   r}   r~   r   hwr)   Z
image_areaZ	roi_levelZpooledZbox_to_levelrp   levelr   Zlevel_boxesZbox_indicesZ	box_rangeZsorting_tensorr	   r   r   r   r     sT      


$&.zPyramidROIAlign.callc                 C   s$   |d d d | j  |d d f S )Nr   rT   r   )r   r   r   r   r   r     s    z$PyramidROIAlign.compute_output_shaper   r   r   r   r   r     s   Vr   c              
   C   s&  t t t | dddt |d gddg}t |t | d dg}t j|ddd\}}}}t j|ddd\}}	}
}t ||}t ||	}t ||
}t ||}t || dt || d }|| ||  }|
| ||	  }|| | }|| }t |t | d t |d g}|S )z`Computes IoU overlaps between two sets of boxes.
    boxes1, boxes2: [N, (y1, x1, y2, x2)].
    r$   r   r   re   r   )rw   r   Ztiler   r	   r   r   r   )Zboxes1Zboxes2Zb1Zb2Zb1_y1Zb1_x1Zb1_y2Zb1_x2Zb2_y1Zb2_x1Zb2_y2Zb2_x2r|   r}   r~   r   intersectionZb1_areaZb2_areaunionZiouoverlapsr   r   r   overlaps_graph  s"    ( $r   c           .         s6  t jt t | d d| gddg}t | t | } W d   n1 sP0    Y  t| dd\} }t|dd\}}t j||dd}t j|t 	|dddf dd	d
}t 	|dk dddf }t 	|dkdddf }	t ||}
t ||	}t ||	}t j||	dd}t
| |}t
| |
}t j|dd}|dk }t j|dd}|dk}t 	|dddf }t 	t |dk |dddf }t|j|j }t |d| }t |d }d|j }t |t |t j t j| }t |d| }t | |}t | |}t || t jt t  d d fdddd d}t ||}t ||}t||}||j }t t |g dd}t ||}|}|jr.t j|ddd\}} }!}"t j|ddd\}#}$}%}&|%|# }'|&|$ }(||# |' }| |$ |( } |!|# |' }!|"|$ |( }"t || |!|"gd}t dt |d })t jt |t j||)|j }*t j!|*dd}*t "|*}*t j||gdd}+t |d },t #|jt |+d  d}-t $|+d|-fdg}+t $|d|,|- fdg}t $|d|,|- fg}t $|d|,|- fdg}t $|*d|,|- gddg}*|+|||*fS )a  Generates detection targets for one image. Subsamples proposals and
    generates target class IDs, bounding box deltas, and masks for each.

    Inputs:
    proposals: [POST_NMS_ROIS_TRAINING, (y1, x1, y2, x2)] in normalized coordinates. Might
               be zero padded if there are not enough proposals.
    gt_class_ids: [MAX_GT_INSTANCES] int class IDs
    gt_boxes: [MAX_GT_INSTANCES, (y1, x1, y2, x2)] in normalized coordinates.
    gt_masks: [height, width, MAX_GT_INSTANCES] of boolean type.

    Returns: Target ROIs and corresponding class IDs, bounding box shifts,
    and masks.
    rois: [TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)] in normalized coordinates
    class_ids: [TRAIN_ROIS_PER_IMAGE]. Integer class IDs. Zero padded.
    deltas: [TRAIN_ROIS_PER_IMAGE, (dy, dx, log(dh), log(dw))]
    masks: [TRAIN_ROIS_PER_IMAGE, height, width]. Masks cropped to bbox
           boundaries and resized to neural network output size.

    Note: Returned arrays might be zero padded if not enough target ROIs.
    r   Zroi_assertionr9   NZtrim_proposalsZtrim_gt_boxesZtrim_gt_class_idsrT   Ztrim_gt_masksru   r   r$   MbP?rt         ?c                      s   t j ddS )Nr$   r   )rw   argmaxr   Zpositive_overlapsr   r   r     r   z)detection_targets_graph.<locals>.<lambda>c                   S   s   t t g t jS r   )rw   r   constantint64r   r   r   r   r     r   )Ztrue_fnZfalse_fn)rT   r   r$   r   re   rZ   r   )%rw   ZAssertZgreaterr	   Zcontrol_dependenciesZidentitytrim_zeros_graphboolean_maskr   r   r   Z
reduce_maxZlogical_andr%   TRAIN_ROIS_PER_IMAGEROI_POSITIVE_RATIOZrandom_shuffler   r   r   Zcondr   Zbox_refinement_graphBBOX_STD_DEVr   	transposeUSE_MINI_MASKr   r   rj   r   r   
MASK_SHAPEr   r   r   r   ).r   gt_class_idsgt_boxesgt_masksr/   Zasserts_	non_zeroscrowd_ixnon_crowd_ixcrowd_boxesr   crowd_overlapscrowd_iou_maxno_crowd_boolZroi_iou_maxZpositive_roi_boolZpositive_indicesZnegative_indicesZpositive_countrZnegative_countZpositive_roisZnegative_roisZroi_gt_box_assignmentroi_gt_boxesroi_gt_class_idsrz   Ztransposed_masksZ	roi_masksry   r|   r}   r~   r   gt_y1gt_x1gt_y2gt_x2gt_hgt_wZbox_idsmasksroisNPr   r   r   detection_targets_graph-  s    (

"



r   c                       s:   e Zd ZdZ fddZdd Zdd Zdd	d
Z  ZS )DetectionTargetLayera)  Subsamples proposals and generates target box refinement, class_ids,
    and masks for each.

    Inputs:
    proposals: [batch, N, (y1, x1, y2, x2)] in normalized coordinates. Might
               be zero padded if there are not enough proposals.
    gt_class_ids: [batch, MAX_GT_INSTANCES] Integer class IDs.
    gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)] in normalized
              coordinates.
    gt_masks: [batch, height, width, MAX_GT_INSTANCES] of boolean type

    Returns: Target ROIs and corresponding class IDs, bounding box shifts,
    and masks.
    rois: [batch, TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)] in normalized
          coordinates
    target_class_ids: [batch, TRAIN_ROIS_PER_IMAGE]. Integer class IDs.
    target_deltas: [batch, TRAIN_ROIS_PER_IMAGE, (dy, dx, log(dh), log(dw)]
    target_mask: [batch, TRAIN_ROIS_PER_IMAGE, height, width]
                 Masks cropped to bbox boundaries and resized to neural
                 network output size.

    Note: Returned arrays might be zero padded if not enough target ROIs.
    c                    s    t t| jf i | || _d S r   )r   r   r   r/   r   r/   r   r   r   r   r     s    zDetectionTargetLayer.__init__c                    sR   |d }|d }|d }|d }g d}t j||||g fdd jj|d}|S )	Nr   r$   rT   rZ   )r   target_class_idstarget_bboxtarget_maskc                    s   t | ||| jS r   )r   r/   )r   rQ   r   zr   r   r   r     r   z+DetectionTargetLayer.call.<locals>.<lambda>r   )r   r   r/   r   )r   r   r   r   r   r   r   outputsr   r   r   r     s    

zDetectionTargetLayer.callc                 C   sD   d | j jdfd | j jfd | j jdfd | j j| j jd | j jd fgS )Nre   r   r$   )r/   r   r   r   r   r   r   r     s    


z)DetectionTargetLayer.compute_output_shapeNc                 C   s   g dS )N)NNNNr   )r   r   maskr   r   r   compute_mask  s    z!DetectionTargetLayer.compute_mask)N)	r   r   r   r    r   r   r   r  r!   r   r   r   r   r     s
   r   c                    s<  t j|dt jd}t jt |jd |gdd}t ||}t ||}t| | j }	t	|	|}	t 
|dkdddf  jrt 
| jkdddf }
t jt dt |
dt d t |t |t |	t d } fdd}t j||t jd}t |d	g}t |t 
|d	kdddf }t jt dt |dt d  j}t |}t t |d |}t jj||d
dd }t |t jt |	t t |dt jf t |dt jf gdd} jt |d  }t |d|fdgd}|S )aS  Refine classified proposals and filter overlaps and return final
    detections.

    Inputs:
        rois: [N, (y1, x1, y2, x2)] in normalized coordinates
        probs: [N, num_classes]. Class probabilities.
        deltas: [N, num_classes, (dy, dx, log(dh), log(dw))]. Class-specific
                bounding box deltas.
        window: (y1, x1, y2, x2) in normalized coordinates. The part of the image
            that contains the image excluding the padding.

    Returns detections shaped: [num_detections, (y1, x1, y2, x2, class_id, score)] where
        coordinates are normalized.
    r$   )rv   Zoutput_typer   r   Nc                    s   t t | dddf }t jjt |t | j jd}t t ||} jt |d  }t j	|d|fgddd}|
 jg |S )z9Apply Non-Maximum Suppression on ROIs of the given class.Nr   )Zmax_output_sizeZiou_thresholdCONSTANTr   )modeZconstant_values)rw   r   r   r   r   r   DETECTION_MAX_INSTANCESZDETECTION_NMS_THRESHOLDr	   r   r   )class_idZixsZ
class_keepgapr/   keepZpre_nms_class_idsZpre_nms_roisZpre_nms_scoresr   r   nms_keep_map*  s    

z-refine_detections_graph.<locals>.nms_keep_mapr   r   T)r   r   .r   r  )rw   r   r   rx   rj   r	   r   r   r   r   r   ZDETECTION_MIN_CONFIDENCEZsetsZset_intersectionr   Zsparse_tensor_to_denser   uniqueZmap_fnr   r   r  r   r   r   r   Zto_floatnewaxisr   )r   Zprobsrz   r   r/   	class_idsr   Zclass_scoresZdeltas_specificZrefined_roisZ	conf_keepZunique_pre_nms_class_idsr
  Znms_keepZ	roi_countZclass_scores_keepZnum_keepZtop_ids
detectionsr  r   r  r   refine_detections_graph  sP    
"

r  c                       s2   e Zd ZdZd	 fdd	Zdd Zdd Z  ZS )
DetectionLayerzTakes classified proposal boxes and their bounding box deltas and
    returns the final detection boxes.

    Returns:
    [batch, num_detections, (y1, x1, y2, x2, class_id, class_score)] where
    coordinates are normalized.
    Nc                    s    t t| jf i | || _d S r   )r   r  r   r/   r   r   r   r   r   i  s    zDetectionLayer.__init__c           
         s   |d }|d }|d }|d }t |}|d d }t|d |d d }t||||g fdd jj}	t|	 jj jj	d	gS )
Nr   r$   rT   rZ   r)   r   c                    s   t | ||| jS r   )r  r/   )rQ   r   r   r   r   r   r   r   ~  r   z%DetectionLayer.call.<locals>.<lambda>r   )
r   norm_boxes_graphr   r   r/   r   rw   r   
BATCH_SIZEr  )
r   r   r   mrcnn_class
mrcnn_bboxr   mr)   r   Zdetections_batchr   r   r   r   m  s     

	zDetectionLayer.callc                 C   s   d | j jdfS )Nr   )r/   r  r   r   r   r   r     s    z#DetectionLayer.compute_output_shape)Nr   r   r   r   r   r  `  s   r  c                 C   s   t jdddd|dd| }t jd| dd	d
dd|}t dd |}t jddd|}t j|d dd	d
dd|}t dd |}|||gS )a  Builds the computation graph of Region Proposal Network.

    feature_map: backbone features [batch, height, width, depth]
    anchors_per_location: number of anchors per pixel in the feature map
    anchor_stride: Controls the density of anchors. Typically 1 (anchors for
                   every pixel in the feature map), or 2 (every other pixel).

    Returns:
        rpn_class_logits: [batch, H * W * anchors_per_location, 2] Anchor classifier logits (before softmax)
        rpn_probs: [batch, H * W * anchors_per_location, 2] Anchor classifier probabilities.
        rpn_bbox: [batch, H * W * anchors_per_location, (dy, dx, log(dh), log(dw))] Deltas to be
                  applied to anchors.
    rb   rY   r;   r:   Zrpn_conv_shared)r>   
activationrV   r7   rT   r4   validlinearZrpn_class_raw)r>   r  r7   c                 S   s   t | t | d ddgS )Nr   r   rT   rw   r   r	   tr   r   r   r     r   zrpn_graph.<locals>.<lambda>softmaxZrpn_class_xxxr9   re   Zrpn_bbox_predc                 S   s   t | t | d ddgS )Nr   r   re   r  r  r   r   r   r     r   )rB   rC   LambdarD   )Zfeature_mapanchors_per_locationanchor_stridesharedrQ   rpn_class_logitsZ	rpn_probsrpn_bboxr   r   r   	rpn_graph  s@    
		r#  c                 C   s2   t jdd|gdd}t||| }tj|g|ddS )a,  Builds a Keras model of the Region Proposal Network.
    It wraps the RPN graph so it can be used multiple times with shared
    weights.

    anchors_per_location: number of anchors per pixel in the feature map
    anchor_stride: Controls the density of anchors. Typically 1 (anchors for
                   every pixel in the feature map), or 2 (every other pixel).
    depth: Depth of the backbone feature map.

    Returns a Keras Model object. The model outputs, when called, are:
    rpn_class_logits: [batch, H * W * anchors_per_location, 2] Anchor classifier logits (before softmax)
    rpn_probs: [batch, H * W * anchors_per_location, 2] Anchor classifier probabilities.
    rpn_bbox: [batch, H * W * anchors_per_location, (dy, dx, log(dh), log(dw))] Deltas to be
                applied to anchors.
    NZinput_rpn_feature_mapr	   r7   Z	rpn_modelr9   )rB   Inputr#  KMModel)r  r  ZdepthZinput_feature_mapr   r   r   r   build_rpn_model  s
    
r(  rd   c                 C   s:  t ||gdd| |g| }tjtj|||fdddd|}tjt dd||d}td|}tjt|d	d
d|}tjt dd||d}td|}tjdd dd|}tjt|dd|}	tjtddd|	}
tjtj|d dddd|}t	|}tj
|d |dfdd|}|	|
|fS )a  Builds the computation graph of the feature pyramid network classifier
    and regressor heads.

    rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
          coordinates.
    feature_maps: List of feature maps from different layers of the pyramid,
                  [P2, P3, P4, P5]. Each has a different resolution.
    image_meta: [batch, (meta data)] Image details. See compose_image_meta()
    pool_size: The width of the square feature map generated from ROI Pooling.
    num_classes: number of classes, which determines the depth of the results
    train_bn: Boolean. Train or freeze Batch Norm layers
    fc_layers_size: Size of the 2 FC layers

    Returns:
        logits: [batch, num_rois, NUM_CLASSES] classifier logits (before softmax)
        probs: [batch, num_rois, NUM_CLASSES] classifier probabilities
        bbox_deltas: [batch, num_rois, NUM_CLASSES, (dy, dx, log(dh), log(dw))] Deltas to apply to
                     proposal boxes
    Zroi_align_classifierr9   r  r>   Zmrcnn_class_conv1Zmrcnn_class_bn1r   r:   r4   Zmrcnn_class_conv2Zmrcnn_class_bn2c                 S   s   t t | ddS )NrZ   rT   )Kr   r   r   r   r   r     r   z&fpn_classifier_graph.<locals>.<lambda>Zpool_squeezemrcnn_class_logitsr  r  re   r  )r  Zmrcnn_bbox_fcr$   r  )r   rB   TimeDistributedrC   r   rD   r  ZDenser*  	int_shapeZReshape)r   r   r   	pool_sizenum_classesrK   fc_layers_sizerQ   r   r+  Zmrcnn_probssr  r   r   r   fpn_classifier_graph  sB    


r2  c                 C   sp  t ||gdd| |g| }tjtjdddddd|}tjt dd||d	}td
|}tjtjdddddd|}tjt dd||d	}td
|}tjtjdddddd|}tjt dd||d	}td
|}tjtjdddddd|}tjt dd||d	}td
|}tjtjdddd
ddd|}tjtj|dddddd|}|S )a  Builds the computation graph of the mask head of Feature Pyramid Network.

    rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
          coordinates.
    feature_maps: List of feature maps from different layers of the pyramid,
                  [P2, P3, P4, P5]. Each has a different resolution.
    image_meta: [batch, (meta data)] Image details. See compose_image_meta()
    pool_size: The width of the square feature map generated from ROI Pooling.
    num_classes: number of classes, which determines the depth of the results
    train_bn: Boolean. Train or freeze Batch Norm layers

    Returns: Masks [batch, num_rois, MASK_POOL_SIZE, MASK_POOL_SIZE, NUM_CLASSES]
    Zroi_align_maskr9   r]   rY   r;   r)  Zmrcnn_mask_conv1Zmrcnn_mask_bn1r   r:   Zmrcnn_mask_conv2Zmrcnn_mask_bn2Zmrcnn_mask_conv3Zmrcnn_mask_bn3Zmrcnn_mask_conv4Zmrcnn_mask_bn4rS   rT   )rV   r  Zmrcnn_mask_deconvr4   r$   Zsigmoid
mrcnn_mask)r   rB   r,  rC   r   rD   ZConv2DTranspose)r   r   r   r.  r/  rK   rQ   r   r   r   build_fpn_mask_graph,  sX    
r4  c                 C   sF   t | | }t t |dd}|d |d  d| |d   }|S )zdImplements Smooth-L1 loss.
    y_true and y_pred are typically: [N, 4], but could be any shape.
    r   r   rt   rT   r$   )r*  absr   Zless)y_truey_predZdiffZless_than_onelossr   r   r   smooth_l1_lossk  s     r9  c                 C   s   t | d} tt| dt j}t t| d}t ||}t ||}tj	||dd}t
t |dkt|t d}|S )zRPN anchor classifier loss.

    rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
               -1=negative, 0=neutral anchor.
    rpn_class_logits: [batch, anchors, 2]. RPN classifier logits for BG/FG.
    r   r$   r   T)targetoutputZfrom_logits        )rw   r   r*  r   r   r   r   Z	not_equalr   Zsparse_categorical_crossentropyswitchr
   meanr   )	rpn_matchr!  Zanchor_classr   r8  r   r   r   rpn_class_loss_graphu  s    $r@  c                 C   s   t |d}tt |d}t||}t jt t |dtjdd}t	||| j
}t||}t t|dkt |td}|S )a  Return the RPN bounding box loss graph.

    config: the model config object.
    target_bbox: [batch, max positive anchors, (dy, dx, log(dh), log(dw))].
        Uses 0 padding to fill in unsed bbox deltas.
    rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
               -1=negative, 0=neutral anchor.
    rpn_bbox: [batch, anchors, (dy, dx, log(dh), log(dw))]
    r   r$   r   r   r<  )r*  r   rw   r   r   r   sumr   r   batch_pack_graphr   r9  r=  r
   r>  r   )r/   r   r?  r"  r   Zbatch_countsr8  r   r   r   rpn_bbox_loss_graph  s     
$rC  c                 C   sZ   t | d} t j|dd}t |d |}t jj| |d}|| }t |t | }|S )a  Loss for the classifier head of Mask RCNN.

    target_class_ids: [batch, num_rois]. Integer class IDs. Uses zero
        padding to fill in the array.
    pred_class_logits: [batch, num_rois, num_classes]
    active_class_ids: [batch, num_classes]. Has a value of 1 for
        classes that are in the dataset of the image, and 0
        for classes that are not in the dataset.
    r   rT   r   r   )labelsZlogits)rw   r   r   r   r   Z(sparse_softmax_cross_entropy_with_logits
reduce_sum)r   Zpred_class_logitsactive_class_idsZpred_class_idsZpred_activer8  r   r   r   mrcnn_class_loss_graph  s    rG  c                 C   s   t |d}t | d} t |dt |d df}t|dkdddf }tt||tj}tj||gdd	}t| |} t	||}t 
t| dkt| |d
td}t |}|S )zLoss for Mask R-CNN bounding box refinement.

    target_bbox: [batch, num_rois, (dy, dx, log(dh), log(dw))]
    target_class_ids: [batch, num_rois]. Integer class IDs.
    pred_bbox: [batch, num_rois, num_classes, (dy, dx, log(dh), log(dw))]
    r   )r   re   r   rT   re   r   Nr$   r   )r6  r7  r<  )r*  r   r-  rw   r   r   r   r   rx   r   r=  r
   r9  r   r>  )r   r   Z	pred_bboxZpositive_roi_ixZpositive_roi_class_idsr   r8  r   r   r   mrcnn_bbox_loss_graph  s"    

rI  c                 C   s   t |d}t| }t | d|d |d f} t|}t |d|d |d |d f}t|g d}t|dkdddf }tt||tj}tj	||gd	d
}t| |}t
||}	t t|dkt j||	dtd}
t |
}
|
S )ax  Mask binary cross-entropy loss for the masks head.

    target_masks: [batch, num_rois, height, width].
        A float32 tensor of values 0 or 1. Uses zero padding to fill array.
    target_class_ids: [batch, num_rois]. Integer class IDs. Zero padded.
    pred_masks: [batch, proposals, height, width, num_classes] float32 tensor
                with values from 0 to 1.
    rH  r   rT   rZ   re   )r   rZ   r$   rT   r   Nr$   r   )r:  r;  r<  )r*  r   rw   r	   r   r   r   r   r   rx   r   r=  r
   Zbinary_crossentropyr   r>  )Ztarget_masksr   Z
pred_masks
mask_shapeZ
pred_shapeZpositive_ixZpositive_class_idsr   r6  r7  r8  r   r   r   mrcnn_mask_loss_graph  s(    



rK  c                    s  |  |}| |\}}}	|	r"dS |j}
tj||j|j|j|jd\}}}}}t	||||}|rt
d tddrt|}t|}|rddl}g d  fdd	}|j}|j}| }||}|j|tj|j|d
d}|j|ksJ d|j|ksJ d|tj}tj|dddk}|dddd|f }|| }t|}tj| jgtjd}| j| j| d  }d||< |rt|||j}t ||
|j|||}|||||dfS )a  Load and return ground truth data for an image (image, mask, bounding boxes).

    augment: (deprecated. Use augmentation instead). If true, apply random
        image augmentation. Currently, only horizontal flipping is offered.
    augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation.
        For example, passing imgaug.augmenters.Fliplr(0.5) flips images
        right/left 50% of the time.
    use_mini_mask: If False, returns full-size masks that are the same height
        and width as the original image. These can be big, for example
        1024x1024x100 (for 100 instances). Mini masks are smaller, typically,
        224x224 and are generated by extracting the bounding box of the
        object and resizing it to MINI_MASK_SHAPE.

    Returns:
    image: [height, width, 3]
    shape: the original shape of the image before resizing and cropping.
    class_ids: [instance_count] Integer class IDs
    bbox: [instance_count, (y1, x1, y2, x2)]
    mask: [height, width, instance_count]. The height and width are those
        of the image unless use_mini_mask is True, in which case they are
        defined in MINI_MASK_SHAPE.
    )NNNNNr$   Zmin_dimZ	min_scaleZmax_dimr  z4'augment' is deprecated. Use 'augmentation' instead.r   r$   N)	Z
SequentialZSomeOfZOneOfZ	SometimesZFliplrZFlipudZ
CropAndPadZAffineZPiecewiseAffinec                    s   |j j v S )z.Determines which augmenters to apply to masks.)r   r   )imagesZ	augmenterparentsdefaultZMASK_AUGMENTERSr   r   hookd  s    zload_image_gt.<locals>.hook)Z	activator)Zhooksz(Augmentation shouldn't change image sizez'Augmentation shouldn't change mask sizer#   r   r   source)!Z
load_imageZget_maskr	   r   resize_imageIMAGE_MIN_DIMIMAGE_MIN_SCALEIMAGE_MAX_DIMIMAGE_RESIZE_MODEZresize_maskloggingZwarningrandomrandintr-   ZfliplrimgaugZto_deterministicZaugment_imageastypeuint8ZHooksImagesboolrA  Zextract_bboxeszerosr/  r   source_class_ids
image_infoZminimize_maskMINI_MASK_SHAPEcompose_image_meta)datasetr/   image_idaugmentaugmentationuse_mini_maskr   r  r  errorZoriginal_shaper   scaler>   cropr[  rQ  r)   rJ  ZdetZ_idxZbboxrF  r`  r   r   rP  r   load_image_gt  sZ    





rl  c           .      C   s  | j d dksJ |jtjks.J d|j|jtjksJJ d|j|jtjksfJ d|jt|dkd }|j d dksJ d|| }|| }|dddd|f }| dddf | dddf  | dddf | dddf   }|dddf |dddf  |dddf |dddf   }t| j d |j d f}t|j d D ].}	||	 }
t	
|
| ||	 ||dd|	f< q\tj|dd	}|t|j d |f }|| }|| }t|d
kd }t|d
k d }t|j|j }|j d |krtjj||dd}n|}|j|j d  }|j d |krRtjj||dd}n|}t||g}|j|j d  }|dkr
|j d dkrt|d
k d }|j d |ksJ tjj||dd}|j d |ksJ t||g}n tjj||dd}t||g}|j d |jks4J d|j d |jd||ddf< d||< | | }|| }|| }|| }tj|j|jdftjd}t|dkd }t	|| ||ddf |||| f< ||j }tj|j|jd |jd |jftjd}|D ] }	||	 }|dksJ d||	 }|dddd|f } |jrtj|jdd td}!|| \}"}#}$}%|%|# }&|$|" }'tt	| |'|&ft|!|"|$|#|%f< |!} ||	 tj\}(})}*}+| |(|*|)|+f },t	|,|j}-|-||	dddd|f< q||||fS )au  Generate targets for training Stage 2 classifier and mask heads.
    This is not used in normal training. It's useful for debugging or to train
    the Mask RCNN heads without using the RPN head.

    Inputs:
    rpn_rois: [N, (y1, x1, y2, x2)] proposal boxes.
    gt_class_ids: [instance count] Integer class IDs
    gt_boxes: [instance count, (y1, x1, y2, x2)]
    gt_masks: [height, width, instance count] Ground truth masks. Can be full
              size or mini-masks.

    Returns:
    rois: [TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)]
    class_ids: [TRAIN_ROIS_PER_IMAGE]. Integer class IDs.
    bboxes: [TRAIN_ROIS_PER_IMAGE, NUM_CLASSES, (y, x, log(h), log(w))]. Class-specific
            bbox refinements.
    masks: [TRAIN_ROIS_PER_IMAGE, height, width, NUM_CLASSES). Class specific masks cropped
           to bbox boundaries and resized to neural network output size.
    r   zExpected int but got {}zExpected bool but got {}zImage must contain instances.NrT   rZ   r$   r   rt   FreplaceTz(keep doesn't match ROI batch size {}, {}re   r   zclass id must be greater than 0)r	   r   r-   r   r   Zbool_r   r_  rj   r   Zcompute_iour   aranger%   r   r   rY  choiceZconcatenateNUM_CLASSESr   Zbox_refinementr   r   r   IMAGE_SHAPEr^  r   Zresizer\  ).rpn_roisr   r   r   r/   Zinstance_idsZrpn_roi_areaZgt_box_arear   rp   gtZrpn_roi_iou_argmaxZrpn_roi_iou_maxZrpn_roi_gt_boxesZrpn_roi_gt_class_idsZfg_idsZbg_idsZfg_roi_countZkeep_fg_idsZ	remainingZkeep_bg_idsr	  Zkeep_extra_idsr   r   r   Zroi_gt_assignmentZbboxesZpos_idsr   r  Zgt_idZ
class_maskZplaceholderr   r   r   r   r   r   r|   r}   r~   r   r  r  r   r   r   build_detection_targets  s    @@$

	
ru  c                 C   s  t j|jd gt jd}t |jdf}t |dk d }|jd dkrt |dkd }|| }	|| }|| }t||	}
t j|
dd}|dk }nt j	|jd gt
d}t||}t j|dd}|t |jd |f }d||dk |@ < t |t j|ddkd	d	df }d||< d||d
k< t |dkd }t||jd  }|dkrnt jj||dd}d||< t |dkd }t||jt |dk  }|dkrt jj||dd}d||< t |dkd }d}t||| D ]\}}|||  }|d |d  }|d |d  }|d d|  }|d d|  }|d |d  }|d |d  }|d d|  }|d d|  }|| | || | t || t || g||< ||  |j  < |d7 }q||fS )a   Given the anchors and GT boxes, compute overlaps and identify positive
    anchors and deltas to refine them to match their corresponding GT boxes.

    anchors: [num_anchors, (y1, x1, y2, x2)]
    gt_class_ids: [num_gt_boxes] Integer class IDs.
    gt_boxes: [num_gt_boxes, (y1, x1, y2, x2)]

    Returns:
    rpn_match: [N] (int32) matches between anchors and GT boxes.
               1 = positive anchor, -1 = negative anchor, 0 = neutral
    rpn_bbox: [N, (dy, dx, log(dh), log(dw))] Anchor bbox deltas.
    r   r   re   r$   r   r   r   g333333?Ngffffff?rT   Frm  rZ   rt   )r-   r_  r	   r   RPN_TRAIN_ANCHORS_PER_IMAGEr   r   Zcompute_overlapsZamaxZonesr^  r   ro  Zargwherer   lenrY  rp  rA  zipr   r   )r)   r   r   r   r/   r?  r"  r   r   r   r   r   r   r   Zanchor_iou_argmaxZanchor_iou_maxZgt_iou_argmaxidsZextrar   rp   r^   rt  r   r   Zgt_center_yZgt_center_xZa_hZa_wZ
a_center_yZ
a_center_xr   r   r   build_rpn_targets7  sb    
$



rz  c                 C   s
  t j|dft jd}td| |jd  }t|jd D ]}|| \}}}	}
|	| }|
| }t|| d}t|	| | d }t|| d}t|
| | d }t j	|||d df}t j	|||d df}d}|t 
|dddf |dddf  |k d| }|t 
|dddf |dddf  |k d| }|jd |kr|jd |krq`qt jt j|ddddd\}}t jt j|ddddd\}}t ||||g}|||| ||d  < q8|||jd   }t j	d| d |d df}t j	d| d |d df}d}|t 
|dddf |dddf  |k d| }|t 
|dddf |dddf  |k d| }|jd |kr|jd |krڐqqt jt j|ddddd\}}t jt j|ddddd\}}t ||||g}||| d< |S )	a]  Generates ROI proposals similar to what a region proposal network
    would generate.

    image_shape: [Height, Width, Depth]
    count: Number of ROIs to generate
    gt_class_ids: [N] Integer ground truth class IDs
    gt_boxes: [N, (y1, x1, y2, x2)] Ground truth boxes in pixels.

    Returns: [count, (y1, x1, y2, x2)] ROI boxes in pixels.
    re   r   g?r   r$   rT   Nr   )r-   r_  r   r%   r	   rj   r   r   rY  rZ  r5  r   sortZhstack)r)   countr   r   r   Zrois_per_boxrp   r   r   r   r   r   r   Zr_y1Zr_y2Zr_x1Zr_x2Zy1y2Zx1x2Z	thresholdr}   r   r|   r~   Zbox_roisZremaining_countZglobal_roisr   r   r   generate_random_rois  sD    6666 r}  r$   c	           -   	   c   s  d}	d}
t | j}d}|pg }t||j}t|j|j||j	|j
}z|
d t| }
|rr|
dkrrt j| ||
 }| j| d |v rt| |||d|jd\}}}}}}n"t| |||||jd\}}}}}}|rddl}|jd|dd	 W qFt |dks
W qFt|j||||\}}|rTt|j|||}|rTt|||||\}}}}|	dkrt j|f|j |jd
}t j||jd dg|jd
}t j||jdg|jd
} t j|f|j t jd
}!t j||jft jd
}"t j||jdft jd
}#t j||jd |jd |jf|jd
}$|rt j||jd df|jd
}%|rt j|f|j |jd
}&t j|f|j |jd
}'t j|f|j |jd
}(t j|f|j |jd
})|jd |jkrt jjt  |jd |jdd}*||* }||* }|dddd|*f }|||	< |ddt j!f ||	< || |	< t"|#t j||!|	< ||"|	d|jd f< ||#|	d|jd f< ||$|	ddddd|jd f< |r||%|	< |r||&|	< ||'|	< ||(|	< ||)|	< |	d7 }	|	|kr*|!||| |"|#|$g}+g },|r|+$|%g |r|+$|&g t %|'d}'|,$|'|(|)g |+|,fV  d}	W qF t&t'fyF    Y qF   t()d| j|  |d7 }|dkrv Y qF0 qFdS )a  A generator that returns images and corresponding target class ids,
    bounding box deltas, and masks.

    dataset: The Dataset object to pick data from
    config: The model config object
    shuffle: If True, shuffles the samples before every epoch
    augment: (deprecated. Use augmentation instead). If true, apply random
        image augmentation. Currently, only horizontal flipping is offered.
    augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation.
        For example, passing imgaug.augmenters.Fliplr(0.5) flips images
        right/left 50% of the time.
    random_rois: If > 0 then generate proposals to be used to train the
                 network classifier and mask heads. Useful if training
                 the Mask RCNN part without the RPN.
    batch_size: How many images to return in each call
    detection_targets: If True, generate detection targets (class IDs, bbox
        deltas, and masks). Typically for debugging or visualizations because
        in trainig detection targets are generated by DetectionTargetLayer.
    no_augmentation_sources: Optional. List of sources to exclude for
        augmentation. A source is string that identifies a dataset and is
        defined in the Dataset class.

    Returns a Python generator. Upon calling next() on it, the
    generator returns two lists, inputs and outputs. The contents
    of the lists differs depending on the received arguments:
    inputs list:
    - images: [batch, H, W, C]
    - image_meta: [batch, (meta data)] Image details. See compose_image_meta()
    - rpn_match: [batch, N] Integer (1=positive anchor, -1=negative, 0=neutral)
    - rpn_bbox: [batch, N, (dy, dx, log(dh), log(dw))] Anchor bbox deltas.
    - gt_class_ids: [batch, MAX_GT_INSTANCES] Integer class IDs
    - gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)]
    - gt_masks: [batch, height, width, MAX_GT_INSTANCES]. The height and width
                are those of the image unless use_mini_mask is True, in which
                case they are defined in MINI_MASK_SHAPE.

    outputs list: Usually empty in regular training. But if detection_targets
        is True then the outputs list contains target class_ids, bbox deltas,
        and masks.
    r   r   r$   rR  N)rf  rg  rh  ziProblems while processing masks of image {}. This image is skipped in the training and should be checked.rT   )
stacklevelr   re   Frm  zError processing image {}rf   )*r-   copy	image_idsr0   rr  r   generate_pyramid_anchorsRPN_ANCHOR_SCALESRPN_ANCHOR_RATIOSr.   RPN_ANCHOR_STRIDErw  rY  shufflera  rl  r   warningswarnr   anyrz  r	   r}  ru  r_  r   rv  r   ZMAX_GT_INSTANCESr   rp  ro  r  
mold_imager\  extendr   GeneratorExitKeyboardInterruptrX  Z	exception)-rd  r/   r  rf  rg  Zrandom_rois
batch_sizeZdetection_targetsno_augmentation_sourcesr_   Zimage_indexr  Zerror_countbackbone_shapesr   re  r   r   r   r   r   ri  r  r?  r"  rs  r   Zmrcnn_class_idsr  r3  Zbatch_image_metaZbatch_rpn_matchZbatch_rpn_bboxZbatch_imagesZbatch_gt_class_idsZbatch_gt_boxesZbatch_gt_masksZbatch_rpn_roisZ
batch_roisZbatch_mrcnn_class_idsZbatch_mrcnn_bboxZbatch_mrcnn_maskry  r   r   r   r   r   data_generator  s\   3





	
"
	

r  c                   @   s   e Zd ZdZdd Zdd Zdd Zd*d
dZdd Zdd Z	d+ddZ
d,ddZd-ddZdd Zdd Zd.ddZd/ddZd d! Zd0d"d#Zd$d% Zd&d' Zd1d(d)Zd	S )2MaskRCNNzpEncapsulates the Mask RCNN model functionality.

    The actual Keras model is in the keras_model property.
    c                 C   s:   |dv sJ || _ || _|| _|   | j||d| _dS )z
        mode: Either "training" or "inference"
        config: A Sub-class of the Config class
        model_dir: Directory to save training logs and trained weights
        r   	inference)r  r/   N)r  r/   	model_dirset_log_dirbuildkeras_model)r   r  r/   r  r   r   r   r     s    zMaskRCNN.__init__c           9         s.  |dv sJ j dd \}}|d t|d ksF|d t|d krNtdtjddj d gddtjjgdd}|d	kr@tjdd
gdtjd}tjddgdtjd}tjdgdtjd}tjddgdtjd}	t	fdd|	}
j
rtjjd jd
 dgdtd}n"tjj d j d
 dgdtd}n|dkr\tjddgdd}tjrjdjd\}}}}}ntjdjd\}}}}}tjjddd|}tjddtjddd|tjjdd d|g}tjd!dtjdd"d|tjjdd#d|g}tjd$dtjdd%d|tjjdd&d|g}tjjd'd(d)d*|}tjjd'd(d+d*|}tjjd'd(d,d*|}tjjd'd(d-d*|}tjddd.d/|}|||||g}||||g}|d	kr,| j  t jf j  tj	 fd0dd1d n| tjtjj}g }|D ]}|||g qNg d2}t t!| }d3d4 t!||D }|\}}} |d	krj"nj#}!t$|!j%d5d6||  g}"|d	krzt	d7d |}#j&stjj"dgd8tjd}$t	fd9d|$}%n|"}%t'd:d|%||
|g\}&}'}(})t(|&||j)j*jj+d;\}*}+},t,|&||j-j*jd<}-tj	d=d d>d|&}.tj	d?d d@d||g}/tj	fdAddBd||| g}0tj	dCd dDd|'|*|#g}1tj	dEd dFd|(|'|,g}2tj	dGd dHd|)|'|-g}3|||||	|g}4j&sH|4|$ ||| |*|+|,|-|"|.|/|0|1|2|3g}t.j/|4|dId}5nt(|"||j)j*jj+d;\}*}+},t0dJd|"|+|,|g}6t	dKd |6}7t,|7||j-j*jd<}-t.j/||g|6|+|,|-|"|| gdId}5j1d
kr*ddLl2m3}8 |8|5j1}5|5S )MzBuild Mask R-CNN architecture.
        input_shape: The shape of the input image.
        mode: Either "training" or "inference". The inputs and
            outputs of the model differ accordingly.
        r  NrT   r[   zImage size must be dividable by 2 at least 6 times to avoid fractions when downscaling and upscaling.For example, use 256, 320, 384, 448, 512, ... etc. rl   r$  input_image_metar   r$   input_rpn_match)r	   r7   r   re   input_rpn_bboxinput_gt_class_idsinput_gt_boxesc                    s   t | t dd S Nr$   rZ   r  r*  r	   r   rl   r   r   r   9  r   z MaskRCNN.build.<locals>.<lambda>r   input_gt_masksr  input_anchorsT)rm   rK   r4   Zfpn_c5p5r9   Z	fpn_p4addrS   Zfpn_p5upsampled)r
   r7   Zfpn_c4p4Z	fpn_p3addZfpn_p4upsampledZfpn_c3p3Z	fpn_p2addZfpn_p3upsampledZfpn_c2p2rY   ZSAMEZfpn_p2)r>   r7   Zfpn_p3Zfpn_p4Zfpn_p5Zfpn_p6)r.  rV   r7   c                    s
   t  S r   )rw   ZVariabler   )r   r   r   r     r   r   )r!  	rpn_classr"  c                 S   s&   g | ]\}}t jd |dt|qS )r$   ru   )rB   ZConcatenatelist)r'   onr   r   r   r*     s   z"MaskRCNN.build.<locals>.<listcomp>ZROI)r   r   r7   r/   c                 S   s   t | d S )NrF  )r   r   r   r   r   r     r   Z	input_roic                    s   t | t dd S r  r  r   r  r   r   r     r   Zproposal_targets)rK   r0  )rK   c                 S   s   | d S )Nr$   r   r   r   r   r   r     r   output_roisc                 S   s   t |  S r   )r@  r   r   r   r   r     r   rpn_class_lossc                    s   t  g| R  S r   )rC  r   )r/   r   r   r     r   rpn_bbox_lossc                 S   s   t |  S r   )rG  r   r   r   r   r     r   mrcnn_class_lossc                 S   s   t |  S r   )rI  r   r   r   r   r     r   mrcnn_bbox_lossc                 S   s   t |  S r   )rK  r   r   r   r   r     r   mrcnn_mask_loss	mask_rcnnZmrcnn_detectionc                 S   s   | dd df S )N.re   r   r   r   r   r   r   .	  r   )ParallelModel)4rr  r%   	ExceptionrB   r%  ZIMAGE_META_SIZErw   r   r   r  r   rb  r^  r+   r,   ZTRAIN_BNrs   rC   ZTOP_DOWN_PYRAMID_SIZErE   ZUpSampling2Dri   get_anchorsr-   broadcast_tor  r	   r(  r  rw  r  r   r  rx  ZPOST_NMS_ROIS_TRAININGZPOST_NMS_ROIS_INFERENCEr   ZRPN_NMS_THRESHOLDZUSE_RPN_ROISr   r2  Z	POOL_SIZErq  ZFPN_CLASSIF_FC_LAYERS_SIZEr4  ZMASK_POOL_SIZEr&  r'  r  Z	GPU_COUNTZparallel_modelr  )9r   r  r/   r   r   r  r  r  r  r  r   r  r  r   rn   ro   rq   rr   ZP5ZP4ZP3ZP2ZP6Zrpn_feature_mapsZmrcnn_feature_mapsZrpnZlayer_outputspZoutput_namesr   r!  r  r"  r   rs  rF  Z
input_roisZtarget_roisr   r   r   r   r+  r  r  r3  r  r  r  Z
class_lossZ	bbox_lossZ	mask_lossr   modelr  Zdetection_boxesr  r   )r   r/   rl   r   r    s   (


















	


		zMaskRCNN.buildc                    s   t t| jd }| jj  t fdd|}t|}|sZddl	}t
|jd| jtj| j|d }t t|d }td	d |}t|}|sddl	}t
|jd
|tj||d }|S )zFinds the last checkpoint file of the last trained model in the
        model directory.
        Returns:
            The path of the last checkpoint file
        r$   c                    s
   |   S r   
startswithfkeyr   r   r   W	  r   z$MaskRCNN.find_last.<locals>.<lambda>r   Nz'Could not find model directory under {}r   rT   c                 S   s
   |  dS )Nr  r  r  r   r   r   r   d	  r   z!Could not find weight files in {})nextoswalkr  r/   NAMElowerfilterr   errnoFileNotFoundErrorZENOENTr   pathjoin)r   Z	dir_namesr  Zdir_nameZcheckpointsZ
checkpointr   r  r   	find_lastN	  s*    
zMaskRCNN.find_lastFNc           	         s   ddl }zddlm} W n ty6   ddlm} Y n0  r@d}|du rPtd|j|dd}d	|jvrxd
|v rx|d
 }| j}t|dr|j	j
n|j
} rt fdd|}|r||| n||| t|dr|  | | dS )zModified version of the corresponding Keras function with
        the addition of multi-GPU support and the ability to exclude
        some layers from loading.
        exclude: list of layer names to exclude
        r   N)saving)topologyTz`load_weights` requires h5py.r   )r  Zlayer_namesZmodel_weightsinner_modelc                    s
   | j  vS r   r9   )lexcluder   r   r   	  r   z'MaskRCNN.load_weights.<locals>.<lambda>close)h5pykeras.enginer  ImportErrorr  ZFileattrsr  hasattrr  layersr  Z$load_weights_from_hdf5_group_by_nameZload_weights_from_hdf5_groupr  r  )	r   filepathZby_namer  r  r  r  r  r  r   r  r   load_weightso	  s2    

zMaskRCNN.load_weightsc                 C   s$   ddl m} d}|d|ddd}|S )z]Downloads ImageNet trained weights from Keras.
        Returns path to weights file.
        r   )get_filez|https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5z4resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5modelsZ a268eb855778b3df3c7506639542a6af)Zcache_subdirZmd5_hash)Zkeras.utils.data_utilsr  )r   r  ZTF_WEIGHTS_PATH_NO_TOPZweights_pathr   r   r   get_imagenet_weights	  s    zMaskRCNN.get_imagenet_weightsc           	         s*  t jj|| jjd}g  j_i  j_g d}|D ]L} j|}|j	 jj
v rRq2tj|j	dd jj|d } j| q2 fdd jjD } jt|  jj|dgt jj d	 |D ]Z}| jjv rqʈ j|} jj| tj|j	dd jj|d } jj| qdS )
zGets the model ready for training. Adds losses, regularization, and
        metrics. Then calls the Keras compile() function.
        )ZlrmomentumZclipnorm)r  r  r  r  r  T)Zkeepdimsr   c                    sH   g | ]@}d |j vrd|j vrtj jj|tt|tj	 qS )ZgammaZbeta)
r7   kerasZregularizersl2r/   ZWEIGHT_DECAYrw   r   r
   r   )r'   r   r   r   r   r*   	  s
   z$MaskRCNN.compile.<locals>.<listcomp>N)	optimizerr8  )r  Z
optimizersZSGDr/   ZGRADIENT_CLIP_NORMr  Z_lossesZ_per_input_lossesZ	get_layerr;  Zlossesrw   Zreduce_meanZLOSS_WEIGHTSgetZadd_lossZtrainable_weightsZadd_ncompilerw  r   Zmetrics_namesr   Zmetrics_tensors)	r   learning_rater  r  Z
loss_namesr7   layerr8  Z
reg_lossesr   r   r   r  	  sD    

zMaskRCNN.compiler   r$   c                 C   s   |dkr|du rt d |p | j}t|dr4|jjn|j}|D ]}|jjdkrptd|j | j	|||d d q>|j
sxq>tt||j}|jjd	kr||j_n||_|r>|dkr>t d
d| |j|jj q>dS )zbSets model layers as trainable if their names match
        the given regular expression.
        r   NzSelecting layers to trainr  r'  z
In model: re   )r  indentr,  z{}{:20}   ({}) )r   r  r  r  r  r   r   r   r7   set_trainableZweightsr^  re	fullmatchr  	trainabler   )r   layer_regexr  r  verboser  r  r  r   r   r   r  	  s0    


zMaskRCNN.set_trainablec              
   C   s   d| _ tj }|rd}t||}|rtt|dt|dt|dt|dt|d}t|dd d | _ td	| j   tj	
| jd
| jj || _tj	
| jd| jj | _| jdd| _dS )a.  Sets the model log directory and epoch counter.

        model_path: If None, or a format different from what this code uses
            then set a new log directory and start epochs from 0. Otherwise,
            extract the log directory and the epoch counter from the file
            name.
        r   zS.*[/\\][\w-]+(\d{4})(\d{2})(\d{2})T(\d{2})(\d{2})[/\\]mask\_rcnn\_[\w-]+(\d{4})\.h5r$   rT   rZ   re   rf   r   zRe-starting from epoch %dz{}{:%Y%m%dT%H%M}zmask_rcnn_{}_*epoch*.h5z*epoch*z{epoch:04d}N)epochdatetimenowr  matchr%   groupr   r  r  r  r  r   r/   r  r  log_dircheckpoint_pathrn  )r   Z
model_pathr  Zregexr  r   r   r   r  
  s,    	
	zMaskRCNN.set_log_dirc	                 C   sH  | j dksJ ddddddd}	||	 v r6|	| }t|| jd	|| jj|d
}
t|| jd	| jjd}tj| jst	| j t
jj| jdd	ddt
jj| jdd	dg}|r||7 }td| j| td| j | | | || jj tjdkrd}nt }| jj|
| j|| jj||| jjd|d	d
 t| j|| _dS )a6  Train the model.
        train_dataset, val_dataset: Training and validation Dataset objects.
        learning_rate: The learning rate to train with
        epochs: Number of training epochs. Note that previous training epochs
                are considered to be done alreay, so this actually determines
                the epochs to train in total rather than in this particaular
                call.
        layers: Allows selecting wich layers to train. It can be:
            - A regular expression to match layer names to train
            - One of these predefined values:
              heads: The RPN, classifier and mask heads of the network
              all: All the layers
              3+: Train Resnet stage 3 and up
              4+: Train Resnet stage 4 and up
              5+: Train Resnet stage 5 and up
        augmentation: Optional. An imgaug (https://github.com/aleju/imgaug)
            augmentation. For example, passing imgaug.augmenters.Fliplr(0.5)
            flips images right/left 50% of the time. You can pass complex
            augmentations as well. This augmentation applies 50% of the
            time, and when it does it flips images right/left half the time
            and adds a Gaussian blur with a random sigma in range 0 to 5.

                augmentation = imgaug.augmenters.Sometimes(0.5, [
                    imgaug.augmenters.Fliplr(0.5),
                    imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))
                ])
            custom_callbacks: Optional. Add custom callbacks to be called
                with the keras fit_generator method. Must be list of type keras.callbacks.
        no_augmentation_sources: Optional. List of sources to exclude for
            augmentation. A source is string that identifies a dataset and is
            defined in the Dataset class.
        r   zCreate model in training mode.z(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)zR(res3.*)|(bn3.*)|(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)zA(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)z0(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)z.*)Zheadsz3+z4+z5+allT)r  rg  r  r  )r  r  r   F)r  Zhistogram_freqZwrite_graphZwrite_images)r  Zsave_weights_onlyz
Starting at epoch {}. LR={}
zCheckpoint Path: {}ntd   )	Zinitial_epochepochsZsteps_per_epoch	callbacksZvalidation_dataZvalidation_stepsZmax_queue_sizeworkersZuse_multiprocessingN)r  keysr  r/   r  r  r  existsr  makedirsr  r  ZTensorBoardZModelCheckpointr  r   r   r  r  r  ZLEARNING_MOMENTUMr7   multiprocessing	cpu_countr  Zfit_generatorZSTEPS_PER_EPOCHZVALIDATION_STEPSr   )r   Ztrain_datasetZval_datasetr  r  r  rg  Zcustom_callbacksr  r  Ztrain_generatorZval_generatorr  r  r   r   r   train>
  sl    +

zMaskRCNN.trainc                 C   s   g }g }g }|D ]}t j|| jj| jj| jj| jjd\}}}}	}
t|| j}td|j	|j	||t
j| jjgt
jd}|| || || qt
|}t
|}t
|}|||fS )a
  Takes a list of images and modifies them to the format expected
        as an input to the neural network.
        images: List of image matrices [height,width,depth]. Images can have
            different sizes.

        Returns 3 Numpy matrices:
        molded_images: [N, h, w, 3]. Images resized and normalized.
        image_metas: [N, length of meta data]. Details about each image.
        windows: [N, (y1, x1, y2, x2)]. The portion of the image that has the
            original image (padding excluded).
        rL  r   r   )r   rS  r/   rT  rU  rV  rW  r  rc  r	   r-   r_  rq  r   r   rx   )r   rM  molded_imagesimage_metaswindowsr   Zmolded_imager   rj  r>   rk  r   r   r   r   mold_inputs
  s6    	




zMaskRCNN.mold_inputsc                 C   s"  t |dddf dkd }|jd dkr4|d n|jd }|d|ddf }|d|df t j}	|d|df }
|t |dddd|	f }t||dd }|\}}}}t ||||g}|| }|| }t ||||g}t 	|| |}t
||dd }t |dddf |dddf  |dddf |dddf   dkd }|jd dkrt j||dd}t j|	|dd}	t j|
|dd}
t j||dd}|	jd }g }t|D ]&}t|| || |}|| q|r t j|d	dnt |dd d
 }||	|
|fS )a[  Reformats the detections of one image from the format of the neural
        network output to a format suitable for use in the rest of the
        application.

        detections: [N, (y1, x1, y2, x2, class_id, score)] in normalized coordinates
        mrcnn_mask: [N, height, width, num_classes]
        original_image_shape: [H, W, C] Original image shape before resizing
        image_shape: [H, W, C] Shape of the image after resizing and padding
        window: [y1, x1, y2, x2] Pixel coordinates of box in the image where the real
                image is excluding the padding.

        Returns:
        boxes: [N, (y1, x1, y2, x2)] Bounding boxes in pixels
        class_ids: [N] Integer class IDs for each bounding box
        scores: [N] Float probability scores of the class_id
        masks: [height, width, num_instances] Instance masks
        Nre   r   rf   rT   rZ   r$   r   r   )r   )r-   r   r	   r\  r   ro  r   
norm_boxesr   divideZdenorm_boxesdeleterj   Zunmold_maskr   rx   empty)r   r  r3  original_image_shaper)   r   Zzero_ixr   ry   r  r   r   r   r   r   r   shiftZwhZwwrj  Z
exclude_ixZ
full_masksrp   Z	full_maskr   r   r   unmold_detections
  sD     B
zMaskRCNN.unmold_detectionsc              	   C   s`  | j dksJ dt|| jjks*J d|dkrXtdt| |D ]}td| qH| |\}}}|d j}|dd	 D ]}|j|ks~J d
q~| |}	t	
|	| jjf|	j }	|dkrtd| td| td|	 | jj|||	gdd\}
}}}}}}g }t|D ]L\}}| |
| || |j|| j|| \}}}}|||||d q|S )a  Runs the detection pipeline.

        images: List of images, potentially of different sizes.

        Returns a list of dicts, one dict per image. The dict contains:
        rois: [N, (y1, x1, y2, x2)] detection bounding boxes
        class_ids: [N] int class IDs
        scores: [N] float probability scores for the class IDs
        masks: [H, W, N] instance binary masks

        Modified by Ondrej Pesek:
            verbose -> verbosity
            verbose form in case of verbosity == 3, not 1
        r  Create model in inference mode.z'len(images) must be equal to BATCH_SIZErZ   Processing {} imagesr   r   r$   Nz\After resizing, all images must have the same size. Check IMAGE_RESIZE_MODE and image sizes.r  r  r   r  r   r  r   r   )r  rw  r/   r  r   r   r  r	   r  r-   r  r  predictr   r  r   )r   rM  	verbosityr   r  r  r  r)   gr   r  r   r3  resultsrp   
final_roisfinal_class_idsfinal_scoresfinal_masksr   r   r   detect*  s^    





zMaskRCNN.detectc                 C   sZ  | j dksJ dt|| jjks*J d|rTtdt| |D ]}td| qD|d j}|dd D ]}|j|ksjJ d	qj| |}t	|| jjf|j }|rtd
| td| td| | j
j|||gdd\}}	}	}
}	}	}	g }t|D ]^\}}dd|jd |jd g}| || |
| |j|| j|\}}}}|||||d q|S )a8  Runs the detection pipeline, but expect inputs that are
        molded already. Used mostly for debugging and inspecting
        the model.

        molded_images: List of images loaded using load_image_gt()
        image_metas: image meta data, also returned by load_image_gt()

        Returns a list of dicts, one dict per image. The dict contains:
        rois: [N, (y1, x1, y2, x2)] detection bounding boxes
        class_ids: [N] int class IDs
        scores: [N] float probability scores for the class IDs
        masks: [H, W, N] instance binary masks
        r  r  z,Number of images must be equal to BATCH_SIZEr  r   r   r$   NzImages must have the same sizer  r  r   r  r  )r  rw  r/   r  r   r   r	   r  r-   r  r  r  r   r  r   )r   r  r  r  r   r)   r
  r   r  r   r3  r  rp   r   r  r  r  r  r   r   r   detect_moldedu  sX    





zMaskRCNN.detect_moldedc                 C   s~   t | j|}t| dsi | _t|| jvrpt| jj| jj|| jj	| jj
}|| _t||dd | jt|< | jt| S )z0Returns anchor pyramid for the given image size._anchor_cacheNrT   )r0   r/   r  r  r   r   r  r  r  r.   r  r   r  )r   r)   r  r^   r   r   r   r    s     

zMaskRCNN.get_anchorsc                 C   s   |dur|ng }t |dkr dS t|tr<t|dd}|jj}|D ]P}||v rVqHtt	||j
rp|  S || | |||}|durH|  S qHdS )a'  Finds the ancestor of a TF tensor in the computation graph.
        tensor: TensorFlow symbolic tensor.
        name: Name of ancestor tensor to find
        checked: For internal use. A list of tensors that were already
                 searched to avoid loops in traversing the graph.
        Ni  /z	(\_\d+)*/)rw  
isinstancer   r  r  rn  opr   r^  r  r7   r   ancestor)r   Ztensorr7   checkedrN  r  r^   r   r   r   r    s     


zMaskRCNN.ancestorc                 C   s   |j jdkr| |jS |S )zIf a layer is encapsulated by another layer, this function
        digs through the encapsulation and returns the layer that holds
        the weights.
        r,  )r   r   find_trainable_layerr  )r   r  r   r   r   r    s    zMaskRCNN.find_trainable_layerc                 C   s2   g }| j jD ] }| |}| r|| q|S )z+Returns a list of layers that have weights.)r  r  r  Zget_weightsr   )r   r  r  r   r   r   get_trainable_layers  s    
zMaskRCNN.get_trainable_layersc                 C   s  | j }t|}| D ]}|dusJ q|j}|jrPtt tsP|t g7 }t	|jt
| }|du r| |\}}}	n|}|d j}
| |
}t|| jjf|j }|||g}|jrtt ts|d ||}tdd t| |D }| D ]\}}t|| q|S )a  Runs a sub-set of the computation graph that computes the given
        outputs.

        image_metas: If provided, the images are assumed to be already
            molded (i.e. resized, padded, and normalized)

        outputs: List of tuples (name, tensor) to compute. The tensors are
            symbolic TensorFlow tensors and the names are for easy tracking.

        Returns an ordered dict of results. Keys are the names received in the
        input and values are Numpy arrays.
        Nr   r<  c                 S   s   g | ]\}}||fqS r   r   )r'   r   vr   r   r   r*   2  r   z&MaskRCNN.run_graph.<locals>.<listcomp>)r  r   valuesr   Zuses_learning_phaser  r*  Zlearning_phaser%   Zfunctionr  r  r	   r  r-   r  r/   r  r   rx  r  itemsr   )r   rM  r   r  r  r  r   Zkfr  r   r)   r   Zmodel_inZ
outputs_npr   r  r   r   r   	run_graph  s,    



zMaskRCNN.run_graph)FN)Nr   r$   )N)NNN)r   )r   )N)N)r   r   r   r    r   r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r   r   r   r   r    s0     D!
06
*
3   
y-F
K
F
	r  c                 C   s6   t | gt| t| t| |g t| }|S )a  Takes attributes of an image and puts them in one 1D array.

    image_id: An int ID of the image. Useful for debugging.
    original_image_shape: [H, W, C] before resizing or padding.
    image_shape: [H, W, C] after resizing and padding
    window: (y1, x1, y2, x2) in pixels. The area of the image where the real
            image is (excluding the padding)
    scale: The scaling factor applied to the original image (float32)
    active_class_ids: List of class_ids available in the dataset from which
        the image came. Useful if training on images from multiple datasets
        where not all classes are present in all datasets.
    )r-   r   r  )re  r  r)   r   rj  rF  metar   r   r   rc  =  s    rc  c                 C   s   | dddf }| ddddf }| ddddf }| ddddf }| dddf }| ddddf }| tj| tj| tj| tj| tj| tjdS )	zParses an array that contains image attributes to its components.
    See compose_image_meta() for more details.

    meta: [batch, meta length] where meta length depends on NUM_CLASSES

    Returns a dict of the parsed values.
    Nr   r$   re   r\         re  r  r)   r   rj  rF  )r\  r-   r   r   r  re  r  r)   r   rj  rF  r   r   r   parse_image_metaW  s    





r#  c                 C   s   | dddf }| ddddf }| ddddf }| ddddf }| dddf }| ddddf }||||||dS )	zParses a tensor that contains image attributes to its components.
    See compose_image_meta() for more details.

    meta: [batch, meta length] where meta length depends on NUM_CLASSES

    Returns a dict of the parsed tensors.
    Nr   r$   re   r\   r  r   r!  r   r"  r   r   r   r   o  s    r   c                 C   s   |  tj|j S )zExpects an RGB image (or array of images) and subtracts
    the mean pixel and converts it to float. Expects image
    colors in RGB order.
    )r\  r-   r   
MEAN_PIXEL)rM  r/   r   r   r   r    s    r  c                 C   s   | |j  tjS )z>Takes a image normalized with mold() and returns the original.)r$  r\  r-   r]  )Znormalized_imagesr/   r   r   r   unmold_image  s    r%  
trim_zerosc                 C   s6   t t jt | ddt j}t j| ||d} | |fS )zOften boxes are represented with matrices of shape [N, 4] and
    are padded with zeros. This removes zero boxes.

    boxes: [N, 4] matrix of boxes.
    non_zeros: [N] a 1D boolean mask identifying the rows to keep
    r$   r   r9   )rw   r   rE  r5  r^  r   )ry   r7   r   r   r   r   r     s    r   c                 C   s:   g }t |D ]}|| |d|| f  qtj|ddS )z_Picks different number of values from each row
    in x depending on the values in counts.
    Nr   r   )rj   r   rw   r   )rQ   ZcountsZnum_rowsr   rp   r   r   r   rB    s    rB  c                 C   sX   t t |t jd\}}t j||||gddt d }t g d}t | | |S )ae  Converts boxes from pixel coordinates to normalized coordinates.
    boxes: [..., (y1, x1, y2, x2)] in pixel coordinates
    shape: [..., (height, width)] in pixels

    Note: In pixel coordinates (y2, x2) is outside the box. But in normalized
    coordinates it's inside the box.

    Returns:
        [..., (y1, x1, y2, x2)] in normalized coordinates
    rT   r   r   r   r<  r<  r   r   )rw   r   r   r   r   r   r  ry   r	   r   r   rj  r  r   r   r   r    s     r  c                 C   sh   t t |t jd\}}t j||||gddt d }t g d}t t t | || t jS )ae  Converts boxes from normalized coordinates to pixel coordinates.
    boxes: [..., (y1, x1, y2, x2)] in normalized coordinates
    shape: [..., (height, width)] in pixels

    Note: In pixel coordinates (y2, x2) is outside the box. But in normalized
    coordinates it's inside the box.

    Returns:
        [..., (y1, x1, y2, x2)] in pixel coordinates
    rT   r   r   r   r'  )	rw   r   r   r   r   r   r   Zmultiplyr   r(  r   r   r   denorm_boxes_graph  s     r)  )N)TT)rS   TT)FT)Trd   )T)FNF)TFNr   r$   FN)r&  )Jr    r  rY  r  r  r&   rX  collectionsr   r  Znumpyr-   Z
tensorflowrw   r  Zkeras.backendZbackendr*  Zkeras.layersr  rB   r  ZengineZKEZkeras.modelsr  r&  r   Zdistutils.versionr   __version__r   ZBatchNormalizationr   r0   rR   rX   rs   r   r   ZLayerr   r   r   r   r   r   r  r  r#  r(  r2  r4  r9  r@  rC  rG  rI  rK  rl  ru  rz  r}  r  r  rc  r#  r   r  r%  r   rB  r  r)  r   r   r   r   <module>   s   
 
/   
?
5aw De39"  
E 
?
$"/ 
w #nI       
          L


