-
Notifications
You must be signed in to change notification settings - Fork 2
Open
Description
replaced dice in metrics.py with
class dice(keras.metrics.Metric):
def __init__(self, name="dice", **kwargs):
super(dice, self).__init__(name=name, **kwargs)
self.dice = self.add_variable(shape=(), name="dice", initializer="zeros")
def update_state(self, y_true, y_pred, sample_weight=None):
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)
eps = tf.keras.backend.epsilon()
intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2, 3, 4))
union = tf.reduce_sum(y_true, axis=(1, 2, 3, 4)) + tf.reduce_sum(
y_pred, axis=self.axis
)
self.dice = (2 * intersection + eps) / (union + eps)
def result(self):
return self.dice.value
def reset_state(self):
self.dice.assign(0.0)
def get_config(self):
base_config = super(dice, self).get_config()
return base_config
gives the following error
File "/net/vast-storage/scratch/vast/gablab/hgazula/nobrainer_training_scripts/1.2.0/scripts/misc/warm_start_multi_gpu.py", line 60, in <module>
history = bem.fit(
File "/net/vast-storage/scratch/vast/gablab/hgazula/nobrainer/nobrainer/processing/segmentation.py", line 94, in fit
_compile()
File "/net/vast-storage/scratch/vast/gablab/hgazula/nobrainer/nobrainer/processing/segmentation.py", line 79, in _compile
self.model_.compile(
File "/om/user/hgazula/venvs/nobrainer_satra/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/om/user/hgazula/venvs/nobrainer_satra/lib/python3.10/site-packages/keras/src/engine/training.py", line 3893, in _validate_compile
for v in getattr(metric, "variables", []):
TypeError: 'property' object is not iterable
Metadata
Metadata
Assignees
Labels
No labels