58 lines
1.3 KiB
Julia
58 lines
1.3 KiB
Julia
using Plots
|
||
using Random
|
||
using Distributions
|
||
|
||
"""
    Bandit(p)

One arm of a Bernoulli multi-armed bandit together with its Beta posterior.

# Fields
- `p::Float64`: the true win rate; `pull` returns 1 with this probability.
- `α::Float64`: first Beta parameter; starts at 1 and grows by each reward `x` passed to `update`.
- `β::Float64`: second Beta parameter; starts at 1 and grows by `1 - x` for each reward `x`.
- `N::Int`: number of rewards folded in by `update`.

`Bandit(p)` starts from the uniform Beta(1, 1) prior with no samples.

# Throws
- `ArgumentError` if `p` is not in `[0, 1]` (it is used as a probability).
"""
mutable struct Bandit
    # Concrete field types (instead of abstract `Number`) let Julia store the
    # fields unboxed and specialize code that reads them.
    p::Float64  # the win rate
    α::Float64
    β::Float64
    N::Int      # number of samples

    function Bandit(p)
        0 <= p <= 1 || throw(ArgumentError("win rate p must lie in [0, 1], got $p"))
        return new(p, 1.0, 1.0, 0)
    end
end
|
||
|
||
"""
    pull(ban::Bandit)

Simulate a single play of `ban`: return the integer `1` (a win) when a uniform
draw from `[0, 1)` falls below `ban.p`, and `0` otherwise.
"""
function pull(ban::Bandit)
    won = rand() < ban.p
    return won ? 1 : 0
end
|
||
|
||
"""
    sample(ban::Bandit)

Draw one value from the Beta distribution parameterized by `ban.α` and `ban.β`,
i.e. one sample from this arm's current belief about its win rate.
"""
function sample(ban::Bandit)
    # NOTE(review): Distributions (via StatsBase) also exports a `sample`
    # function; this definition only works as a new local function if that
    # imported name has not been used yet in this module — confirm or rename.
    posterior = Beta(ban.α, ban.β)
    return rand(posterior)
end
|
||
|
||
"""
    update(ban::Bandit, x::Number)

Fold one observed reward `x` into `ban` in place: `α` grows by `x`, `β` grows
by its complement `1 - x`, and the sample count `N` grows by one.
Returns the new sample count.
"""
function update(ban::Bandit, x::Number)
    complement = 1 - x
    ban.α = ban.α + x
    ban.β = ban.β + complement
    newcount = ban.N + 1
    ban.N = newcount
    return newcount
end
|
||
|
||
"""
    ban_plot(bandits, trial)

Plot the Beta(α, β) density of every bandit in `bandits` on one fresh figure
titled with `trial`, then display that figure once.

Each call builds its own figure, so curves from earlier calls are not carried
over into later snapshots. Returns `nothing`.
"""
function ban_plot(bandits::AbstractArray, trial::Number)
    grid = range(0, 1; length=200)  # points in [0, 1] at which each density is evaluated
    fig = plot(; title="Bandit distributions after $trial trials",
               xlabel="win rate", ylabel="density")
    for b in bandits
        # Broadcast the scalar pdf over the grid: the array method
        # pdf(d::UnivariateDistribution, x::AbstractArray) is deprecated in
        # recent Distributions.jl releases.
        plot!(fig, grid, pdf.(Beta(b.α, b.β), grid); label="p = $(b.p)")
    end
    display(fig)
    return nothing
end
|
||
|
||
# Thompson-sampling demo: repeatedly play the arm whose posterior draw is
# largest, then compare the running average reward with the best win rate.
num_trials = 2000;
bandit_probs = [0.2, 0.5, 0.75];
bandits = [Bandit(p) for p in bandit_probs];
# Numbers of completed trials after which the posteriors are plotted.
sample_points = [5, 10, 20, 50, 100, 200, 500, 1000, 1500, 1999];
rewards = zeros(num_trials);

for i in 1:num_trials
    # Thompson sampling: draw once from each bandit's posterior and play the
    # bandit with the largest draw.
    j = argmax([sample(b) for b in bandits])

    x = pull(bandits[j])
    rewards[i] = x
    update(bandits[j], x)

    # Plot only after the update, so the figure shows exactly i completed
    # trials, matching its "after $trial trials" title.
    if i in sample_points
        ban_plot(bandits, i)
    end
end

# Running mean of the reward after each trial (no need to materialize the range).
cumulative_average = cumsum(rewards) ./ (1:num_trials);
performance_plot = plot(cumulative_average; xaxis=:log, label="cumulative average reward")
hline!(performance_plot, [maximum(bandit_probs)]; label="best win rate")
display(performance_plot)