
🚀 TorchRL Reinforcement Learning Tutorial: An Advanced Guide to PPO 🚀

Tags: Miscellaneous

In this tutorial, you will follow the detailed steps of implementing reinforcement learning for the inverted-pendulum task with PyTorch and torchrl. From environment initialization and configuration, through policy and value network design, to loss computation and optimization, the six key components covered here guide you through the entire process.

Install the Required Libraries
!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm
Set the Hyperparameters
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

frame_skip = 4                # each selected action is repeated for 4 simulation frames
total_frames = 1_000_000      # total number of frames to collect over the whole run
frames_per_batch = 1000       # frames collected per iteration of the collector
num_epochs = 10               # optimization epochs per collected batch
sub_batch_size = 64           # minibatch size for each gradient step
clip_epsilon = 0.2            # PPO clipping parameter
gamma = 0.99                  # discount factor
lmbda = 0.95                  # GAE lambda
entropy_eps = 1e-4            # entropy bonus coefficient
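For reference, clip_epsilon, gamma and lmbda correspond to two standard formulas that the loss section later in this tutorial relies on: PPO's clipped surrogate objective and the generalized advantage estimate (these are the textbook definitions, restated here as a reminder rather than taken from the code above):

$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\Big[\min\big(r_t(\theta)\,\hat{A}_t,\; \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$$

$$\hat{A}_t = \sum_{l \ge 0} (\gamma\lambda)^{l}\,\delta_{t+l}, \qquad \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$

Here $\epsilon$ is clip_epsilon, $\gamma$ is gamma, $\lambda$ is lmbda, and the entropy bonus weighted by entropy_eps is added to the objective to encourage exploration.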
Build the Environment and Transforms

Environment creation and transform setup

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import Compose, DoubleToFloat, ObservationNorm, StepCounter

# Base MuJoCo environment for the inverted-pendulum task
base_env = GymEnv("InvertedDoublePendulum-v4", device=device, frame_skip=frame_skip)

# In TorchRL, transforms are attached with TransformedEnv; StepCounter(max_steps=1000) also truncates episodes at 1000 steps
env = TransformedEnv(
    base_env,
    Compose(ObservationNorm(in_keys=["observation"]), DoubleToFloat(in_keys=["observation"]), StepCounter(max_steps=1000)),
)
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)  # estimate normalization statistics from 1000 random steps
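As an optional sanity check (not part of the original steps), TorchRL ships a check_env_specs helper that runs a short rollout and verifies that the transformed environment's specs match the data it actually produces:

from torchrl.envs.utils import check_env_specs

check_env_specs(env)          # raises if the declared specs and a real rollout disagree
print(env.observation_spec)   # inspect the normalized observation spec
print(env.action_spec)        # 1-dimensional continuous action in [-1, 1]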

Network Design

from torch import nn
from torch.nn import functional as F

class Actor(nn.Module):
    """Gaussian policy network: outputs the location and scale of the action distribution."""
    def __init__(self, num_cells=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.LazyLinear(num_cells),   # input size is inferred from the observation on the first call
            nn.Tanh(),
            nn.Linear(num_cells, num_cells),
            nn.Tanh(),
            nn.Linear(num_cells, num_cells),
            nn.Tanh(),
            nn.Linear(num_cells, 2),    # one (loc, scale) pair for the 1-dimensional action
        )

    def forward(self, obs):
        loc, scale = self.net(obs).chunk(2, dim=-1)
        # softplus keeps the scale strictly positive without killing gradients, unlike a hard clamp
        scale = F.softplus(scale) + 1e-4
        return loc, scale

class Value(nn.Module):
    """State-value network used by GAE and the critic loss."""
    def __init__(self, num_cells=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.LazyLinear(num_cells),
            nn.Tanh(),
            nn.Linear(num_cells, num_cells),
            nn.Tanh(),
            nn.Linear(num_cells, num_cells),
            nn.Tanh(),
            nn.Linear(num_cells, 1),    # scalar state value
        )

    def forward(self, obs):
        return self.net(obs)
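TorchRL's collector and loss modules exchange data as TensorDicts, so the plain nn.Module networks above have to be wrapped before they can be used in the following sections. Below is a minimal sketch of that wrapping with TensorDictModule, ProbabilisticActor and ValueOperator; the key names and the TanhNormal distribution follow TorchRL's usual continuous-control setup and can be adapted as needed.

from tensordict.nn import TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator

# Expose the actor as a probabilistic policy that samples actions from a squashed Gaussian
policy_module = ProbabilisticActor(
    module=TensorDictModule(Actor().to(device), in_keys=["observation"], out_keys=["loc", "scale"]),
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,   # keeps sampled actions inside the env's [-1, 1] bounds
    return_log_prob=True,            # the PPO loss needs the sampling log-probabilities
)

# Expose the critic so it writes its prediction under the "state_value" key
value_module = ValueOperator(module=Value().to(device), in_keys=["observation"])

# Run each module once to materialize the LazyLinear layers before the optimizer is built
policy_module(env.reset())
value_module(env.reset())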
Create the Data Collector and Replay Buffer
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

# The collector rolls the policy out in the environment and yields batches of frames_per_batch transitions
collector = SyncDataCollector(env, policy_module, frames_per_batch=frames_per_batch, total_frames=total_frames, device=device)

# The buffer holds one collected batch at a time and is sampled without replacement during the PPO epochs
replay_buffer = ReplayBuffer(storage=LazyTensorStorage(max_size=frames_per_batch), sampler=SamplerWithoutReplacement())
Define the Loss and the Training Loop
from collections import defaultdict

from torch import nn
from tqdm import tqdm
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

# GAE computes the advantage estimates that feed the PPO objective
advantage_module = GAE(gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True)

# Clipped PPO loss combining the policy objective, the critic loss and an entropy bonus
loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optimizer = torch.optim.Adam(loss_module.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_frames // frames_per_batch, eta_min=1e-5)

logs = defaultdict(list)
pbar = tqdm(total=total_frames)

for i, data in enumerate(collector):
    for _ in range(num_epochs):
        # Recompute advantages for the current batch, then refill the replay buffer with it
        advantage_module(data)
        replay_buffer.extend(data.reshape(-1).cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss = loss_vals["loss_objective"] + loss_vals["loss_critic"] + loss_vals["loss_entropy"]
            loss.backward()
            nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=10.0)  # clip gradients for stability
            optimizer.step()
            optimizer.zero_grad()

    logs["reward"].append(data["next", "reward"].mean().item())
    pbar.update(data.numel())
    pbar.set_description(f"Reward: {logs['reward'][-1]:.4f}, Batch: {i}")
    scheduler.step()
By following the steps above, you will have used TorchRL to implement the PPO algorithm for the inverted-pendulum reinforcement-learning task. Remember to log rewards and losses during training so that you can evaluate how the algorithm's performance evolves over time.
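As a minimal sketch of that evaluation (the plotting code is an addition, not part of the original steps), you can plot the logged per-batch reward once training finishes:

from matplotlib import pyplot as plt

plt.plot(logs["reward"])            # mean reward of each collected batch
plt.xlabel("collector iteration")
plt.ylabel("mean reward")
plt.title("Training reward over time")
plt.show()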
