Réseaux de neurones avec R (#2)
Dans le post précédent, la classification d’un dataset relativement simple a été réalisée de manière très efficace avec un réseau de neurones mais également avec une technique beaucoup plus conventionnelle: une régression logistique.
Ici, on va voir que les deux approches se détachent lorsque la dimensionnalité augmente. En effet, les ANN sont beaucoup moins sensibles au phénomène connu sous le nom de fléau de la dimension.
On va ici utiliser un autre dataset de l’étude « Breast Cancer Wisconsin (Diagnostic) » de l’UCI: wdbc.data dont la description est ici: wdbc.names.
L’étude contient 569 cas dont 37% correspondent à une tumeur maligne. On a cette fois-ci 30 prédicteurs (contre 9 dans le billet précédent).
On charge le dataset dans un dataframe puis on procède à quelques ajustements:
- Suppression de la premiere colonne (V1) qui correspond à un ID patient.
- Binarisation (0/1) de la colonne V2 qui correspond au diagnostic (codification initiale: B/M)
- Centrage/réduction à l’aide de la fonction « scale » de toutes les valeurs du dataframe (sauf celles de la colonne V2 qui est binarisée)
> breastcancer <- read.csv(url("https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data"), header=FALSE)
>
> breastcancer <- breastcancer[-1]
> breastcancer$V2 <- as.character(levels(breastcancer$V2))[breastcancer$V2]
> breastcancer$V2 <- as.numeric(ifelse(breastcancer$V2=="B",0,1))
> breastcancer[,-1] <- scale(breastcancer[,-1])[,]
> summary(breastcancer)
V2 V3 V4 V5 V6 V7 V8 V9
Min. :0.0000 Min. :-2.0279 Min. :-2.2273 Min. :-1.9828 Min. :-1.4532 Min. :-3.10935 Min. :-1.6087 Min. :-1.1139
1st Qu.:0.0000 1st Qu.:-0.6888 1st Qu.:-0.7253 1st Qu.:-0.6913 1st Qu.:-0.6666 1st Qu.:-0.71034 1st Qu.:-0.7464 1st Qu.:-0.7431
Median :0.0000 Median :-0.2149 Median :-0.1045 Median :-0.2358 Median :-0.2949 Median :-0.03486 Median :-0.2217 Median :-0.3419
Mean :0.3726 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.00000 Mean : 0.0000 Mean : 0.0000
3rd Qu.:1.0000 3rd Qu.: 0.4690 3rd Qu.: 0.5837 3rd Qu.: 0.4992 3rd Qu.: 0.3632 3rd Qu.: 0.63564 3rd Qu.: 0.4934 3rd Qu.: 0.5256
Max. :1.0000 Max. : 3.9678 Max. : 4.6478 Max. : 3.9726 Max. : 5.2459 Max. : 4.76672 Max. : 4.5644 Max. : 4.2399
V10 V11 V12 V13 V14 V15 V16 V17
Min. :-1.2607 Min. :-2.74171 Min. :-1.8183 Min. :-1.0590 Min. :-1.5529 Min. :-1.0431 Min. :-0.7372 Min. :-1.7745
1st Qu.:-0.7373 1st Qu.:-0.70262 1st Qu.:-0.7220 1st Qu.:-0.6230 1st Qu.:-0.6942 1st Qu.:-0.6232 1st Qu.:-0.4943 1st Qu.:-0.6235
Median :-0.3974 Median :-0.07156 Median :-0.1781 Median :-0.2920 Median :-0.1973 Median :-0.2864 Median :-0.3475 Median :-0.2201
Mean : 0.0000 Mean : 0.00000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000
3rd Qu.: 0.6464 3rd Qu.: 0.53031 3rd Qu.: 0.4706 3rd Qu.: 0.2659 3rd Qu.: 0.4661 3rd Qu.: 0.2428 3rd Qu.: 0.1067 3rd Qu.: 0.3680
Max. : 3.9245 Max. : 4.48081 Max. : 4.9066 Max. : 8.8991 Max. : 6.6494 Max. : 9.4537 Max. :11.0321 Max. : 8.0229
V18 V19 V20 V21 V22 V23 V24 V25
Min. :-1.2970 Min. :-1.0566 Min. :-1.9118 Min. :-1.5315 Min. :-1.0960 Min. :-1.7254 Min. :-2.22204 Min. :-1.6919
1st Qu.:-0.6923 1st Qu.:-0.5567 1st Qu.:-0.6739 1st Qu.:-0.6511 1st Qu.:-0.5846 1st Qu.:-0.6743 1st Qu.:-0.74797 1st Qu.:-0.6890
Median :-0.2808 Median :-0.1989 Median :-0.1404 Median :-0.2192 Median :-0.2297 Median :-0.2688 Median :-0.04348 Median :-0.2857
Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.00000 Mean : 0.0000
3rd Qu.: 0.3893 3rd Qu.: 0.3365 3rd Qu.: 0.4722 3rd Qu.: 0.3554 3rd Qu.: 0.2884 3rd Qu.: 0.5216 3rd Qu.: 0.65776 3rd Qu.: 0.5398
Max. : 6.1381 Max. :12.0621 Max. : 6.6438 Max. : 7.0657 Max. : 9.8429 Max. : 4.0906 Max. : 3.88249 Max. : 4.2836
V26 V27 V28 V29 V30 V31 V32
Min. :-1.2213 Min. :-2.6803 Min. :-1.4426 Min. :-1.3047 Min. :-1.7435 Min. :-2.1591 Min. :-1.6004
1st Qu.:-0.6416 1st Qu.:-0.6906 1st Qu.:-0.6805 1st Qu.:-0.7558 1st Qu.:-0.7557 1st Qu.:-0.6413 1st Qu.:-0.6913
Median :-0.3409 Median :-0.0468 Median :-0.2693 Median :-0.2180 Median :-0.2233 Median :-0.1273 Median :-0.2163
Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000 Mean : 0.0000
3rd Qu.: 0.3573 3rd Qu.: 0.5970 3rd Qu.: 0.5392 3rd Qu.: 0.5307 3rd Qu.: 0.7119 3rd Qu.: 0.4497 3rd Qu.: 0.4504
Max. : 5.9250 Max. : 3.9519 Max. : 5.1084 Max. : 4.6965 Max. : 2.6835 Max. : 6.0407 Max. : 6.8408
>
On divise ensuite le dataset breastcancer en un échantillon d’apprentissage (70%) et un échantillon de test (30%). L’appel a set.seed permet d’obtenir des résultats reproductibles.
> set.seed(1234) > sub <- sample(nrow(breastcancer), floor(nrow(breastcancer) * 0.70)) > breastcancer_train <- breastcancer[sub,] > breastcancer_test <- breastcancer[-sub,] >
On tente de construire un modèle de régression logistique:
> breastcancer_train$V2 <- as.factor(breastcancer_train$V2) > log_model <- glm(V2 ~ .,family=binomial,data=breastcancer_train) Warning messages: 1: glm.fit: algorithm did not converge 2: glm.fit: fitted probabilities numerically 0 or 1 occurred >
L’algorithme ne converge pas.
On peut aussi tenter une stepwise regression en autorisant un nombre d’itérations très important (10000) de manière à maximiser la probabilité de détermination d’un ensemble pertinent de prédicteurs:
> step(log_model, dir="backward", trace=0, steps=10000)
Call: glm(formula = V2 ~ V3 + V8 + V9 + V12 + V14 + V16 + V17 + V22 +
V23 + V24 + V30, family = binomial, data = breastcancer_train)
Coefficients:
(Intercept) V3 V8 V9 V12 V14 V16 V17 V22 V23 V24 V30
197.7 -841.6 -532.0 319.8 543.8 -125.8 1135.4 145.3 -481.3 1860.9 367.8 700.4
Degrees of Freedom: 397 Total (i.e. Null); 386 Residual
Null Deviance: 531.2
Residual Deviance: 5.285e-06 AIC: 24
There were 50 or more warnings (use warnings() to see the first 50)
> tail(warnings())
Warning messages:
1: glm.fit: algorithm did not converge
2: glm.fit: fitted probabilities numerically 0 or 1 occurred
3: glm.fit: algorithm did not converge
4: glm.fit: fitted probabilities numerically 0 or 1 occurred
5: glm.fit: algorithm did not converge
6: glm.fit: fitted probabilities numerically 0 or 1 occurred
>
Même chose ici. L’algorithme ne parvient toujours pas à converger…
Essayons maintenant de construire un ANN avec la fonction nnet. A l’instar du billet précédent, j’utilise arbitrairement 15 neurones au sein de la couche cachée:
> library(nnet)
> breastcancer_train$V2 <- as.numeric(levels(breastcancer_train$V2))[breastcancer_train$V2]
> nn <- nnet(V2 ~ ., data=breastcancer_train, size=15)
# weights: 481
initial value 50.360967
iter 10 value 23.745479
iter 20 value 18.998525
iter 30 value 16.996355
iter 40 value 15.005363
iter 50 value 15.000040
iter 60 value 14.998821
iter 70 value 14.003261
iter 80 value 14.000612
iter 90 value 14.000243
iter 100 value 14.000087
final value 14.000087
stopped after 100 iterations
> pred <- predict(nn, newdata=breastcancer_test, supplemental_cols=c("V2"))
> pred.bin <- ifelse(pred>0.5,1,0)
> table(pred.bin, breastcancer_test$V2)
pred.bin 0 1
0 89 1
1 12 69
>
Contrairement à la régression logistique, le modèle ANN permet d’obtenir une excellente performance de classification (92%) sur l’échantillon de test.
Dans le prochain post on verra un exemple de beaucoup plus grande envergure…