Wednesday 25 November 2020

Calculate tensorflow Metric using more than one batch at a time

I'm using tf.keras and have a metric that can only be calculated reliably over several batches of validation data at once. Is there some way to accumulate batches before calculating the metric?

I'd like to do something like this:

import tensorflow as tf

class MultibatchMetric(tf.keras.metrics.Metric):
    def __init__(self, num_batches, name="sdr_metric", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_batches = num_batches
        # Plain Python lists: this only behaves as intended when run eagerly.
        self.batch_accumulator = []
        self.my_metric = []

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.batch_accumulator.append((y_true, y_pred))
        # Once enough batches are buffered, compute the metric over all of them.
        if len(self.batch_accumulator) >= self.num_batches:
            # custom_multibatch_metric_func is my own function (not shown here).
            metric = custom_multibatch_metric_func(self.batch_accumulator)
            self.my_metric.append(metric)
            self.batch_accumulator = []

    def result(self):
        # Mean of the per-group metric values computed so far.
        return tf.reduce_mean(self.my_metric)

    def reset_states(self):
        self.my_metric = []
        self.batch_accumulator = []

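However, all of this has to run on the TensorFlow graph: model.fit and model.evaluate call update_state inside a tf.function, where Python-side state such as these lists is only updated while the function is being traced. That severely complicates things.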
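One possible workaround, sketched here under several assumptions, is to keep the buffered batches in tf.Variables instead of Python lists, so the accumulation is expressed as graph ops. The sketch assumes y_true and y_pred are 2-D float tensors of shape (batch, feature_dim) and ignores sample_weight. The class MultibatchAccumulator, the feature_dim parameter and mse_metric (a stand-in for custom_multibatch_metric_func) are all hypothetical names. It is written as a plain tf.Module rather than a tf.keras.metrics.Metric subclass to keep the graph mechanics visible.

```python
import tensorflow as tf


class MultibatchAccumulator(tf.Module):
    """Hypothetical sketch: buffers (y_true, y_pred) in tf.Variables so the
    accumulation also works inside a tf.function (i.e. on the graph)."""

    def __init__(self, num_batches, feature_dim, metric_fn):
        super().__init__()
        self.num_batches = num_batches
        self.feature_dim = feature_dim
        self.metric_fn = metric_fn  # called on all buffered samples at once
        # Buffers with an unknown leading dimension so they can grow per batch.
        self.true_buf = tf.Variable(
            tf.zeros((0, feature_dim)),
            shape=tf.TensorShape([None, feature_dim]), trainable=False)
        self.pred_buf = tf.Variable(
            tf.zeros((0, feature_dim)),
            shape=tf.TensorShape([None, feature_dim]), trainable=False)
        self.batches_seen = tf.Variable(0, trainable=False)
        # Running sum and count of the per-group metric values.
        self.metric_sum = tf.Variable(0.0, trainable=False)
        self.metric_count = tf.Variable(0.0, trainable=False)

    @tf.function
    def update_state(self, y_true, y_pred):
        # Append the current batch to the buffers.
        self.true_buf.assign(
            tf.concat([self.true_buf, tf.cast(y_true, tf.float32)], axis=0))
        self.pred_buf.assign(
            tf.concat([self.pred_buf, tf.cast(y_pred, tf.float32)], axis=0))
        self.batches_seen.assign_add(1)
        # AutoGraph converts this tensor-valued `if` into a tf.cond.
        if self.batches_seen >= self.num_batches:
            self.metric_sum.assign_add(
                self.metric_fn(self.true_buf, self.pred_buf))
            self.metric_count.assign_add(1.0)
            self._clear_buffers()

    def _clear_buffers(self):
        empty = tf.zeros((0, self.feature_dim))
        self.true_buf.assign(empty)
        self.pred_buf.assign(empty)
        self.batches_seen.assign(0)

    def result(self):
        # Mean over completed groups; 0.0 if no group is complete yet.
        return tf.math.divide_no_nan(self.metric_sum, self.metric_count)

    def reset_states(self):
        self._clear_buffers()
        self.metric_sum.assign(0.0)
        self.metric_count.assign(0.0)


def mse_metric(y_true, y_pred):
    # Hypothetical stand-in for custom_multibatch_metric_func.
    return tf.reduce_mean(tf.square(y_true - y_pred))
```

As in the original idea, batches left over at the end of an epoch (fewer than num_batches) are discarded by reset_states. The buffers hold num_batches batches of labels and predictions in memory, and each new batch shape triggers a retrace of update_state.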



from Calculate tensorflow Metric using more than one batch at a time
