74 lines
1.6 KiB
Julia
74 lines
1.6 KiB
Julia
using Distributions
|
|
using Random
|
|
using Plots
|
|
|
|
# Seed the global random number generator with 0 so that every run of this script draws
# the same sequence of random numbers.
Random.seed!(0);
|
|
|
|
"""
    BanditArm(m)

One arm of a multi-armed bandit.

# Fields
- `m`: true mean of the arm (added to the noise of every pull); fixed at construction.
- `m_estimate`: running estimate of `m`, starting at zero.
- `N`: number of samples that have been folded into `m_estimate` so far.

The fields are concretely typed (`Float64`/`Int`) rather than `Number`, so field access is
type-stable and `m_estimate` never changes type between the first and later updates.
"""
mutable struct BanditArm
    m::Float64
    m_estimate::Float64
    N::Int

    # A fresh arm has seen no samples: the estimate and the sample count start at zero.
    BanditArm(m) = new(m, 0.0, 0)
end
|
|
|
|
"""
    pull(ban::BanditArm)

Draw one reward from the arm: a sample from `Normal(0, 1)` shifted by the arm's mean
`ban.m`.
"""
function pull(ban::BanditArm)
    noise = rand(Normal(0, 1))
    return noise + ban.m
end
|
|
|
|
"""
    update(ban::BanditArm, x::Number)

Fold the new sample `x` into the arm's running mean and return the updated estimate.

Increments `ban.N` and applies the incremental-mean recurrence
`m_N = (1 - 1/N) * m_{N-1} + x / N`, so no history of samples has to be stored.
"""
function update(ban::BanditArm, x::Number) # x is the newly observed sample
    ban.N += 1
    # The new sample is weighted by 1/N. Dividing by (N * x) instead would add the
    # reciprocal of the sample, and the estimate would not track the sample mean.
    ban.m_estimate = (1 - 1 / ban.N) * ban.m_estimate + x / ban.N
    return ban.m_estimate
end
|
|
|
|
"""
    run_experiment(m1, m2, m3, ϵ, N)

Run an ϵ-greedy experiment on three bandit arms with true means `m1`, `m2` and `m3`.

On each of the `N` rounds a uniformly random arm is pulled with probability `ϵ`; otherwise
the arm with the highest current estimate is pulled. Afterwards the cumulative average
reward is plotted on a log-scaled x axis together with one constant line per true mean,
the final estimate of every arm is printed, and the percentage of rounds spent on an arm
other than the truly best one is printed.

# Arguments
- `m1`, `m2`, `m3`: true mean of each arm.
- `ϵ`: exploration probability.
- `N`: number of rounds; used as an array length, so it must be an integer.

# Returns
The vector of cumulative average rewards, one entry per round.
"""
function run_experiment(m1::Number, m2::Number, m3::Number, ϵ::Number, N::Number)
    bandits = [BanditArm(m1), BanditArm(m2), BanditArm(m3)]
    means = [m1, m2, m3]
    true_best = argmax(means)
    count_suboptimal = 0

    # Concrete element type wide enough for "Float64 noise + mean"; an abstractly typed
    # `Array{Number}` would box every stored reward.
    T = promote_type(Float64, typeof(m1), typeof(m2), typeof(m3))
    rewards = Vector{T}(undef, N)

    for i in eachindex(rewards)
        if rand() < ϵ
            # Explore: pick any arm uniformly at random.
            j = rand(1:length(bandits))
        else
            # Exploit: pick the arm whose current estimate is highest (first on ties).
            j = argmax([b.m_estimate for b in bandits])
        end

        x = pull(bandits[j])
        update(bandits[j], x)

        if j != true_best
            count_suboptimal += 1
        end
        rewards[i] = x
    end

    gr()
    # Broadcasting over the range directly avoids materialising `Array(1:N)`.
    cumulative_average = cumsum(rewards) ./ (1:N)

    plt = plot(cumulative_average; xaxis=:log)
    # One constant reference line per true mean.
    for m in means
        plot!(plt, fill(float(m), N); xaxis=:log)
    end
    display(plt)

    for b in bandits
        println(b.m_estimate)
    end
    # Scale the fraction by 100 so the printed value matches its "Percent" label.
    println("Percent suboptimal for ϵ = $ϵ: ", 100 * count_suboptimal / N)

    return cumulative_average
end
|
|
|
|
# True means of the three arms used in every run below.
m1, m2, m3 = 1.5, 2.5, 3.5

# Repeat the 100000-round experiment for ϵ = 0.1, 0.05 and 0.01 (in that order) and keep
# the cumulative-average curve returned by each run.
c_1, c_05, c_01 = map(ϵ -> run_experiment(m1, m2, m3, ϵ, 100000), (0.1, 0.05, 0.01));

# Overlay the three curves on a single figure.
plot(c_1)
plot!(c_05)
plot!(c_01)
|