ICML 2024 | 具有O(L)训练存储和O(1)推理功耗的时间可逆脉冲神经网络

©PaperWeekly 原创 · 作者 | 李国齐课题组

单位 | 中国科学院自动化研究所

研究方向 | 类脑计算


脉冲神经网络(Spike Neural Network,SNN)因其受大脑启发的神经元动态和基于脉冲的计算模式,被认为是一种低功耗的人工神经网络(Artifical Neural Network,ANN)替代方案。然而受限于 SNN 中的神经元的时空动态特性,SNN 的训练显存开销与运算时间均远远大于 ANN [1,2,3,4]


为解决此问题,本文提出一种时间可逆计算范式,并基于此开发了 T-RevSNN 模型。与现有的 Spike-driven Transformer [5] 相比,T-RevSNN 的内存效率、训练时间加速和推理能效分别具有 8.6 倍、2 倍和 1.6 倍的显著提高。



论文标题:

High-Performance Temporal Reversible Spiking Neural Networks with O(L) Training Memory and O(1) Inference Cost论文地址:https://openreview.net/forum?id=s4h6nyjM9H代码地址:https://github.com/BICLab/T-RevSNN



背景

当前 SNN 模型的任务性能已在 ImageNet 上达到 80% 准确率 [6],已能够满足绝大多数实际任务场景,但是其训练难度仍然远高于同架构下的 ANN。如何降低 SNN 的训练难度是目前 SNN 领域的重点难题。SNN 的训练困难来源于其使用的 BPTT 训练算法。在训练时需存储每一层、每一时间步的神经元的激活值,即在训练时显存复杂度为 O(LT),其中 L 是层数,T 是时间步。例如,训练 10 时间步的脉冲 ResNet-19 比 ANN-ResNet-19 多需约 20 倍显存 [1]。 


为解决这个问题,目前主流方法是解耦 SNN 训练过程与时间步。然而,它们中没有一个能够同时实现低廉的训练内存和低推理能耗,因为它们往往只在一个方向上进行优化。同时最近的研究显示,SNN 的时间反向传播对最终梯度影响小。既然如此,我们是否可以仅在关键位置保留时间前向,而关闭其他神经元的时间动态呢? 


基于此,我们考虑仅在关键位置保留时间前向,关闭其他神经元的时间动态。我们设计了时间可逆的 SNN (T-RevSNN)。首先,为减少训练内存,仅在每个阶段的输出脉冲层激活时间动态,并实现时间传递的可逆性,避免存储所有神经元的膜电位和激活。


其次,关闭的脉冲神经元不进行时间动态,简化为不重用时间维度的参数,同时通过一次编码输入,将特征和网络分为 T 组,避免增加参数和能耗。第三,为提升性能,采用多级信息传递,重新设计 SNN 块,并调整残差连接以确保有效性。



本文贡献

我们的贡献包括:


1. 我们重新设计了 SNN 的前向传播,简单直观地同时实现了低训练内存、低功耗和高性能。


2. 我们在三个方面进行了系统设计,以实现提出的想法,包括关键脉冲神经元的多级时间可逆前向信息传递、输入编码和网络架构的分组设计,以及SNN块和残差连接的改进。


3. 在 ImageNet-1k 上,我们的模型在基于 CNN 的 SNN 上实现了最先进的准确性,同时具有最小的内存和推理成本,并且训练速度最快。与当前基于 Transformer 的 SNN 相比,即基于脉冲驱动的 Transformer,我们的模型在准确性上接近,而内存效率、训练时间加速和推理能效可以显著提高分别为 8.6×、2.0× 和 1.6×。



动机:尽可能少的梯度反传

我们设置了以下实验来分析哪些脉冲神经元的关键哪些不关键。同时由于遍历工作量太大,一般认为个典型的神经网络被分为四个阶段,每个阶段的特征层次各不相同。所以我们假设 SNN 中两个阶段交界处的时间信息传递很重要。 为此我们设计了如下实验,首先在 CIFAR-10 上训练 Spiking Resnet,并将其设为基线,为了确定哪些神经元的时间梯度会对模型的训练过程产生显著影响,我们移除了可疑神经元的时间梯度,分为如下两种情况: 


1. 案例1:我们保留每个阶段最后一层的时间梯度,并去除其他神经元的时间梯度;


2. 案例2:我们采取相反的方法,去除每个阶段的最终脉冲神经元的时间梯度,但保留其他神经元。

▲ 图1. 基线与案例1的余弦相似度随训练过程变化图


▲ 图2. 基线与案例2的余弦相似度随训练过程变化图


之后我们计算随着 epoch 增加基线和案例 1,2 之间的余弦相似度变化。相似度高表明该条件下的神经元的时间动态重要,相似度低则反之。 最终结果如图1和图2所示。随着训练周期的增加,图 1(基线与案例 1 的比较)的相似度始终保持在高水平。相比之下,图 2(基线与案例 2 的比较)的相似度始终较小。这说明每个阶段最后一层的时间梯度比前面阶段的脉冲神经元更重要,我们称这些脉冲神经元为"关键神经元"。



方法

基于动机中的发现,我们设计了 Turn off / on 两种脉冲神经元,分别对应于在动机中找到的不关键或关键的神经元。

▲ 图3. 所提出的T-RevSNN的时间前向传播的示意图和网络结构细节

4.1 时间可逆脉冲神经元

T-RevSNN 中的脉冲神经元分为两种关键和不关键神经元。 对于不关键神经元,即上图中的绿色部分神经元,我们将其时间维度的连接进行关闭。我们称这种关闭后的神经元为 Turn off 神经元。Turn off 神经元与一般的脉冲神经元唯一的区别是丧失了时间维度上膜电势的信息传递。其中 Turn off 脉冲神经元的前向传播可描述如下:

可以看到 Turn off 脉冲神经元的权重更新依赖于空间和时间梯度。 对于 Turn on 脉冲神经元,其权重更新依赖于空间和时间梯度。受可逆性概念的启发,我们观察到其是自然可逆的。因此其前向传播可描述如下:


随后,可以在 之间建立可逆变换。这意味着在计算第一个时间步的梯度时,无需存储所有时间步的膜电位和激活值。我们只需要存储 。这减少了 SNNs 多时间步训练所需的内存。Turn on 神经元的时间复杂度与传统的 SNN 训练一致为 O(T)。

4.2 高性能的SNN训练框架

为了提升 SNN 的性能, 研究者们提出了许多方法,然而,上述方法不足以实现高精度的 SNN,为此我们首先引入了多层次连接训练框架 首先我们在相邻时间步的 SNNs 之间建立了更强的多层次连接(如图 3 所示)。我们将前一时间步的更深层次的高级特征纳入到当前时间步的信息融合中。通常,我们可以按照以下方式构建前向信息传递:


其次我们重写设计的基本的 SNN 模块。它由两个深度可分离卷积(DWConv/PWConv)和一个残差连接组成。我们去掉了所有批量归一化(BN)模块,转而去使用将网络中所有层的权重都进行了归一化的方法来稳定训练。 之后我们使用了 ReZero 技术来增强网络在初始化后满足动态等距的能力和促进高效的网络训练。为了保证在推理中只发生加法运算,我们使用重参数化,将 ReZero 的缩放比例(即图 4 中的 α)合并到上一层的权重中。


▲ 图4. 遵循ConvNext范式的基本的SNN模块



结果

5.1 不同训练方法复杂度分析

▲ 图5. T-RevSNN和其他SNN训练优化方法的前传和反传示意图
传统的 SNN 训练算法(STBP)在计算从最后一层的最后一个时间步的输出到第一层的第一个时间步的输入的梯度时所需的记忆和计算构成了训练 SNNs(脉冲神经网络)的记忆和时间复杂度。我们在表 1 和图 5 中分析了所提出的 T-RevSNN 和其他 SNN 训练优化方法 [2,3,4] 的训练内存和时间复杂度。▲ 表1. 不同算法训练和推理的计算复杂度

5.2 消融实验

我们对 T-RevSNN 的不同时间步长和是否使用缩放残差连接进行了消融实验。时间步长:在我们的设计中,我们将整个网络的参数分为T组子网络。在下表中,我们分析了不同的时间步长 T 对准确度、训练速度和内存的影响。由于我们固定了总参数数量,增加T意味着每个时间步的子网络变得更小。相应地,训练所需的内存会减少,但训练时间会相应增加。此外,可以看到准确性与时间步之间的关系并不是线性的。


▲ 表2. 关于时间步长的消融实验

缩放残差连接:可以看到使用该技术有助于提高模型的收敛速度和最终准确度,如下表所示。

▲ 表3. 关于残差连接的消融实验

5.3 主要实验结果

T-RevSNN 在 ImageNet 上的结果如下所示。本文取得了 SNN 域中最快的训练速度和最低的内存消耗。


▲ 表4. 在大型ImageNet数据集上的实验如上表所示,本文所提出的 T-RevSNN 以 85.7 MB/图片的内存消耗,和 9.1 分钟/周期的训练时间远低于脉冲 Transformer 和脉冲卷积模型。体现了 T-RevSNN 在训练速度、内存需求和推理功耗方面的显著优势,同时在性能上也具有竞争力。尽管准确率低于 Spike-driven Transformer,但我们认为这是由架构引起的差距,并且未来可以解决。 全文到此结束,更多细节建议查看原文。

参考文献

[1] Fang W, Chen Y, Ding J, et al. Spikingjelly: An open-source machine learning infrastructure platform for spike-based intelligence[J]. Science Advances, 2023, 9(40): eadi1480.

[2] Zhang H, Zhang Y. Memory-efficient reversible spiking neural networks[C]. Proceedings of the AAAI Conference on Artificial Intelligence. 2024, 38(15): 16759-16767.

[3] Meng Q, Xiao M, Yan S, et al. Towards memory-and time-efficient backpropagation for training spiking neural networks[C]. Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023: 6166-6176.

[4] Xiao M, Meng Q, Zhang Z, et al. Online training through time for spiking neural networks[J]. Advances in neural information processing systems, 2022, 35: 20717-20730.

[5] Yao M, Hu J, Zhou Z, et al. Spike-driven transformer[J]. Advances in neural information processing systems, 2024, 36:64043--64058.

[6] Yao M, Hu J K, Hu T, et al. Spike-driven Transformer V2: Meta Spiking Neural Network Architecture Inspiring the Design of Next-generation Neuromorphic Chips[C]. The Twelfth International Conference on Learning Representations.


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


···

相关推荐

  • 招聘|快手算法工程师
  • 【报名中】阿里云 x StarRocks:极速湖仓第二季—上海站
  • 分布式 Data Warebase - 让数据涌现智能
  • 微软野心再现:对Excel和谷歌Sheet下手了!
  • 百度内容生态视频AIGC新探索
  • 神速!枪击特朗普刺客手机已被破解!一文讲解FBI破解手机有多难:曾找苹果CEO库克建后门惹怒被拒,FBI:花百万美元我们自己搞!
  • 【云原生|K8S系列】K8s新手必看,不可不知的K8s技能,Service发现全解析!
  • 首个WebAgent在线评测框架和流程数据管理平台来了,GPT-4、Qwen登顶闭源和开源榜首!
  • Prompt工程师要下岗了!北大发布Prompt自动增强系统PAS,超越SOTA
  • AI+教育!前OpenAI联创Andrej Karpathy官宣创业!创办第一所AI原生学校
  • 无损加速最高5x,EAGLE-2让RTX 3060的生成速度超过A100
  • Mistral AI两连发:7B数学推理专用、Mamba2架构代码大模型
  • 快手开源LivePortrait,GitHub 6.6K Star,实现表情姿态极速迁移
  • 早半年发arXiv,却被质疑抄袭:活在微软AutoGen阴影里的CAMEL
  • AKOOL助力戛纳广告大奖,发布革命性实时数字人平台
  • 程序员都干过哪些很刺激的事情?
  • React 渲染流程可视化,有大佬实现了!
  • KG与大模型之三问三答及Agent遇见RAG:PersonalRAG及长文本压缩新策略CompAct
  • 今天,我38岁了!
  • 利用 RFM 和 CLTV 进行客户价值分析