I'm using tf.keras and have a metric that can only be calculated reliably over multiple batches of validation data. Is there some way to accumulate batches before calculating the metric?
I'd like to do something like this:
```python
import tensorflow as tf

class MultibatchMetric(tf.keras.metrics.Metric):
    def __init__(self, num_batches, name="sdr_metric", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_batches = num_batches
        self.batch_accumulator = []  # (y_true, y_pred) pairs in the current window
        self.my_metric = []          # one metric value per completed window

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.batch_accumulator.append((y_true, y_pred))
        if len(self.batch_accumulator) >= self.num_batches:
            # custom_multibatch_metric_func scores several batches at once
            metric = custom_multibatch_metric_func(self.batch_accumulator)
            self.my_metric.append(metric)
            self.batch_accumulator = []

    def result(self):
        return tf.reduce_mean(self.my_metric)

    def reset_state(self):  # called reset_states before TF 2.5
        self.my_metric = []
        self.batch_accumulator = []
```
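However, Keras runs `update_state` inside the TensorFlow graph, where Python lists and the `len()` check are evaluated only once at trace time rather than on every batch, which severely complicates things.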
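One graph-friendly direction, sketched here under several assumptions, is to keep the accumulated batches in `tf.Variable` buffers whose leading dimension is left as `None` (so they can grow between calls) and to use `tf.cond` instead of a Python `if`, so the "window is full" check runs as a graph op. The `example_shape` argument (the per-example shape, needed to give the empty buffers a fixed rank) and the MSE body and two-tensor signature of `custom_multibatch_metric_func` are hypothetical placeholders. `sample_weight` is ignored, and batches left over at the end of an epoch (fewer than `num_batches`) are never scored. It assumes TF 2.x with a Keras version that calls `reset_state` (TF 2.5 or later).

```python
import tensorflow as tf


def custom_multibatch_metric_func(y_true, y_pred):
    # Hypothetical placeholder: MSE over every example in the window.
    # Replace with the real metric that needs several batches at once.
    return tf.reduce_mean(tf.square(y_true - y_pred))


class MultibatchMetric(tf.keras.metrics.Metric):
    """Sketch: buffer num_batches batches in variables, then score them together."""

    def __init__(self, num_batches, example_shape=(), name="sdr_metric", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_batches = num_batches
        # example_shape (hypothetical) fixes the rank of the buffers.
        buf_shape = tf.TensorShape([None, *example_shape])
        empty = tf.zeros([0, *example_shape])
        # Leading dimension is None, so assign() may change the buffer length.
        self.true_buf = tf.Variable(empty, shape=buf_shape, trainable=False)
        self.pred_buf = tf.Variable(empty, shape=buf_shape, trainable=False)
        self.batches_seen = tf.Variable(0, trainable=False)  # batches in current window
        self.metric_sum = tf.Variable(0.0, trainable=False)  # sum of window scores
        self.metric_count = tf.Variable(0.0, trainable=False)  # number of windows scored

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight is ignored in this sketch.
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        self.true_buf.assign(tf.concat([self.true_buf, y_true], axis=0))
        self.pred_buf.assign(tf.concat([self.pred_buf, y_pred], axis=0))
        self.batches_seen.assign_add(1)

        def score_window():
            value = custom_multibatch_metric_func(self.true_buf, self.pred_buf)
            self.metric_sum.assign_add(tf.cast(value, tf.float32))
            self.metric_count.assign_add(1.0)
            # Empty the buffers (keeping their trailing dims) and start a new window.
            self.true_buf.assign(self.true_buf[:0])
            self.pred_buf.assign(self.pred_buf[:0])
            self.batches_seen.assign(0)
            return tf.constant(True)

        def keep_buffering():
            return tf.constant(False)

        # A graph-level branch instead of a Python `if` on a tensor.
        tf.cond(self.batches_seen >= self.num_batches, score_window, keep_buffering)

    def result(self):
        # Mean of the per-window scores; 0 before any window has been scored.
        return tf.math.divide_no_nan(self.metric_sum, self.metric_count)

    def reset_state(self):
        self.true_buf.assign(self.true_buf[:0])
        self.pred_buf.assign(self.pred_buf[:0])
        self.batches_seen.assign(0)
        self.metric_sum.assign(0.0)
        self.metric_count.assign(0.0)
```

Storing raw batches costs memory proportional to `num_batches`. If the model can afford larger validation batches, a simpler alternative is to rebatch the validation dataset, e.g. `val_ds.unbatch().batch(batch_size * num_batches)`, so that each `update_state` call already sees `num_batches` batches' worth of examples.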
from Calculate tensorflow Metric using more than one batch at a time