高级检索

    基于余弦相似度的多模态模仿学习方法

    Multi-Modal Imitation Learning Method with Cosine Similarity

    • 摘要: 生成对抗模仿学习(generative adversarial imitation learning, GAIL)是一种基于生成对抗框架的逆向强化学习(inverse reinforcement learning, IRL)方法,旨在从专家样本中模仿专家策略. 在实际任务中,专家样本往往由多模态策略产生. 然而,现有的GAIL方法大部分假设专家样本产自于单一模态策略,导致生成对抗模仿学习只能学习到部分模态策略,即出现模式塌缩问题,这极大地限制了模仿学习方法在多模态任务中的应用. 针对模式塌缩问题,提出了基于余弦相似度的多模态模仿学习方法(multi-modal imitation learning method with cosine similarity,MCS-GAIL). 该方法引入编码器和策略组,通过编码器提取专家样本的模态特征,计算采样样本与专家样本之间特征的余弦相似度,并将其加入策略组的损失函数中,引导策略组学习对应模态的专家策略. 此外,MCS-GAIL使用新的极小极大博弈公式指导策略组以互补的方式学习不同模态策略. 在假设条件成立的情况下,通过理论分析证明了MCS-GAIL的收敛性. 为了验证方法的有效性,将MCS-GAIL用于格子世界和MuJoCo平台上,并与现有模式塌缩方法进行比较. 实验结果表明,MCS-GAIL在所有环境中均能有效学习到多个模态策略,且具有较高的准确性和稳定性.

       

      Abstract: Generative adversarial imitation learning is an inverse reinforcement learning (IRL) method based on generative adversarial framework to imitate expert policies from expert demonstrations. In practical tasks, expert demonstrations are often generated from multi-modal policies. However, most of the existing generative adversarial imitation learning (GAIL) methods assume that the expert demonstrations are generated from a single modal policy, which leads to the mode collapse problem where the generative adversarial imitation learning can only partially learn the modal policies. Therefore, the application of the method is greatly limited for multi-modal tasks. To address the mode collapse problem, we propose the multi-modal imitation learning method with cosine similarity (MCS-GAIL). The method introduces an encoder and a policy’s group, extracts the modal features of the expert demonstrations by the encoder, calculates the cosine similarity of the features between the sample of policy sampling and the expert demonstrations, and adds them to the loss function of the policy’s group to help the policy’s group learn the expert policies of the corresponding modalities. In addition, MCS-GAIL uses a new min-max game formulation for the policy’s group to learn different modal policies in a complementary way. Under the assumptions, we prove the convergence of MCS-GAIL by theoretical analysis. To verify the effectiveness of the method, MCS-GAIL is implemented on the Grid World and MuJoCo platforms and compared with the existing mode collapse methods. The experimental results show that MCS-GAIL can effectively learn multiple modal policies in all environments with high accuracy and stability.

       

    /

    返回文章
    返回