DiscoRL
DiscoRL
DeepMind 最新力作,10月15日 Nature 预印版。
论文地址: https://www.nature.com/articles/s41586-025-09761-x
介绍
DeepMind 证明,通过在大量复杂环境中,对智能体的累积经验直接进行元学习,机器能够自己发现一种超越人工设计规则的先进的强化学习算法,这种算法是一个更新智能体策略与预测的强化学习规则,DeepMind 又在 Atari 测试中刷榜了,所以他们认为实现先进人工智能需要的强化学习算法可能很快不依赖于人工设计而是智能体自己发现。总的来说这个看法和 Sutton 的思想一致,可能元强化学习是一个实现先进人工智能的途径。
这种自主发现强化学习规则的方法,依赖于两个重要改进:搜索范围扩大到规则潜空间;在复杂多样的环境进行大规模的学习。
规则潜空间,作者认为强化学习算法的核心是规则,规则向着目标更新一个或者多个预测值或者策略本身,而目标可能是未来奖励,未来预测值等等。在不同目标下的规则有 TD( \( r + \gamma V(s') \) ),Q-learning( \( r + \gamma \max_a Q(s', a) \) ),PPO(r + \gamma V(s’) - V(s)),辅助任务(监督学习辅助强化学习训练,但是不直接在更新中,譬如学一个世界模型来预测下个状态等),后继特征( \( \phi(s_{t+1}) + \gamma \Psi(s_{t+1}, \pi(s_{t+1})) \) 其中 \( \Psi(s_{t+1}, \pi(s_{t+1})) \) 称为后继特征,就是状态特征的累计折扣),分布强化学习( \(r + \gamma Z(s', a')\) 其中 \( Z(s', a') \) 是回报分布),目标的选择决定了这些目标是价值函数、模型还是后继特征。
在作者的框架中,强化学习规则用一个元网络表示,这个网络直接决定了预测和策略的更新方向,这样系统能够在无需定义目标是什么的情况下自动发现有效的预测机制(选择什么目标)和使用方式(怎么用这个目标更新策略)。理论上这个框架可以发现现有的强化学习规则,但是这样灵活的形式也能让智能体创造出一些全新的强化学习规则,在一些特定环境里这些规则可能会更优秀。
在发现阶段,作者创建了一群智能体,这些智能体在各自独立的环境中进行并行训练,探索哪些强化学习规则更好,环境从多样化且有挑战性的任务集合(环境的差别在于任务不同,初始条件不同,任务参数不同,奖励函数不同)中随机抽取,每个智能体都按照自己的强化学习规则(不同的算法,不同的目标,不同的超参)更新,这些超参数和目标被参数化为一个可微分的强化学习规则,用一个向量表示,使用元梯度逐步改进这个强化学习规则本身。这样就能进化出一个更好的通用强化学习算法。
DeepMind的大规模测试表明,他们用这个元强化学习算法,学出来的称为DiscoRL的强化学习规则,再次在Atari刷榜,而且在没训练过的任务上也有不错的效率和泛化性,如果在发现阶段,再用一些更复杂多样的环境,DiscoRL的性能和泛化水平会进一步提高。
发现方法
发现方法,分为两类优化:智能体优化与元优化。智能体参数通过将其策略和预测更新至强化学习规则生成的目标来优化,强化学习规则的元参数通过更新其目标以最大化智能体的累计奖励来优化。
智能体网络
作者没有人工设计这些更新方式(例如策略梯度)、损失(例如TD误差)、预测(例如价值估计)等等,定义了一个有丰富表达能力的预测空间,这个空间的各个参数的表征义也是由元网络指定的,以元学习的方式让智能体自主优化所需目标,这方法能保持对现有强化学习核心思想的表征能力,也能支持广阔的强化学习规则探索。
由 \(\theta\) 参数化的网络除了输出策略 \( \pi \) 外还输出两类预测:
-
观测条件向量预测 \(y(s) \in \mathbb{R}^n\),一个依赖于 \(s\) 的输出用于表征类似于,预测类:例如值函数估计的TD预测和蒙特卡洛预测的输出。(在 Atari 这种图像任务中对原始状态-图像进行特征提取很重要,所以默认提取,提取后的特征向量,认为是一般意义上的状态用 \(s\) 表示,而原始状态命名为观测用 \(o\) 表示)
-
动作条件向量预测 \( z(s, a) \in \mathbb{R}^m \),一个依赖于 \(s, a\) 的输出用于表征类似于,控制类:例如动作值函数估计的TD控制和蒙特卡洛控制的输出。
后继特征等等也都是依赖于 \(s\) 或者 \(s, a\) 所以这两个向量有极强的普适性。
除了这两种预测,作者还是给智能体加上了一些预定义好的预测减小搜索空间,一个是动作值函数的预测 \(Q(s, a)\) 一个是动作条件辅助策略预测 \(P(s, a)\) (这个预测输出一个概率分布,智能体在 \(s\) 执行了动作 \(a\) 后的未来策略的动作分布 \(P((s, a) \approx \pi(a' | s')\))。
元网络
多数现代强化学习都是用前向视角,算法接受从时间步 \(t\) 到 \(t+n\) 未来 \(n\) 步的轨迹信息,利用这些信息更新智能体的预测或者策略,一般是使用自举的目标更新,是对未来预测的目标(例如TD)
作者的强化学习规则,使用一个元网络去决定智能体应该把预测和策略如何移动,在时刻 \(t\) 生成这一次更新的目标,元网络的输入包括:
-
从 \(t\) 到 \(t+n\) 智能体预测和策略的轨迹
-
从 \(t\) 到 \(t+n\) 的奖励和 episode 结束标志
使用一个标准的 LSTM 网络去处理输入,也可以使用别的架构。
所以元网络和智能体网络不是同步学习的,在智能体网络给出预测后的 \(n\) 步,元网络才收集轨迹进行训练。
这种输入输出的选择保留了强化学习规则的优良特性:
-
元网络能够处理任意观测值和任意规模的离散动作空间,因为元网络直接以预测信息作为输入从预测中间接获取信息。通过跨动作维度共享的权重处理动作特定的输入与输出(使用特定的头去处理不同动作的输入与输出,隐层不变),能够泛化到完全不同的环境(从极少的动作到极多的动作都可以)。
-
元网络对智能体网络的设计是无感的,因为他只接受智能体的输出,只要智能体网络能生成所需形式的输出 \(\pi, y ,z\),元网络所发现的强化学习规则就可以泛化到任意的智能体架构或者规模。(元网络不挑智能体网络,是什么都行)
-
元网络定义的搜索空间包含了自举法(能学到TD等方法)这一重要的算法思想,因为它的输入是一个未来的轨迹。
-
由于元网络同时处理策略与预测值,他不仅能元学习辅助任务,还能直接利用预测值更新策略(譬如使用优势)
-
输出目标值在表达能力上严格优于输出标量损失函数,因为它将 Q-learning 等半梯度方法也纳入了搜索空间。(由于神经网络的表达能力优势,搜索空间不止包含标量的损失函数,损失函数一定是可导的–标准梯度,也包括半梯度的方法,例如 Q-leaning \(Q_{target} = r + \gamma \max_a Q(s', a')\) 这样损失不可导的半梯度方法)
在继承了标准强化学习算法的这些特性的基础上,参数丰富的神经网络使所发现的规则能够实现效率更高,情景适应性更强的算法。
智能体优化
智能体参数 \(\theta\) 通过最小化预测与策略和元网络生成的目标之间的距离来更新,损失函数可以表示为:
$$ L(\theta) = \mathbb{E}_{s, a, s' \sim \pi_\theta}[D(\hat{\pi}, \pi_\theta(s)) + D(\hat{y}, y_\theta(s)) + D(\hat{z}, z_\theta(s, a)) + L_{aux}] $$其中,\(D(p, q)\) 表示 \(p\) 与 \(q\) 之间的距离函数,作者选择了 KL 散度作为距离函数,具有足够的普适性,且DeepMind团队此前的工作已经证明了 KL 散度是最合适的,相比交叉熵,L2,JS散度而言。算距离前使用 softmax 函数对每个向量进行了归一化处理。
辅助损失 \(L_{aux}\) 是预定义语义的预测的损失,
$$ L_{aux} = D(\hat{q}, q_\theta(s, a)) + D(\hat{p}, p_\theta(s, a)) $$\(\hat{q}\) 是 Retrace 算法的动作值目标,再使用 two-hot 变成类似于 C51 的离散分布,方便计算 KL 散度,\(\hat{p}\) 是单步未来策略,这里的距离也是用 KL 散度。
这个 Retrace 是 Q-target 估计器里的 GAE,在 off-policy 的情况下算一个稳健,低方差,几乎无偏的 Q-target
$$ \begin{align} \delta_k &= r_k + \gamma (1 - d_{k+1}) V(s_{k+1}) - Q(s_k, a_k) \\ c_k &= \lambda \cdot \min\left(1, \frac{\pi(a_k \mid s_k)}{\mu(a_k \mid s_k)}\right) \cdot c_{k+1} \\ q^{ret}_k &= V(s_k) + c_k [ Q(s_k,a_k) − V(s_k) + \delta_k ] \end{align} $$\(k\) 是时间步, \(\delta_k\) 是 V 的单步 TD 误差,\( r_k + \gamma (1 - d_{k+1}) V(s_{k+1}) \) 是 Q 的自举估计值,令 \(Q(s_k, a_k)\) 逼近于目标回报的自举估计。\( d_{k+1} \) 是终止指示器,终止状态没有下一个状态了所以直接用奖励。
\(c_k\) 是截断重要性权重,一个后向递归,将 \(\lambda\) 和截断 IS 权重乘积累计起来,相当于 \(\lambda\) 折扣和裁剪的重要性采样权重,对该步的信任程度考虑后续轨迹。
\(q_{k}^{ret}\) 是 Retrace 动作值目标,其中 V 是平均动作 Q 值,当学习目标不可靠时,V 值提供一个稳定的基准。\(Q(s_k,a_k) − V(s_k)\) 是优势项,当前动作比平均动作差时减小目标 Q 值,防止高估;好时防止低估。\(\delta_k\) 就是标准的更新信号,当 \(c_k = 1\) 退化成标准的目标 Q 值。
元优化
目标是发现一个有元网络和元参数 \(\eta\) 表示的强化学习规则,让智能体最大化在不同训练环境中的奖励,元参数的性能指标和它的梯度可以表示如下:
$$ \begin{align} J(\eta) &= \mathbb{E}_\varepsilon \mathbb{E}_\theta [J(\theta)] \nabla_\eta J(\eta) &= \mathbb{E}_\varepsilon \mathbb{E}_\theta [\nabla_\eta \theta \nabla_\theta J(\theta)] \end{align} $$其中 \(\varepsilon\) 表示从环境分布中采样的环境,\(\theta\) 代表由初始参数分布及其在强化学习规则作用下随学习过程演化所生成的智能体参数。\(J(\theta) = \mathbb{E}[\sum_t \gamma^t r_t]\) 是期望累计折扣奖励。元参数使用上述方差通过梯度上升法进行优化。
为了估计元梯度,作者实例化了一组智能体群体,这些智能体在一系列采样环境中依据元网络进行学习。为确保此采样估计的元梯度近似值接近目标元梯度真实分布,采用了来自高难度基准测试的大量复杂环境。探索过程呈现出多样化的强化学习挑战,例如奖励稀疏性、任务时间跨度、环境的部分可观测性或随机性。
为促使更新规则在有限智能体生命周期内快速学习,各智能体的参数会被周期性重置。元梯度项 \(\nabla_\eta J(\eta)\) 可通过链式法则分解为两个梯度项:\(\nabla_\eta \theta\) 与 \(\nabla_\theta J(\theta)\)。第一项可理解为智能体更新过程的梯度,而第二项是标准强化学习目标的梯度。为估计第一项,对智能体进行多次迭代更新,沿整个更新过程反向传播。由于智能体数量过多,作者使用滑动窗口只对最近的20次智能体更新进行反向传播。为估计第二项,使用 A2C 方法,再训练一个元价值函数网络,估计优势从而计算策略梯度,这个元价值函数只用于发现阶段,不是元学习规则的一部分。
实验
测试了 Atari 57,再次刷榜,从未进行训练的环境也进行了测试,依然有不错的泛化性,通篇都在说很厉害。
分析
作者对在雅达利57个游戏上训练的强化学习规则进行了仔细分析。
定性分析
从定性角度看,所发现的预测信号在关键事件(如获得奖励或策略熵值变化)发生前会呈现显著峰值。进一步通过测量观测数据各部分的梯度范数,探究了观测值中哪些特征会引发元学习预测的强烈响应。结果表明,元学习预测倾向于关注未来可能相关的对象(观测中未来可能会变得极为关键的东西,譬如游戏中一个可能会刷怪的位置,几秒后就刷怪了),这与策略函数和价值函数的关注点(短视,只能看到现在存在的东西)存在明显差异。这些发现表明,DiscoRL已学会识别并预测适度时间范围内的关键事件,从而对策略函数和价值函数等现有概念形成了有效补充。
信息分析
为验证定性分析结论,作者进一步探究了预测结果所包含的信息。首先收集 DiscoRL 智能体在 10 款雅达利游戏中的运行数据,并训练神经网络分别从发现的预测结果、策略函数或价值函数中预测目标变量。结果显示,相较于策略函数与价值函数,发现的预测结果对即将到来的高额奖励及未来策略熵值具有更强的信息表征能力。这表明发现的预测可能捕获了任务相关的独特信息,而这些信息未被策略函数和价值函数充分提取。
自举机制的涌现
作者还发现了证据表明 DiscoRL 采用了自举机制。当给元网络对未来时间步的预测输入 \(z_{t+k}\) 扰动时,会显著影响当前目标预测值 \(\hat{z}_t\)。这意味着未来预测结果被用于构建当前预测的目标值。这种自举机制与发现的预测被证实对算法性能至关重要。若在计算目标值 \(\hat{y}\) 和 \(\hat{z}\) 时将元网络的 \(y\) 和 \(z\) 输入置零(从而禁用自举机制),算法性能将大幅下降。若在计算包括策略目标在内的所有目标值时将 \(y\) 和 \(z\) 输入置零,性能衰退更为显著。这表明发现的预测深度参与策略更新过程,而非仅作为辅助任务存在。(这说明存在一个尚未被发现的目标有远超现有任何目标的合理性)
方法
实现细节就不仔细看了,这里贴上翻译
元网络
元网络将智能体输出的轨迹以及环境中的相关量映射为目标值:
$$\text{m}_\eta: f_\theta(s_t), f_{\theta^-}(s_t), a_t, r_t, b_t, \ldots, f_\theta(s_{t+n}), f_{\theta^-}(s_{t+n}), a_{t+n}, r_{t+n}, b_{t+n} \mapsto \hat{\pi}, \hat{y}, \hat{z}$$其中,\(\eta\) 代表元参数,而 \(f_\theta=[{\pi_\theta(s), y_\theta(s), z_\theta(s), q_\theta(s)}]\) 是参数为 \(\theta\) 的智能体输出。\(a, r, b\) 分别是智能体采取的动作、获得的奖励和回合终止指示器。\(\theta^-\) 是参数 \(\theta\) 的指数移动平均。这种函数形式允许元网络搜索的规则空间严格大于元学习一个标量损失函数所能搜索的空间。这将在补充信息中进一步讨论。元网络通过在时间上反向展开)一个 LSTM 来处理输入。这使得它能够考虑 \(n\) 步的未来信息来生成目标,类似于 \(\text{TD}(\lambda)\) 等多步 RL 方法。我们发现,这种架构在计算上比 Transformer 等替代方案更有效率,同时取得了相似的性能。元网络使用共享权重跨动作维度处理动作特定的输入和输出,并通过对其进行平均来计算一个中间嵌入向量。这使得元网络能够处理任意数量的动作。更多细节可在补充信息中找到。为了使元网络能够发现更广泛的算法类别(例如奖励归一化等需要维护智能体生命周期统计信息的算法),我们添加了一个额外的循环神经网络。这个 “Meta-RNN” 是沿着智能体更新(从 \(\theta_i\) 到 \(\theta_{i+1}\))正向展开,而不是沿着回合中的时间步展开。Meta-RNN 的核心是另一个 LSTM 模块。对于每一次智能体更新,一整个批次的轨迹被嵌入为一个单一向量,并传递给这个 LSTM。Meta-RNN 能够潜在地捕获智能体整个生命周期中的学习动态,生成适应特定智能体和环境的目标。Meta-RNN 略微提高了总体性能。更多细节在补充信息中描述。
元优化稳定性
大规模进行发现时,会出现一些挑战,主要是因为来自不同环境的智能体带来的不平衡梯度信号(unbalanced gradient signals)以及由智能体较长生命周期引起的短视梯度(myopic gradients)。我们引入了几种方法来缓解这些问题。稳定方法 1:优势项归一化首先,在 \(\text{A2C}\) 中估计优势项(advantage term)以估计元梯度中的 \(\nabla_{\theta}J(\theta)\) 时,我们对优势项进行如下归一化:
$$\bar{A} = \frac{A - \mu}{\sigma}$$其中 \(\bar{A}\) 是归一化的优势,\(\mu\) 和 \(\sigma\) 分别是在智能体生命周期内累积的优势的指数移动平均和标准差。我们发现这使得优势项的规模在不同环境之间保持平衡。稳定方法 2:聚合元梯度时的个体 Adam 优化器此外,在聚合来自智能体群体的元梯度时,我们在对每个智能体计算出的元梯度应用一个单独的 \(\text{Adam}\) 优化器之后,取所有智能体元梯度的平均值:
$$\eta \leftarrow \eta + \frac{1}{n} \sum_{i=1}^{n} \text{ADAM}(g_i)$$其中 \(g_i\) 是来自群体中第 \(i\) 个智能体的元梯度估计。我们发现这有助于归一化每个智能体元梯度的量级。稳定方法 3:元正则化损失(Meta-Regularisation Losses)我们在元目标 \(J(\eta)\) 中添加了两个元正则化损失(\(L_{\text{ent}}\) 和 \(L_{\text{kl}}\)),如下所示:
$$\text{E}_{\theta}[J(\theta) - L_{\text{ent}}(\theta) - L_{\text{kl}}(\theta)]$$\(L_{\text{ent}}(\theta)\):
$$L_{\text{ent}}(\theta) = -E_{s, a}[H(y_\theta(s)) + H(z_\theta(s, a))]$$这是对预测 \(y\) 和 \(z\) 的熵正则化,其中 \(H(\cdot)\) 是给定分类分布的熵。我们发现这有助于防止预测过早收敛。\(L_{\text{kl}}(\theta)\):
$$L_{\text{kl}}(\theta) = D_{\text{KL}}(\pi_{\theta^-} || \hat{\pi})$$这是使用智能体参数的指数移动平均 \(\theta^-\) 的目标网络策略与元网络策略目标 \(\hat{\pi}\) 之间的 KL 散度。这防止了元网络提出过于激进的更新,这些更新可能导致性能崩溃。
补充细节
我们开发了一个基于 JAX 的框架,该框架借鉴了 Podracer 架构,将计算分布在 \(\text{TPU}\) 上。在这个框架中,每个智能体都是独立模拟的,所有智能体的元梯度都是并行计算的。通过聚合所有智能体的元梯度,元参数被同步更新。我们使用了 MixFlow-MG 来最小化运行的计算成本。对于 Disco57,我们通过以词典顺序循环遍历 57 个 \(\text{Atari}\) 环境,实例化了 128 个智能体。对于 Disco103,我们实例化了 206 个智能体,包含了来自 \(\text{Atari}\)、\(\text{ProcGen}\) 和 \(\text{DMLab-30}\) 的每个环境的两份拷贝。Disco57 使用 1024 个 \(\text{TPUv3}\) 核心发现了 64 小时,Disco103 使用 2048 个 \(\text{TPUv3}\) 核心发现了 60 小时。用于计算元梯度的元价值函数使用 V-Trace 进行更新,折扣因子为 0.997,\(\text{TD}(\lambda)\) 系数为 0.95。元价值函数和智能体网络使用 Adam 优化器进行优化,学习率为 0.0003。对于元参数更新,我们使用学习率为 0.001 且梯度裁剪为 1.0 的 \(\text{Adam}\) 优化器。每个智能体基于一个包含 96 条轨迹的批次进行更新,每条轨迹包含 29 个时间步。在每个批次中,在策略(on-policy)轨迹和从重放缓冲区(replay buffer)采样的轨迹被混合,其中重放轨迹占每个批次的 90%。在每个元步骤中,生成 48 条轨迹用于计算元梯度和更新元价值函数。智能体参数重置机制每个智能体的参数在其消耗完分配的经验预算后会被重置。重置时,新的经验预算从类别 (200M, 100M, 50M, 20M) 中采样,其权重与预算成反比,这样每个类别采样的总经验量是相同的。这是基于我们观察到大部分学习发生在生命周期的早期,并且在我们初步的小规模调查中显示出边际改进。
超参和评估
对于在保留基准(held-out benchmarks)上的评估,我们仅调优了学习率,其范围为 {0.0001, 0.0003, 0.0005}。其余的超参数是根据文献中基线算法选取的。对 \(\text{Atari}\) 游戏的评估(如图 2a和扩展数据表 1所示)使用了 \(\text{IMPALA}^{34}\) 网络的修改版本,其参数数量有所增加,以匹配 \(\text{MuZero}^{2}\) 使用的智能体网络大小。具体来说,我们使用了具有四个卷积残差块(过滤器数量分别为 256, 384, 384, 256)、一个 768 维度共享全连接最终层,以及一个基于 \(\text{LSTM}\) 的动作条件预测组件的网络。该组件由一个隐藏状态维度为 1024 的 \(\text{LSTM}\) 和一个1024 维度的全连接层组成。\(\text{DMLab-30}\) 评估(图 2c和扩展数据表 3)使用了与 \(\text{IMPALA}\) 中使用的相同的动作空间离散化和智能体网络架构。有关超参数列表,请参见扩展数据表 6。为了验证我们评估的统计显著性,我们对 \(\text{Atari}\)、\(\text{ProcGen}\) 和 \(\text{DMLab}\) 中的每个环境使用了两个随机种子进行初始化;对 \(\text{Crafter}\) 和 \(\text{Nethack}\) 使用了三个种子;对 \(\text{Sokoban}\) 使用了五个种子。