Monday 27 June 2022

Understanding tf.keras.metrics.Precision and Recall for multiclass classification

I am building a model for a multiclass classification problem and I want to evaluate its performance using precision and recall. The dataset has 4 classes, and the labels are one-hot encoded.

I was reading the Precision and Recall tf.keras documentation, and have some questions:

  1. When calculating precision and recall for multiclass classification, how can I average over all of the labels, i.e. get the global precision and recall? Is it calculated with macro or micro averaging? The documentation does not specify this, unlike scikit-learn's.
  2. If I want to calculate precision and recall for each label separately, can I use the class_id argument for each label to get a one-vs-rest (binary) metric, as I have done in the code below?
  3. Would the top_k argument with top_k=2 be helpful here, or is it unsuitable for a classification problem with only 4 classes?
  4. When measuring the performance of each class, what is the difference between setting top_k=1 and not setting top_k at all?

import tensorflow as tf

# Assumes `model` is an already-built Keras model ending in
# Dense(4, activation="softmax") and trained on one-hot labels.
model.compile(
    optimizer='sgd',
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[
        tf.keras.metrics.CategoricalAccuracy(),
        # class 0
        tf.keras.metrics.Precision(class_id=0, top_k=2),
        tf.keras.metrics.Recall(class_id=0, top_k=2),
        # class 1
        tf.keras.metrics.Precision(class_id=1, top_k=2),
        tf.keras.metrics.Recall(class_id=1, top_k=2),
        # class 2
        tf.keras.metrics.Precision(class_id=2, top_k=2),
        tf.keras.metrics.Recall(class_id=2, top_k=2),
        # class 3
        tf.keras.metrics.Precision(class_id=3, top_k=2),
        tf.keras.metrics.Recall(class_id=3, top_k=2),
    ],
)
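
The plain-Python sketch below illustrates how, as far as I understand the TF 2.x implementation, tf.keras.metrics.Precision and Recall count true positives, false positives and false negatives for one-hot targets. It is an assumption-based re-implementation, not Keras code. With top_k=None each score is compared against a default threshold of 0.5. With top_k set, only the k highest scores per sample are kept and all of them count as positive predictions. class_id restricts counting to one column (one-vs-rest), and with no class_id all columns are pooled, which amounts to micro averaging. Precision and Recall take no averaging argument, so the macro average here is computed by hand from the per-class values. The toy data and the helper names (predicted_positive, counts, safe_div, precision_recall, macro_precision_recall) are hypothetical.

```python
# ASSUMPTION: a sketch of the TF 2.x counting logic of tf.keras.metrics.Precision
# and Recall for one-hot targets. It is not Keras code.

DEFAULT_THRESHOLD = 0.5  # used only when top_k is None


def predicted_positive(scores, top_k=None):
    """Return one boolean per class: is that class predicted positive?"""
    if top_k is None:
        # Element-wise threshold: a sample may end up with zero, one or
        # several positive classes.
        return [s > DEFAULT_THRESHOLD for s in scores]
    # Keep the k highest scores (ties go to the lower index, because Python's
    # sort is stable); every kept class counts as a positive prediction.
    kept = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
    return [i in kept for i in range(len(scores))]


def counts(y_true, y_pred, top_k=None, class_id=None):
    """Accumulate (tp, fp, fn) over all samples, like the metric's state."""
    tp = fp = fn = 0
    for true_row, score_row in zip(y_true, y_pred):
        positive = predicted_positive(score_row, top_k)
        # class_id=None: every one-hot column is pooled (micro averaging).
        # class_id=i: only column i is counted (one-vs-rest for class i).
        columns = range(len(true_row)) if class_id is None else [class_id]
        for i in columns:
            actual = true_row[i] == 1
            if positive[i] and actual:
                tp += 1
            elif positive[i]:
                fp += 1
            elif actual:
                fn += 1
    return tp, fp, fn


def safe_div(num, den):
    # Returns 0.0 instead of NaN when the denominator is zero
    # (assumed to match Keras, which reports 0 in that case).
    return num / den if den else 0.0


def precision_recall(y_true, y_pred, top_k=None, class_id=None):
    tp, fp, fn = counts(y_true, y_pred, top_k, class_id)
    return safe_div(tp, tp + fp), safe_div(tp, tp + fn)


def macro_precision_recall(y_true, y_pred, top_k=None):
    """Macro average: unweighted mean of the per-class values (manual step)."""
    n_classes = len(y_true[0])
    per_class = [precision_recall(y_true, y_pred, top_k, c) for c in range(n_classes)]
    return (sum(p for p, _ in per_class) / n_classes,
            sum(r for _, r in per_class) / n_classes)


# Hypothetical toy data: 4 classes, one-hot labels, softmax-like scores.
y_true = [
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
    [1, 0, 0, 0],
    [0, 1, 0, 0],
]
y_pred = [
    [0.70, 0.15, 0.10, 0.05],
    [0.40, 0.35, 0.15, 0.10],
    [0.10, 0.20, 0.60, 0.10],
    [0.10, 0.55, 0.05, 0.30],
    [0.45, 0.30, 0.15, 0.10],
    [0.05, 0.80, 0.10, 0.05],
]

print("micro, no top_k:", precision_recall(y_true, y_pred))
print("micro, top_k=1 :", precision_recall(y_true, y_pred, top_k=1))
print("micro, top_k=2 :", precision_recall(y_true, y_pred, top_k=2))
print("macro, top_k=1 :", macro_precision_recall(y_true, y_pred, top_k=1))
for c in range(4):
    print(f"class {c}, top_k=1:", precision_recall(y_true, y_pred, top_k=1, class_id=c))
```

On this toy data the sketch gives micro precision/recall of 0.75/0.5 with no top_k, 4/6 for both with top_k=1 (equal to accuracy, since each sample then has exactly one predicted and one true class), and 0.5/1.0 with top_k=2, where every sample gets two positive predictions. If the sketch matches Keras, this suggests top_k=1 is the natural choice for a single-label 4-class problem, and that omitting top_k makes a sample whose highest score is below 0.5 count as predicting no class at all.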

Any clarification of these metrics would be appreciated. Thanks in advance.


