reinforcement/bayesian_bandit.jl

58 lines
1.3 KiB
Julia
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using Plots
using Random
using Distributions
"""
    Bandit(p)

One arm of a Bernoulli bandit together with the parameters of a Beta distribution
that is updated as the arm is played.

The constructor stores the win rate `p` and starts with `α = 1`, `β = 1` and `N = 0`.
Fields are concretely typed (instead of the abstract `Number`) so that field access
is type-stable; integer arguments are converted on construction and assignment.

# Throws
- `ArgumentError` if `p` is not within `[0, 1]`.
"""
mutable struct Bandit
    p::Float64  # the win rate
    α::Float64  # first shape parameter passed to `Beta` in `sample`
    β::Float64  # second shape parameter passed to `Beta` in `sample`
    N::Int      # number of samples
    function Bandit(p::Real)
        # A win rate is a probability; reject values that cannot be one (including NaN).
        0 <= p <= 1 || throw(ArgumentError("win rate p must lie in [0, 1], got $p"))
        return new(p, 1, 1, 0)
    end
end
"""
    pull(ban::Bandit)

Play the arm once: draw a uniform random number with `rand()` and return `1` when it
is smaller than `ban.p`, otherwise `0`.
"""
function pull(ban::Bandit)
    won = rand() < ban.p
    return won ? 1 : 0
end
"""
    sample(ban::Bandit)

Draw one random value from `Beta(ban.α, ban.β)`.
"""
# NOTE(review): Distributions may also export a `sample` (re-exported from StatsBase);
# this definition appears to create a separate function in this module — confirm.
function sample(ban::Bandit)
    posterior = Beta(ban.α, ban.β)
    return rand(posterior)
end
"""
    update(ban::Bandit, x::Number)

Fold one observed sample `x` into `ban`: add `x` to `ban.α`, add `1 - x` to `ban.β`,
and increment the sample counter `ban.N`. Returns the incremented count.
"""
function update(ban::Bandit, x::Number) # x is a sample number
    ban.α = ban.α + x        # the sample itself
    ban.β = ban.β + (1 - x)  # its complement
    ban.N = ban.N + 1
    return ban.N
end
"""
    ban_plot(bandits::Array, trial::Number)

Plot the density of `Beta(b.α, b.β)` for every `b` in `bandits` on a grid of 200
points over `[0, 1]`, in a single figure titled with `trial`, and display it once.

Each call starts from a fresh figure, so curves from earlier calls do not pile up in
later snapshots. Returns `nothing`.
"""
function ban_plot(bandits::Array, trial::Number)
    x = collect(LinRange(0, 1, 200))
    # New figure per call: a bare `plot!` would keep drawing onto whatever plot is
    # current, mixing this snapshot with the curves of all previous snapshots.
    plt = plot(title="Bandit distributions after $trial trials")
    for b in bandits
        # Broadcast the scalar `pdf` over the grid; `Ref` keeps the distribution
        # from being treated as a collection by the broadcast.
        y = pdf.(Ref(Beta(b.α, b.β)), x)
        plot!(plt, x, y, label="p = $(b.p)")
    end
    # Show the finished figure once instead of after every single curve.
    display(plt)
    return nothing
end
# Simulation settings: number of pulls and the win rate of each arm.
num_trials = 2000;
bandit_probs = [0.2, 0.5, 0.75];
bandits = [Bandit(p) for p in bandit_probs];
# Trial counts after which the bandits' current Beta distributions are plotted.
sample_points = [5, 10, 20, 50, 100, 200, 500, 1000, 1500, 1999];
rewards = zeros(num_trials);
for i in 1:num_trials
    # Thompson sampling: draw once from every arm's Beta distribution and play
    # the arm whose draw is largest.
    j = argmax(map(sample, bandits))
    x = pull(bandits[j])
    rewards[i] = x
    update(bandits[j], x)
    # Take the snapshot only after trial i has been recorded, so the figure
    # labelled with i really reflects i trials (not i - 1).
    if i in sample_points
        ban_plot(bandits, i)
    end
end
# Running mean of the rewards, compared against the best arm's win rate.
cumulative_average = cumsum(rewards) ./ (1:num_trials);
plot(cumulative_average, xaxis=:log)
plot!(fill(maximum(bandit_probs), num_trials))