Using Deep Q-Network to Learn How To Play Flappy Bird - Explained
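The program uses reinforcement learning. A Convolutional neural network reads frames from the emulator (the screen) to obtain the current game state (s0). The deep neural network then evaluates that state, executes the corresponding action (a), observes the resulting reward (r), and receives the next state (s1). This loop is a classic Markov decision process (MDP), implemented here with a Convolutional neural network (CNN). Unlike a typical CNN trained on fixed labels, the network is updated with a cost function derived from the Bellman equation: the prediction Q(s0, a) is regressed toward the target r + γ · max_a' Q(s1, a'), or toward r alone when the episode has ended.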
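A minimal NumPy sketch of that update, assuming a minibatch of Q-values has already been computed by the network. The helper names `dqn_targets` and `dqn_loss`, the choice of two actions, and all numbers are hypothetical and only illustrate the arithmetic. The target is y = r when the transition ended the episode and y = r + γ · max_a' Q(s', a') otherwise; the cost is the mean squared error between y and the Q-value of the action actually taken, selected by multiplying with a one-hot action vector.

```python
import numpy as np

def dqn_targets(rewards, next_q, terminals, gamma):
    """Bellman targets: y = r if the transition ended the episode, else r + gamma * max_a' Q(s', a')."""
    rewards = np.asarray(rewards, dtype=float)
    next_q = np.asarray(next_q, dtype=float)
    terminals = np.asarray(terminals, dtype=bool)
    return np.where(terminals, rewards, rewards + gamma * next_q.max(axis=1))

def dqn_loss(q_values, actions_one_hot, targets):
    """Mean squared error between the targets and Q(s, a) of the action actually taken."""
    q_taken = np.sum(np.asarray(q_values) * np.asarray(actions_one_hot), axis=1)
    return float(np.mean(np.square(np.asarray(targets) - q_taken)))

# Hypothetical minibatch: 3 transitions, 2 actions (index 0 = do nothing, 1 = flap)
q_values = np.array([[1.0, 2.0], [0.5, 0.0], [3.0, 1.0]])  # Q(s, .) from the network
actions = np.array([[0, 1], [1, 0], [1, 0]])               # one-hot actions taken
rewards = [0.1, 0.1, -1.0]
next_q = np.array([[1.0, 1.5], [2.0, 0.0], [9.0, 9.0]])    # Q(s', .) from the network
terminals = [False, False, True]                           # last transition ended the episode

y = dqn_targets(rewards, next_q, terminals, gamma=0.5)
loss = dqn_loss(q_values, actions, y)
print(y, loss)
```

Here y comes out as 0.85, 1.1 and -1.0: the terminal transition ignores its (large) next-state values entirely, so a crash is never credited with future reward.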
Video: 7-minute demo of a DQN playing Flappy Bird
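Below is a more detailed explanation of the part of the code that trains the network. For the complete program, see the DeepLearningFlappyBird project listed in the references.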
```python
import random
from collections import deque

import cv2
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (placeholders, sessions)

# Assumes module-level constants (GAME, ACTIONS, GAMMA, OBSERVE, EXPLORE,
# INITIAL_EPSILON, FINAL_EPSILON, REPLAY_MEMORY, BATCH, FRAME_PER_ACTION)
# and the `game` emulator module are defined elsewhere in the program.

def trainNetwork(s, readout, h_fc1, sess):
    # define the cost function: least-squares regression of Q(s, a) onto the target y
    a = tf.placeholder("float", [None, ACTIONS])  # one-hot action that was taken
    y = tf.placeholder("float", [None])           # Bellman target per sample
    # Q-value of the taken action: mask the readout with the one-hot action
    readout_action = tf.reduce_sum(tf.multiply(readout, a), axis=1)
    cost = tf.reduce_mean(tf.square(y - readout_action))
    train_step = tf.train.AdamOptimizer(1e-6).minimize(cost)

    # open up a game state to communicate with emulator
    game_state = game.GameState()

    # store the previous observations in replay memory
    D = deque()

    # printing
    a_file = open("logs_" + GAME + "/readout.txt", 'w')
    h_file = open("logs_" + GAME + "/hidden.txt", 'w')

    # get the first state by doing nothing and preprocess the image to 80x80x4
    do_nothing = np.zeros(ACTIONS)
    do_nothing[0] = 1
    x_t, r_0, terminal = game_state.frame_step(do_nothing)
    x_t = cv2.cvtColor(cv2.resize(x_t, (80, 80)), cv2.COLOR_BGR2GRAY)
    ret, x_t = cv2.threshold(x_t, 1, 255, cv2.THRESH_BINARY)
    s_t = np.stack((x_t, x_t, x_t, x_t), axis=2)

    # saving and loading networks
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())  # initialize weights before any restore
    checkpoint = tf.train.get_checkpoint_state("saved_networks")
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)
        print("Successfully loaded:", checkpoint.model_checkpoint_path)
    else:
        print("Could not find old network weights")

    # start training
    epsilon = INITIAL_EPSILON
    t = 0
    while "flappy bird" != "angry bird":  # always true: train forever
        # choose an action epsilon greedily
        readout_t = readout.eval(feed_dict={s: [s_t]})[0]
        a_t = np.zeros([ACTIONS])
        action_index = 0
        if t % FRAME_PER_ACTION == 0:
            if random.random() <= epsilon:
                # explore: choose a random action
                print("----------Random Action----------")
                action_index = random.randrange(ACTIONS)
                a_t[action_index] = 1
            else:
                # exploit: choose the action with the largest Q-value
                action_index = np.argmax(readout_t)
                a_t[action_index] = 1
        else:
            a_t[0] = 1  # not a decision frame: do nothing

        # scale down epsilon to lower the probability of taking a random action
        if epsilon > FINAL_EPSILON and t > OBSERVE:
            epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE

        # run the selected action and observe next state and reward
        x_t1_colored, r_t, terminal = game_state.frame_step(a_t)
        x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY)
        ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY)
        x_t1 = np.reshape(x_t1, (80, 80, 1))
        # the newest frame goes in channel 0; the oldest of the previous four is dropped
        s_t1 = np.append(x_t1, s_t[:, :, :3], axis=2)

        # store the transition (state, action, reward, next state, terminal) in replay memory D
        D.append((s_t, a_t, r_t, s_t1, terminal))
        if len(D) > REPLAY_MEMORY:
            D.popleft()  # discard the oldest transition

        # only train once the observation phase is over
        if t > OBSERVE:
            # sample a minibatch to train on
            minibatch = random.sample(D, BATCH)

            # get the batch variables
            s_j_batch = [d[0] for d in minibatch]
            a_batch = [d[1] for d in minibatch]
            r_batch = [d[2] for d in minibatch]
            s_j1_batch = [d[3] for d in minibatch]

            y_batch = []
            readout_j1_batch = readout.eval(feed_dict={s: s_j1_batch})
            for i in range(0, len(minibatch)):
                terminal_j = minibatch[i][4]
                if terminal_j:
                    # if terminal, the target only equals the reward
                    y_batch.append(r_batch[i])
                else:
                    # if not terminal, bootstrap with the Bellman equation
                    y_batch.append(r_batch[i] + GAMMA * np.max(readout_j1_batch[i]))

            # perform gradient step
            train_step.run(feed_dict={
                y: y_batch,
                a: a_batch,
                s: s_j_batch})

        # update the old values
        s_t = s_t1
        t += 1

        # save progress every 10000 iterations
        if t % 10000 == 0:
            saver.save(sess, 'saved_networks/' + GAME + '-dqn', global_step=t)

        # print info
        if t <= OBSERVE:  # observing, i.e. the first 100000 steps
            state = "observe"
        elif t <= OBSERVE + EXPLORE:
            state = "explore"
        else:
            state = "train"
        print("TIMESTEP", t, "/ STATE", state,
              "/ EPSILON", epsilon, "/ ACTION", action_index, "/ REWARD", r_t,
              "/ Q_MAX %e" % np.max(readout_t))

        # write info to files
        if t % 10000 <= 100:
            a_file.write(",".join([str(x) for x in readout_t]) + '\n')
            h_file.write(",".join([str(x) for x in h_fc1.eval(feed_dict={s: [s_t]})[0]]) + '\n')
            cv2.imwrite("logs_" + GAME + "/frame" + str(t) + ".png", x_t1)
```
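The exploration and memory logic of the loop can be isolated into a small, framework-free sketch. It assumes Q-values arrive as a plain Python list; the names `annealed_epsilon`, `epsilon_greedy` and `ReplayMemory`, and the settings in the usage lines, are hypothetical and not part of the original program. The schedule is a closed-form version of the per-step linear decrement of epsilon after the observation phase, and the one-hot vector is always built from the index that was actually chosen.

```python
import random
from collections import deque

def annealed_epsilon(t, initial_eps, final_eps, observe, explore):
    """Linear schedule: hold initial_eps while observing, then decay to final_eps over `explore` steps."""
    if t <= observe:
        return initial_eps
    frac = min(1.0, (t - observe) / explore)
    return initial_eps - frac * (initial_eps - final_eps)

def epsilon_greedy(q_values, epsilon, rng=random):
    """Return (action_index, one_hot): random action with probability epsilon, else argmax Q."""
    n = len(q_values)
    if rng.random() < epsilon:  # strict '<': epsilon = 0 never explores, epsilon = 1 always does
        index = rng.randrange(n)                            # explore
    else:
        index = max(range(n), key=lambda i: q_values[i])    # exploit
    one_hot = [0] * n
    one_hot[index] = 1  # encode the same action that was selected
    return index, one_hot

class ReplayMemory:
    """Bounded FIFO buffer of (s, a, r, s_next, terminal) transitions."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are evicted automatically

    def append(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size, rng=random):
        return rng.sample(self.buffer, batch_size)  # uniform, without replacement

    def __len__(self):
        return len(self.buffer)

# Usage with hypothetical settings
rng = random.Random(0)
memory = ReplayMemory(capacity=3)
for step in range(5):
    memory.append(("s%d" % step, 0, 0.1, "s%d" % (step + 1), False))
batch = memory.sample(2, rng)
eps = annealed_epsilon(t=150, initial_eps=0.1, final_eps=0.0, observe=100, explore=100)
action, one_hot = epsilon_greedy([0.2, 0.9], eps, rng)
```

Using `deque(maxlen=...)` instead of an explicit length check keeps eviction in one place: after five appends to a capacity-3 memory, only the last three transitions remain.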
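1. This learning method is called off-policy learning: actions are chosen by an ε-greedy behaviour policy (and replayed from older transitions stored in D), but the Bellman target uses the greedy choice max_a' Q(s', a') regardless of which action will actually be taken next. If there is "off", is there also "on"? Yes: it is called on-policy learning, and it will be introduced in a separate article.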
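2. Early reinforcement learning methods learned by updating a Q-table. What is a Q-table? Consider every combination of state and action. If there are only two possible states and two possible actions, there are 2 × 2 = 4 combinations; during training we update the Q-value of each of these 4 entries so that the model learns the most suitable action in each state. This is very easy. Now keep the two actions but let the states range over everything that can appear on a Flappy Bird screen: how much memory would the computer have to allocate for the table up front? So we replace the Q-table with a deep neural network (DNN, also called a Q-network) that learns which action to take in a given situation, instead of storing a value for every point of this high-dimensional space. The drawback is that this relies on a great deal of computation; as the cost of GPU computing has fallen in recent years, the approach has become widely adopted.

References:
1. DeepLearningFlappyBird
2. Q-learning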
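Notes 1 and 2 can be made concrete with a minimal tabular Q-learning sketch for the 2-state, 2-action case. The environment (its transitions and rewards) and the values of alpha, gamma and epsilon are hypothetical, chosen only for illustration. The Q-table is literally a 2 × 2 list, actions are picked ε-greedily, and every update bootstraps from the maximum Q-value of the next state rather than from the action that will actually be taken, which is the off-policy property from note 1.

```python
import random

# Hypothetical deterministic environment: (state, action) -> (next_state, reward)
TRANSITIONS = {
    (0, 0): (0, 0.0),  # state 0, action 0: stay, no reward
    (0, 1): (1, 1.0),  # state 0, action 1: move to state 1, reward 1
    (1, 0): (0, 0.0),  # state 1, action 0: move back to state 0, no reward
    (1, 1): (1, 2.0),  # state 1, action 1: stay, reward 2
}

def q_learning(steps=20000, alpha=0.1, gamma=0.9, epsilon=0.2, seed=0):
    rng = random.Random(seed)
    q = [[0.0, 0.0], [0.0, 0.0]]  # the Q-table: 2 states x 2 actions = 4 entries
    state = 0
    for _ in range(steps):
        # behaviour policy: epsilon-greedy
        if rng.random() < epsilon:
            action = rng.randrange(2)
        else:
            action = max(range(2), key=lambda a: q[state][a])
        next_state, reward = TRANSITIONS[(state, action)]
        # off-policy target: greedy max over next actions
        target = reward + gamma * max(q[next_state])
        q[state][action] += alpha * (target - q[state][action])
        state = next_state
    return q

q = q_learning()
greedy_policy = [max(range(2), key=lambda a: q[s][a]) for s in range(2)]
print(q)
print(greedy_policy)
```

For this toy environment the optimal values can be checked by hand with γ = 0.9: Q*(1, 1) = 2 / (1 - 0.9) = 20, Q*(0, 1) = 1 + 0.9 · 20 = 19, and both action-0 entries equal 0.9 · 19 = 17.1, so the learned greedy policy should choose action 1 in both states. Flappy Bird has the same two actions but far too many states for such a table, which is why the DQN swaps the table for a network.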