New files for other algorithms
This commit is contained in:
parent
6d6b0a5186
commit
40bf4570af
|
@ -52,7 +52,7 @@ for b in bandits
|
||||||
println("Mean estimate: ", b.p_estimate)
|
println("Mean estimate: ", b.p_estimate)
|
||||||
end
|
end
|
||||||
|
|
||||||
println("Total reward eaarned: ", sum(rewards))
|
println("Total reward earned: ", sum(rewards))
|
||||||
println("Overall win rate: ", sum(rewards)/ num_trials)
|
println("Overall win rate: ", sum(rewards)/ num_trials)
|
||||||
println("Number times explored: ", num_times_explored)
|
println("Number times explored: ", num_times_explored)
|
||||||
println("Number times exploited: ", num_times_exploited)
|
println("Number times exploited: ", num_times_exploited)
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
using Random
|
||||||
|
using Plots
|
||||||
|
|
||||||
|
mutable struct Bandit
|
||||||
|
p::Number #the win rate
|
||||||
|
p_estimate::Number #How to estimate the win rate
|
||||||
|
N::Number #Number of samples
|
||||||
|
|
||||||
|
Bandit(p) = new(p,5,1)
|
||||||
|
end
|
||||||
|
|
||||||
|
function pull(ban::Bandit)
|
||||||
|
return convert(Int,rand() < ban.p)
|
||||||
|
end
|
||||||
|
|
||||||
|
function update(ban::Bandit, x::Number) #x is a sample number
|
||||||
|
ban.N += 1
|
||||||
|
ban.p_estimate = ((ban.N - 1) * ban.p_estimate + x) / ban.N
|
||||||
|
end
|
||||||
|
|
||||||
|
num_trials = 10000;
|
||||||
|
ϵ = 0.1;
|
||||||
|
bandit_probs = [0.2,0.5,0.75];
|
||||||
|
|
||||||
|
bandits = [Bandit(p) for p in bandit_probs];
|
||||||
|
rewards = zeros(num_trials);
|
||||||
|
num_times_explored = 0;
|
||||||
|
num_times_exploited = 0;
|
||||||
|
num_optimal = 0;
|
||||||
|
optimal_j = argmax([b.p for b in bandits]);
|
||||||
|
println("Optimal j: ", optimal_j);
|
||||||
|
|
||||||
|
for i in 1:num_trials
|
||||||
|
j = argmax([b.p_estimate for b in bandits])
|
||||||
|
x = pull(bandits[j])
|
||||||
|
rewards[i] = x
|
||||||
|
update(bandits[j],x)
|
||||||
|
end
|
||||||
|
|
||||||
|
for b in bandits
|
||||||
|
println("Mean estimate: ", b.p_estimate)
|
||||||
|
end
|
||||||
|
|
||||||
|
println("Total reward earned: ", sum(rewards))
|
||||||
|
println("Overall win rate: ", sum(rewards)/ num_trials)
|
||||||
|
println("Number of times selected each bandit ", [b.N for b in bandits])
|
||||||
|
|
||||||
|
cumulative_rewards = cumsum(rewards)
|
||||||
|
win_rates = cumulative_rewards ./ Array(1:num_trials)
|
||||||
|
plot(win_rates)
|
||||||
|
plot!(ones(num_trials) .* max(bandit_probs...))
|
|
@ -0,0 +1,50 @@
|
||||||
|
using Random
|
||||||
|
using Plots
|
||||||
|
|
||||||
|
mutable struct Bandit
|
||||||
|
p::Number #the win rate
|
||||||
|
p_estimate::Number #How to estimate the win rate
|
||||||
|
N::Number #Number of samples
|
||||||
|
Bandit(p) = new(p,0,0)
|
||||||
|
end
|
||||||
|
|
||||||
|
function pull(ban::Bandit)
|
||||||
|
return convert(Int,rand() < ban.p)
|
||||||
|
end
|
||||||
|
|
||||||
|
function update(ban::Bandit, x::Number) #x is a sample number
|
||||||
|
ban.N += 1
|
||||||
|
ban.p_estimate = ((ban.N - 1) * ban.p_estimate + x) / ban.N
|
||||||
|
end
|
||||||
|
|
||||||
|
function ucb(mean::Number,n::Number,nⱼ::Number)
|
||||||
|
return mean + √(2*log(n) / nⱼ)
|
||||||
|
end
|
||||||
|
|
||||||
|
num_trials = 100000;
|
||||||
|
ϵ = 0.1;
|
||||||
|
bandit_probs = [0.2,0.5,0.75];
|
||||||
|
|
||||||
|
bandits = [Bandit(p) for p in bandit_probs];
|
||||||
|
rewards = zeros(num_trials);
|
||||||
|
total_plays = 0;
|
||||||
|
optimal_j = argmax([b.p for b in bandits]);
|
||||||
|
|
||||||
|
for j in 1:size(bandits)[1]
|
||||||
|
x = pull(bandits[j])
|
||||||
|
total_plays += 1
|
||||||
|
update(bandits[j],x)
|
||||||
|
end
|
||||||
|
|
||||||
|
for i in 1:num_trials
|
||||||
|
j = argmax([ucb(b.p_estimate,total_plays,b.N) for b in bandits])
|
||||||
|
x = pull(bandits[j])
|
||||||
|
total_plays += 1
|
||||||
|
update(bandits[j],x)
|
||||||
|
|
||||||
|
rewards[i] = x
|
||||||
|
end
|
||||||
|
|
||||||
|
cumulative_average = cumsum(rewards) ./ Array(1:num_trials);
|
||||||
|
plot(cumulative_average,xaxis=:log)
|
||||||
|
plot!(ones(num_trials) .* max(bandit_probs...))
|
Loading…
Reference in New Issue