Chapter 11: Baird's Counterexample

This notebook reproduces Baird's counterexample from Chapter 11 of Sutton & Barto, which shows that semi-gradient off-policy TD with linear function approximation can diverge.
using ReinforcementLearning, ReinforcementLearningEnvironments, RLIntro.BairdCounter
using Plots
env = BairdCounterEnv()
# output (current state 1, a 7-state observation space, and a 2-action space):
# BairdCounterEnv(1, DiscreteSpace{Int64}(1, 7, 7), DiscreteSpace{Int64}(1, 2, 2))
Base.@kwdef struct RecordWeights <: AbstractHook
    weights::Vector{Vector{Float64}} = []
end

# Hooks are callables dispatched on a stage type; this one fires after every
# action and snapshots the learner's weight vector so we can plot it later.
(h::RecordWeights)(::PostActStage, agent, env, action_obs) =
    push!(h.weights, agent.π.π_target.learner.approximator.weights |> deepcopy)
Base.@kwdef struct StateMapping <: AbstractPreprocessor
    mapping::Array{Int,2} = features
end

# Replace the raw state index with its feature vector (a row of the feature
# matrix). The default refers to the global `features` defined in the next
# cell; it is only evaluated when StateMapping() is constructed.
(p::StateMapping)(s::Int) = p.mapping[s, :]
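As a quick sanity check (hypothetical usage with a toy 2×2 matrix, since `features` itself is only defined in the next cell), the preprocessor simply returns the feature row of a state:

StateMapping(mapping = [2 1; 1 2])(2)  # [1, 2], the second row of the mapping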
env = BairdCounterEnv()
ns = length(observation_space(env))  # 7 states
na = length(action_space(env))       # 2 actions (dashed, solid)

# All weights start at 1 except w₇ = 10, as in the book's setup.
init_weights = ones(Float64, 8)
init_weights[7] = 10

# Feature matrix: states 1–6 are represented as 2wᵢ + w₈; state 7 as w₇ + 2w₈.
features = zeros(ns, length(init_weights))
for i in 1:6
    features[i, i] = 2
    features[i, 8] = 1
end
features[7, 7] = 1
features[7, 8] = 2

# Behavior policy takes the dashed action (1) with probability 6/7;
# the target policy always takes the solid action (2).
π_b = obs -> rand() < 6/7 ? 1 : 2
π_t = obs -> 2
prob_b = [6/7, 1/7]
prob_t = [0., 1.]
RL.get_prob(f::typeof(π_b), s, a) = prob_b[a]
RL.get_prob(f::typeof(π_t), s, a) = prob_t[a]

agent = Agent(
    π=OffPolicy(
        π_target=VBasedPolicy(
            learner=TDLearner(
                approximator=LinearVApproximator(init_weights),
                γ=0.99,
                optimizer=Descent(0.01),
                n=0,
                method=:SRS
            ),
            f=π_t
        ),
        π_behavior=π_b
    ),
    buffer=episode_RTSA_buffer(state_eltype=Any)
)

env = WrappedEnv(
    env=BairdCounterEnv(),
    preprocessor=StateMapping()
)

hook = RecordWeights()
run(agent, env, StopAfterStep(1000); hook=hook)

# Plot the evolution of all 8 weight components; as in Figure 11.2 of the
# book, they grow without bound.
p = plot(legend=:topleft)
for i in 1:length(init_weights)
    plot!(p, [w[i] for w in hook.weights])
end
p
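To see where the divergence comes from, look at the importance-sampling ratio ρ(a) = π_target(a|s) / π_behavior(a|s) that scales each off-policy update. A minimal sketch, reusing the prob_t and prob_b vectors defined above:

ρ = prob_t ./ prob_b  # [0.0, 7.0]

Updates on the frequent dashed action are dropped entirely (ρ = 0), while updates on the rare solid action are amplified sevenfold. Combined with linear function approximation and bootstrapping, this off-policy update distribution is exactly the "deadly triad" of Chapter 11, and it is what makes the weights in the plot above grow without bound.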