表格数据建模性能对比:树模型 VS 深度学习

中文标题:为什么基于树的模型在表格数据上仍然优于深度学习?

英文标题:Why do tree-based models still outperform deep learning on tabular data?

发布平台:NIPS

Advances in neural information processing systems

发布日期:2022-01-01

引用量(非实时):1483

DOI:10.48550/ARXIV.2207.08815

作者:Léo Grinsztajn, Edouard Oyallon, Gaël Varoquaux

关键字: #树模型 #表格数据 #深度学习

文章类型:conferencePaper

品读时间:2025-01-31 19:32

1 文章萃取

1.1 核心观点

本文搜集了很多领域的优质数据集,并通过重复的筛选和清洗,构建了一个包含 45 个数据集的测试集合,并设置了严格的流程与评价方法用于对比分析树模型与深度学习。最终发现,树模型在表格数据上的表现优于深度学习

在后续的分析中,本文先从目标的不规则性,分析了两类模型在目标函数拟合上的特质;然后分别从冗余特征和特征旋转这两种干扰上入手,发现了一个关键结论:神经网络类模型由于存在对特征不变性的学习,导致其更容易受到冗余特征的干扰

1.2 综合评价

  • 实验设计严谨,代码和数据开源,可复现性强
  • 论文简练,表达清晰,直命主题,冗余信息少
  • 对两类模型进行了深入浅出的剖析,受益匪浅

1.3 主观评分:⭐⭐⭐⭐⭐

2 精读笔记

2.1 结论:树模型在表格数据集上表现优于深度学习

实验设计:

  • 严格筛选数据,排除标签偏斜、数据缺失、信息过少、样本量过少的数据集
  • 最终定义了一个包含 45 个数据集的集合,数据集涵盖了不同的领域
  • 树模型主要考虑三个经典模型:RandomForest,GradientBoostingTrees, XGBoost
  • 深度学习主要考虑四个经典模型:MLP、Resnet、FT_Transformer、SAINT
  • 对于分类/回归问题,分别用准确率/ $R^2$ 衡量模型在每类数据集上的性能
  • 对于每个模型,用类似平均最小距离(ADTM)的综合指标来衡量模型表现

主要结论:对于不同类型的任务和特征,树模型的收敛速度和最终性能均高于深度学习模型

1.不同模型随着迭代次数的收敛情况(仅考虑了数值类特征) 2.不同模型随着迭代次数的收敛情况(同时考虑了数值类特征和类别型特征)

神经网络对于表格数据具备充足的表达力,但可能缺乏合理的正则化?

2.2 分析:为什么树模型的表现优于深度学习?

2.2.1 发现 1:神经网络倾向于过度平滑的解,难以拟合不规则的函数

通过高斯核平滑器对训练集的标签进行平滑后,树模型的性能下降明显而神经网络无影响: 400

2.2.2 发现 2:神经网络类模型更容易受到无信息的冗余特征影响

  • 移除一半的特征对于 GBT 类树模型的精度影响不大,但对神经网络类模型影响更明显
  • 添加冗余的无效特征对神经网络类模型的负面影响,也要远大于对 GBT 类树模型的影响
2.2.3 发现 3:数据旋转干扰对神经网络类模型的性能影响更小,尤其是ResNet

研究表明,对于需要学习旋转不变性的模型来说,在数据中加入大量冗余的特征,会导致算法需要更多的训练数据来学习和适应这些冗余特征(在预期性能不变的情况下,冗余特征的增加会导致训练所需样本量线性增加)

这可能在一定程度上解释,为什么神经网络类模型更容易受到冗余特征的影响(因为它们会学习旋转不变性)

相关资源

往年同期文章