Bridge619

Bridge619

Bridge619

命定的局限尽可永在,不屈的挑战却不可须臾或缺!

101 文章数
17 评论数
来首音乐
光阴似箭
今日已经过去小时
这周已经过去
本月已经过去
今年已经过去个月

《Liquid Time-constant Networks:液态时间常数的连续时间神经网络解析》

Bridge619
2025-01-26 / 0 评论 / 278 阅读 / 0 点赞
Liquid Time-constant Networks:麻省理工学院(MIT)的研究者开发出了一种新型的神经网络。

原论文➡️Liquid Time-constant Networks - AAAI

1. 简介

在时间序列建模领域,传统的循环神经网络(RNN)长期面临着两大挑战:一是固定时间步长难以适应多尺度时序特征(如高频波动与长期趋势的共存);二是梯度消失或爆炸问题导致长期依赖建模能力受限。近年来神经微分方程(Neural ODE)通过将深度学习与微分方程结合,为连续时间动态建模提供了新思路。然而,神经ODE的表达能力受限于其固定的时间常数,面对复杂时序模式(如混沌系统或非均匀采样数据)时仍显不足。

《Liquid Time-constant Networks》(LTCs)的提出,正是为了解决上述问题。文章的核心创新是将时间常数从固定参数转变为隐藏状态的函数,通过非线性门控机制动态调节系统的响应速度。这种设计灵感源自生物神经系统的适应性——突触传递速度可随输入强度和上下文动态调整。LTCs不仅继承了神经ODE的连续时间建模优势,还通过"液态"时间常数实现了更灵活的时序特征捕捉能力。论文通过理论证明和实验验证,展示了LTCs在稳定性、表达能力和实际任务性能上的显著提升。下面对其架构进行详细解读。

2. LTCs详解


2.1 LTCs的数学推导与液态时间常数机制

LTCs的核心创新在于通过动态调整系统的时间常数,使模型能够自适应输入信号的特性。这一机制通过重新设计连续时间循环神经网络(CT-RNN)的微分方程实现。

传统CT-RNN的局限性:

传统CT-RNN的动力学方程通常表示为:

$$\frac{d\mathbf{x}(t)}{dt} = -\frac{\mathbf{x}(t)}{\tau} + f(\mathbf{x}(t), \mathbf{I}(t), t, \theta),$$

其中$\tau$为固定时间常数,$-\frac{\mathbf{x}(t)}{\tau}$项确保系统向平衡状态演化,$f$为神经网络。然而,固定时间常数$\tau$限制了模型对多尺度时序特征的适应性。


LTCs的改进方程:

LTCs通过引入非线性调制项$\mathbf{S}(t)$重构微分方程:

$$\frac{d\mathbf{x}(t)}{dt} = -\frac{\mathbf{x}(t)}{\tau} + \mathbf{S}(t).$$

其中$\mathbf{S}(t)$定义为:

$$\mathbf{S}(t) = f(\mathbf{x}(t), \mathbf{I}(t), t, \theta) \cdot \left(A - \mathbf{x}(t)\right),$$

这里$f(\cdot)$是参数化的神经网络,$A$为超参数,表示系统的目标稳态值。该设计通过误差项$(A - \mathbf{x}(t))$动态加权,调整状态更新速率。


液态时间常数的推导:

将$\mathbf{S}(t)$代入原方程并展开:

$$\frac{d\mathbf{x}(t)}{dt} = -\frac{\mathbf{x}(t)}{\tau} + f(\cdot) \cdot \left(A - \mathbf{x}(t)\right).$$

合并同类项后得到:

$$\frac{d\mathbf{x}(t)}{dt} = -\left( \frac{1}{\tau} + f(\cdot) \right)\mathbf{x}(t) + f(\cdot) A.$$

令$$\tau_{\text{sys}} = \frac{\tau}{1 + \tau f(\cdot)},$$ 方程可重写为:

$$\frac{d\mathbf{x}(t)}{dt} = -\frac{\mathbf{x}(t)}{\tau_{\text{sys}}} + \frac{f(\cdot) A}{\tau_{\text{sys}}}.$$

此时,系统的有效时间常数$\tau_{\text{sys}}$由神经网络$f(\cdot)$动态生成。 这种动态调整的时间常数使得网络能够根据输入和当前状态灵活地调整动态特性,增强了模型的表达能力。


动态调整机制的解释:

  • 高频信号:当$f(\cdot)$输出增大时,$\tau_{\text{sys}}$减小(例如$\tau_{\text{sys}} \to 0.1$秒),系统响应加速以捕捉快速变化。
  • 低频信号:当$f(\cdot)$输出减小时,$\tau_{\text{sys}}$增大(例如$\tau_{\text{sys}} \to 1$秒),系统进入慢速模式以维持长期记忆。

例如,在处理心电信号时:

  • 突发心跳异常(高频)触发$\tau_{\text{sys}}$减小,模型快速响应异常波动。
  • 平稳阶段(低频)$\tau_{\text{sys}}$增大,稳定跟踪基线趋势。

输入与状态依赖:

$f(⋅)$的输入包括当前状态 $x(t)$和外部输入 $I(t)$,因此每个时间步的时间常数均与输入特征和网络状态相关,形成“液态”特性。


稳定性与优势:

通过约束$f(\cdot)$的输出范围(如使用Sigmoid函数将$f(\cdot)$限制在$[0,1]$),系统稳定性得以保证:

  • 分母$1 + \tau f(\cdot)$始终为正,避免发散。
  • 李雅普诺夫稳定性理论可证明状态轨迹收敛至平衡点。

液态时间常数的设计不仅提升了模型对多尺度时序特征的表达能力,还通过动态资源分配优化了计算效率。例如,在混沌系统建模中,LTCs能够同时捕捉快速瞬态变化(如洛伦兹吸引子的分岔)和慢速趋势演化,为复杂时序任务提供了理论支撑。

2.2 前向传播与Fused Solver设计

LTCs的动力学方程属于刚性微分方程(Stiff Equations)。这类方程的特点是状态演化中存在多个时间尺度(例如快变与慢变过程的耦合),传统显式ODE求解器(如Runge-Kutta方法)需要极细的离散化步长才能保证稳定性,导致计算效率低下。为解决这一问题,LTCs提出了一种结合显式与隐式欧拉方法的Fused Solver,在保证稳定性的同时兼顾计算效率。


刚性方程与求解器挑战:

LTCs的状态方程为:

$$\frac{d\mathbf{x}(t)}{dt} = -\left( \frac{1}{\tau} + f(\cdot) \right)\mathbf{x}(t) + f(\cdot) A,$$

其中非线性项$f(\cdot)$的动态变化可能使方程呈现刚性特征。例如,当$f(\cdot)$输出较大时,时间常数$\tau_{\text{sys}}$极小,系统响应速度极快,此时显式方法(如显式欧拉)需要极小的步长$\Delta t$以避免数值发散。

传统Runge-Kutta方法(如Dormand-Prince)虽然自适应调整步长,但对于刚性方程,其计算复杂度呈指数增长。因此,LTCs设计了混合显式-隐式策略的Fused Solver。


Fused Solver的数学推导:

Fused Solver的核心思想是将状态更新方程中的线性部分隐式处理,非线性部分显式处理。

具体策略为:

  • 显式部分:保留非线性$f(·)$中的当前状态$x(t)$;
  • 隐式部分:将线性衰减项中的$x(t)$替换为未来状态$x(t+\Delta t)$。

具体步骤如下:

步骤1:

LTC网络的动力学方程为:

$$\frac{d\mathbf{x}}{dt} = -\frac{1}{\tau}\mathbf{x}(t) + f\big(\mathbf{x}(t), \mathbf{I}(t), t, \theta\big)A,$$

其中:

  • $-\frac{1}{\tau}\mathbf{x}(t)$ 是线性衰减项(刚性部分),
  • $f(\cdot)A$ 是非线性驱动项(非刚性部分),
  • $\tau$ 为时间常数,$A$ 为权重矩阵,$\mathbf{I}(t)$ 为时变输入。

步骤2:确定融合策略

为兼顾稳定性和计算效率,采用显式-隐式混合策略,更新后的导数近似为:

$$\frac{d\mathbf{x}}{dt}\bigg|_{\text{fused}} \approx -\frac{1}{\tau}\mathbf{x}(t+\Delta t) + f\big(\mathbf{x}(t), \mathbf{I}(t), t, \theta\big)A.$$

步骤3:应用欧拉方法进行离散化

使用欧拉方法的更新公式:

$$\mathbf{x}(t+\Delta t) = \mathbf{x}(t) + \Delta t \cdot \frac{d\mathbf{x}}{dt}\bigg|_{\text{fused}},$$

将融合后的导数代入,得到:

$$\mathbf{x}(t+\Delta t) = \mathbf{x}(t) + \Delta t \cdot \left[ -\frac{1}{\tau}\mathbf{x}(t+\Delta t) + f\big(\mathbf{x}(t), \mathbf{I}(t), t, \theta\big)A \right].$$


步骤4:整理方程并解出 $\mathbf{x}(t+\Delta t)$

  1. 将含 $\mathbf{x}(t+\Delta t)$ 的项移至左侧:

    $$\mathbf{x}(t+\Delta t) + \frac{\Delta t}{\tau}\mathbf{x}(t+\Delta t) = \mathbf{x}(t) + \Delta t \cdot f(\cdot)A.$$

  2. 提取 $\mathbf{x}(t+\Delta t)$ 的公因子:

    $$\mathbf{x}(t+\Delta t) \left( 1 + \frac{\Delta t}{\tau} \right) = \mathbf{x}(t) + \Delta t \cdot f(\cdot)A.$$

  3. 最终解出更新公式:

    $$\mathbf{x}(t+\Delta t) = \frac{\mathbf{x}(t) + \Delta t \cdot f\big(\mathbf{x}(t), \mathbf{I}(t), t, \theta\big)A}{1 + \frac{\Delta t}{\tau}}.$$


步骤5:统一分母形式

分母可写为:

$$1 + \Delta t \left( \frac{1}{\tau} \right),$$

因此最终离散化公式为:

$$\mathbf{x}(t+\Delta t) = \frac{\mathbf{x}(t) + \Delta t \cdot f(\cdot)A}{1 + \Delta t \left( \frac{1}{\tau} \right)}.$$


算法实现步骤

image-20250127180743003

1. 参数定义

  • 参数集合 $\theta$:
    • $\tau^{(N \times 1)}$:时间常数向量,控制每个神经元的衰减速率,维度为 $N \times 1$。
    • $\gamma^{(M \times N)}$:输入权重矩阵,连接外部输入到神经元,维度为 $M \times N$。
    • $\gamma_\tau^{(N \times N)}$:循环权重矩阵,定义神经元间的自反馈连接强度,维度为 $N \times N$。
    • $\mu^{(N \times 1)}$:偏置向量,调整激活函数的偏移量,维度为 $N \times 1$。
    • $A^{(N \times 1)}$:非线性激活后的缩放向量,维度为 $N \times 1$。
  • 其他参数
    • $L$:时间展开步数(迭代次数),对应时间区间 $[t, t+L\Delta t]$。
    • $\Delta t$:时间步长,控制离散化精度。
    • $N$:神经元数量。

2. 输入与输出

  • 输入
    • $M$ 维时间序列 $\mathbf{I}(t)$,长度为 $T$,表示为 $\mathbf{I}(t) \in \mathbb{R}^{M \times T}$。
    • 初始状态 $\mathbf{x}(0)$,维度为 $N \times 1$,即 $\mathbf{x}(0) \in \mathbb{R}^{N \times 1}$。
  • 输出
    • 更新后的神经元状态 $\mathbf{x}{t+\Delta t}$,维度为 $N \times 1$,即 $\mathbf{x}{t+\Delta t} \in \mathbb{R}^{N \times 1}$。

3. 核心函数 FusedStep

功能:基于融合显式-隐式欧拉方法,单步更新神经元状态。
公式

$$\mathbf{x}(t+\Delta t) = \frac{\mathbf{x}(t) + \Delta t \cdot \left[ f(\mathbf{x}(t), \mathbf{I}(t), t, \theta) \odot A \right]}{1 + \Delta t \left( \frac{1}{\tau} + f(\mathbf{x}(t), \mathbf{I}(t), t, \theta) \right)},$$

操作说明

  • 分子:当前状态 $\mathbf{x}(t)$ 加上非线性驱动的增量:

    $$\mathbf{x}(t) + \Delta t \cdot \left[ f(\cdot) \odot A \right].$$

    其中,$\odot$ 表示哈达玛积(逐元素相乘),即:

    $$\left[ f(\cdot) \odot A \right]_i = f_i(\cdot) \cdot A_i \quad (i=1,2,\ldots,N).$$

  • 分母:稳定性项,确保刚性系统的数值稳定:

    $$1 + \Delta t \left( \frac{1}{\tau} + f(\cdot) \right).$$

    所有运算按元素进行,例如:

    $$\left[ \frac{1}{\tau} + f(\cdot) \right]_i = \frac{1}{\tau_i} + f_i(\cdot).$$


4. 主算法流程

步骤1:初始化

  • 初始状态 $\mathbf{x}_{t+\Delta t}$ 设为当前状态 $\mathbf{x}(t)$:

    $$\mathbf{x}_{t+\Delta t} \leftarrow \mathbf{x}(t).$$

步骤2:时间展开循环

  • 循环次数:$L$ 次(对应时间区间 $[t, t+L\Delta t]$)。

  • 每次迭代

    • 调用 FusedStep 函数,基于当前状态 $\mathbf{x}(t)$ 和输入 $\mathbf{I}(t)$,计算下一状态:
      $$\mathbf{x}_{t+\Delta t} \leftarrow \text{FusedStep}(\mathbf{x}(t), \mathbf{I}(t), \Delta t, \theta).$$

    • 更新后的状态作为下一时间步的输入。

步骤3:返回最终状态

  • 完成 $L$ 次迭代后,输出 $\mathbf{x}_{t+\Delta t}$。

通过Fused Solver的设计,LTCs在保持连续时间建模优势的同时,克服了刚性方程的数值稳定性难题,为复杂时序任务提供了高效的求解框架。

2.3 反向传播:基于BPTT的LTC训练方法

image-20250127183156941

BPTT训练流程与数学推导

LTC网络的训练通过时间反向传播(Backpropagation Through Time, BPTT)实现,其核心是将连续时间的ODE动态离散化,并按时间步展开计算梯度。以下是详细推导过程:


1. 前向传播动态

LTC的离散化状态更新公式为(见式(3)):

$$\mathbf{x}(t+\Delta t) = \frac{\mathbf{x}(t) + \Delta t \cdot \left[ f(\mathbf{x}(t), \mathbf{I}(t), t, \theta) \odot A \right]}{1 + \Delta t \left( \frac{1}{\tau} + f(\mathbf{x}(t), \mathbf{I}(t), t, \theta) \right)}.$$

其中:

  • $f(\cdot) = \tanh\left( \gamma_\tau \mathbf{x}(t) + \gamma \mathbf{I}(t) + \mu \right)$ 为非线性激活函数。
  • $\theta = {\tau, \gamma, \gamma_\tau, \mu, A}$ 为可学习参数。

输出预测值为:

$$\hat{y}(t) = W_{\text{out}} \cdot \mathbf{x}(t) + b_{\text{out}}.$$


2. 损失函数与梯度计算

总损失为各时间步损失之和:

$$L_{\text{total}} = \sum_{t=1}^T L(y(t), \hat{y}(t)).$$

以均方误差(MSE)为例:

$$L(t) = \frac{1}{2} \left( y(t) - \hat{y}(t) \right)^2.$$

参数更新需计算梯度

$$\nabla L(\theta) = \frac{\partial L_{\text{total}}}{\partial \theta} = \sum_{t=1}^T \frac{\partial L(t)}{\partial \theta}.$$

以参数 $\gamma$(输入权重)为例,其梯度为:

$$\frac{\partial L(t)}{\partial \gamma} = \frac{\partial L(t)}{\partial \hat{y}(t)} \cdot \frac{\partial \hat{y}(t)}{\partial \mathbf{x}(t)} \cdot \frac{\partial \mathbf{x}(t)}{\partial \gamma}.$$


3. 反向传播链式法则

步骤1:输出层梯度

$$\frac{\partial L(t)}{\partial \hat{y}(t)} = \hat{y}(t) - y(t).$$

步骤2:输出层到隐藏层的梯度

$$\frac{\partial \hat{y}(t)}{\partial \mathbf{x}(t)} = W_{\text{out}}.$$

步骤3:隐藏层状态对参数 $\gamma$ 的梯度
需递归计算 $\frac{\partial \mathbf{x}(t)}{\partial \gamma}$:

$$\frac{\partial \mathbf{x}(t)}{\partial \gamma} = \sum_{k=1}^t \frac{\partial \mathbf{x}(t)}{\partial \mathbf{x}(k)} \cdot \frac{\partial \mathbf{x}(k)}{\partial \gamma}.$$

其中,$\frac{\partial \mathbf{x}(t)}{\partial \mathbf{x}(k)}$ 表示状态 $\mathbf{x}(t)$ 对历史状态 $\mathbf{x}(k)$ 的依赖,具体为:

$$\frac{\partial \mathbf{x}(t)}{\partial \mathbf{x}(k)} = \prod_{m=k}^{t-1} \frac{\partial \mathbf{x}(m+1)}{\partial \mathbf{x}(m)}.$$

状态转移雅可比矩阵

$$\frac{\partial \mathbf{x}(m+1)}{\partial \mathbf{x}(m)} = \frac{1 + \Delta t \cdot \left( \frac{1}{\tau} \right)}{\left[1 + \Delta t \left( \frac{1}{\tau} + f(\cdot) \right)\right]^2} \cdot \left( 1 + \Delta t \cdot \frac{\partial f(\cdot)}{\partial \mathbf{x}(m)} \odot A \right).$$


4. Adjoint方法与BPTT的对比

Adjoint方法通过伴随方程计算梯度:

$$\frac{d \mathbf{a}(t)}{dt} = -\mathbf{a}(t)^T \frac{\partial f}{\partial \mathbf{x}}, \quad \mathbf{a}(t) = \frac{\partial L}{\partial \mathbf{x}(t)}.$$

其优势为内存复杂度 $O(1)$,但存在以下问题:

  • 数值误差:反向积分时因丢失前向轨迹导致精度下降(见图1中"FWD acc")。
  • 稳定性差:刚性ODE的伴随方程可能发散。

BPTT方法显式存储中间状态,梯度计算为:

$$\frac{\partial L_{\text{total}}}{\partial \theta} = \sum_{t=1}^T \frac{\partial L(t)}{\partial \theta} \bigg|_{\text{显式轨迹}}.$$

优势:

  • 高精度:前向与后向积分均基于精确存储的轨迹。
  • 灵活性:支持任意优化器(如Adam)。

复杂度分析

表1对比了两种方法的复杂度(单层网络 $f$):

image-20250127182727209

关键结论

  • BPTT通过显式存储中间状态(内存开销高),保证梯度计算的精确性。
  • Adjoint方法牺牲精度以节省内存,适用于大规模低精度需求场景。

优化策略

  1. 梯度裁剪:防止梯度爆炸,限制 $|\nabla L(\theta)| \leq c$。

  2. 自适应优化器:使用Adam更新规则:

    $$\theta_{k+1} = \theta_k - \alpha \cdot \frac{m_k}{\sqrt{v_k} + \epsilon},$$

    其中 $m_k$ 和 $v_k$ 为梯度的一阶和二阶矩估计。

  3. 分段训练:将长序列分割为子序列,分别进行前向和反向传播,平衡内存与计算效率。


总结:BPTT方法通过显式存储前向轨迹,在内存开销与计算精度间取得平衡,为LTC网络的刚性ODE求解提供了高精度的训练框架。

3. 实验评估

3.1 实验设置与基准模型

原论文在多个真实世界的监督学习任务中评估了Liquid Time-Constant Networks (LTCs)的性能,并与以下模型进行对比:

  • 传统RNN变体:LSTM、CTRNN(连续时间RNN)、GRU-D(带衰减门控的GRU)。
  • 现代方法:ODE-RNN、Latent ODE(基于神经ODE的编码器)。
  • 其他基线:RNN-Decay、RNN-MAE等。

实验覆盖以下任务:

  1. 人类活动识别(Person Activity Dataset)
  2. Half-Cheetah物理建模(运动轨迹预测)
  3. 时间序列预测(包括手势识别、交通流量预测等)

3.2 主要实验结果解析

3.2.1 人类活动识别(Person Activity Dataset)

  • 数据集:6554条人类活动序列(如躺、走、坐),时间间隔211ms。

  • 两种实验设置

    • 第一设置:直接对比LTCs与基准模型(CTRNN、ODE-RNN等)。

      • 结果:LTCs显著优于所有基线(表4),尤其在CTRNN和神经ODE上优势明显。

        image-20250127183759887
    • 第二设置:复现Rubanova et al. (2019)的改进实验设置。

      • 结果:LTCs以高准确率(88.2% ± 0.5)超越所有模型(表5),对比ODE-RNN(82.9%)和Latent ODE(84.6%)。

        image-20250127183850176

3.2.2 Half-Cheetah物理建模

  • 目标:评估连续时间模型对物理动力学的建模能力。

  • 数据:来自MuJoCo HalfCheetah-v2环境的25条运动轨迹。

  • 关键发现

    • LTCs在轨迹长度增长的下界分析中表现出更强的表达能力(扩展自Raghu et al. (2017))。

    • 实验验证了LTCs在复杂动力学建模中的优势。

      image-20250127183927428

3.2.3 时间序列预测任务

  • 任务范围:手势识别(Gesture)、占用检测(Occupancy)、顺序MNIST等。

  • 结果总结(表3)

    image-20250127183947028

    • LTCs在4/7任务中提升5%-70%:例如顺序MNIST准确率达98.41%,显著优于LSTM(98.0%)和CT-RNN(97.5%)。
    • 其他任务表现持平或略优:如交通流量预测(MSE 0.169)和臭氧预测(F1-score 0.284)。

3.3 实验结论

  1. 性能优势:LTCs在多数任务中(4/7)实现5%-70%的性能提升,尤其在物理仿真和长序列预测任务中表现突出。
  2. 表达能力:通过轨迹长度分析,验证了LTCs在连续时间建模中更强的函数逼近能力。
  3. 鲁棒性:在噪声和不规则采样的数据中(如人类活动识别),LTCs的稳定性优于传统ODE-RNN和Latent ODE。

总结:LTCs通过融合刚性ODE求解器和高表达力的非线性动态,在复杂时间序列任务中展现了显著优势,为连续时间建模提供了新的解决方案。

4. LTCs简单实现

最后基于PyTorch对液态时间常数网络(LTCs)进行简单复现,包含核心的LTCCell单元和完整的LTC模型:通过动态调整时间常数的微分方程(融合显式-隐式求解器)处理连续时间序列,自定义隐藏层维度与时间步长;训练模块采用Adam优化器和梯度裁剪策略,并绘制损失曲线。

import torch
import torch.nn as nn
import matplotlib.pyplot as plt


class LTCCell(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(LTCCell, self).__init__()
        self.hidden_dim = hidden_dim

        # 定义可学习参数
        self.tau = nn.Parameter(torch.ones(hidden_dim))  # 时间常数向量
        self.gamma = nn.Parameter(torch.randn(input_dim, hidden_dim))  # 输入权重
        self.gamma_tau = nn.Parameter(torch.randn(hidden_dim, hidden_dim))  # 循环权重
        self.mu = nn.Parameter(torch.zeros(hidden_dim))  # 偏置
        self.A = nn.Parameter(torch.ones(hidden_dim))  # 缩放向量

        # 参数初始化
        nn.init.xavier_normal_(self.gamma)
        nn.init.xavier_normal_(self.gamma_tau)

    def fused_step(self, x, I_t, dt):
        """Fused Solver的单步更新"""
        # 计算f(⋅) = tanh(gamma_tau * x + gamma * I + mu)
        f = torch.tanh(
            torch.matmul(x, self.gamma_tau) +
            torch.matmul(I_t, self.gamma) +
            self.mu
        )

        # 分子:x + dt*(f⊙A)
        numerator = x + dt * (f * self.A)

        # 分母:1 + dt*(1/tau + f)
        denominator = 1 + dt * (1 / self.tau + f)

        # 更新状态
        x_next = numerator / denominator
        return x_next

    def forward(self, x, I_seq, dt=0.1):
        """处理整个时间序列"""
        states = []
        for t in range(I_seq.size(0)):
            x = self.fused_step(x, I_seq[t], dt)
            states.append(x)
        return torch.stack(states)


class LTC(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(LTC, self).__init__()
        self.ltc_cell = LTCCell(input_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)  # 输出层

    def forward(self, I_seq, init_state=None, dt=0.1):
        # 初始化隐藏状态
        if init_state is None:
            batch_size = I_seq.size(1)
            h0 = torch.zeros(batch_size, self.ltc_cell.hidden_dim).to(I_seq.device)
        else:
            h0 = init_state

        # 通过LTC单元处理序列
        states = self.ltc_cell(h0, I_seq, dt)

        # 通过全连接层得到输出
        outputs = self.fc(states)
        return outputs


# 训练和可视化模块
def train_and_visualize():
    # 超参数配置
    input_dim = 3
    hidden_dim = 32
    output_dim = 1
    seq_len = 50
    batch_size = 16
    num_epochs = 100

    # 初始化模型
    model = LTC(input_dim, hidden_dim, output_dim)

    # 生成示例数据(随机数据演示)
    inputs = torch.randn(seq_len, batch_size, input_dim)
    targets = torch.randn(seq_len, batch_size, output_dim)  # 随机目标

    # 训练配置
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_history = []

    # 训练循环
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        pred = model(inputs)
        loss = criterion(pred, targets)
        loss.backward()

        # 梯度裁剪
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        loss_history.append(loss.item())

        # 打印训练进度
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

    # 绘制损失曲线
    plt.figure(figsize=(10, 6))
    plt.plot(loss_history, 'b-o', linewidth=2, markersize=4)
    plt.title('Training Loss Curve')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.savefig('ltc_training_loss.png')
    plt.show()


if __name__ == "__main__":
    train_and_visualize()

损失函数迭代曲线:

image-20250127185039925

文章不错,扫码支持一下吧~
上一篇 下一篇
评论