stable baseline3
date: 2021-02-20 excerpt: stable baseline3について
stable baseline3について
- 強化学習の便利ライブラリ
- 特徴
- pytorchのラッパー
- 自分でモデル、外部環境を与えることができる
各機能の説明
BaseFeaturesExtractor 継承してネットワークとモデルを定義する
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class MyNN(BaseFeaturesExtractor):
"""
:param observation: (gym.Space)
:parma features_dim: (int) Number of features extracted.
"""
def __init__(self, observation, features_dim):
super(MyNN, self).__init__(observation_space, features_dim)
... # パラメータを書く
def forward(self, x):
... # ネットワークの接続を書く
Monitor
環境を引数にモニター可能な環境を返す
DQN
モデル
model = DQN('NAMING_PARAM',
m_env: Monitor,
policy_kwargs=policy_kwargs,
gamma,
learning_rate,
learning_starts,
target_update_interval,
exploration_fraction,
tau
)
# 学習
model.learn(total_timesteps)
公式サイトによる例
- Examples
- 可視化機能が微妙に足りないが、どういったワークフローで学習と推論が行われるか理解できる