# reinforcement/epsilon_greedy_starter.jl
# A minimal epsilon-greedy multi-armed bandit experiment.

using Random # rand() is in Base; Random is useful for Random.seed!() if reproducibility is wanted
using Plots
mutable struct Bandit
    p::Float64          # the true win rate
    p_estimate::Float64 # running estimate of the win rate
    N::Int              # number of samples collected so far
    Bandit(p) = new(p, 0.0, 0)
end
# Simulate one pull of the arm: returns 1 with probability p, else 0.
function pull(ban::Bandit)
    return Int(rand() < ban.p)
end
# Fold one observed reward x into the running mean estimate.
# (Renamed update -> update! per the Julia convention for mutating functions.)
function update!(ban::Bandit, x::Number)
    ban.N += 1
    ban.p_estimate = ((ban.N - 1) * ban.p_estimate + x) / ban.N
end
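# Quick sanity check of the Bandit API (illustrative only, not part of the
# experiment below): repeated pull/update! should drive p_estimate toward p.
let demo = Bandit(0.5)
    for _ in 1:1000
        update!(demo, pull(demo))
    end
    println("Demo estimate after 1000 pulls: ", demo.p_estimate)
end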
num_trials = 10000
ϵ = 0.1 # probability of exploring instead of exploiting
bandit_probs = [0.2, 0.5, 0.75]
bandits = [Bandit(p) for p in bandit_probs]
rewards = zeros(num_trials)
num_times_explored = 0
num_times_exploited = 0
num_optimal = 0
optimal_j = argmax([b.p for b in bandits])
println("Optimal j: ", optimal_j)
for i in 1:num_trials
    # Without this declaration, += on these top-level counters errors in a
    # non-interactive script (soft-scope rules since Julia 1.5).
    global num_times_explored, num_times_exploited, num_optimal
    if rand() < ϵ
        # explore: pick an arm uniformly at random
        num_times_explored += 1
        j = rand(1:length(bandits))
    else
        # exploit: pick the arm with the highest current estimate
        num_times_exploited += 1
        j = argmax([b.p_estimate for b in bandits])
    end
    if j == optimal_j
        num_optimal += 1
    end
    x = pull(bandits[j])
    rewards[i] = x
    update!(bandits[j], x)
end
for b in bandits
    println("Mean estimate: ", b.p_estimate)
end
println("Total reward earned: ", sum(rewards))
println("Overall win rate: ", sum(rewards)/ num_trials)
println("Number times explored: ", num_times_explored)
println("Number times exploited: ", num_times_exploited)
println("Number of times the optimal bandit was selected: ", num_optimal)
# Plot the running win rate against the best achievable rate.
cumulative_rewards = cumsum(rewards)
win_rates = cumulative_rewards ./ (1:num_trials)
plot(win_rates, label = "running win rate")
plot!(fill(maximum(bandit_probs), num_trials), label = "optimal win rate")
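# When run non-interactively (e.g. `julia epsilon_greedy_starter.jl`), the
# plot window may never appear; a minimal sketch, assuming writing a PNG to
# disk is acceptable (the filename is illustrative):
savefig("epsilon_greedy_win_rates.png")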