ISSN 1000-1239 CN 11-1777/TP

Journal of Computer Research and Development (计算机研究与发展), 2022, Vol. 59, Issue (3): 674-682. doi: 10.7544/issn1000-1239.20200693

• Artificial Intelligence •

A Metric Learning Based Unsupervised Domain Adaptation Method with Its Application on Mortality Prediction

Cai Derun, Li Hongyan   

  1. (School of Electronics Engineering and Computer Science, Peking University, Beijing 100871) (Key Laboratory of Machine Perception (Peking University), Ministry of Education, Beijing 100871) (cdr@stu.pku.edu.cn)
  • Online: 2022-03-07
  • Supported by: 
    This work was supported by the National Key Research and Development Program of China (2021YFE0205300) and the National Natural Science Foundation of China (62172018, 62102008).

Abstract: In recent years, deep learning models have been widely applied to prediction tasks in healthcare and have achieved good results. However, these models often suffer from insufficient labeled training data, from a shift in the overall data distribution, and from shifts in the data distributions of individual classes, all of which reduce prediction accuracy. To address these problems, we propose an unsupervised domain adaptation method based on metric learning that combines domain-adversarial training with an additive cosine margin loss: additive margin softmax based adversarial domain adaptation (AMS-ADA). First, the method uses a bidirectional long short-term memory (BiLSTM) network with an attention mechanism to extract features. Second, borrowing the idea of generative adversarial networks, it reduces the overall data distribution shift through adversarial domain adaptation. Third, borrowing the idea of metric learning, it further reduces the class-level data distribution shift by maximizing the decision margin between classes in the angular space. Together, these components improve both the domain adaptation effect and the prediction accuracy of the model. We evaluate the method on the mortality prediction task for ICU patients using real-world healthcare datasets. The experimental results show that, because it handles data distribution shift better than five baseline models, our method achieves better classification performance than these baselines.
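The metric-learning component named in the abstract is the additive margin softmax (AM-Softmax) loss. Below is a minimal sketch of that loss for a single sample. It assumes the standard AM-Softmax formulation (L2-normalized features and class weight vectors, a scale s, and an additive cosine margin m) rather than the exact AMS-ADA configuration, which the abstract does not give. The helper names and the defaults s = 30 and m = 0.35 are illustrative assumptions, not values from the article. The BiLSTM-attention encoder and the domain discriminator are omitted.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit L2 norm (assumes it is not all zeros)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_logits(feature: Sequence[float],
                  class_weights: Sequence[Sequence[float]]) -> List[float]:
    """Cosine of the angle between the feature and each class weight vector."""
    f = l2_normalize(feature)
    return [sum(a * b for a, b in zip(f, l2_normalize(w))) for w in class_weights]


def am_softmax_loss(feature: Sequence[float],
                    class_weights: Sequence[Sequence[float]],
                    label: int,
                    s: float = 30.0,   # illustrative scale (assumption)
                    m: float = 0.35    # illustrative cosine margin (assumption)
                    ) -> float:
    """Standard AM-Softmax loss for one sample (not necessarily AMS-ADA's exact setup).

    loss = -log( e^{s(cos_y - m)} / (e^{s(cos_y - m)} + sum_{j != y} e^{s cos_j}) )
    Subtracting m from the target-class cosine forces the sample to beat the
    other classes by at least m in cosine (angular) space.
    """
    cos = cosine_logits(feature, class_weights)
    logits = [s * (c - m) if j == label else s * c for j, c in enumerate(cos)]
    # Numerically stable log-sum-exp: subtract the maximum logit first.
    mx = max(logits)
    lse = mx + math.log(sum(math.exp(z - mx) for z in logits))
    return lse - logits[label]
```

For a binary task such as survived (0) versus died (1) with class weights [[1, 0], [0, 1]], a feature lying exactly between the two weights gives a loss of log 2 when m = 0. With m = 0.35 the same feature costs about s·m = 10.5, and that extra cost is what pushes each class away from the decision boundary in angular space.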

Key words: unsupervised domain adaptation, deep learning, mortality prediction, domain adversarial network, metric learning, attention mechanism
