arbre fonctionne pour 5 datasets
This commit is contained in:
parent
54badfdff4
commit
69dbe483f2
2 changed files with 23 additions and 3 deletions
|
@ -45,7 +45,10 @@ class DecisionTree: #nom de la class à changer
|
|||
sorted_labels = train_labels[attribute_sort_order]
|
||||
lags = np.hstack((np.array([False]),sorted_labels[:-1] != sorted_labels[1:]))
|
||||
lags2 = np.hstack((np.array([False]),lags))[:-1]
|
||||
potential_splits = 0.5*train[attribute_sort_order,attribute][lags]+0.5*train[attribute_sort_order,attribute][lags2]
|
||||
potential_splits = (
|
||||
0.5*train[attribute_sort_order,attribute][lags]+
|
||||
0.5*train[attribute_sort_order,attribute][lags2]
|
||||
)
|
||||
if (len(potential_splits)==0):
|
||||
potential_splits = np.array([np.median(train[attribute_sort_order,attribute])])
|
||||
#print("Potential Split:"+str(potential_splits))
|
||||
|
@ -115,13 +118,13 @@ class DecisionTree: #nom de la class à changer
|
|||
# pour chaque valeur de l'attribut, faire un sous-arbre
|
||||
if (self.attribute_type=="continuous"):
|
||||
for v in [True,False]:
|
||||
print("Nouvelle branche: l'attribut "+str(a_max)+"<="+str(a_max_split)+" est: "+str(v))
|
||||
#print("Nouvelle branche: l'attribut "+str(a_max)+"<="+str(a_max_split)+" est: "+str(v))
|
||||
train_pos = np.where((train[:,a_max] <= a_max_split) == v)
|
||||
subtree = self.decision_tree_learning(train[train_pos],train_labels[train_pos],attributes,train_labels)
|
||||
tree.append(("Branche",a_max,a_max_split,v,subtree))
|
||||
if (self.attribute_type=="discrete"):
|
||||
for v in np.unique(train[:,a_max]):
|
||||
print("Nouvelle branche: l'attribut "+str(a_max)+" est: "+str(v))
|
||||
#print("Nouvelle branche: l'attribut "+str(a_max)+" est: "+str(v))
|
||||
train_pos = np.where(train[:,a_max] == v)
|
||||
subtree = self.decision_tree_learning(train[train_pos],train_labels[train_pos],attributes,train_labels)
|
||||
tree.append(("Branche",a_max,v,subtree))
|
||||
|
|
17
Code/main.py
17
Code/main.py
|
@ -20,6 +20,23 @@ dt1.train(train1, train_labels1)
|
|||
dt1.predict(test1[0],test_labels1[0])
|
||||
dt1.test(test1, test_labels1)
|
||||
|
||||
dt2 = DecisionTree.DecisionTree(attribute_type="discrete")
|
||||
dt2.train(train2, train_labels2)
|
||||
dt2.tree
|
||||
dt2.predict(test2[0],test_labels2[0])
|
||||
dt2.test(test2, test_labels2)
|
||||
|
||||
dt3 = DecisionTree.DecisionTree(attribute_type="discrete")
|
||||
dt3.train(train3, train_labels3)
|
||||
dt3.tree
|
||||
dt3.predict(test3[0],test_labels3[0])
|
||||
dt3.test(test3, test_labels3)
|
||||
|
||||
dt4 = DecisionTree.DecisionTree(attribute_type="discrete")
|
||||
dt4.train(train4, train_labels4)
|
||||
dt4.tree
|
||||
dt4.predict(test4[0],test_labels4[0])
|
||||
dt4.test(test4, test_labels4)
|
||||
|
||||
dt5 = DecisionTree.DecisionTree(attribute_type="discrete")
|
||||
dt5.train(train5, train_labels5)
|
||||
|
|
Loading…
Reference in a new issue