Jun / Sep 25 2019

Chapter 5: Blackjack (Figure 5.3)

using ReinforcementLearning, ReinforcementLearningEnvironments
using RLIntro, RLIntro.BlackJack

# Start every episode from the state evaluated in the book: the player's cards
# sum to 13 with a usable ace (an ace and a deuce), and the dealer is showing a deuce.
env = BlackJackEnv(; init=(BlackJack.Hands(13, [1, 2], true), BlackJack.Hands(2, [2], false)))

# Number of states and actions: 220 states (presumably 11 player sums × 10 dealer
# cards × 2 usable-ace flags) and 2 actions (hit or stick).
ns, na = length(observation_space(env)), length(action_space(env))
(220, 2)
# Index of the initial state in the tabular representation.
init_state = get_state(observe(env))
27
# The true value of the initial state under the target policy, determined in the
# book by separately generating one hundred million episodes with the target policy.
const GOLD_VAL = -0.27726
-0.27726
# Integer index of the :stick action.
stick_action = findall(x -> x == :stick, BlackJack.ACTIONS)[]

# The target policy from the book: stick only when the player's sum is 20 or 21
# (indices 10 and 11 along the player-sum dimension), hit otherwise.
table = fill(1, size(BlackJack.INDS)...)
table[:, 10:11, :] .= stick_action
table = reshape(table, :);  # flatten so the table can be indexed by the scalar state id

# The behavior policy picks hit or stick with equal probability in every state.
π_behavior = TabularRandomPolicy(fill(0.5, length(table), na))
TabularRandomPolicy([0.5 0.5; 0.5 0.5; … ; 0.5 0.5; 0.5 0.5])
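
Because the target policy is deterministic and the behavior policy assigns probability 0.5 to each of the two actions, the per-step importance ratio π(a|s)/b(a|s) is either 1/0.5 = 2 (when the sampled action agrees with the target policy) or 0. A minimal sketch of the resulting episode weight (the helper below is hypothetical, not part of RLIntro):

# Hypothetical helper, not part of RLIntro: with a deterministic target policy and
# a uniform behavior policy over two actions, an episode's importance weight is
# 2^T when all T actions match the target policy, and 0 otherwise.
episode_weight(actions, target_actions) =
    all(actions .== target_actions) ? 2.0^length(actions) : 0.0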
# A hook that records, after every episode, the squared error between the current
# estimate of the initial state's value and its true value GOLD_VAL.
struct StoreMSE <: AbstractHook
    mse::Vector{Float64}
end

StoreMSE() = StoreMSE(Float64[])

(f::StoreMSE)(::PostEpisodeStage, agent, env, obs) =
    push!(f.mse, (GOLD_VAL - agent.π.π_target.learner.approximator(init_state))^2)
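
For reference, the two estimators being compared here are (Sutton & Barto, Equations 5.5 and 5.6):

\mathrm{(ordinary)}\quad V(s) = \frac{\sum_{t \in \mathcal{T}(s)} \rho_{t:T(t)-1} G_t}{|\mathcal{T}(s)|},
\qquad
\mathrm{(weighted)}\quad V(s) = \frac{\sum_{t \in \mathcal{T}(s)} \rho_{t:T(t)-1} G_t}{\sum_{t \in \mathcal{T}(s)} \rho_{t:T(t)-1}},

where \mathcal{T}(s) is the set of first-visit time steps to s, G_t is the return from time t, and \rho_{t:T(t)-1} = \prod_{k=t}^{T(t)-1} \pi(A_k \mid S_k) / b(A_k \mid S_k). Ordinary importance sampling is unbiased but can have very high (even infinite) variance; weighted importance sampling is biased, with a bias that vanishes asymptotically, but has much lower variance.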
# Off-policy first-visit Monte Carlo evaluation with ordinary importance sampling.
function mse_of_ordinary_sampling()
    agent = Agent(
        π = OffPolicy(
            VBasedPolicy(
                learner = MonteCarloLearner(
                    approximator = TabularVApproximator(ns),
                    kind = FIRST_VISIT,
                    sampling = ORDINARY_IMPORTANCE_SAMPLING,
                ),
                f = TabularDeterministicPolicy(table = table, nactions = na),
            ),
            π_behavior,
        ),
        buffer = episode_RTSA_buffer(),
    )
    hook = StoreMSE()
    run(agent, env, StopAfterEpisode(10_000; is_show_progress = false); hook = hook)
    hook.mse
end
mse_of_ordinary_sampling (generic function with 1 method)
# The same agent, but with weighted importance sampling. The two CachedSums are
# presumably the running numerator and denominator of the weighted average
# (see the sketch below).
function mse_of_weighted_sampling()
    agent = Agent(
        π = OffPolicy(
            VBasedPolicy(
                learner = MonteCarloLearner(
                    approximator = TabularVApproximator(ns),
                    kind = FIRST_VISIT,
                    sampling = WEIGHTED_IMPORTANCE_SAMPLING,
                    returns = (CachedSum(), CachedSum()),
                ),
                f = TabularDeterministicPolicy(table = table, nactions = na),
            ),
            π_behavior,
        ),
        buffer = episode_RTSA_buffer(),
    )
    hook = StoreMSE()
    run(agent, env, StopAfterEpisode(10_000; is_show_progress = false); hook = hook)
    hook.mse
end
mse_of_weighted_sampling (generic function with 1 method)
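
Weighted importance sampling does not need to store individual returns: the estimate can be maintained incrementally from a running weight sum, which is presumably what the two CachedSums above cache. A minimal sketch of that incremental rule (Sutton & Barto, Section 5.6); the struct below is illustrative, not the RLIntro implementation:

# Illustrative only, not the RLIntro implementation: incremental weighted-IS update.
mutable struct WeightedISEstimate
    v::Float64  # current estimate V(s)
    c::Float64  # cumulative sum of importance weights C(s)
end
WeightedISEstimate() = WeightedISEstimate(0.0, 0.0)

function update!(est::WeightedISEstimate, ρ, g)
    ρ == 0 && return est.v              # zero-weight episodes leave the estimate unchanged
    est.c += ρ
    est.v += (ρ / est.c) * (g - est.v)  # V ← V + (ρ/C)(G − V)
    est.v
end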
using Plots, StatsBase

# Average the per-episode squared error over 100 independent runs, as in Figure 5.3.
plot(mean(mse_of_ordinary_sampling() for _ in 1:100);
    xscale=:log10, label="Ordinary Importance Sampling",
    xlabel="Episodes (log scale)", ylabel="Mean square error (average over 100 runs)")
plot!(mean(mse_of_weighted_sampling() for _ in 1:100);
    xscale=:log10, label="Weighted Importance Sampling")
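
As in the book's Figure 5.3, the error of weighted importance sampling should drop quickly and stay low, while ordinary importance sampling, though unbiased, has far higher variance on this problem and its error remains large for many episodes.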