Friday, 24 September 2021

Plotting top n features using permutation importance

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.inspection import permutation_importance
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
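
# ... data loading, preprocessing and model fitting from the documentation example omitted;
# `rf` is the fitted model, `X_test` the held-out features (a DataFrame) and `y_test` the held-out labels
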
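# Shuffle each feature 10 times on the test set and record the resulting drop in score
result = permutation_importance(rf,
                                X_test,
                                y_test,
                                n_repeats=10,
                                random_state=42,
                                n_jobs=2)
# Feature indices ordered from least to most important (mean over the repeats)
sorted_idx = result.importances_mean.argsort()

# One horizontal box per feature; the first index is drawn at the bottom,
# so the most important feature ends up at the top
fig, ax = plt.subplots()
ax.boxplot(result.importances[sorted_idx].T,
           vert=False,
           labels=X_test.columns[sorted_idx])

ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()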

In the code above, taken from this example in the scikit-learn documentation, is there a way to plot only the top 3 features instead of all of them?
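A minimal sketch of one way to do this, under stated assumptions: because `argsort()` returns indices in ascending order of mean importance, slicing its last three entries keeps only the three most important features, and the rest of the plotting code can stay the same. To keep the sketch self-contained, it swaps the documentation's dataset and pipeline for synthetic data from `make_classification` with placeholder feature names (`feature_0`, `feature_1`, ...); these, and the variable name `top_n`, are assumptions and not part of the original example. The tick labels are set with `set_yticks`/`set_yticklabels` rather than the boxplot `labels` argument only because the synthetic data is a NumPy array without `.columns`.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the documentation's dataset and pipeline (assumption)
X, y = make_classification(n_samples=500, n_features=8, n_informative=3,
                           random_state=0)
feature_names = np.array([f"feature_{i}" for i in range(X.shape[1])])
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
rf = RandomForestClassifier(random_state=42).fit(X_train, y_train)

result = permutation_importance(rf, X_test, y_test,
                                n_repeats=10, random_state=42)

top_n = 3  # hypothetical name: how many features to show
# argsort() is ascending, so the last top_n indices belong to the
# features with the highest mean importance
sorted_idx = result.importances_mean.argsort()[-top_n:]

fig, ax = plt.subplots()
# Horizontal boxes are drawn bottom-up, so the most important feature
# (the last index) ends up at the top of the plot
ax.boxplot(result.importances[sorted_idx].T, vert=False)
ax.set_yticks(np.arange(1, top_n + 1))
ax.set_yticklabels(feature_names[sorted_idx])
ax.set_title(f"Top {top_n} permutation importances (test set)")
fig.tight_layout()
plt.show()
```

With the documentation's DataFrame, the equivalent change would be `sorted_idx = result.importances_mean.argsort()[-3:]`, keeping `labels=X_test.columns[sorted_idx]` as before.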


