maintenant, l'arbre se construit et se parse correctement
This commit is contained in:
parent
108839fcea
commit
ee8c7cbb2a
2 changed files with 39 additions and 11 deletions
|
@ -35,24 +35,29 @@ class DecisionTree: #nom de la class à changer
|
|||
_,_,counts = np.unique(train_labels,return_index=True, return_counts=True)
|
||||
total_count = sum(counts)
|
||||
entropie_total = -sum(counts/total_count * np.log2(counts/total_count))
|
||||
|
||||
#print("Entropie Total:"+str(entropie_total))
|
||||
#Trouver split
|
||||
attribute_sort_order = np.argsort(train[:,attribute])
|
||||
sorted_labels = train_labels[attribute_sort_order]
|
||||
lags = np.hstack((np.array([False]),sorted_labels[:-1] != sorted_labels[1:]))
|
||||
potential_splits = train[attribute_sort_order,attribute][lags]
|
||||
lags2 = np.hstack((np.array([False]),lags))[:-1]
|
||||
potential_splits = 0.5*train[attribute_sort_order,attribute][lags]+0.5*train[attribute_sort_order,attribute][lags2]
|
||||
if (len(potential_splits)==0):
|
||||
potential_splits = np.array([np.median(train[attribute_sort_order,attribute])])
|
||||
#print("Potential Split:"+str(potential_splits))
|
||||
split_gain = []
|
||||
for v in potential_splits:
|
||||
split_labels_1 = train_labels[train[:,attribute] < v]
|
||||
split_labels_2 = train_labels[train[:,attribute] >= v]
|
||||
split_labels_1 = train_labels[train[:,attribute] <= v]
|
||||
split_labels_2 = train_labels[train[:,attribute] > v]
|
||||
_,_,counts1 = np.unique(split_labels_1,return_index=True, return_counts=True)
|
||||
total_count1 = sum(counts1)
|
||||
entropie_total1 = -sum(counts1/total_count1 * np.log2(counts1/total_count1))
|
||||
_,_,counts2 = np.unique(split_labels_2,return_index=True, return_counts=True)
|
||||
total_count2 = sum(counts2)
|
||||
entropie_total2 = -sum(counts2/total_count2 * np.log2(counts2/total_count2))
|
||||
split_gain.append(entropie_total+(total_count1/total_count*entropie_total1+total_count2/total_count*entropie_total2))
|
||||
split_gain.append(entropie_total-(total_count1/total_count*entropie_total1+total_count2/total_count*entropie_total2))
|
||||
#Valeur unique attribut
|
||||
#print("Split Gain:"+str(split_gain))
|
||||
best_split = potential_splits[np.argmax(split_gain)]
|
||||
best_gain = max(split_gain)
|
||||
|
||||
|
@ -65,13 +70,19 @@ class DecisionTree: #nom de la class à changer
|
|||
classes_uniques = np.unique(train_labels)
|
||||
# la feuille est vide
|
||||
if (n_examples == 0):
|
||||
return list(("Feuille",self.plurality_value(parent_examples)))
|
||||
l1 = []
|
||||
l1.append(("Feuille",self.plurality_value(parent_examples)))
|
||||
return list(l1)
|
||||
# tous les exemples ont la même classe
|
||||
elif len(classes_uniques)==1:
|
||||
return list(("Feuille",classes_uniques[0]))
|
||||
l1 = []
|
||||
l1.append(("Feuille",classes_uniques[0]))
|
||||
return l1
|
||||
# la liste d'attributs est vides
|
||||
elif (sum(attributes)==0):
|
||||
return list(("Feuille",self.plurality_value(train_labels)))
|
||||
l1 = []
|
||||
l1.append(("Feuille",self.plurality_value(train_labels)))
|
||||
return l1
|
||||
else:
|
||||
# Calcul du gain
|
||||
attr = np.where(attributes==1)[0]
|
||||
|
@ -89,7 +100,8 @@ class DecisionTree: #nom de la class à changer
|
|||
attributes[a_max]=0
|
||||
# pour chaque valeur de l'attribut, faire un sous-arbre
|
||||
for v in [True,False]:
|
||||
train_pos = np.where((train[:,a_max] < a_max_split) == v)
|
||||
print("Nouvelle branche: l'attribut "+str(a_max)+"<="+str(a_max_split)+" est: "+str(v))
|
||||
train_pos = np.where((train[:,a_max] <= a_max_split) == v)
|
||||
subtree = self.decision_tree_learning(train[train_pos],train_labels[train_pos],attributes,train_labels)
|
||||
tree.append(("Branche",a_max,a_max_split,v,subtree))
|
||||
return tree
|
||||
|
@ -120,8 +132,17 @@ class DecisionTree: #nom de la class à changer
|
|||
self.tree = self.decision_tree_learning(train, train_labels, attributes, None)
|
||||
|
||||
|
||||
def extract_tree(self, myTree, exemple):
    """Walk the list-encoded decision tree and return the label predicted for *exemple*.

    Each node in ``myTree`` is either a leaf ``("Feuille", label)`` or a branch
    ``("Branche", attribute, split, truth, subtree)``.  A branch is followed
    when the outcome of the test ``exemple[attribute] <= split`` equals the
    stored ``truth`` flag.  Returns ``None`` when no node matches.
    """
    for node in myTree:
        # Leaf reached: the stored label is the prediction.
        if node[0] == 'Feuille':
            return node[1]
        # Branch node: descend only when the split-test outcome matches
        # the truth value this branch was built for.
        _, attribute, split, truth, subtree = node
        # NOTE: '==' (not 'is') on purpose — exemple may hold numpy scalars,
        # whose comparison yields numpy.bool_, not the True/False singletons.
        if (exemple[attribute] <= split) == truth:
            return self.extract_tree(subtree, exemple)
    # No leaf reached and no branch matched.
    return None
|
||||
|
||||
|
||||
|
||||
def predict(self, exemple, label):
|
||||
"""
|
||||
Prédire la classe d'un exemple donné en entrée
|
||||
|
@ -131,6 +152,11 @@ class DecisionTree: #nom de la class à changer
|
|||
alors l'exemple est bien classifié, si non c'est une missclassification
|
||||
|
||||
"""
|
||||
|
||||
return self.extract_tree(self.tree,exemple)
|
||||
|
||||
|
||||
|
||||
|
||||
def test(self, test, test_labels):
|
||||
"""
|
||||
|
|
|
@ -12,4 +12,6 @@ dt = DecisionTree.DecisionTree()
|
|||
|
||||
dt.train(train,train_labels)
|
||||
|
||||
dt.tree
|
||||
dt.tree
|
||||
|
||||
[(dt.predict(exemple,label),label) for exemple,label in zip(train,train_labels)]
|
Loading…
Add table
Reference in a new issue