HATS:用于股票走势预测的分层图注意力网络

中文标题: HATS:用于股票走势预测的分层图注意力网络

英文标题:HATS: A Hierarchical Graph Attention Network for Stock Movement Prediction

发布平台:预印本

发布日期:2019-11-12

引用量(非实时):126

DOI:

作者:Raehyun Kim, Chan Ho So, Minbyul Jeong, Sanghoon Lee, Jinkyu Kim, Jaewoo Kang

关键字: #HATS #GAT

文章类型:preprint

品读时间:2024-01-13 17:58

1 文章萃取

1.1 核心观点

分层注意力网络(HATS)作为关系建模模块,有选择地从不同关系类型的相邻节点中聚合信息,并将这些信息(对预测最有用的是股票间的相关关系信息)添加到每个公司的表示中;最终的信息表示将用于特定任务(个股价格/指数趋势)的预测

实验分析表明,HATS 借助自主的信息选择机制实现了最优的性能表现,并且其性能表现取决于输入的关系数据。HATS 在标普 500 中股票进行了广分实验,其在夏普比率和 F1 分数方面的性能分别比现有基线高 19.8% 和 3%

1.2 综合评价

  • 本文通过两种注意力得分挖掘存在合理关系间的相关股票对,实现自适应的信息提取
  • 本文提出的注意力机制与后来的 Transform 相似,但针对图结构进行了更合理的调整
  • 本文实验丰富,针对不同波动区间分别评估模型的准确率和投资表现,结果更为可靠

1.3 主观评分:⭐⭐⭐⭐⭐

2 精读笔记

2.1 算法细节

HATS 整体框架:

  • 特征提取(Feature Extraction)模块:使用 LSTM (性能更好但训练困难,后续用于个股预测)和 GRU (性能一般但训练高效,后续用于指数趋势预测)进行特征提取,根据历史价格变动模式来表示个股的当前状态;
  • 关系建模(Relational Modeling)模块:基于分层注意力网络(HATS)捕获相邻节点和关系类型的重要性
  • 预测层(Prediction Layer):根据特定的任务激活不同的模块;个股价格使用单个股票的状态信息来预测(节点分类任务);指数趋势使用图池化(graph pooling)方法来聚合不同节点的信息以表示指数(图分类任务)

分层注意力网络(HATS)细节:

  • 假设时刻 $t$ 的股票 $c$ 对应的特征提取模块输出(嵌入表示)为 $e_c$;而关系类型 $m$ 对应的嵌入表示为 $e_{r_m}$
  • 对于关系类型 $m$,$e_c$ 对应的邻接节点集合可表示为 $N_c^{r_m}$;而股票 $f$ 与股票 $c$ 存在类型为 $m$ 的关系,即 $f\in N_c^{r_m}$
  • 对于关系类型 $m$,将关系类型嵌入 $e_{r_m}$ 和两个节点嵌入 $e_c,e_f$ 表示连接到一个向量中,即 $x^{r_m}_{cf}$
  • 对于关系类型 $m$,计算股票 $c$ 对不同相关股票/节点的状态注意力得分: $$

\begin{aligned}\upsilon_{cf}&=x_{cf}^{r_{m}}W_{s}+b_{s} \\ \\ \alpha_{ij}^{r_{m}}&=\frac{\exp(v_{cf})}{\sum_{k}\exp(v_{ck})},\quad k\in N_{i}^{r_{m}}\end{aligned} $$

  • 最终股票 $c$ 在关系类型为 $m$ 下汇总的关系信息向量可表示如下:$s_c^{r_m}=\sum_{f\in N^{r_m}}\alpha_{cf}^{r_m}e_f$
  • 继续拼接关系类型嵌入 $e_{r_m}$ 、节点嵌入嵌入 $e_c$ 和节点关系信息向量 $s_c^{r_m}$,得到 $\widetilde{x}^{r_m}_{c}$
  • 计算股票 $c$ 针对不同的关系类型的状态注意力得分:

$$ \begin{aligned}\widetilde{\upsilon}_{c}^{r_{m}}&=\widetilde{x}^{r_m}_{c}W_{r}+b_{r} \\ \\

\widetilde{\alpha}_{c}^{r_{m}}&=\frac{\exp(\widetilde{\upsilon}_{c}^{r_{m}})}{\sum_{k}\exp(\widetilde{\upsilon}_{c}^{r_{k}})},\quad|N_{c}^{r_{k}}|\ne 0 \end{aligned} $$

  • 根据不同关系类型的注意力得分,加权汇总所有关系信息向量,并更新节点 $c$ 的嵌入表示: $$

e_c^{new}=\Sigma_{k}\widetilde{\alpha}_{c}^{r_{k}}s_c^{r_k}+e_c $$

HATS 先根据特定关系类型下的邻接节点注意力得分有选择地从相邻节点收集有关特定关系的信息,再根据不同关系类型的注意力得分汇总特定关系搜集到的信息向量,最后加权所有有效关系类型下的有效邻接信息,更新股票/节点的嵌入表示

注意区分本文中提到的注意力和后来 Transform 机制中的注意力,二者理念相似但计算方式略有不同

实验分析

数据说明:

  • 剔除标普 500 股票中相对独立(不存在相关性)的公司,保留 431 家公司作为实验对象
  • 研究样本的价格时间跨度为2013年2月8日至2019年6月17日(总共1174个交易日)
  • 0_life/精品资源/数据资源/知识图数据资源#Wikidata 收集了企业关系数据,以公司为节点,提取与目标公司相关的 75 种关系
  • 根据数据的波动性,将整个数据集划分为 8 个较小的区间
  • 每个区间包含 250 天用于训练,50 天用于评估/验证,100 天用于测试

交易策略:

  • 先使用两个阈值对个股趋势进行分级:上涨、中性、下跌
  • 对 15 家晋升概率最高的公司,选择买入;对 15 家降级概率最高的公司,选择卖出

评价指标:投资回报率,夏普比率,准确度和 F1 分数(分类任务)

第 4 阶段预测效果最佳的关系类型 Top10:

不同阶段不同模型的分类精度(F1 值):

MLP CNN LSTM GCN GCN20 TGC HATS
Phase 1 0.2876 0.3111 0.3173 0.2874 0.3161 0.3110 0.3314
Phase 2 0.2862 0.3208 0.3228 0.3068 0.3339 0.3088 0.3347
Phase 3 0.2763 0.2938 0.3064 0.2692 0.3113 0.2237 0.3100
Phase 4 0.2810 0.3176 0.3030 0.2940 0.3240 0.2970 0.3267
Phase 5 0.2873 0.3354 0.3333 0.3116 0.3450 0.3329 0.3496
Phase 6 0.2855 0.3265 0.3229 0.2914 0.3140 0.2798 0.3394
Phase 7 0.2876 0.3111 0.3173 0.2874 0.3161 0.3110 0.3314
Phase 8 0.2862 0.3208 0.3228 0.3068 0.3339 0.3088 0.3347
Phase 9 0.2741 0.2390 0.2793 0.2980 0.3160 0.2851 0.3219
Phase 10 0.2529 0.2128 0.3134 0.3002 0.3272 0.2951 0.3243
Phase 11 0.2500 0.2270 0.2997 0.2714 0.3031 0.2577 0.3091
Phase 12 0.2678 0.2921 0.2968 0.2956 0.3299 0.3270 0.3396
Average 0.2769 0.2923 0.3113 0.2933 0.3225 0.2948 0.3294
  • GCN:基本图卷积神经网络模型,考虑所有类型的关系
  • GCN 20:使用相同的 GCN 模型,只考虑实验性能最佳的 20 种关系类型
  • TGC:基于时序图卷积的关系建模,是一种经典的 GNN 股票预测模型

不同阶段不同模型的分类精度(准确率):

MLP CNN LSTM GCN GCN200 TGC HATS
Phase 1 0.3455 0.3540 0.3597 0.3752 0.3700 0.3811 0.3725
Phase 2 0.3342 0.3626 0.3604 0.3735 0.3726 0.3701 0.3752
Phase 3 0.3547 0.3571 0.3771 0.3860 0.3834 0.3831 0.3859
Phase 4 0.3647 0.3855 0.3816 0.3992 0.3897 0.4059 0.3884
Phase 5 0.3208 0.3834 0.3684 0.4191 0.4164 0.4239 0.4176
Phase 6 0.3300 0.3803 0.3841 0.3627 0.3699 0.3716 0.3869
Phase 7 0.3553 0.4309 0.4252 0.4510 0.4488 0.4477 0.4502
Phase 8 0.3537 0.3901 0.3891 0.3993 0.4040 0.4022 0.4049
Phase 9 0.3711 0.3686 0.3511 0.3955 0.3836 0.3872 0.3923
Phase 10 0.3515 0.3484 0.3630 0.3552 0.3623 0.3667 0.3674
Phase 11 0.3813 0.3924 0.3608 0.3823 0.3794 0.3693 0.3953
Phase 12 0.3434 0.3615 0.3584 0.3530 0.3753 0.3827 0.3802
Average 0.3505 0.3762 0.3732 0.3877 0.3880 0.3910 0.3931

不同阶段不同模型的投资回报率:

MLP CNN LSTM GCN GCN20 TGC HATS
Phase 1 0.0672 -0.0506 0.0904 -0.0264 -0.0103 -0.0517 0.1231
Phase 2 -0.0195 0.0929 0.1005 0.1057 0.2435 0.1247 0.1759
Phase 3 0.0029 -0.0623 0.0454 -0.0189 0.0246 -0.0100 0.0703
Phase 4 0.0945 -0.0578 0.1429 0.0028 0.0385 0.0113 0.1779
Phase 5 -0.0623 -0.0673 -0.0159 -0.0427 0.0415 -0.0202 0.0183
Phase 6 -0.0002 0.0140 0.0400 0.0748 0.0828 0.0581 0.0726
Phase 7 -0.0081 0.0272 0.0246 0.0201 -0.0389 -0.0143 0.0860
Phase 8 0.0319 -0.0122 0.0742 0.0837 0.2356 0.2175 0.0662
Phase 9 0.0143 -0.0311 0.0234 0.0375 -0.0437 -0.0211 0.0394
Phase 10 0.0040 -0.0239 -0.0209 0.0126 0.0164 -0.0523 0.0612
Phase 11 0.0395 -0.0106 -0.0085 0.0597 0.0685 0.1321 0.1890
Phase 12 0.0068 0.0051 0.0222 0.0334 -0.0241 0.0775 0.0732
Average 0.0142 -0.0147 0.0432 0.0285 0.0529 0.0376 0.0961

不同阶段不同模型的夏普比率:

MLP CNN LSTM GCN GCN20 TGC HATS
Phase 1 2.4410 -1.4459 2.3553 -0.2802 -0.1013 -0.5029 2.4796
Phase 2 -1.0063 3.2835 4.0651 2.4700 4.9007 3.0525 4.3903
Phase 3 0.1070 -1.7872 1.0642 -0.2477 0.2994 -0.1796 1.2503
Phase 4 2.1602 -0.6064 2.2014 0.0289 0.3173 0.1085 2.3961
Phase 5 -1.6039 -1.7851 -0.4455 -0.7090 0.6222 -0.4131 0.4087
Phase 6 -0.0095 0.3565 1.1960 1.8324 2.0390 2.9435 1.6945
Phase 7 -0.4010 0.9306 0.8354 0.4078 -0.9107 -0.6618 2.0334
Phase 8 1.0398 -0.3917 2.1975 1.3746 3.1870 4.1305 1.4830
Phase 9 0.4915 -1.9624 0.4905 0.7758 -0.6896 -0.3619 0.8060
Phase 10 0.6667 -1.8774 -0.5671 0.3263 1.3576 -1.2023 1.4382
Phase 11 0.8059 -0.7052 -0.1983 2.5786 1.3053 3.3379 3.6146
Phase 12 0.1684 0.2055 0.6334 0.8066 -0.8201 1.7799 1.9014
Average 0.4050 -0.4821 1.1523 0.7803 0.9589 1.0026 1.9914

不同预测模型及其资产价值变化的比较:

不同模型针对不同指数的分类准确率:

S5CONS S5FINL S5INFT S5ENRS S5UTIL Average
MLP 0.2986 0.3002 0.2867 0.2785 0.2928 0.2913
CNN 0.3013 0.3157 0.3036 0.3011 0.3025 0.3049
LSTM 0.3405 0.2859 0.3454 0.3109 0.2942 0.3154
GCN 0.3410 0.3040 0.3423 0.2848 0.3111 0.3166
TGC 0.3322 0.3051 0.3391 0.2736 0.2911 0.3082
HATS 0.3758 0.3148 0.3518 0.3267 0.3256 0.3389

使用 T-SNE 对节点表示进行可视化

相关资源

往年同期文章