基于因果框架的医疗领域数据分布偏移评估

中文标题:基于因果框架的医疗领域数据分布偏移评估

英文标题:Diagnosing failures of fairness transfer across distribution shift in real-world medical settings

发布平台:NeurIPS

NeurIPS

发布日期:2023-02-10

引用量(非实时):

DOI:10.48550/arXiv.2202.01034

作者:Jessica Schrouff, Natalie Harris, Oluwasanmi Koyejo, Ibrahim Alabdulmohsin, Eva Schnider, Krista Opsahl-Ong, Alex Brown, Subhrajit Roy, Diana Mincu, Christina Chen, Awa Dieng, Yuan Liu, Vivek Natarajan, Alan Karthikesalingam, Katherine Heller, Silvia Chiappa, Alexander D'Amour

关键字: #因果框架 #分布偏移 #AI公平

文章类型:preprint

品读时间:2023-05-09 14:36

1 文章萃取

1.1 核心观点

  • 在医疗领域进行机器学习安全部署时,合理诊断和减轻数据分布偏移是保障模型公平性的关键。本文基于因果框架提出了一种通用的分布偏移检验算法,能独立于场景适用于多种类型数据的偏移检验,最终结果包含针对每一项特征的检验p值,并以联合因果图的形式说明分布的结构偏移,可以用于辅助医疗模型的落地。

1.2 综合评价

  • 本文提出的算法通用性强,可以通过建模算法的替换升级适用更全面的场景
  • 本文在两类数据(皮肤图像和病历文本)进行多中心的算法验证,结论具备较好地说服力
  • 整体算法的有效性依赖于联合因果图的构建,可靠性对数据规模有一定要求
  • 缺少对算法结果的应用相关试验,仅从理论上提出了一些可行的方案

1.3 主观评分:⭐⭐⭐⭐⭐

2 精读笔记

2.1 背景知识

符号定义:

  • $X$是一组特征(协变量),$Y$是要预测的结果或标签
  • $A$是一个或多个敏感属性(比如人口统计学属性)
  • 一个分类或回归模型$f(X')$
  • $X'$根据上下文有两种定义,$X':=(X,A)$或$X':=X$

因果贝叶斯网络(Causal Bayesian network,CBN)是一种有向无环图

  • 节点$U$表示各个属性/风险事件,边表示不同属性/风险事件间的因果联系
  • 当节点$U^i$指向节点$U^j$时,表示$U^i$是$U^j$的一个直接原因(a direct cause)
  • 节点$U^i$的所有直接原因可表示为父因集合(causal parents),用符号$pa(U^i)$表示

联合因果图(joint causal graph)是对CBN的一种拓展

  • 为更好地表示数据分布的偏移,引入环境变量$S$来扩展CBN
  • $S = 0$表示来自源环境的数据,$S = 1$表示来自目标环境的数据
  • 源数据可表示为$P ( A , X , Y | S = 0)$,目标数据可表示为$P ( A , X , Y | S = 1)$
  • 当变量$X$存在环境导致的分布偏移时,可用$S\rightarrow X$表示
  • 环境导致的分布偏移可能是复合的,即作用在多个变量中:$S\rightarrow {A,X,Y}$

联合因果图示例:

  • 上图中假设目标Y是协变量X的直接原因,而敏感特征A是X和Y的直接原因
  • 这种情况也被称为反因果推断(Anti-causal prediction)
  • 当协变量X是目标Y的直接原因时,被称为因果推断(Causal prediction)

联合因果推理(Joint Causal Inference,JCI)框架:

  • 一种在不同情况下对系统因果进行建模的框架
  • 该框架兼容多种因果发现算法,并允许潜在的数据杂质
  • JCI不需要了解干预目标或类型,可以统一地处理不同干预措施
  • 关于JCI框架的更多细节可参阅原论文

2.2 算法细节

数据分布变化会严重影响ML模型的行为,这种影响也会表现在模型的公平性上。比如在医疗领域,模型在医院A的训练满足公平性标准,但是在使用医院B的数据进行测试时,会发现可能不满足公平性标准。

有很多研究数据分布偏移、模型鲁棒性、公平性,但很少有方法能够诊断转变的本质并指导制定适当的缓解策略。而本文设计了一种满足JCI框架的统计测试方法,可以基于应用和目标数据集的简化联合因果图,评估系统遇到的分布偏移的结构,识别阻碍数据公平性的属性转移。具体算法对应伪代码如下:

  • 该算法针对图中每个节点$U$(对应每一种$A,X,Y$)进行单独的环境偏移评估
  • $S = 0$表示源环境,对应数据为$D$;$S = 1$表示目标环境,对应数据为$D'$
  • 每个节点的环境偏移都是通过假设检验对应$p$值来描述的。原假设为$H_0:P(U|pa(U),S=0)=P(U|pa(U),S=1)$,拒绝此原检验则意味着联合因果图中存在$S\rightarrow U$,即节点$U$存在分布偏移的情况

原假设$H_0$的转换:

  • 考虑用变量$M$来表示$pa(U)$,$M$的概率密度函数可定义为$\pi(M)$: $$H_0:P(U|M,S=0)=P(U|M,S=1)$$
  • 原假设$H_0$对于任意的$pa(U)$都应该是成立的: $$H_0:\int E[U|M|S=0]\pi(M)dM=\int E[U|M|S=1]\pi(M)dM$$
  • 上式经过化简后,可转化为以下形式(过程略): $$H_0: E[\frac{\pi(M)}{P(M|S=0)}U|S=0]= E[\frac{\pi(M)}{P(M|S=1)}U|S=1]$$
  • 假设$M$满足均匀分布(不同集合的出现概率是相等的),可定义$w_0(M)$和$w_1(M)$如下: $$\begin{equation} \left\{ \begin{gathered} w_0(M)=\frac{\pi(M)}{P(M|S=0)}\propto P(M|S=0)^{-1} \ \\ w_1(M)=\frac{\pi(M)}{P(M|S=1)}\propto P(M|S=1)^{-1} \end{gathered} \right.

\end{equation}$$

  • 将$w_0(M)$和$w_1(M)$带入$H_0$后便得到了原假设的最终形式: $$H_0: E[w_0U|S=0]- E[w_1U|S=1]=0$$

上式的原假设,可解释为$U$的重加权(权重为$w$)分布在不同环境$S$下是否保持一致

用于度量差异性的$p$值的计算过程理解:

  • 两种环境下的数据都会划分为训练集$D_w$和测试集$D_t$
  • 对训练集进行抽样后得到不同环境下的训练子集$Q$和$Q'$,构建分类器
  • 在测试子集$V$上计算分类器的预测输出$P(S|M)$,进而得到$w_0$和$w_1$
  • 最后根据$H_0:E[w_0U|S=0]- E[w_1U|S=1]=0$,得到$t$检验值
  • 重复以上$bootstrap$过程$n$次,得到更稳健的假设检验$p$值

当特征的维度较高时,可用考虑仅针对特征的低维表示(摘要,summary)进行测试。使用这种方法时,需要保证摘要数据的条件分布与原始数据的条件分布的一致性

算法的有效性验证(三种实验):

  1. 使用相同的数据进行随机拆分并进行算法测试,最终预测准确率在50%上下浮动,假阳性率约为5%,接近假设检验阈值;但需要注意结果的方差随数据量的减少而增加(说明此算法在少样本情况下不稳定)
  2. 针对皮肤数据中特定皮肤状况的年轻患者进行下采样,人为构建特征的分布偏移,算法识别成功
  3. 对Y进行统一的平移,人为构建标签的分布偏移,算法识别成功

2.3 实验分析

公平性的统计定义:子群准确性差距、人口奇偶性和均等优势

实验阶段主要展示了两个案例研究:皮肤科和电子健康记录( Electronic Health Records,EHR )

皮肤科数据包含患者人口统计学信息A、病种Y以及1~6张皮肤照片X,源数据使用来自美国2个州的12024+1925+1924(训练、验证、测试)例患者,目标数据使用来自哥伦比亚和澳大利亚的病例数据。

皮肤病数据的对比实验结果如下:

  • 图(a)显示算法针对皮肤科数据构建的简化联合因果图
  • 图(b)显示,源数据对应人群的年龄分布存在显著差异,表明S与A存在直接关系
  • 图(c)显示,源数据中大于65岁的老年女性患者包括更多的“湿疹”和“牛皮癣”病例,即排除年龄性别的因素影响下,病种分布依然存在显著差异,表明S与Y存在直接关系
  • 图(d)显示,类别为‘sk/isk’的老年女性在图像数据存在显著差异,表明S与X存在直接关系

EHR则主要使用开源的MIMIC-Ⅲ数据集,本文将医学ICU(MICU),手术ICU(SICU)和创伤手术ICU(TSICU)视为源数据(通用ICU);心脏手术恢复单元(CSRU)和冠状动脉护理单位(CCU)视为目标数据(专科ICU)。预测目标是“入住ICU时长是否大于3天”

EHR数据的对比实验结果如下:

  • 图(a)显示算法针对EHR数据构建的简化联合因果图,其中A表示年龄/性别、M表示病史、X表示检验/体征信息、T表示治疗手段,Y表示入住ICU时长
  • 图(b)显示,源数据对应人群的年龄、性别分布均存在显著差异,表明S与A存在直接关系
  • 图(c)显示,目标数据中大于65岁的老年男性患者的外周血管系统相关的合并症患病率更高,即排除年龄性别的因素影响下,合并症依然存在显著差异,表明S与M存在直接关系
  • 图(d)显示,接收加压剂/肌肉(特定治疗手段)下患有实体肿瘤合并症(特定合并症)的老年男性在ICU入住时长存在显著差异,表明S与Y存在直接关系

两个案例中不同亚组(年龄和性别)的模型迁移前后性能表现:

其他建模细节:

  • 实际训练时主要使用逻辑回归或GPT类算法,对于每种测试重复训练100次
  • 对于高维数据需要进行低维表示,包括皮肤病相关图像数据(28TPU,24h~30h)和EHR文本数据(2CPU,1h),以上过程也会重复10次以增加结果的公平性
  • 皮肤病相关图像数据(448x448)借助宽ResNet101x3特征抽取器进行编码,对应模型架构如下:
    (摘自论文 A deep learning system for differential diagnosis of skin diseases
  • GBT类算法可能对某一类子集过度重视而导致算法出错(由于忽视了某一子类,而导致分布偏移没有正确识别)。因此需要对类权重进行裁剪(比如限制最大值为10)来缓解此问题
  • 相关建模代码主要参考自谷歌之前开源的一套EHR预测建模框架 #待补充

改进数据分布偏移的5种可行方案:

  1. 问题选择:专注于临床/政策保障型或预期危害较低的任务,减少ML不公平导致的影响
  2. 数据搜集:针对可能存在分布偏移的场景,尽量在早期完成数据搜集,方便算法测试与调整
  3. 输出定义:考虑对数据分布偏移不敏感的中间结果,比如图像分割任务VS诊断任务
  4. 算法改进:手工设计的特征工程可能有助于对特定偏见的归纳,降低部署的风险
  5. 后处理:针对特定场景进行迁移学习或重新训练,部署后保持对模型公平性的监控

其他分析总结:

  • 使用Alabdulmohsin等人提出的方法进行预测输出后处理后,模型迁移前后的性能表现差异显著缩小,但经过本文算法进行检验测试后,发现数据的公平性并没有得到有效缓解
  • 本文提出的算法考虑了因果结构和全部的直接影响因素,能有效指导缓解策略的选择
  • 本文工作并没有评估不同模型架构和训练策略的影响,也没有考虑社会因素等无法观察到的变量影响
  • 未来的工作可以拓展到更多类型的变量,要重视对偏见来源的探究
  • 医疗领域的公平性是值得探索的方向,但在应用前一定要注意严格的验证

通过在皮肤科和电子健康记录( Electronic Health Records,EHR )中的试验分析可知,算法结果对于选择适当的缓解策略至关重要。此外,我们的工作还表明,临床上可能发生的变化往往比目前的缓解技术所能处理的更复杂,这突出表明需要进一步研究整个ML处理流程中更广泛的补救措施

相关资源

往年同期文章