67 lines
No EOL
2 KiB
Python
67 lines
No EOL
2 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
Created on Fri Mar 22 23:55:53 2019
|
|
Matrice de confusion et calcul des métriques
|
|
@author: François Pelletier
|
|
"""
|
|
|
|
import numpy as np
|
|
|
|
def confusion_matrix(obs_labels, pred_labels):
    """
    Return the confusion matrix of two label vectors.

    Rows are indexed by the observed label and columns by the predicted
    label, so ``cm[i, j]`` counts the samples observed as ``i`` and
    predicted as ``j``. Labels are used directly as matrix indices and must
    therefore be non-negative integers.

    The matrix is square with side ``max(label) + 1`` taken over BOTH
    vectors. A predicted class that never appears among the observations is
    therefore counted instead of raising ``IndexError``.

    Parameters
    ----------
    obs_labels : array-like of int
        Observed (true) labels.
    pred_labels : array-like of int
        Predicted labels; must have the same shape as ``obs_labels``.

    Returns
    -------
    numpy.ndarray
        Float matrix of counts; shape ``(0, 0)`` for empty input.

    Raises
    ------
    ValueError
        If the two vectors differ in shape or contain a negative label.
    """
    obs = np.asarray(obs_labels)
    pred = np.asarray(pred_labels)
    if obs.shape != pred.shape:
        # zip() would silently drop the tail of the longer vector.
        raise ValueError(
            "obs_labels and pred_labels must have the same shape, "
            "got {} and {}".format(obs.shape, pred.shape))
    if obs.size == 0:
        return np.zeros((0, 0))
    if min(obs.min(), pred.min()) < 0:
        # Negative indices would silently wrap around to the last row/column.
        raise ValueError("labels must be non-negative integers")

    n_classes = int(max(obs.max(), pred.max())) + 1
    cm = np.zeros((n_classes, n_classes))
    # np.add.at accumulates repeated (row, col) pairs, unlike cm[obs, pred] += 1.
    np.add.at(cm, (obs, pred), 1)
    return cm
|
|
|
|
def prediction_metrics(cm, obs_labels, pred_labels):
    """
    Return the accuracy and the per-class precision and recall.

    The metrics are defined as:

        accuracy  = (tp + tn) / all   -- fraction of matching labels
        precision = tp / (tp + fp)    -- cm[k, k] / sum of column k
        recall    = tp / (tp + fn)    -- cm[k, k] / sum of row k

    Parameters
    ----------
    cm : array-like, 2-D
        Confusion matrix with observed labels on rows and predicted labels
        on columns, indexed directly by label value.
    obs_labels : array-like
        Observed (true) labels.
    pred_labels : array-like
        Predicted labels; must have the same shape as ``obs_labels``.

    Returns
    -------
    tuple
        ``(accuracy, precision, recall)``. ``accuracy`` is NaN for empty
        input. ``precision`` and ``recall`` are lists with one value per
        distinct observed label, in sorted label order. A label whose
        metric is undefined (zero denominator) or that has no entry in
        ``cm`` is omitted, so the lists may be shorter than the number
        of classes.

    Raises
    ------
    ValueError
        If the two label vectors differ in shape.
    """
    # Converting to arrays makes == elementwise even when plain lists are passed.
    obs = np.asarray(obs_labels)
    pred = np.asarray(pred_labels)
    if obs.shape != pred.shape:
        raise ValueError(
            "obs_labels and pred_labels must have the same shape, "
            "got {} and {}".format(obs.shape, pred.shape))

    if obs.size:
        accuracy = np.count_nonzero(obs == pred) / float(obs.size)
    else:
        accuracy = float("nan")

    cm = np.asarray(cm)
    predicted_totals = cm.sum(axis=0)  # tp + fp for each class
    observed_totals = cm.sum(axis=1)   # tp + fn for each class

    precision = []
    recall = []
    for label in np.unique(obs):
        try:
            tp = cm[label, label]
            predicted_total = predicted_totals[label]
            observed_total = observed_totals[label]
        except IndexError:
            # The label has no row/column in cm; skip it rather than fail.
            continue
        # A zero denominator means the metric is undefined for this class;
        # it is left out of the list rather than reported as NaN.
        if predicted_total > 0:
            precision.append(tp / predicted_total)
        if observed_total > 0:
            recall.append(tp / observed_total)

    return accuracy, precision, recall
|
|
|
|
def print_prediction_metrics(cm, accuracy, precision, recall, compute_time):
    """
    Print the confusion matrix, the metrics and the elapsed time.

    Each section is a title line followed by the value on the next line.
    Every section after the first is preceded by a blank line, and the
    elapsed time is printed with an "s" suffix.
    """
    sections = (
        ("Matrice de confusion:", cm),
        ("\nExactitude:", accuracy),
        ("\nPrécision:", precision),
        ("\nRappel:", recall),
    )
    for title, value in sections:
        print(title)
        print(value)

    # The elapsed time is stringified last, exactly as before.
    print("\nCalculé en:")
    print(str(compute_time) + "s")