From ee8c7cbb2afe979305810d14093db0fe3dfa5124 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Pelletier?= Date: Mon, 29 Apr 2019 20:20:05 -0400 Subject: [PATCH] maintenant, l'arbre se construit et se parse correctement --- Code/DecisionTree.py | 46 ++++++++++++++++++++++++++++++++++---------- Code/main.py | 4 +++- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/Code/DecisionTree.py b/Code/DecisionTree.py index b3ddd11..d2c66c8 100644 --- a/Code/DecisionTree.py +++ b/Code/DecisionTree.py @@ -35,24 +35,29 @@ class DecisionTree: #nom de la class à changer _,_,counts = np.unique(train_labels,return_index=True, return_counts=True) total_count = sum(counts) entropie_total = -sum(counts/total_count * np.log2(counts/total_count)) - + #print("Entropie Total:"+str(entropie_total)) #Trouver split attribute_sort_order = np.argsort(train[:,attribute]) sorted_labels = train_labels[attribute_sort_order] lags = np.hstack((np.array([False]),sorted_labels[:-1] != sorted_labels[1:])) - potential_splits = train[attribute_sort_order,attribute][lags] + lags2 = np.hstack((np.array([False]),lags))[:-1] + potential_splits = 0.5*train[attribute_sort_order,attribute][lags]+0.5*train[attribute_sort_order,attribute][lags2] + if (len(potential_splits)==0): + potential_splits = np.array([np.median(train[attribute_sort_order,attribute])]) + #print("Potential Split:"+str(potential_splits)) split_gain = [] for v in potential_splits: - split_labels_1 = train_labels[train[:,attribute] < v] - split_labels_2 = train_labels[train[:,attribute] >= v] + split_labels_1 = train_labels[train[:,attribute] <= v] + split_labels_2 = train_labels[train[:,attribute] > v] _,_,counts1 = np.unique(split_labels_1,return_index=True, return_counts=True) total_count1 = sum(counts1) entropie_total1 = -sum(counts1/total_count1 * np.log2(counts1/total_count1)) _,_,counts2 = np.unique(split_labels_2,return_index=True, return_counts=True) total_count2 = sum(counts2) entropie_total2 = -sum(counts2/total_count2 * np.log2(counts2/total_count2)) - split_gain.append(entropie_total+(total_count1/total_count*entropie_total1+total_count2/total_count*entropie_total2)) + split_gain.append(entropie_total-(total_count1/total_count*entropie_total1+total_count2/total_count*entropie_total2)) #Valeur unique attribut + #print("Split Gain:"+str(split_gain)) best_split = potential_splits[np.argmax(split_gain)] best_gain = max(split_gain) @@ -65,13 +70,19 @@ class DecisionTree: #nom de la class à changer classes_uniques = np.unique(train_labels) # la feuille est vide if (n_examples == 0): - return list(("Feuille",self.plurality_value(parent_examples))) + l1 = [] + l1.append(("Feuille",self.plurality_value(parent_examples))) + return list(l1) # tous les exemples ont la même classe elif len(classes_uniques)==1: - return list(("Feuille",classes_uniques[0])) + l1 = [] + l1.append(("Feuille",classes_uniques[0])) + return l1 # la liste d'attributs est vides elif (sum(attributes)==0): - return list(("Feuille",self.plurality_value(train_labels))) + l1 = [] + l1.append(("Feuille",self.plurality_value(train_labels))) + return l1 else: # Calcul du gain attr = np.where(attributes==1)[0] @@ -89,7 +100,8 @@ class DecisionTree: #nom de la class à changer attributes[a_max]=0 # pour chaque valeur de l'attribut, faire un sous-arbre for v in [True,False]: - train_pos = np.where((train[:,a_max] < a_max_split) == v) + print("Nouvelle branche: l'attribut "+str(a_max)+"<="+str(a_max_split)+" est: "+str(v)) + train_pos = np.where((train[:,a_max] <= a_max_split) == v) subtree = self.decision_tree_learning(train[train_pos],train_labels[train_pos],attributes,train_labels) tree.append(("Branche",a_max,a_max_split,v,subtree)) return tree @@ -120,8 +132,17 @@ class DecisionTree: #nom de la class à changer self.tree = self.decision_tree_learning(train, train_labels, attributes, None) + def extract_tree(self,myTree,exemple): + for b in myTree: + # On a atteint la feuille + if b[0] == 'Feuille': + return b[1] + # On est dans une branche, on teste le split + if ((exemple[b[1]] <= b[2]) == b[3]): + return self.extract_tree(b[4],exemple) + return None + - def predict(self, exemple, label): """ Prédire la classe d'un exemple donné en entrée @@ -131,6 +152,11 @@ class DecisionTree: #nom de la class à changer alors l'exemple est bien classifié, si non c'est une missclassification """ + + return self.extract_tree(self.tree,exemple) + + + def test(self, test, test_labels): """ diff --git a/Code/main.py b/Code/main.py index f9edf32..e1e875e 100644 --- a/Code/main.py +++ b/Code/main.py @@ -12,4 +12,6 @@ dt = DecisionTree.DecisionTree() dt.train(train,train_labels) -dt.tree \ No newline at end of file +dt.tree + +[(dt.predict(exemple,label),label) for exemple,label in zip(train,train_labels)] \ No newline at end of file