Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715200759
  • Loading branch information
tensorflower-gardener committed Jan 14, 2025
1 parent 772da52 commit 4dc9ca2
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
1 change: 1 addition & 0 deletions official/vision/configs/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class RetinaNetHead(hyperparams.Config):
@dataclasses.dataclass
class DetectionGenerator(hyperparams.Config):
apply_nms: bool = True
decode_boxes: bool = True
pre_nms_top_k: int = 5000
pre_nms_score_threshold: float = 0.05
nms_iou_threshold: float = 0.5
Expand Down
20 changes: 20 additions & 0 deletions official/vision/serving/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ def _normalize_coordinates(self, detections_dict, dict_keys, image_info):

return detections_dict

def _flatten_output(self, feature_map, feature_size=4):
flatten_outputs = []
for level_output in feature_map.values():
flatten_outputs.append(
tf.reshape(level_output, (self._batch_size, -1, feature_size))
)
return tf.concat(flatten_outputs, axis=1)

def preprocess(
self, images: tf.Tensor
) -> Tuple[tf.Tensor, Mapping[str, tf.Tensor], tf.Tensor]:
Expand Down Expand Up @@ -271,6 +279,18 @@ def serve(self, images: tf.Tensor):
final_outputs['detection_outer_boxes'] = detections[
'detection_outer_boxes'
]
elif (
isinstance(self.params.task.model, configs.retinanet.RetinaNet)
and not self.params.task.model.detection_generator.decode_boxes
):
final_outputs = {
'raw_boxes': self._flatten_output(detections['box_outputs'], 4),
'raw_scores': tf.sigmoid(
self._flatten_output(
detections['cls_outputs'], self.params.task.model.num_classes
)
),
}
else:
# For RetinaNet model, apply export_config.
if isinstance(self.params.task.model, configs.retinanet.RetinaNet):
Expand Down
38 changes: 38 additions & 0 deletions official/vision/serving/detection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _get_detection_module(
normalized_coordinates=False,
nms_version='batched',
output_intermediate_features=False,
decode_boxes=True,
):
params = exp_factory.get_exp_config(experiment_name)
params.task.model.outer_boxes_scale = outer_boxes_scale
Expand All @@ -48,6 +49,8 @@ def _get_detection_module(
params.task.model.detection_generator.nms_version = nms_version
if output_intermediate_features:
params.task.export_config.output_intermediate_features = True
if not decode_boxes:
params.task.model.detection_generator.decode_boxes = False
detection_module = detection.DetectionModule(
params,
batch_size=1,
Expand Down Expand Up @@ -232,6 +235,41 @@ def test_export_normalized_coordinates_no_nms(
max_values.numpy(), tf.ones_like(max_values).numpy()
)

@parameterized.parameters(
'retinanet_mobile_coco',
'retinanet_spinenet_coco',
)
def test_export_without_decoding_boxes(
self,
experiment_name,
):
input_type = 'tflite'
tmp_dir = self.get_temp_dir()
module = self._get_detection_module(
experiment_name,
input_type=input_type,
apply_nms=False,
decode_boxes=False,
)

self._export_from_module(module, input_type, tmp_dir)

imported = tf.saved_model.load(tmp_dir)
detection_fn = imported.signatures['serving_default']

images = self._get_dummy_input(
input_type, batch_size=1, image_size=(640, 640)
)
outputs = detection_fn(tf.constant(images))

self.assertContainsSubset(
{
'raw_boxes',
'raw_scores',
},
outputs.keys(),
)


if __name__ == '__main__':
tf.test.main()

0 comments on commit 4dc9ca2

Please sign in to comment.