Changes for saamplig wwith a bayesian prior
This commit is contained in:
parent
58088870b4
commit
52465ffe87
|
@ -0,0 +1,63 @@
|
||||||
|
using Plots
|
||||||
|
using Random
|
||||||
|
using Distributions
|
||||||
|
|
||||||
|
mutable struct Bandit
|
||||||
|
true_mean::Number
|
||||||
|
predicted_mean::Number
|
||||||
|
λ::Number
|
||||||
|
sum_x::Number
|
||||||
|
τ::Number
|
||||||
|
N::Number #Number of samples
|
||||||
|
Bandit(true_mean) = new(true_mean,0,1,0,1,0)
|
||||||
|
end
|
||||||
|
|
||||||
|
function pull(ban::Bandit)
|
||||||
|
return (rand(Normal(0,1)) / √(ban.τ)) + ban.true_mean
|
||||||
|
end
|
||||||
|
|
||||||
|
function sample(ban::Bandit)
|
||||||
|
return (rand(Normal(0,1)) / √(ban.λ)) + ban.predicted_mean
|
||||||
|
end
|
||||||
|
|
||||||
|
function update(ban::Bandit, x::Number) #x is a sample number
|
||||||
|
ban.λ += ban.τ
|
||||||
|
ban.sum_x += x
|
||||||
|
ban.predicted_mean = ban.τ * ban.sum_x / ban.λ
|
||||||
|
ban.N += 1
|
||||||
|
end
|
||||||
|
|
||||||
|
function ban_plot(bandits::Array,trial::Number)
|
||||||
|
plt = plot()
|
||||||
|
x = convert(Array,LinRange(-3,6,200))
|
||||||
|
for b in bandits
|
||||||
|
y = pdf(Normal(b.predicted_mean,√(1.0/b.λ)),x)
|
||||||
|
plot!(plt,x,y, title="Bandit distributions after $trial trials",label="Real mean: $(b.true_mean), Num of plays: $(b.N)")
|
||||||
|
end
|
||||||
|
display(plt)
|
||||||
|
end
|
||||||
|
|
||||||
|
num_trials = 2000;
|
||||||
|
bandit_means = [1,2,3];
|
||||||
|
|
||||||
|
bandits = [Bandit(p) for p in bandit_means];
|
||||||
|
sample_points = [5,10,20,50,100,200,500,1000,1500,1999];
|
||||||
|
rewards = zeros(num_trials);
|
||||||
|
|
||||||
|
for i in 1:num_trials
|
||||||
|
# Thomson sampling
|
||||||
|
j = argmax([sample(b) for b in bandits])
|
||||||
|
|
||||||
|
if i in sample_points
|
||||||
|
ban_plot(bandits,i)
|
||||||
|
end
|
||||||
|
|
||||||
|
x = pull(bandits[j])
|
||||||
|
rewards[i] = x
|
||||||
|
update(bandits[j],x)
|
||||||
|
end
|
||||||
|
|
||||||
|
cumulative_average = cumsum(rewards) ./ Array(1:num_trials);
|
||||||
|
plot(cumulative_average,xaxis=:log)
|
||||||
|
plot!(ones(num_trials) .* max(bandit_probs...))
|
||||||
|
|
Loading…
Reference in New Issue