now the tree builds and parses correctly

François Pelletier 2019-04-29 20:20:05 -04:00
parent 108839fcea
commit ee8c7cbb2a
2 changed files with 39 additions and 11 deletions


@@ -35,24 +35,29 @@ class DecisionTree: # class name to be changed
_,_,counts = np.unique(train_labels,return_index=True, return_counts=True)
total_count = sum(counts)
entropie_total = -sum(counts/total_count * np.log2(counts/total_count))
#print("Entropie Total:"+str(entropie_total))
# Find the split point
attribute_sort_order = np.argsort(train[:,attribute])
sorted_labels = train_labels[attribute_sort_order]
lags = np.hstack((np.array([False]),sorted_labels[:-1] != sorted_labels[1:]))
-potential_splits = train[attribute_sort_order,attribute][lags]
+lags2 = np.hstack((np.array([False]),lags))[:-1]
+potential_splits = 0.5*train[attribute_sort_order,attribute][lags]+0.5*train[attribute_sort_order,attribute][lags2]
+if (len(potential_splits)==0):
+    potential_splits = np.array([np.median(train[attribute_sort_order,attribute])])
#print("Potential Split:"+str(potential_splits))
split_gain = []
for v in potential_splits:
-    split_labels_1 = train_labels[train[:,attribute] < v]
-    split_labels_2 = train_labels[train[:,attribute] >= v]
+    split_labels_1 = train_labels[train[:,attribute] <= v]
+    split_labels_2 = train_labels[train[:,attribute] > v]
    _,_,counts1 = np.unique(split_labels_1,return_index=True, return_counts=True)
    total_count1 = sum(counts1)
    entropie_total1 = -sum(counts1/total_count1 * np.log2(counts1/total_count1))
    _,_,counts2 = np.unique(split_labels_2,return_index=True, return_counts=True)
    total_count2 = sum(counts2)
    entropie_total2 = -sum(counts2/total_count2 * np.log2(counts2/total_count2))
-    split_gain.append(entropie_total+(total_count1/total_count*entropie_total1+total_count2/total_count*entropie_total2))
+    split_gain.append(entropie_total-(total_count1/total_count*entropie_total1+total_count2/total_count*entropie_total2))
# Unique attribute value
#print("Split Gain:"+str(split_gain))
best_split = potential_splits[np.argmax(split_gain)]
best_gain = max(split_gain)
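
Note on the hunk above: it changes two things. Candidate thresholds become midpoints between consecutive sorted examples whose labels differ (with a median fallback when no label ever changes), and the gain becomes parent entropy minus the size-weighted child entropies; the old version added them, so argmax picked the worst split. A minimal standalone sketch of the corrected computation, with illustrative function names that are not from this repo:

    import numpy as np

    def entropy(labels):
        # Shannon entropy H = -sum(p * log2(p)) of a label vector
        if len(labels) == 0:
            return 0.0
        _, counts = np.unique(labels, return_counts=True)
        p = counts / counts.sum()
        return -np.sum(p * np.log2(p))

    def best_threshold(values, labels):
        # Candidate thresholds: midpoints between consecutive sorted
        # examples whose labels differ (the intent of the lags/lags2 code)
        order = np.argsort(values)
        v, y = values[order], labels[order]
        boundaries = np.where(y[:-1] != y[1:])[0]
        candidates = 0.5 * (v[boundaries] + v[boundaries + 1])
        if len(candidates) == 0:
            candidates = np.array([np.median(v)])
        parent_h = entropy(labels)
        gains = []
        for t in candidates:
            left, right = labels[values <= t], labels[values > t]
            # Information gain: parent entropy minus size-weighted
            # child entropies (the sign this commit fixes)
            gains.append(parent_h - (len(left) / len(labels) * entropy(left)
                                     + len(right) / len(labels) * entropy(right)))
        return candidates[int(np.argmax(gains))], max(gains)
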
@@ -65,13 +70,19 @@ class DecisionTree: # class name to be changed
classes_uniques = np.unique(train_labels)
# the leaf is empty (no examples)
if (n_examples == 0):
-    return list(("Feuille",self.plurality_value(parent_examples)))
+    l1 = []
+    l1.append(("Feuille",self.plurality_value(parent_examples)))
+    return list(l1)
# all examples have the same class
elif len(classes_uniques)==1:
-    return list(("Feuille",classes_uniques[0]))
+    l1 = []
+    l1.append(("Feuille",classes_uniques[0]))
+    return l1
# the attribute list is empty
elif (sum(attributes)==0):
-    return list(("Feuille",self.plurality_value(train_labels)))
+    l1 = []
+    l1.append(("Feuille",self.plurality_value(train_labels)))
+    return l1
else:
# Compute the gain
attr = np.where(attributes==1)[0]
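
Worth noting why the old base cases broke tree parsing: list() applied to a tuple flattens the tuple into its elements, so list(("Feuille", c)) yields ['Feuille', c] rather than a list containing one leaf node. The new code builds the singleton list explicitly. A small sketch of the difference; the plurality_value body below is an assumption about what a majority-vote helper would look like, not code from this repo:

    import numpy as np

    # list() on a tuple splits it into its elements:
    old_leaf = list(("Feuille", 1))   # ['Feuille', 1] -- two bare elements
    new_leaf = [("Feuille", 1)]       # [('Feuille', 1)] -- one node tuple
    assert old_leaf != new_leaf

    def plurality_value(labels):
        # Hypothetical helper mirroring self.plurality_value's apparent role:
        # return the most frequent class in `labels` (majority vote)
        classes, counts = np.unique(labels, return_counts=True)
        return classes[int(np.argmax(counts))]
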
@@ -89,7 +100,8 @@ class DecisionTree: # class name to be changed
attributes[a_max]=0
# for each value of the attribute, build a subtree
for v in [True,False]:
-    train_pos = np.where((train[:,a_max] < a_max_split) == v)
+    print("New branch: attribute "+str(a_max)+"<="+str(a_max_split)+" is: "+str(v))
+    train_pos = np.where((train[:,a_max] <= a_max_split) == v)
    subtree = self.decision_tree_learning(train[train_pos],train_labels[train_pos],attributes,train_labels)
    tree.append(("Branche",a_max,a_max_split,v,subtree))
return tree
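
Each recursive call therefore yields a list of two ("Branche", attribute, split, outcome, subtree) tuples, one per truth value of the test train[:, a_max] <= a_max_split. For illustration (values made up), a tree with a single split on attribute 0 at threshold 2.5 would look like:

    tree = [
        ("Branche", 0, 2.5, True,  [("Feuille", "A")]),   # x[0] <= 2.5
        ("Branche", 0, 2.5, False, [("Feuille", "B")]),   # x[0] >  2.5
    ]
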
@@ -120,8 +132,17 @@ class DecisionTree: # class name to be changed
self.tree = self.decision_tree_learning(train, train_labels, attributes, None)
+def extract_tree(self,myTree,exemple):
+    for b in myTree:
+        # We reached a leaf
+        if b[0] == 'Feuille':
+            return b[1]
+        # We are in a branch: test the split
+        if ((exemple[b[1]] <= b[2]) == b[3]):
+            return self.extract_tree(b[4],exemple)
+    return None
def predict(self, exemple, label):
"""
Predict the class of an example given as input
@@ -131,6 +152,11 @@ class DecisionTree: # class name to be changed
then the example is correctly classified, otherwise it is a misclassification
"""
+return self.extract_tree(self.tree,exemple)
def test(self, test, test_labels):
"""

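The new extract_tree walks that node-list structure: it returns the class stored in a ("Feuille", class) node, and otherwise recurses into the branch whose recorded truth value matches the exemple[attribute] <= split test. A standalone sketch of the same traversal, run on the illustrative tree from above:

    def extract_tree(tree, example):
        # Return the leaf class, or recurse into the branch whose
        # truth value matches the <= test on this example
        for node in tree:
            if node[0] == "Feuille":
                return node[1]
            _, attribute, split, outcome, subtree = node
            if (example[attribute] <= split) == outcome:
                return extract_tree(subtree, example)
        return None

    tree = [
        ("Branche", 0, 2.5, True,  [("Feuille", "A")]),
        ("Branche", 0, 2.5, False, [("Feuille", "B")]),
    ]
    print(extract_tree(tree, [1.0]))  # A  (1.0 <= 2.5)
    print(extract_tree(tree, [4.0]))  # B  (4.0 >  2.5)
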

@@ -12,4 +12,6 @@ dt = DecisionTree.DecisionTree()
dt.train(train,train_labels)
dt.tree
+[(dt.predict(exemple,label),label) for exemple,label in zip(train,train_labels)]
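
The comprehension added to the test script pairs each prediction with its true label over the training set; a training accuracy follows directly from those pairs. A short continuation of the script's session, assuming the dt, train and train_labels defined above:

    pairs = [(dt.predict(exemple, label), label) for exemple, label in zip(train, train_labels)]
    accuracy = sum(pred == label for pred, label in pairs) / len(pairs)
    print("Training accuracy:", accuracy)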