How to work with grouped data
A common property of data science problems is a natural grouping of the samples: for instance, you may have multiple samples for the same customer. In such cases you need to make sure that all samples from a given group end up in the same fold, e.g. during cross-validation.
Let's prepare a dataset with groups.
```python
%%capture
!pip install probatus
```
```python
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=100, n_features=10, random_state=42)
groups = [i % 5 for i in range(100)]
groups[:10]
```

```
[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
```
The integers in the `groups` variable indicate the group id to which a given sample belongs.
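To see why this matters, here is a small sketch (ours, not part of the tutorial) showing that a plain `KFold` split ignores the groups: with the interleaved groups above, every group leaks into both the train and the validation side of the first fold.

```python
from sklearn.model_selection import KFold

groups = [i % 5 for i in range(100)]

# Take the first fold of an ordinary (group-unaware) 5-fold split.
train_idx, test_idx = next(KFold(n_splits=5).split(groups))
train_groups = {groups[i] for i in train_idx}
test_groups = {groups[i] for i in test_idx}

# Because the groups are interleaved and KFold splits contiguously,
# all five groups appear on both sides of the split:
print(train_groups & test_groups)  # → {0, 1, 2, 3, 4}
```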
One of the easiest ways to split the data while respecting the groups is `sklearn.model_selection.GroupKFold`. You can read more about other ways of splitting grouped data in the sklearn user guide on cross-validation.
```python
from sklearn.model_selection import GroupKFold

cv = list(GroupKFold(n_splits=5).split(X, y, groups=groups))
```
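As a quick sanity check (our addition), you can verify that in every `GroupKFold` split the train and validation groups are disjoint:

```python
from sklearn.model_selection import GroupKFold

groups = [i % 5 for i in range(100)]
cv = list(GroupKFold(n_splits=5).split(groups, groups=groups))

# Every fold keeps whole groups on one side of the split.
for train_idx, test_idx in cv:
    assert {groups[i] for i in train_idx}.isdisjoint({groups[i] for i in test_idx})
```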
Such a variable can be passed to the `cv` parameter in `probatus`, as well as to hyperparameter-optimization classes such as `RandomizedSearchCV`.
```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from probatus.feature_elimination import ShapRFECV

model = RandomForestClassifier(random_state=42)
param_grid = {
    "n_estimators": [5, 7, 10],
    "max_leaf_nodes": [3, 5, 7, 10],
}
search = RandomizedSearchCV(model, param_grid, cv=cv, n_iter=1, random_state=42)

shap_elimination = ShapRFECV(model=search, step=0.2, cv=cv, scoring="roc_auc", n_jobs=3, random_state=42)
report = shap_elimination.fit_compute(X, y)
report
```
| | num_features | features_set | eliminated_features | train_metric_mean | train_metric_std | val_metric_mean | val_metric_std |
|---|---|---|---|---|---|---|---|
| 1 | 10 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] | [6, 7] | 0.999562 | 0.000876 | 0.954945 | 0.090110 |
| 2 | 8 | [0, 1, 2, 3, 4, 5, 8, 9] | [5] | 0.999118 | 0.001081 | 0.945513 | 0.089606 |
| 3 | 7 | [0, 1, 2, 3, 4, 8, 9] | [4] | 0.999559 | 0.000548 | 0.928749 | 0.137507 |
| 4 | 6 | [0, 1, 2, 3, 8, 9] | [8] | 0.999179 | 0.001051 | 0.969288 | 0.058854 |
| 5 | 5 | [0, 1, 2, 3, 9] | [9] | 0.999748 | 0.000237 | 0.961767 | 0.066540 |
| 6 | 4 | [0, 1, 2, 3] | [1] | 0.999433 | 0.000700 | 0.950816 | 0.090982 |
| 7 | 3 | [0, 2, 3] | [0] | 0.999120 | 0.000729 | 0.970596 | 0.051567 |
| 8 | 2 | [2, 3] | [3] | 0.999496 | 0.000617 | 0.938639 | 0.117736 |
| 9 | 1 | [2] | [] | 0.998424 | 0.001819 | 0.938339 | 0.097936 |
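One detail worth noting: because `cv` here is a materialized list of splits rather than the `GroupKFold` object itself, downstream estimators need no `groups` argument at fit time — the group information is already baked into the index lists. A minimal sklearn-only sketch of this (our illustration, not part of the tutorial):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GroupKFold, RandomizedSearchCV

X, y = make_classification(n_samples=100, n_features=10, random_state=42)
groups = [i % 5 for i in range(100)]

# Materialize the splits: each entry is a (train_idx, test_idx) pair.
cv = list(GroupKFold(n_splits=5).split(X, y, groups=groups))

search = RandomizedSearchCV(
    RandomForestClassifier(random_state=42),
    {"n_estimators": [5, 7, 10]},
    cv=cv,
    n_iter=1,
    random_state=42,
)
# No groups= needed here: the splits are already fixed. If you passed
# GroupKFold(n_splits=5) as cv instead, you would have to call
# search.fit(X, y, groups=groups).
search.fit(X, y)
```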