Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 88 additions & 9 deletions centernet/dataloaders/centernet_input.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import tensorflow as tf
from official.vision.beta.dataloaders import parser

from centernet.ops import preprocessing_ops
from official.vision.beta.ops import box_ops, preprocess_ops
from centernet.ops import preprocessing_ops as ops
from yolo.ops import preprocessing_ops

class CenterNetParser(parser.Parser):
def __init__(
Expand All @@ -10,12 +11,22 @@ def __init__(
max_num_instances: int,
gaussian_iou: float,

aug_rand_saturation=True,
aug_rand_brightness=True,
aug_rand_zoom=True,
aug_rand_hue=True,
seed=1,
):
self._num_classes = num_classes
self._max_num_instances = max_num_instances
self._gaussian_iou = gaussian_iou
self._gaussian_bump = True
self._gaussian_rad = -1
self._aug_rand_zoom = aug_rand_zoom
# self._mosaic_frequency
# self._jitter_im
self._seed = seed
# self._random_flip

def _generate_heatmap(self, boxes, output_size, input_size):
boxes = tf.cast(boxes, dtype=tf.float32)
Expand Down Expand Up @@ -68,17 +79,17 @@ def _generate_heatmap(self, boxes, output_size, input_size):
height = tf.math.ceil(height * height_ratio)

if self._gaussian_rad == -1:
radius = preprocessing_ops.gaussian_radius((height, width), self._gaussian_iou)
radius = ops.gaussian_radius((height, width), self._gaussian_iou)
radius = tf.math.maximum(0, tf.math.floor(radius))
else:
radius = self._gaussian_rad

# test
# tl_heatmaps = preprocessing_ops.draw_gaussian(tl_heatmaps[category], category, [xtl, ytl], radius)
# inputs heatmap, center, radius, k=1
tl_heatmaps = preprocessing_ops.draw_gaussian(tl_heatmaps, [[category, xtl, ytl, radius]])
br_heatmaps = preprocessing_ops.draw_gaussian(br_heatmaps, [[category, xbr, ybr, radius]])
ct_heatmaps = preprocessing_ops.draw_gaussian(ct_heatmaps, [[category, xct, yct, radius]], scaling_factor=5)
tl_heatmaps = ops.draw_gaussian(tl_heatmaps, [[category, xtl, ytl, radius]])
br_heatmaps = ops.draw_gaussian(br_heatmaps, [[category, xbr, ybr, radius]])
ct_heatmaps = ops.draw_gaussian(ct_heatmaps, [[category, xct, yct, radius]], scaling_factor=5)

else:
# TODO: See if this is a typo
Expand Down Expand Up @@ -126,12 +137,78 @@ def _parse_train_data(self, decoded_tensors):
images: the image tensor.
labels: a dict of Tensors that contains labels.
"""
print('running _parse_train_data')
# TODO: input size, output size
image = decoded_tensors["image"]
image = decoded_tensors["image"] / 255

labels = self._generate_heatmap(
decoded_tensors["groundtruth_boxes"],
output_size, input_size
)
boxes = decoded_tensors['groundtruth_boxes']
classes = decoded_tensors['groundtruth_classes']

image_shape = tf.shape(image)[:2]

print('image_shape: ', image_shape)
#CROP
# if self._aug_rand_zoom > 0.0 and self._mosaic_frequency > 0.0:
# zfactor = preprocessing_ops.rand_uniform_strong(self._aug_rand_zoom, 1.0)
if self._aug_rand_zoom > 0.0:
zfactor = preprocessing_ops.rand_scale(self._aug_rand_zoom)
else:
zfactor = tf.convert_to_tensor(1.0)

# # TODO: random_op_image not defined
# image, crop_info = preprocessing_ops.random_op_image(
# image, self._jitter_im, zfactor, zfactor, self._aug_rand_translate)

image = tf.image.stateless_random_crop(image, size=[image_shape[0]*zfactor, image_shape[1]*zfactor, 3], seed = seed)

#RESIZE
shape = tf.shape(image)
width = shape[1]
height = shape[0]
image, boxes, classes = preprocessing_ops.resize_crop_filter(
image,
boxes,
classes,
default_width=width, # randscale * self._net_down_scale,
default_height=height, # randscale * self._net_down_scale,
target_width=self._image_w,
target_height=self._image_h,
randomize=False)

#CLIP DETECTION TO BOUNDARIES
boxes = box_ops.clip_boxes(boxes, shape)

#RANDOM HORIZONTAL FLIP
if self._random_flip:
image, boxes, _ = preprocess_ops.random_horizontal_flip(
image, boxes, seed=self._seed)

# Color and lighting jittering
image = tf.image.rgb_to_hsv(image)
i_h, i_s, i_v = tf.split(image, 3, axis=-1)
if self._aug_rand_hue:
delta = preprocessing_ops.rand_uniform_strong(
-0.1, 0.1
) # tf.random.uniform([], minval= -0.1,maxval=0.1, seed=self._seed, dtype=tf.float32)
i_h = i_h + delta # Hue
i_h = tf.clip_by_value(i_h, 0.0, 1.0)
if self._aug_rand_saturation:
delta = preprocessing_ops.rand_scale(
0.75
) # tf.random.uniform([], minval= 0.5,maxval=1.1, seed=self._seed, dtype=tf.float32)
i_s = i_s * delta
if self._aug_rand_brightness:
delta = preprocessing_ops.rand_scale(
0.75
) # tf.random.uniform([], minval= -0.15,maxval=0.15, seed=self._seed, dtype=tf.float32)
i_v = i_v * delta
image = tf.concat([i_h, i_s, i_v], axis=-1)
image = tf.image.hsv_to_rgb(image)

return image, labels

def _parse_eval_data(self, data):
Expand Down Expand Up @@ -170,7 +247,9 @@ def generate_heatmaps(self, dectections):

# tl_heatmaps, br_heatmaps, ct_heatmaps = generate_heatmaps(1, 2, (416, 416), detections)
# ct_heatmaps[batch_id, class_id, ...]
plt.imshow(ct_heatmaps[0, ...])
plt.show()

# plt.imshow(ct_heatmaps[0, ...])
# plt.show()

# This is to run the test
# tf.test.main()
8 changes: 4 additions & 4 deletions centernet/ops/preprocessing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ def draw_gaussian(heatmap, blobs, scaling_factor=1, dtype=tf.float32):
left, right = tf.math.minimum(x, radius), tf.math.minimum(width - x, radius + 1)
top, bottom = tf.math.minimum(y, radius), tf.math.minimum(height - y, radius + 1)

print('heatmap ',heatmap)
print(len(heatmap))
print('category ',category)
# print('heatmap ',heatmap)
# print(len(heatmap))
# print('category ',category)

# TODO: make sure this replicates original functionality
# masked_heatmap = heatmap[0, category, y - top:y + bottom, x - left:x + right]
Expand Down Expand Up @@ -180,7 +180,7 @@ def draw_gaussian(heatmap, blobs, scaling_factor=1, dtype=tf.float32):
heatmap_mask = heatmap_mask_ta.stack()
heatmap_mask = tf.reshape(heatmap_mask, (-1, 3))
heatmap = tf.tensor_scatter_nd_max(heatmap, heatmap_mask, masked_gaussian * scaling_factor)
print('after ',heatmap)
# print('after ',heatmap)
return heatmap

# def draw_gaussian(heatmap, category, center, radius, scaling_factor=1):
Expand Down