reinforcement/ubc1.jl

51 lines
1.1 KiB
Julia

using Random
using Plots
"""
    Bandit(p)

One arm of a Bernoulli bandit with win rate `p`.

A new bandit starts with an estimate of zero and no recorded samples. The
fields are concretely typed so that reads and writes on them are type-stable;
`p` is converted to `Float64` on construction.
"""
mutable struct Bandit
    p::Float64           # the win rate
    p_estimate::Float64  # current estimate of the win rate
    N::Int               # number of samples collected so far
    Bandit(p) = new(p, 0, 0)
end
"""
    pull(ban::Bandit)

Play the bandit once. Return `1` when a uniform random draw falls below
`ban.p`, and `0` otherwise.
"""
function pull(ban::Bandit)
    won = rand() < ban.p
    return Int(won)
end
"""
    update(ban::Bandit, x::Number)

Record the sample `x` on `ban`: increment the sample count `N` and replace
`p_estimate` with the mean of all samples seen so far. Return the new mean.
"""
function update(ban::Bandit, x::Number)
    ban.N += 1
    count = ban.N
    # Rebuild the sum of the previous (count - 1) samples from the old mean,
    # add the new sample, then average over the new count.
    sample_sum = (count - 1) * ban.p_estimate + x
    new_estimate = sample_sum / count
    ban.p_estimate = new_estimate
    return new_estimate
end
"""
    ucb(mean::Number, n::Number, nⱼ::Number)

Return the UCB1 index of an arm: `mean + sqrt(2 * log(n) / nⱼ)`.

# Arguments
- `mean`: current estimate of the arm's mean reward.
- `n`: total number of plays across all arms; must be at least 1, otherwise
  `log(n)` is negative and `sqrt` throws a `DomainError`.
- `nⱼ`: number of times this arm has been played.
"""
function ucb(mean::Number, n::Number, nⱼ::Number)
    # The exploration bonus is the square root of 2 ln(n) / nⱼ; without the
    # root the bonus dominates the mean and the arm is over-explored.
    return mean + sqrt(2 * log(n) / nⱼ)
end
# UCB1 experiment: play a set of Bernoulli bandits, always choosing the arm
# with the highest upper-confidence index, and plot the running average reward.
num_trials = 100000;
ϵ = 0.1;  # not referenced anywhere in this section
bandit_probs = [0.2, 0.5, 0.75];
bandits = [Bandit(p) for p in bandit_probs];
rewards = zeros(num_trials);
total_plays = 0;
optimal_j = argmax([b.p for b in bandits]);  # not referenced anywhere in this section

# Play every arm once so that each b.N is non-zero before ucb divides by it.
for b in bandits
    x = pull(b)
    # `global` is required: without it, running this file as a script makes
    # `total_plays` a new loop-local variable and `+=` raises UndefVarError.
    global total_plays += 1
    update(b, x)
end

# Main loop: pick the arm with the largest UCB index, play it, record the reward.
for i in 1:num_trials
    j = argmax([ucb(b.p_estimate, total_plays, b.N) for b in bandits])
    x = pull(bandits[j])
    global total_plays += 1
    update(bandits[j], x)
    rewards[i] = x
end

# Running mean of the rewards after each trial.
cumulative_average = cumsum(rewards) ./ (1:num_trials);
plot(cumulative_average, xaxis=:log)
# Horizontal reference line at the best arm's true win rate.
plot!(fill(maximum(bandit_probs), num_trials))