• home
  • about
  • 全ての投稿
  • ソフトウェア・ハードウェアの設定のまとめ
  • 分析関連のまとめ
  • ヘルスケア関連のまとめ
  • 生涯学習関連のまとめ

stable baseline3

date: 2021-02-20 excerpt: stable baseline3について

tag: reinforcement learinglibstable baseline3


stable baseline3について

  • 強化学習の便利ライブラリ
  • 特徴
    • pytorchのラッパー
    • 自分でモデル、外部環境を与えることができる

各機能の説明

BaseFeaturesExtractor 継承してネットワークとモデルを定義する

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class MyNN(BaseFeaturesExtractor):
  """
  :param observation: (gym.Space)
  :parma features_dim: (int) Number of features extracted.
  """
  def __init__(self, observation, features_dim):
	super(MyNN, self).__init__(observation_space, features_dim)

	... # パラメータを書く

  def forward(self, x):
	... # ネットワークの接続を書く

Monitor
環境を引数にモニター可能な環境を返す

DQN

モデル

model = DQN('NAMING_PARAM', 
			m_env: Monitor, 
			policy_kwargs=policy_kwargs,
            gamma, 
			learning_rate,
            learning_starts, 
			target_update_interval, 
			exploration_fraction, 
			tau
			)

# 学習
model.learn(total_timesteps)

公式サイトによる例

  • Examples
    • 可視化機能が微妙に足りないが、どういったワークフローで学習と推論が行われるか理解できる


reinforcement learinglibstable baseline3 Share Tweet