q learning pas mal avancé
This commit is contained in:
parent
aca1c3c599
commit
77caa10f89
3 changed files with 74 additions and 30 deletions
|
@ -1,3 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# analysis.py
|
||||
# -----------
|
||||
# Licensing Information: You are free to use or extend these projects for
|
||||
|
@ -25,45 +26,45 @@ def question2():
|
|||
return answerDiscount, answerNoise
|
||||
|
||||
def question3a():
|
||||
answerDiscount = None
|
||||
answerNoise = None
|
||||
answerLivingReward = None
|
||||
answerDiscount = 0.9
|
||||
answerNoise = 0.1
|
||||
answerLivingReward = -3
|
||||
return answerDiscount, answerNoise, answerLivingReward
|
||||
# If not possible, return 'NOT POSSIBLE'
|
||||
|
||||
def question3b():
|
||||
answerDiscount = None
|
||||
answerNoise = None
|
||||
answerLivingReward = None
|
||||
answerDiscount = 0.1
|
||||
answerNoise = 0.1
|
||||
answerLivingReward = -2
|
||||
return answerDiscount, answerNoise, answerLivingReward
|
||||
# If not possible, return 'NOT POSSIBLE'
|
||||
|
||||
def question3c():
|
||||
answerDiscount = None
|
||||
answerNoise = None
|
||||
answerLivingReward = None
|
||||
answerDiscount = 0.9
|
||||
answerNoise = 0
|
||||
answerLivingReward = 0
|
||||
return answerDiscount, answerNoise, answerLivingReward
|
||||
# If not possible, return 'NOT POSSIBLE'
|
||||
|
||||
def question3d():
|
||||
answerDiscount = None
|
||||
answerNoise = None
|
||||
answerLivingReward = None
|
||||
answerDiscount = 0.1
|
||||
answerNoise = 0.1
|
||||
answerLivingReward = 1
|
||||
return answerDiscount, answerNoise, answerLivingReward
|
||||
# If not possible, return 'NOT POSSIBLE'
|
||||
|
||||
def question3e():
|
||||
answerDiscount = None
|
||||
answerNoise = None
|
||||
answerLivingReward = None
|
||||
answerDiscount = 0
|
||||
answerNoise = 0
|
||||
answerLivingReward = 1
|
||||
return answerDiscount, answerNoise, answerLivingReward
|
||||
# If not possible, return 'NOT POSSIBLE'
|
||||
|
||||
def question6():
|
||||
answerEpsilon = None
|
||||
answerLearningRate = None
|
||||
return answerEpsilon, answerLearningRate
|
||||
# If not possible, return 'NOT POSSIBLE'
|
||||
# return answerEpsilon, answerLearningRate
|
||||
return 'NOT POSSIBLE'
|
||||
|
||||
if __name__ == '__main__':
|
||||
print 'Answers to analysis questions:'
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# qlearningAgents.py
|
||||
# ------------------
|
||||
# Licensing Information: You are free to use or extend these projects for
|
||||
|
@ -43,6 +44,7 @@ class QLearningAgent(ReinforcementAgent):
|
|||
ReinforcementAgent.__init__(self, **args)
|
||||
|
||||
"*** YOUR CODE HERE ***"
|
||||
self.q_values = {}
|
||||
|
||||
def getQValue(self, state, action):
|
||||
"""
|
||||
|
@ -51,8 +53,11 @@ class QLearningAgent(ReinforcementAgent):
|
|||
or the Q node value otherwise
|
||||
"""
|
||||
"*** YOUR CODE HERE ***"
|
||||
util.raiseNotDefined()
|
||||
|
||||
if (state, action) in self.q_values:
|
||||
return self.q_values.get((state, action))
|
||||
else:
|
||||
return 0
|
||||
# util.raiseNotDefined()
|
||||
|
||||
def computeValueFromQValues(self, state):
|
||||
"""
|
||||
|
@ -62,7 +67,14 @@ class QLearningAgent(ReinforcementAgent):
|
|||
terminal state, you should return a value of 0.0.
|
||||
"""
|
||||
"*** YOUR CODE HERE ***"
|
||||
util.raiseNotDefined()
|
||||
legal_actions = self.getLegalActions(state)
|
||||
|
||||
if len(legal_actions) == 0:
|
||||
return 0.0
|
||||
|
||||
return max([self.getQValue(state, action) for action in legal_actions])
|
||||
|
||||
# util.raiseNotDefined()
|
||||
|
||||
def computeActionFromQValues(self, state):
|
||||
"""
|
||||
|
@ -71,7 +83,23 @@ class QLearningAgent(ReinforcementAgent):
|
|||
you should return None.
|
||||
"""
|
||||
"*** YOUR CODE HERE ***"
|
||||
util.raiseNotDefined()
|
||||
|
||||
legalActions = self.getLegalActions(state)
|
||||
|
||||
if len(legalActions) == 0:
|
||||
return None
|
||||
|
||||
q_values = [self.getQValue(state, action) for action in legalActions]
|
||||
q_max = max(q_values)
|
||||
|
||||
q_max_indices = []
|
||||
for index, value in enumerate(q_values):
|
||||
if value == q_max:
|
||||
q_max_indices.append(index)
|
||||
|
||||
return legalActions[random.choice(q_max_indices)]
|
||||
|
||||
# util.raiseNotDefined()
|
||||
|
||||
def getAction(self, state):
|
||||
"""
|
||||
|
@ -86,11 +114,21 @@ class QLearningAgent(ReinforcementAgent):
|
|||
"""
|
||||
# Pick Action
|
||||
legalActions = self.getLegalActions(state)
|
||||
action = None
|
||||
"*** YOUR CODE HERE ***"
|
||||
util.raiseNotDefined()
|
||||
|
||||
return action
|
||||
"*** YOUR CODE HERE ***"
|
||||
if len(legalActions) == 0:
|
||||
return None
|
||||
|
||||
best_action = self.computeActionFromQValues(state)
|
||||
|
||||
if util.flipCoin(self.epsilon):
|
||||
# Action aléatoire
|
||||
return random.choice(legalActions)
|
||||
else:
|
||||
# Meilleure action
|
||||
return best_action
|
||||
|
||||
# util.raiseNotDefined()
|
||||
|
||||
def update(self, state, action, nextState, reward):
|
||||
"""
|
||||
|
@ -102,7 +140,12 @@ class QLearningAgent(ReinforcementAgent):
|
|||
it will be called on your behalf
|
||||
"""
|
||||
"*** YOUR CODE HERE ***"
|
||||
util.raiseNotDefined()
|
||||
q_value = self.getQValue(state, action)
|
||||
best_value = self.getValue(nextState)
|
||||
new_q_value = (1-self.alpha)*q_value+self.alpha*(reward+self.discount*best_value)
|
||||
self.q_values[(state, action)] = new_q_value
|
||||
self.q_values.update({(state, action): new_q_value})
|
||||
# util.raiseNotDefined()
|
||||
|
||||
def getPolicy(self, state):
|
||||
return self.computeActionFromQValues(state)
|
||||
|
|
|
@ -50,7 +50,7 @@ class ValueIterationAgent(ValueEstimationAgent):
|
|||
"*** YOUR CODE HERE ***"
|
||||
states = self.mdp.getStates()
|
||||
|
||||
print "__init__ ... states: " + str(states)
|
||||
# print "__init__ ... states: " + str(states)
|
||||
|
||||
for i in range(iterations):
|
||||
# On reprend les valeurs de l'itération précédente comme référence
|
||||
|
@ -107,9 +107,9 @@ class ValueIterationAgent(ValueEstimationAgent):
|
|||
return None
|
||||
|
||||
q_values = [self.computeQValueFromValues(state, action) for action in possibleActions]
|
||||
print "computeActionFromValues ... q_values: "+str(q_values)
|
||||
print "index:"+str(q_values.index(max(q_values)))
|
||||
print "action:"+str(possibleActions[q_values.index(max(q_values))])
|
||||
# print "computeActionFromValues ... q_values: "+str(q_values)
|
||||
# print "index:"+str(q_values.index(max(q_values)))
|
||||
# print "action:"+str(possibleActions[q_values.index(max(q_values))])
|
||||
return possibleActions[q_values.index(max(q_values))]
|
||||
|
||||
def getPolicy(self, state):
|
||||
|
|
Loading…
Add table
Reference in a new issue