用 Lag-Llama 进行时间序列预测实战


Lag-Llama——开源时间序列基础模型

Lag-Llama 是由 LLaMA 团队开发的时间序列基础模型,于2023年发布后迅速受到人工智能界的关注。这些预训练的模型经过大量时间序列数据的预训练,具备了存储不同频率和长度的时间序列数据的一般数据模式的能力,因此能够识别未见过的数据模式,且无需进行大量的微调。对于大型时间序列基础模型进行进一步微调,可以使它们实现与非基础模型相当的预测能力。

Lag-Llama 模型是基于LLaMA 模型的解码器部分进行训练的,它是一种单变量概率预测的通用基础模型。通过添加“Lag”作为前缀,该模型使用时间序列的滞后项作为协变量,以捕获时间依赖性,而不假设线性或平稳性。

时间序列数据和语言数据之间显然存在差异。时间序列具有当前值和滞后值之间的时间模式,并且包含与日历相关的信息,如一周中的某一天、一个月中的一周等。因此,Lag-Llama 将时间序列数据的输入用作滞后协变量(t=1, 7, 14, 21, …, 𝛕 ),以及日历相关的功能。由于时间序列数据在编码输入时需要不同的方式,这就是 Lag-Llama 使用LLaMA 模型的解码器部分的原因。

在使用 Lag-Llama 之前,请注意以下几点:该工具专为单变量时间序列设计,用于概率预测。它利用滞后协变量和日历特征作为输入,基于 LLaMA 的仅解码器部分。它进行zero-shot(ZSL)和 fwe-shot (FSL)。

本文云朵君将和大家一起探讨学习使用该方法来预测沃尔玛每周的商店销售数据,介绍该方法的架构,解释零点学习的概念,并学习概率预测的评估指标,即连续排序概率得分(CRPS)。

  • 输入 - 滞后协变量和日期特征
  • Lag-Llama 的架构
  • 概率预测
  • 零点学习和少点学习
  • 使用 Lag-Llama 预测沃尔玛每周商店销售额
  • 评估 - 连续排序概率得分 (CRPS)

Lag-Llama 的输入

Lag-Llama 输入--滞后协变量和日期特征。

尽管大型语言模型(LLM)源自时间序列 RNN/LSTM,但我们不直接将时间序列数据输入LLM,因为这两种数据是不同的。时间序列基础模型旨在将时间序列数据作为输入,然后进行相应编码,捕捉时间依赖性。Lag-Llama 利用时间序列过去值的滞后特征来捕捉时间依赖性。这是该模型前缀为“Lag”的原因。

时间序列数据还可以提取与日期相关的信息,例如一周中的哪一天、一个月中的哪一周等。Lag-Llama 将日期相关特征添加到滞后协变量(t=1, 7, 14, 21, ..., 𝛕)中,如图(1)所示。

图(1):Lag-Llama的特征

了解了输入,现在来了解一下它的架构。

Lag-Llama 的架构

Lag-Llama 是基于 LLaMA,而 LLaMA 又是基于 Transformer 模型的。LLaMA(大型语言模型 Meta AI)是 Meta AI 于 2023 年发布的开源大型语言模型,LLaMA 沿用了 Transformer 架构,但对其进行了三处修改。相比于 Transformer 模型,LLaMA 的三大修改是:

  1. RMSNorm 归一化函数 [GPT3]:GPT-3 中使用了 RMSNorm ,以提高训练的稳定性。LLaMA 采用 RMSNorm 对每个变压器子层的输入进行归一化,而不是对输出进行归一化。
  2. 使用 SwiGLU 激活函数 [PaLM]:谷歌人工智能在 2022 年 4 月提出了 PaLM(Pathways Language Model)。PaLM 采用 SwiGLU 激活函数代替 ReLU 非线性激活函数来提高性能。因此,LLaMA 采用了 SwiGLU 来提高性能。
  3. Rotary Embeddings [GPTNeo]。GPTNeo 模型与 GPT-2 类似,发布在 Github:EleutherAI/gpt-neo(https://github.com/EleutherAI/gpt-neo)中。GPTNeo 使用旋转位置嵌入(RoPE) 来取代绝对位置嵌入,以获得更好的性能。因此,LLaMA 采用了 RoPE。

概率预测

Lag-Llama 方法将概率预测视为从学生 t 分布中抽取的样本,并需要对学生 t 分布的自由度、均值和尺度三个关键参数进行建模。除学生 t 分布外,Lag-Llama 还可以灵活应用其他分布。

ZSL和FSL

Lag-Llama的作者介绍称,它在未见过的数据集上表现出强大的零次学习能力,并在根据特定数据对模型进行微调后,又展现出强大的少量学习能力。接下来了解一下零样本学习和少样本学习的含义。

Zero-shot learning (ZSL) and few-shot learning (FSL) 是机器学习的子领域,侧重于训练模型以泛化到新的、未见过的数据。两者的主要区别在于训练数据数量,通常称为“shots”。ZSL假设模型无法访问目标领域或任务中的标注数据,因此无需任何标注数据就能识别新的、未见过的类别。与此相反,FSL假设模型可以从目标领域或任务中获取少量标注数据。

  • ZSL → 无标注数据
  • FSL → 少量标注数据

零样本学习(zero-shot learning)是一个相对较新的概念,其基本思想是在多个领域或任务中学习共享表征。这样一来,模型就能够在没有明确训练数据的情况下识别并泛化到新的类别或任务。具体来说,这通常是通过使用共享嵌入层来实现的,该嵌入层可以将来自不同领域或任务的输入数据映射到一个共同的向量空间,其中保留了输入之间的相似性。

  1. 预训练:在相关领域或任务的大型数据集上对模型进行预训练,让它学会识别和分类不同的类别或任务。
  2. 共享嵌入层:添加共享嵌入层,将来自不同领域或任务的输入数据映射到一个共同的向量空间,对预训练模型进行微调。
  3. 迁移学习:冻结共享嵌入层的同时,在目标领域或任务的少量标注示例上对模型进行微调,适应新的领域或任务,同时利用从预训练任务中学到的知识。
  4. 推理:对模型进行微调后,可以对新的、未见过的类别或任务进行预测,通过共享嵌入层传递输入数据,然后通过微调模型来实现。

 Lag-Llama 的训练语料库由 27 个时间序列数据集组成,涵盖能源、交通、经济、自然、空气质量和云计算等多个领域。训练数据的多样性包括频率、每个序列的长度、预测长度和多序列数量的差异。数据源的广泛性赋予了 Lag-Llama 零点学习的能力。

环境要求

Lag-Llama 库使用 Python gluonTS 库进行数据格式化、预测和评估。

安装 gluonTS 时,需要把 numpy 降级到 1.23。所以建议你再创建一个 conda 虚拟环境,避免影响其他资源。

!pip install --upgrade mxnet==1.6.0
!pip install gluonts==0.14.2
!pip uninstall numpy # Downgrade numpy to 1.23
!pip3 install mxnet-mkl==1.6.0 numpy==1.23.1

!git clone https://github.com/time-series-foundation-models/lag-llama/

cd lag-llama

# !pip install -r requirements.txt --quiet # this could take some time
# Or pip install one by one of the libraries in requirements.txt so you can observe the progress
!pip install gluonts[torch]
!pip install torch>=2.0.0
!pip install wandb
!pip install scipy
!pip install pandas==2.1.4
!pip install huggingface_hub[cli]

从 Huggingface 下载 Lag_Llama:

!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir /content/lag-llama

导入需要用到的python库。

from itertools import islice
from matplotlib import pyplot as plt
import matplotlib.dates as mdates
import torch
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from lag_llama.gluon.estimator import LagLlamaEstimator

数据处理

Kaggle公开了沃尔玛商店的历史销售数据。该数据集包含多个商店的每周销售系列。

%matplotlib inline
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
# 可以在#公众号:数据STUDIO 后台回复 云朵君 免费领取
data = pd.read_csv('/walmart2.csv'
# convert string to datetime64
data["ds"] = pd.to_datetime(data["Date"],format='%d-%m-%Y')
#data = data.sort_values(by=['Store','ds'])
data.tail()

数据集包括以下字段:

  • Store -- 商店:每个沃尔玛商店的唯一标识符
  • Date -- 日期:2010 年 2 月 5 日至 2012 年 11 月 1 日的销售周
  • Weekly_Sales -- 每周销售额:指定商店在给定一周内的销售额

其他字段包括:本周是否为特殊假日周、销售当天的气温、商店所在地区的燃料成本、消费价格指数和失业率。

沃尔玛每周商店销售额

绘制时间序列

将数据透视为所需的数据形状,并查看前 5 家商店的每周销售额。

# pivot the data into the correct shape
storewide = data.pivot(index='ds', columns='Store', values='Weekly_Sales')
some_stores = storewide.loc[:,1:5# Plot only Store 1 - 5
storewide = some_stores. # Model only Store 1-10 in this demo
# plot the pivoted dataframe
some_stores.plot(figsize=(124))
plt.legend(loc='upper left')
plt.title("Walmart Weekly Sales of Store 1 - 10")

将看到商店每周销售额的共同变化。

前 5 家商店的每周销售额

我们需要为模型训练预留一些实时数据,为模型验证预留一些非实时数据。

"实时" 和 "非实时" 数据分割

我们将 85% 的数据作为 "实时" 训练数据,其余 15% 作为 "非实时" 测试数据。

print("The time series has", storewide.shape[0], "weeks")
len_train = int(storewide.shape[0] * 0.85)
train_data = storewide[0:len_train] # 121 weeks
test_data = storewide[len_train:] # 22 weeks
[train_data.shape, test_data.shape]

时间序列有 143 周。我们将 85% 作为训练数据,其余作为测试数据。训练数据有 121 周,测试数据有 22 周。

转换为 GluonTS 格式

任何时间序列数据都应包含三个基本要素:开始日期、目标数据和数据频率。GluonTS 要求数据格式包含这三个要素。下面的代码将数据集转换为与 gluonTS 兼容的格式,通过计算最小日期获得起始日期,并将列作为目标。

# Prepare the data for deepAR format
from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName

def to_deepar_format(dataframe, freq):
    start_index = dataframe.index.min()
    data = [{
                FieldName.START:  start_index,
                FieldName.TARGET:  dataframe[c].values,
            }
            for c in dataframe.columns]
    print(data[0])
    return ListDataset(data, freq=freq)
train_data_lds = to_deepar_format(train_data, 'W')
test_data_lds = to_deepar_format(test_data, 'W')

同样的处理方法也适用于其他时间序列数据。加载完成后,我们可以开始建模过程。GluonTS要求在训练过程中使用上下文数据的长度以及在预测时使用的长度。在这里,我们将指定训练数据的长度作为上下文数据,并将指定时间外数据的长度作为预测数据。

context_length = len(train_data) # 121
prediction_length = len(test_data) # 22

Lag-Lama建模

使用以下代码加载并打开 Lag-Llama 模型

ckpt = torch.load("lag-llama.ckpt", map_location=torch.device('cuda:0'))
estimator_args = ckpt["hyper_parameters"]["model_kwargs"]

在这里,我们将介绍.ckpttorch.load()的新概念。CKPT 文件是由 PyTorch Lightning 创建的检查点文件,其中包含了一个 PyTorch Lightning 模型的转储。这个文件包含了加载模型所需的所有内容。在模型训练过程中,开发人员可以保存模型的当前状态为检查点文件,以便在未来继续开发。另一个重要概念是 PyTorch Lightning 库,它实际上是 PyTorch 的一个接口。PyTorch 作为一个强大的深度学习框架而闻名,后来也被用于生产可扩展的深度学习模型。

我们将使用 LagLlamaEstimator() 声明模型。它需要.ckpt文件、上下文和预测长度,以及下面的一些其他参数。

estimator = LagLlamaEstimator(
    ckpt_path="lag-llama.ckpt",
    prediction_length=prediction_length,
    context_length=context_length,

    # estimator args
    input_size=estimator_args["input_size"],
    n_layer=estimator_args["n_layer"],
    n_embd_per_head=estimator_args["n_embd_per_head"],
    n_head=estimator_args["n_head"],
    scaling=estimator_args["scaling"],
    time_feat=estimator_args["time_feat"],
)

lightning_module = estimator.create_lightning_module()
transformation = estimator.create_transformation()
predictor = estimator.create_predictor(transformation, lightning_module)

通过predictor()函数,我们可以生成预测结果。需要注意的是,Lag-Llama会进行零点学习(ZSL),以对沃尔玛数据进行预测,这可能包括之前未曾见过的情况。

Lag-Lama 预测

forecast_it, ts_it = make_evaluation_predictions(
    dataset=train_data_lds,
    predictor=predictor,
)

forecasts = list(forecast_it)
tss = list(ts_it)

可视化输出值:

plt.figure(figsize=(2015))
date_formater = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size'15})

for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):
    ax = plt.subplot(33, idx+1)

    plt.plot(ts[-4 * dataset.metadata.prediction_length:].to_timestamp(), label="target", )
    forecast.plot( color='g')
    plt.xticks(rotation=60)
    ax.xaxis.set_major_formatter(date_formater)
    ax.set_title(forecast.item_id)

plt.gcf().tight_layout()
plt.legend()
plt.show()
Lag-Llama 为每个时间序列提供概率预测

模型评估

评价指标--连续排列概率得分(CRPS)

在文本末尾,将介绍连续排序概率得分(CRPS),它是一种常用的评估指标,特别适用于概率预测。当预测涉及一系列概率值时,我们应如何评估性能?对于点估计,可以使用MSE、MAE或MAPE。但对于概率预测,我们关注预测分布的扩散和中心倾向。如果预测分布的扩散极大,导致任何预测都有可能,则该模型不可被视为优秀模型。

CRPS范围从0到正无穷大,当预测的累积分布函数(CDF)与观测结果完全吻合时,CRPS为0,我们希望CRPS越低越好。连续排序概率得分(CRPS)的计算公式:

  • 给定随机变量 的累积分布函数 (CDF),即
  • 是海维塞德阶跃函数。如果 ,它的值为 1.0,否则为 0。它定义了每个预测概率是否超过观察结果。海维塞德阶跃函数简单来说就是

公式中的整合意味着评分考虑了整个潜在结果范围及其相关概率。

evaluator = Evaluator()
agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))
print("CRPS:", agg_metrics['mean_wQuantileLoss'])
CRPS: 0.07203626685077341

结果虽然不是零,但可以接受。如果你还有其他改进的方法,欢迎评论区讨论!

写在最后

在本文中,云朵君和大家一起学习了使用 Lag-Llama 进行数据零样本预测的方法,包括 Lag-Llama 的架构和零样本学习的概念。我们还探讨了概率预测的评估指标,即连续排序概率得分(CRPS)。


🏴‍☠️宝藏级🏴‍☠️ 原创公众号『数据STUDIO』内容超级硬核。公众号以Python为核心语言,垂直于数据科学领域,包括可戳👉 PythonMySQL数据分析数据可视化机器学习与数据挖掘爬虫 等,从入门到进阶!

长按👇关注- 数据STUDIO -设为星标,干货速递

相关推荐

  • 这算是裁到大动脉了吧
  • [开源]轻松构建车联网平台,可应用于各种车辆监管场景和应用平台
  • Kubernetes新手必看:快速生成YAML清单的终极指南!
  • 记一次疑似JVM内存泄漏的排查过程
  • 高中信息技术考试竟然有Flash、IIS、Frontpage、Access、VB……
  • 带您认识物联网首选协议MQTT
  • 29.3K Star强!集成微信登录,核心代码就10行
  • 下半年!真心建议大家冲一冲新兴领域,工资高前景好
  • 探索TypeScript的映射类型,从简单到高级的7个实例
  • 【第17讲】6月17日,AI代写(期刊、演讲稿、小说)
  • 茅台降价,“黄牛”公司纷纷跑路
  • 黄仁勋 · 加州理工2024届毕业典礼演讲 | 2024年6月14日(全文+视频)
  • 面试为什么老爱问 Redis?
  • 成都周报丨策源投了清华系大模型,高新区天使母基金遴选GP
  • 天使++轮拿了近亿融资丨投融周报
  • 代码学上头了,感觉自己又行了!
  • 最新编程语言排行榜,C++ 和 Go 成为新王?!
  • 大模型理解复杂表格,字节&中科大出手了
  • 微软也“扶不起”的 Win11!明年退役的 Win10 市占率再涨,网友:不如专心搞 Win12 吧
  • 京东云来了。。。