In [1]:
import gym
import numpy as np
In [2]:
def select_a_with_epsilon_greedy(curr_s, q_value, epsilon=0.1):
    a = np.argmax(q_value[curr_s, :])
    
    if np.random.rand() < epsilon:
        all_actions = range(Q.shape[1])
        del all_actions[a]
        a = np.random.choice(all_actions)
    
    return a
In [4]:
env = gym.make('FrozenLake-v0')
[2017-04-08 16:51:58,087] Making new env: FrozenLake-v0
In [5]:
print env.observation_space, env.action_space
Discrete(16) Discrete(4)
In [10]:
Q = np.random.randn(env.observation_space.n, env.action_space.n)
#Q = np.zeros([env.observation_space.n, env.action_space.n])
In [11]:
num_episodes = 1000
rewards = []
alpha = 0.90
gamma = 0.95
epsilon = 0.3
epsilon_decay = 0.995
In [12]:
for i_episode in range(num_episodes):
    state = env.reset()
    reward_episode = 0
    done = False
    for t in range(600):
        action = select_a_with_epsilon_greedy(state, Q)
        new_state, reward, done, _  = env.step(action)
        Q[state, action] = (1 - alpha)*Q[state, action] + alpha*(reward + gamma*np.max(Q[new_state,:]))
        reward_episode += reward
        state = new_state
        if done == True:
            epsilon *= epsilon_decay
            break
    rewards.append(reward_episode)
In [13]:
print "Average reward over time: " +  str(sum(rewards)/num_episodes)
Average reward over time: 0.022
In [14]:
print "Q-values", Q
Q-values [[ 0.72004006  0.60984445  0.63523615  0.76008057]
 [ 0.63221879  0.58038665  0.60240954  0.57760969]
 [ 0.75129567  1.18166167  0.76814843  0.73413324]
 [ 0.78507737  0.80026918  1.24570388  0.7945747 ]
 [ 0.59323982  0.66832636  0.82429042  0.57772198]
 [ 0.14779713 -0.43395151  0.66563234 -0.27506472]
 [ 0.69323003  0.69564561  1.1045144   0.69568877]
 [ 0.07449064 -0.46787468  1.3325042  -0.21237733]
 [ 0.78729121  0.76354526  0.73244601  0.69499116]
 [ 0.62201938  1.87268469  0.70941668  0.65555688]
 [ 1.20459361  2.63625257  1.53260809  2.26849456]
 [-0.40443151  2.81133734 -2.27908152  0.54970654]
 [ 0.63929857 -2.89408856  0.6868482   0.80993811]
 [ 0.83136056  1.68283836  2.11032436  1.18025433]
 [ 1.84899468  1.37367502  1.67876782  1.18209879]
 [-0.67431893  0.66180157 -1.3182722  -1.11126144]]
In [15]:
env.render()
  (Left)
SFFF
FHFH
FFFH
HFFG
In [ ]: