From 69dbe483f26c0a76feb553f131551f6633d6f654 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Pelletier?= Date: Mon, 29 Apr 2019 22:28:32 -0400 Subject: [PATCH] arbre fonctionne pour 5 datasets --- Code/DecisionTree.py | 9 ++++++--- Code/main.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/Code/DecisionTree.py b/Code/DecisionTree.py index b21f7a4..8309ea5 100644 --- a/Code/DecisionTree.py +++ b/Code/DecisionTree.py @@ -45,7 +45,10 @@ class DecisionTree: #nom de la class à changer sorted_labels = train_labels[attribute_sort_order] lags = np.hstack((np.array([False]),sorted_labels[:-1] != sorted_labels[1:])) lags2 = np.hstack((np.array([False]),lags))[:-1] - potential_splits = 0.5*train[attribute_sort_order,attribute][lags]+0.5*train[attribute_sort_order,attribute][lags2] + potential_splits = ( + 0.5*train[attribute_sort_order,attribute][lags]+ + 0.5*train[attribute_sort_order,attribute][lags2] + ) if (len(potential_splits)==0): potential_splits = np.array([np.median(train[attribute_sort_order,attribute])]) #print("Potential Split:"+str(potential_splits)) @@ -115,13 +118,13 @@ class DecisionTree: #nom de la class à changer # pour chaque valeur de l'attribut, faire un sous-arbre if (self.attribute_type=="continuous"): for v in [True,False]: - print("Nouvelle branche: l'attribut "+str(a_max)+"<="+str(a_max_split)+" est: "+str(v)) + #print("Nouvelle branche: l'attribut "+str(a_max)+"<="+str(a_max_split)+" est: "+str(v)) train_pos = np.where((train[:,a_max] <= a_max_split) == v) subtree = self.decision_tree_learning(train[train_pos],train_labels[train_pos],attributes,train_labels) tree.append(("Branche",a_max,a_max_split,v,subtree)) if (self.attribute_type=="discrete"): for v in np.unique(train[:,a_max]): - print("Nouvelle branche: l'attribut "+str(a_max)+" est: "+str(v)) + #print("Nouvelle branche: l'attribut "+str(a_max)+" est: "+str(v)) train_pos = np.where(train[:,a_max] == v) subtree = self.decision_tree_learning(train[train_pos],train_labels[train_pos],attributes,train_labels) tree.append(("Branche",a_max,v,subtree)) diff --git a/Code/main.py b/Code/main.py index aa2e523..e3fc2f5 100644 --- a/Code/main.py +++ b/Code/main.py @@ -20,6 +20,23 @@ dt1.train(train1, train_labels1) dt1.predict(test1[0],test_labels1[0]) dt1.test(test1, test_labels1) +dt2 = DecisionTree.DecisionTree(attribute_type="discrete") +dt2.train(train2, train_labels2) +dt2.tree +dt2.predict(test2[0],test_labels2[0]) +dt2.test(test2, test_labels2) + +dt3 = DecisionTree.DecisionTree(attribute_type="discrete") +dt3.train(train3, train_labels3) +dt3.tree +dt3.predict(test3[0],test_labels3[0]) +dt3.test(test3, test_labels3) + +dt4 = DecisionTree.DecisionTree(attribute_type="discrete") +dt4.train(train4, train_labels4) +dt4.tree +dt4.predict(test4[0],test_labels4[0]) +dt4.test(test4, test_labels4) dt5 = DecisionTree.DecisionTree(attribute_type="discrete") dt5.train(train5, train_labels5)