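# reinforcement/cours_ex1.jl
#
# 49 lines
# 1.2 KiB
# Julia
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters.
#
# This file contains Unicode characters that might be confused with other characters.
# If you think that this is intentional, you can safely ignore this warning.
# Use the Escape button to reveal them.
# (The characters in question are the intentional Greek/math identifiers μ, σ, α and ∇
# used as variable names below; they are valid Julia identifiers.)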

using LinearAlgebra
using Distributions
using Random
using Zygote
using Plots
# Seed the global RNG so the sampled noise (normalgenerator draws from the global RNG)
# and therefore the whole fit below are reproducible from run to run.
Random.seed!(0);
"""
    normalgenerator(amount, μ, σ, lowerbound=0, upperbound=1)

Draw `amount` independent samples from a `Normal(μ, σ)` distribution truncated to the
interval `[lowerbound, upperbound]` (by default `[0, 1]`).

Samples come from the global RNG (no RNG argument is taken), so results depend on the
current global seed. Returns a vector of length `amount`.
"""
function normalgenerator(amount::Number, μ::Number, σ::Number,
                         lowerbound::Number=0, upperbound::Number=1)
    base = Normal(μ, σ)
    bounded = Truncated(base, lowerbound, upperbound)
    draws = rand(bounded, amount)
    return draws
end
"""
    noisyline(intercept, slope, samples, μ, σ, lb=0, ub=1) -> (X, y)

Generate a noisy straight-line dataset.

The abscissae are `x = 1, 2, …, samples` (as `Float64`). The targets are
`y = slope * x + intercept + noise`, where `noise` is drawn with
`normalgenerator(samples, μ, σ, lb, ub)`, i.e. `Normal(μ, σ)` truncated to `[lb, ub]`.

Returns the design matrix `X = [x ones(samples)]` (a `samples × 2` matrix whose second
column is the bias term) and the target vector `y`.
"""
function noisyline(intercept::Number, slope::Number, samples::Number, μ::Number, σ::Number,
                   lb::Number=0, ub::Number=1)
    # Draw the noise first, so the global RNG is consumed exactly as before.
    ε = normalgenerator(samples, μ, σ, lb, ub)
    xs = collect(Float64, 1:samples)
    targets = ε .+ (slope .* xs .+ intercept)
    design = [xs ones(samples)]
    return design, targets
end
"""
    MSE(X, y, w)

Return the mean squared error of the linear model `X * w` against the targets `y`,
i.e. `mean((y - X * w) .^ 2)`.

Accepts any `AbstractArray` arguments (views, ranges, …), not only concrete `Array`s, as
long as `X * w` is defined and has the same shape as `y`. The function is written with
plain array operations so it can be differentiated by Zygote.
"""
function MSE(X::AbstractArray, y::AbstractArray, w::AbstractArray)
    residual = y - X * w
    return mean(residual .^ 2)
end
"""
    gradient_descent(X, y, α, w, iter) -> (w, costs)

Fit the weights of the linear model `X * w ≈ y` by running `iter` steps of plain
gradient descent on `MSE(X, y, w)` with learning rate `α`, using Zygote for the
gradient.

Returns the final weights and a vector `costs`, where `costs[i]` is the MSE evaluated at
the weights *before* the `i`-th update (so the cost of the returned weights is not
included). The caller's `w` is not mutated, because each step rebinds `w` to a new array.
"""
function gradient_descent(X::Array, y::Array, α::Number, w::Array, iter::Number)
    costs = Vector{Float64}(undef, iter)
    for i in 1:iter
        # A single forward pass gives both the cost and a pullback for the gradient.
        # Previously, MSE was evaluated once for the cost and again inside `gradient`.
        cost, back = pullback(MSE, X, y, w)
        costs[i] = cost
        # Seeding with one(cost) matches what `gradient` does for a scalar output.
        # Only the gradient w.r.t. the weights (3rd argument) is needed.
        ∇w = back(one(cost))[3]
        w = w - α * ∇w
    end
    return w, costs
end
"""
    get_res_line(X, result)

Evaluate the fitted line at every sample. The abscissae are taken from the first column
of `X`, and `result` is interpreted as `[slope, intercept]`. Returns the vector
`result[1] .* X[:, 1] .+ result[2]`.
"""
function get_res_line(X::Array, result::Array)
    slope, intercept = result[1], result[2]
    xs = view(X, :, 1)  # no copy of the column; the broadcast below allocates the output
    return @. slope * xs + intercept
end
# Synthetic data: y = 4x + 2 + noise for x = 1..100. The noise is Normal(0, 1) truncated
# to noisyline's default bounds [0, 1], so it is non-negative (it biases y upward).
# X has columns [x, 1], so weights are ordered [slope, intercept].
X,y = noisyline(2,4,100,0,1);
# N = number of samples (rows), D = number of columns (x and bias).
# NOTE(review): N is not used in the lines shown here; it may be used further down.
N,D = size(X);
# Initial weights: all ones.
w = ones(D);
# 6 gradient-descent steps with learning rate 1e-4.
# `pred` is the fitted weight vector [slope, intercept] (not predictions).
# `cost` is the per-iteration MSE history.
pred,cost = gradient_descent(X,y,0.0001,w,6);