Understanding AlphaZero by Reading Source Code in Julia
AlphaZero is a machine learning method that combines deep neural networks, reinforcement learning, and randomized tree search over discrete actions. In a sense, it mimics a brain (a trainable network) combined with a dopamine system (learning driven by rewards) and logical thinking (look-ahead search guided by that reward-trained network). However, it requires high-performance computing to truly shine. At a high level, this is how it works:
- Set up your problem as an environment that has states, actions that can change the state, and rewards that signal the goal of interacting with this environment. In other words, set up your problem as a game.
- Set up a deep neural network with two heads that takes the state as input and outputs a value of the state (value) and a probability vector over possible actions (policy). Minimize the difference between the predicted value and the end-of-game value (squared differences), and between the predicted policy and the search-estimated policy (cross-entropy, i.e. the negative dot product of the search policy with the log of the predicted policy). Use any suitable optimization algorithm. During play, this network is called an oracle, because the search consults it for value and policy predictions. A sketch of this loss appears right after this list.
- Simulate the game (self-play) to collect data. Before each decision, use search to improve the oracle-policy. Then use the improved policy to select an action, and repeat until the game is over. Play many games.
- During search, select an action that maximizes a measure (U) that combines an estimated value after action (Q) with the oracle-policy (P), the counts of actions (N) taken previously, and an exploration parameter (c). The basic idea is that the weight of the predicted policy (P) decreases as more experience of the action (N) accumulates or as the user changes the exploration parameter. Conversely, the weight of the value of actions (Q) increases as they are updated using predicted values and end-of-game rewards.
- During search, at each new state, the value of the state is predicted by the neural network or given by the game as a reward, and the values of the actions that led there (Q) are updated. Updating takes the average of the previous action values and the new value, i.e. Q becomes (N × Q + value) / (N + 1). The action values at the new state (Q) start at 0, so at first only the predicted value and policy affect the next decision there, as defined by the measure (U). Each time a state is revisited, its action values are updated again.
- During search, the counts of the actions taken at the initial state of interest (the root) are recorded. The proportion of each action becomes the final search-improved policy, since the most-taken actions appear to lead to the highest-value states down the line (among the paths that were searched).
- After searching from a given state, the search-improved policy is used to select an action. For each new state, the search repeats, until the game ends and returns a reward. Many such games are played until there is enough data on states, search-improved policies, and end-of-game rewards to improve the predictions.
- Improve the value and policy predictions with data from the self-play simulations.
- Train the neural network used to guide the search. To reiterate, a two-headed network learns to predict the observed end-of-game rewards and the search-improved policies at the same time based on the game state.
- To test the trained network, pit the trained and untrained networks against each other for multiple games and count the fraction of wins (or whatever metric). If the trained network is significantly better, use it to guide the data-collecting games in the next iteration. Then again, with enriched data, train a new network.
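To make the training targets of the two-headed network concrete, here is a minimal, hypothetical sketch of the loss in Julia. It is not the package’s actual loss function (which, as we will see, also adds L2 regularization and a penalty for putting probability on invalid actions); the names are made up for illustration.
# v̂ and p̂ are the network's value and policy outputs;
# z is the observed end-of-game reward and π is the search-improved policy.
value_loss(v̂, z) = (v̂ - z)^2                      # squared difference
policy_loss(p̂, π) = -sum(π .* log.(p̂ .+ 1e-8))    # cross-entropy against the search policy
total_loss(v̂, z, p̂, π) = value_loss(v̂, z) + policy_loss(p̂, π)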
This is a lot of text, but it is still vague. Let’s try to understand it better by reading some Julia code. Note the license.
Copyright (c) 2019 Jonathan Laurent
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
But before reading the source code, let’s see a usage example. Based on the documentation, everything starts at these train methods.
train(e::Experiment; args...) = UserInterface.resume!(Session(e; args...))
train(s::String; args...) = train(Examples.experiments[s]; args...)
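Given these definitions, starting the example run presumably boils down to a one-liner like the following (assuming train is exported, or qualified with the module that defines it):
using AlphaZero
# Looks up Examples.experiments["connect-four"], wraps it in a Session, and resumes training.
train("connect-four")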
An “Experiment” specifies the whole thing to be run: the game, the parameters, the network, and the benchmarks. Let’s look at the example experiment called “connect-four”.
module ConnectFour
export GameSpec, GameEnv, Board
# Define the game.
include("game.jl")
module Training
using AlphaZero
import ..GameSpec
# Define the experiment.
include("params.jl")
end
include("solver.jl")
end
The game specification encodes the rules of Connect Four, so we won’t look at it more closely. Note that this interface cannot express arbitrary kinds of environments. Still, you could implement the interface for your own game, or in general for almost any environment (inputs and outputs) that can be framed as a game.
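As a toy illustration of what “framing a problem as a game” means, here is a tiny self-contained environment in plain Julia. It deliberately does not use the package’s GI interface (which requires more methods, e.g. for state encoding); it only shows the state/action/reward shape that a game implementation has to provide.
# A toy "walk to the target" game, purely illustrative.
mutable struct CountingGame
    position::Int   # the state
end
init_game() = CountingGame(0)
available_actions(g::CountingGame) = (-1, +1)            # step left or right
play!(g::CountingGame, a::Int) = (g.position += a; g)    # an action changes the state
game_terminated(g::CountingGame) = abs(g.position) >= 3  # termination condition
reward(g::CountingGame) = g.position >= 3 ? 1.0 : -1.0   # the goal signal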
Once you have the game interface defined, params.jl shows what kind of information you need to train an agent with this package.
# Neural network constructor.
# Residual-block two-headed architecture.
Network = NetLib.ResNet
# Neural network parameters.
netparams = NetLib.ResNetHP(
num_filters=128,
num_blocks=5,
# Convolutional layers.
conv_kernel_size=(3, 3),
num_policy_head_filters=32,
num_value_head_filters=32,
batch_norm_momentum=0.1)
# Neural network optimization.
learning = LearningParams(
use_gpu=true,
use_position_averaging=true,
samples_weighing_policy=LOG_WEIGHT,
batch_size=1024,
loss_computation_batch_size=1024,
# Adam optimizer.
optimiser=Adam(lr=2e-3),
# Loss regularization.
l2_regularization=1e-4,
nonvalidity_penalty=1.,
min_checkpoints_per_epoch=1,
max_batches_per_checkpoint=2000,
num_checkpoints=1)
# Self-play just means simulating games to collect data for learning.
self_play = SelfPlayParams(
# Self-play simulation.
sim=SimParams(
num_games=5000,
# Distributed computing.
num_workers=128,
batch_size=64,
use_gpu=true,
reset_every=2,
flip_probability=0.,
alternate_colors=false),
# Self-play search.
mcts=MctsParams(
num_iters_per_turn=600,
# Exploration-enhancing parameters.
cpuct=2.0,
prior_temperature=1.0,
# Often you may want to decrease exploration over time.
temperature=PLSchedule([0, 20, 30], [1.0, 1.0, 0.3]),
dirichlet_noise_ϵ=0.25,
dirichlet_noise_α=1.0)
)
# The arena compares agents: the newly trained network against the current best one.
arena = ArenaParams(
sim=SimParams(
num_games=128,
num_workers=128,
batch_size=128,
use_gpu=true,
reset_every=2,
flip_probability=0.5,
alternate_colors=true),
mcts=MctsParams(
self_play.mcts,
temperature=ConstSchedule(0.2),
dirichlet_noise_ϵ=0.05),
update_threshold=0.05
)
# All parameters.
params = Params(
arena=arena,
self_play=self_play,
learning=learning,
num_iters=15,
ternary_outcome=true,
# Augment data with symmetrical rotations of states.
use_symmetries=true,
memory_analysis=nothing,
# Data-collecting games accumulate experiences to memory to learn from.
mem_buffer_size=PLSchedule([0, 15], [400_000, 1_000_000])
)
# Benchmarks compare AlphaZero against simpler baseline agents. Comparing against baselines is good practice in any data science project.
mcts_baseline = Benchmark.MctsRollouts(MctsParams(arena.mcts, num_iters_per_turn=1000, cpuct=1.))
alphazero_player = Benchmark.Full(arena.mcts)
benchmark_sim = SimParams(
arena.sim;
num_games=256,
num_workers=256,
batch_size=256,
alternate_colors=false)
benchmark = [Benchmark.Duel(alphazero_player, mcts_baseline, benchmark_sim)]
experiment = Experiment("connect-four", GameSpec(), params, Network, netparams, benchmark)
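Two helper types appear repeatedly in these parameters: ConstSchedule, which always returns the same value, and PLSchedule, which (as the name suggests) interpolates piecewise-linearly between the given (iteration, value) points. A rough, illustrative sketch of the idea, not the package’s implementation:
# Evaluate a piecewise-linear schedule at iteration itc.
function pl_schedule(xs::Vector{Int}, ys::Vector{Float64}, itc::Int)
    itc <= xs[1]   && return ys[1]
    itc >= xs[end] && return ys[end]
    i = findlast(x -> x <= itc, xs)
    t = (itc - xs[i]) / (xs[i+1] - xs[i])
    return ys[i] + t * (ys[i+1] - ys[i])
end
pl_schedule([0, 20, 30], [1.0, 1.0, 0.3], 25)  # temperature halfway between 1.0 and 0.3, i.e. 0.65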
Okay, so training an experiment resumes a session. Let’s see Session. We can see that it adds some information to the experiment.
mutable struct Session{Env}
env :: Env
dir :: String
logger :: Logger
autosave :: Bool
save_intermediate :: Bool
benchmark :: Vector{Benchmark.Evaluation}
progress :: Union{Progress, Nothing}
report :: SessionReport
# Skip...
end
# Env-objects look like this. Fields are quite self-explanatory (cur, current; nn, neural network).
mutable struct Env{GameSpec, Network, State}
gspec :: GameSpec
params :: Params
curnn :: Network
bestnn :: Network
# Memory holds the experience from data-collecting games to learn from.
memory :: MemoryBuffer{GameSpec, State}
# Iteration counter.
itc :: Int
# Skip....
end
Resuming the session calls train! on the environment (Env) within the session. This runs self-play and learning steps for the given number of iterations.
function train!(env::Env, handler=nothing)
while env.itc < env.params.num_iters
Handlers.iteration_started(handler)
resize_memory!(env, env.params.mem_buffer_size[env.itc])
# Self-play.
sprep, spperfs = Report.@timed self_play_step!(env, handler)
mrep, mperfs = Report.@timed memory_report(env, handler)
# Learn.
lrep, lperfs = Report.@timed learning_step!(env, handler)
rep = Report.Iteration(spperfs, mperfs, lperfs, sprep, mrep, lrep)
env.itc += 1
Handlers.iteration_finished(handler, rep)
end
Handlers.training_finished(handler)
end
So, training an AlphaZero agent alternates between self-play and learning steps. Let’s look at the self-play step first.
function self_play_step!(env::Env, handler)
params = env.params.self_play
Handlers.self_play_started(handler)
# Get the best neural network to use as an oracle for the search.
make_oracle() = Network.copy(env.bestnn, on_gpu=params.sim.use_gpu, test_mode=true)
# Make a player that uses this oracle within MCTS search.
simulator = Simulator(make_oracle, self_play_measurements) do oracle
return MctsPlayer(env.gspec, oracle, params.mcts)
end
# Player plays the game and new data (experience) are collected for learning.
results, elapsed = @timed simulate_distributed(
simulator, env.gspec, params.sim, game_simulated=()->Handlers.game_played(handler)
)
# Add experiences to memory.
new_batch!(env.memory)
for x in results
push_trace!(env.memory, x.trace, params.mcts.gamma)
end
# Skip...
Handlers.self_play_finished(handler, report)
return report
end
It’s just playing against yourself and collecting the data. After self_play_step!, the agent takes the learning_step!.
function learning_step!(env::Env, handler)
ap = env.params.arena
lp = env.params.learning
checkpoints = Report.Checkpoint[]
losses = Float32[]
tloss, teval, ttrain = 0., 0., 0.
# Get the data collected during self-play.
experience = get_experience(env.memory)
if env.params.use_symmetries
experience = augment_with_symmetries(env.gspec, experience)
end
# Skip...
# Initialize training of current network.
trainer, tconvert = @timed Trainer(env.gspec, env.curnn, experience, lp)
init_status = learning_status(trainer)
status = init_status
Handlers.learning_started(handler)
# Skip....
# Looping over checkpoints.
for k in 1:lp.num_checkpoints
Handlers.updates_started(handler, status)
# Update current network's parameters.
dlosses, dttrain = @timed batch_updates!(trainer, nbatches)
status, dtloss = @timed learning_status(trainer)
Handlers.updates_finished(handler, status)
tloss += dtloss
ttrain += dttrain
append!(losses, dlosses)
if isnothing(ap)
env.curnn = get_trained_network(trainer)
env.bestnn = copy(env.curnn)
nn_replaced = true
else
Handlers.checkpoint_started(handler)
# Compare current and previous best network and update if better.
env.curnn = get_trained_network(trainer)
eval_report = compare_networks(env.gspec, env.curnn, env.bestnn, ap, handler)
teval += eval_report.time
success = (eval_report.avgr >= best_evalr)
if success
nn_replaced = true
env.bestnn = copy(env.curnn)
best_evalr = eval_report.avgr
end
# Skip...
end
end
# Skip...
Handlers.learning_finished(handler, report)
return report
end
compare_networks calls pit_networks for two-player games and evaluate_network for single-player games. These simulate the game using the given networks and return the results.
function pit_networks(gspec, contender, baseline, params, handler)
make_oracles() = (
Network.copy(contender, on_gpu=params.sim.use_gpu, test_mode=true),
Network.copy(baseline, on_gpu=params.sim.use_gpu, test_mode=true))
simulator = Simulator(make_oracles, record_trace) do oracles
white = MctsPlayer(gspec, oracles[1], params.mcts)
black = MctsPlayer(gspec, oracles[2], params.mcts)
return TwoPlayers(white, black)
end
samples = simulate(
simulator, gspec, params.sim,
game_simulated=(() -> Handlers.checkpoint_game_played(handler)))
return rewards_and_redundancy(samples, gamma=params.mcts.gamma)
end
function evaluate_network(gspec, net, params, handler)
make_oracles() = Network.copy(net, on_gpu=params.sim.use_gpu, test_mode=true)
simulator = Simulator(make_oracles, record_trace) do oracle
MctsPlayer(gspec, oracle, params.mcts)
end
samples = simulate(
simulator, gspec, params.sim,
game_simulated=(() -> Handlers.checkpoint_game_played(handler)))
return rewards_and_redundancy(samples, gamma=params.mcts.gamma)
end
So what is this MctsPlayer, and what does it mean to simulate the simulator? This object packages all the information needed to run the MCTS search enhanced by the neural network oracle. Glance over the object’s contents below.
struct MctsPlayer{M} <: AbstractPlayer
mcts :: M
niters :: Int
timeout :: Union{Float64, Nothing}
τ :: AbstractSchedule{Float64}
function MctsPlayer(mcts::MCTS.Env; τ, niters, timeout=nothing)
@assert niters > 0
@assert isnothing(timeout) || timeout > 0
new{typeof(mcts)}(mcts, niters, timeout, τ)
end
end
function MctsPlayer(game_spec::AbstractGameSpec, oracle, params::MctsParams; timeout=nothing)
mcts = MCTS.Env(
game_spec,
oracle,
# Gamma is used to discount the future.
gamma=params.gamma,
# Cpuct controls the amount of exploration.
cpuct=params.cpuct,
# Noise and temperature also increase the amount of exploration.
noise_ϵ=params.dirichlet_noise_ϵ,
noise_α=params.dirichlet_noise_α,
prior_temperature=params.prior_temperature)
return MctsPlayer(mcts, niters=params.num_iters_per_turn, τ=params.temperature, timeout=timeout)
end
# This is a key structure.
mutable struct Env{State, Oracle}
tree :: Dict{State, StateInfo}
oracle :: Oracle
gamma :: Float64
cpuct :: Float64
noise_ϵ :: Float64
noise_α :: Float64
prior_temperature :: Float64
total_simulations :: Int64
total_nodes_traversed :: Int64
gspec :: GI.AbstractGameSpec
# Skip...
end
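The tree maps each state to a StateInfo record. That struct is not shown here, but judging from the fields accessed later (info.stats, info.Vest, a.P, a.W, a.N), it presumably looks roughly like this hypothetical reconstruction:
# Per-action statistics and per-state bookkeeping (field names taken from the code below).
struct ActionStats
    P :: Float64   # oracle prior probability of the action
    W :: Float64   # cumulative value accumulated over simulations
    N :: Int       # number of times the action was taken
end
struct StateInfo
    stats :: Vector{ActionStats}   # one entry per available action
    Vest  :: Float64               # the oracle's value estimate of the state
end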
Going inside simulate_distributed, we find that a function play_game is finally called in each distributed instance.
function play_game(gspec, player; flip_probability=0.)
game = GI.init(gspec)
trace = Trace(GI.current_state(game))
while true
if GI.game_terminated(game)
return trace
end
if !iszero(flip_probability) && rand() < flip_probability
GI.apply_random_symmetry!(game)
end
# "think" returns the policy (probability of each available action).
actions, π_target = think(player, game)
τ = player_temperature(player, game, length(trace))
π_sample = apply_temperature(π_target, τ)
# Decision is made randomly based on the policy distribution.
a = actions[Util.rand_categorical(π_sample)]
# Action is finally taken in the environment.
GI.play!(game, a)
# Information is recorded to the "trace" until game is terminated.
push!(trace, π_target, GI.white_reward(game), GI.current_state(game))
end
end
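The temperature τ comes from the schedule in MctsParams and controls how greedy move selection is. apply_temperature is not shown here, but the standard AlphaZero recipe, and presumably what it implements, is to raise the visit-count policy to the power 1/τ and renormalize (with τ close to 0 approaching a deterministic pick of the most-visited action). An illustrative sketch:
# Sharpen (τ < 1) or flatten (τ > 1) a probability vector; τ = 1 leaves it unchanged.
function apply_temperature_sketch(π::Vector{Float64}, τ::Float64)
    τ == 1.0 && return π
    logπ = log.(π .+ eps()) ./ τ
    p = exp.(logπ .- maximum(logπ))   # numerically stable exponentiation
    return p ./ sum(p)
end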
How does think compute the policy then? This involves explore! and policy.
function think(p::MctsPlayer, game)
if isnothing(p.timeout)
MCTS.explore!(p.mcts, game, p.niters)
# With "timeout" you can keep searching for a fixed amount of time.
# All this depends on your resources that the method has no glue about.
# Perhaps a more intelligent method would be aware of the available resources too.
else
start = time()
while time() - start < p.timeout
MCTS.explore!(p.mcts, game, p.niters)
end
end
return MCTS.policy(p.mcts, game)
end
# Run simulations with initial noise.
function explore!(env::Env, game, nsims)
# Dirichlet distribution gives a random vector summing to 1.
η = dirichlet_noise(game, env.noise_α)
for i in 1:nsims
env.total_simulations += 1
# Games are run based on oracle evaluations (deep neural network), parameters, and noise.
# Parameters and noise control the amount of exploration done around the neural net predictions.
run_simulation!(env, GI.clone(game), η=η)
end
end
# Search-estimated policy is just the distribution of actions taken in the search simulations.
function policy(env::Env, game)
actions = GI.available_actions(game)
state = GI.current_state(game)
info =
try env.tree[state]
catch e
if isa(e, KeyError)
error("MCTS.explore! must be called before MCTS.policy")
else
rethrow(e)
end
end
Ntot = sum(a.N for a in info.stats)
π = [a.N / Ntot for a in info.stats]
π ./= sum(π)
return actions, π
end
run_simulation! is the key function in explore!. It plays one simulated game, selecting actions by their UCT scores (U), updates the values of the actions taken (Q), and increments the counts of how many times each action was taken (N). These counts are what policy above reads at the root state.
# This does recursion until the game is terminated.
function run_simulation!(env::Env, game; η, root=true)
if GI.game_terminated(game)
return 0.
else
state = GI.current_state(game)
actions = GI.available_actions(game)
info, new_node = state_info(env, state)
# For new states, returns the oracle evaluation.
if new_node
return info.Vest
# Otherwise, returns the Q-value.
else
# Additional noise added for the first state.
ϵ = root ? env.noise_ϵ : 0.
# Final evaluations for possible actions.
scores = uct_scores(info, env.cpuct, ϵ, η)
# Action with maximum score is selected.
action_id = argmax(scores)
action = actions[action_id]
wp = GI.white_playing(game)
# Apply action and update the state.
GI.play!(game, action)
# Reward.
wr = GI.white_reward(game)
# Opponent's (black) positive rewards should have opposite sign.
r = wp ? wr : -wr
pswitch = wp != GI.white_playing(game)
# Value of the rest of the game, obtained recursively from the next state.
qnext = run_simulation!(env, game, η=η, root=false)
qnext = pswitch ? -qnext : qnext
# The final Q-value.
q = r + env.gamma * qnext
update_state_info!(env, state, action_id, q)
env.total_nodes_traversed += 1
return q
end
end
end
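update_state_info! is not shown, but given the W and N fields, it presumably just adds the new value estimate to the chosen action’s cumulative value and increments its visit count, along these lines (reusing the hypothetical ActionStats from above):
# Accumulate the new value estimate q for the chosen action.
function update_state_info_sketch!(info, action_id, q)
    a = info.stats[action_id]
    info.stats[action_id] = ActionStats(a.P, a.W + q, a.N + 1)
end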
To see what this UCT score metric actually is, we need to check uct_scores.
function uct_scores(info::StateInfo, cpuct, ϵ, η)
@assert iszero(ϵ) || length(η) == length(info.stats)
sqrtNtot = sqrt(Ntot(info))
return map(enumerate(info.stats)) do (i, a)
# W is the cumulative value of the action over simulations; dividing by N gives the mean value Q.
Q = a.W / max(a.N, 1)
# Add additional noise to the oracle policy (P) at the first state of simulation.
P = iszero(ϵ) ? a.P : (1-ϵ) * a.P + ϵ * η[i]
# The UCT score adds an exploration bonus to the value of the action Q.
# The bonus is the oracle policy P, weighted by the exploration parameter "cpuct",
# by the total visit count of the state, and inversely by how often this action was tried.
Q + cpuct * P * sqrtNtot / (a.N + 1)
end
end
So the UCT score is Q + cpuct * P * sqrt(Ntot) / (N + 1): it combines the values accumulated during search (Q, which mixes oracle value predictions and observed rewards) with the oracle policy prior (P), whose weight grows with how much the state has been explored overall and shrinks with how often that particular action has been tried.
We have covered the self-play part of the method quite well. In the learning part, we still have batch_updates! left. It is regular deep learning code, so we won’t look at it very closely.
function batch_updates!(tr::Trainer, n)
regws = Network.regularized_params(tr.network)
L(batch...) = losses(tr.network, regws, tr.params, tr.Wmean, tr.Hp, batch)[1]
data = Iterators.take(tr.batches_stream, n)
ls = Vector{Float32}()
Network.train!(tr.network, tr.params.optimiser, L, data, n) do i, l
push!(ls, l)
end
Network.gc(tr.network)
return ls
end
An example method of train! from the Network module is below.
function Network.train!(callback, nn::FluxNetwork, opt::Adam, loss, data, n)
optimiser = Flux.Adam(opt.lr)
params = Flux.params(nn)
for (i, d) in enumerate(data)
l, grads = lossgrads(params) do
loss(d...)
end
Flux.update!(optimiser, params, grads)
callback(i, l)
end
end
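The helper lossgrads is not shown above; it presumably evaluates the loss while recording gradients with Zygote (Flux’s automatic differentiation backend), along these lines:
# Hypothetical sketch of the gradient helper built on Zygote's pullback API.
import Zygote
function lossgrads_sketch(f, args...)
    val, back = Zygote.pullback(f, args...)
    grads = back(one(val))   # seed the backward pass with 1 to get d(loss)/d(params)
    return val, grads
end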
Finally, let’s look at the residual neural network (ResNet) implementation based on Flux.
abstract type TwoHeadNetwork <: FluxNetwork end
function Network.forward(nn::TwoHeadNetwork, state)
c = nn.common(state)
v = nn.vhead(c)
p = nn.phead(c)
# Network predicts a tuple (policy, value).
return (p, v)
end
mutable struct ResNet <: TwoHeadNetwork
gspec
hyper
common
vhead
phead
end
function ResNetBlock(size, n, bnmom)
pad = size .÷ 2
layers = Chain(
Conv(size, n=>n, pad=pad),
BatchNorm(n, relu, momentum=bnmom),
Conv(size, n=>n, pad=pad),
BatchNorm(n, momentum=bnmom))
return Chain(
# Residual networks have multilayer-skipping connections.
SkipConnection(layers, +),
x -> relu.(x))
end
function ResNet(gspec::AbstractGameSpec, hyper::ResNetHP)
indim = GI.state_dim(gspec)
outdim = GI.num_actions(gspec)
ksize = hyper.conv_kernel_size
@assert all(ksize .% 2 .== 1)
pad = ksize .÷ 2
nf = hyper.num_filters
npf = hyper.num_policy_head_filters
nvf = hyper.num_value_head_filters
bnmom = hyper.batch_norm_momentum
# The common layers need to learn features of the state
# that can predict both policies and values.
common = Chain(
# Convolutional layers are the basis of this architecture.
Conv(ksize, indim[3]=>nf, pad=pad),
BatchNorm(nf, relu, momentum=bnmom),
[ResNetBlock(ksize, nf, bnmom) for i in 1:hyper.num_blocks]...)
# Policy head: its predictions feed into the UCT scores during search.
phead = Chain(
Conv((1, 1), nf=>npf),
BatchNorm(npf, relu, momentum=bnmom),
flatten,
Dense(indim[1] * indim[2] * npf, outdim),
softmax)
# Value head: its state-value predictions feed into the Q-values during search.
vhead = Chain(
Conv((1, 1), nf=>nvf),
BatchNorm(nvf, relu, momentum=bnmom),
flatten,
Dense(indim[1] * indim[2] * nvf, nf, relu),
Dense(nf, 1, tanh))
# This struct has the forward-method above.
ResNet(gspec, hyper, common, vhead, phead)
end
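Putting it together: the experiment passes Network and netparams from params.jl to this constructor, so (hypothetically) instantiating and querying the Connect Four network by hand would look something like this:
# Manual use of the network; the training loop normally does this for you.
nn = NetLib.ResNet(GameSpec(), netparams)
# p, v = Network.forward(nn, states)  # policy and value for a batch of encoded states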