Réseaux de neurones avec R (#3)
Il peut être intéressant de visualiser la topologie d’un ANN. Le blog suivant indique une méthode simple pour y parvenir: https://beckmw.wordpress.com/2013/11/14/visualizing-neural-networks-in-r-update/
Le code de la fonction plot.nnet est disponible sur https://gist.github.com/fawda123/7471137
Essayons à l’aide de cette fonction de représenter le réseau de neurones construit dans ce billet :
> library(devtools)
Warning message:
package ‘devtools’ was built under R version 3.2.5
> source("C:/RTI/Stats/nnet_plot_update.r")
>
> plot.nnet(nn_model)
>
La représentation s’affiche alors:
On retrouve bien:
- les 9 prédicteurs en entrées (associés à 9 neurones d’entrée I1 -> I9)
- les 5 neurones de la couche cachée (H1 -> H5)
- le neurone de sortie O1 (entrainé pour estimer V11)
- les neurones de biais pour chaque couche (B1 et B2)
Les poids synaptiques ne sont pas figurés car cela alourdirait considérablement le graphique. En revanche, l’épaisseur des liens est proportionnelle à l’importance du poids. La couleur informe du signe associé: gris pour un poids négatif, noir pour un poids positif.
Par exemple, dans le graphique ci-dessus, on voit que le neurone H3 a le poids positif le plus important.
On peut le confirmer aisément avec les données numériques:
> summary(nn_model) a 9-5-1 network with 56 weights options were - b->h1 i1->h1 i2->h1 i3->h1 i4->h1 i5->h1 i6->h1 i7->h1 i8->h1 i9->h1 1.04 -11.14 -14.36 -14.44 -12.92 -12.05 -14.21 -12.57 -14.14 -7.14 b->h2 i1->h2 i2->h2 i3->h2 i4->h2 i5->h2 i6->h2 i7->h2 i8->h2 i9->h2 2.07 -11.58 -12.34 -12.19 -9.69 -10.64 -12.48 -11.75 -10.62 -5.45 b->h3 i1->h3 i2->h3 i3->h3 i4->h3 i5->h3 i6->h3 i7->h3 i8->h3 i9->h3 -1.50 3.37 3.68 3.90 3.02 2.87 4.01 3.29 3.28 1.58 b->h4 i1->h4 i2->h4 i3->h4 i4->h4 i5->h4 i6->h4 i7->h4 i8->h4 i9->h4 -0.42 -1.53 -2.51 -1.69 -1.37 -1.76 -3.17 -1.76 -1.90 -1.75 b->h5 i1->h5 i2->h5 i3->h5 i4->h5 i5->h5 i6->h5 i7->h5 i8->h5 i9->h5 -1.60 5.33 7.01 6.26 5.65 5.71 6.81 6.44 5.68 2.15 b->o h1->o h2->o h3->o h4->o h5->o -5.80 -35.31 -5.04 41.26 24.80 2.44 >
