Lag-Llama 是由 LLaMA 团队开发的时间序列基础模型,于2023年发布后迅速受到人工智能界的关注。这些预训练的模型经过大量时间序列数据的预训练,具备了存储不同频率和长度的时间序列数据的一般数据模式的能力,因此能够识别未见过的数据模式,且无需进行大量的微调。对于大型时间序列基础模型进行进一步微调,可以使它们实现与非基础模型相当的预测能力。
Lag-Llama 模型是基于LLaMA 模型的解码器部分进行训练的,它是一种单变量概率预测的通用基础模型。通过添加“Lag”作为前缀,该模型使用时间序列的滞后项作为协变量,以捕获时间依赖性,而不假设线性或平稳性。
时间序列数据和语言数据之间显然存在差异。时间序列具有当前值和滞后值之间的时间模式,并且包含与日历相关的信息,如一周中的某一天、一个月中的一周等。因此,Lag-Llama 将时间序列数据的输入用作滞后协变量(t=1, 7, 14, 21, …, 𝛕 ),以及日历相关的功能。由于时间序列数据在编码输入时需要不同的方式,这就是 Lag-Llama 使用LLaMA 模型的解码器部分的原因。
在使用 Lag-Llama 之前,请注意以下几点:该工具专为单变量时间序列设计,用于概率预测。它利用滞后协变量和日历特征作为输入,基于 LLaMA 的仅解码器部分。它进行zero-shot(ZSL)和 fwe-shot (FSL)。
本文云朵君将和大家一起探讨学习使用该方法来预测沃尔玛每周的商店销售数据,介绍该方法的架构,解释零点学习的概念,并学习概率预测的评估指标,即连续排序概率得分(CRPS)。
Lag-Llama 输入--滞后协变量和日期特征。
尽管大型语言模型(LLM)源自时间序列 RNN/LSTM,但我们不直接将时间序列数据输入LLM,因为这两种数据是不同的。时间序列基础模型旨在将时间序列数据作为输入,然后进行相应编码,捕捉时间依赖性。Lag-Llama 利用时间序列过去值的滞后特征来捕捉时间依赖性。这是该模型前缀为“Lag”的原因。
时间序列数据还可以提取与日期相关的信息,例如一周中的哪一天、一个月中的哪一周等。Lag-Llama 将日期相关特征添加到滞后协变量(t=1, 7, 14, 21, ..., 𝛕)中,如图(1)所示。
图(1):Lag-Llama的特征了解了输入,现在来了解一下它的架构。
Lag-Llama 是基于 LLaMA,而 LLaMA 又是基于 Transformer 模型的。LLaMA(大型语言模型 Meta AI)是 Meta AI 于 2023 年发布的开源大型语言模型,LLaMA 沿用了 Transformer 架构,但对其进行了三处修改。相比于 Transformer 模型,LLaMA 的三大修改是:
Lag-Llama 方法将概率预测视为从学生 t 分布中抽取的样本,并需要对学生 t 分布的自由度、均值和尺度三个关键参数进行建模。除学生 t 分布外,Lag-Llama 还可以灵活应用其他分布。
Lag-Llama的作者介绍称,它在未见过的数据集上表现出强大的零次学习能力,并在根据特定数据对模型进行微调后,又展现出强大的少量学习能力。接下来了解一下零样本学习和少样本学习的含义。
Zero-shot learning (ZSL) and few-shot learning (FSL) 是机器学习的子领域,侧重于训练模型以泛化到新的、未见过的数据。两者的主要区别在于训练数据数量,通常称为“shots”。ZSL假设模型无法访问目标领域或任务中的标注数据,因此无需任何标注数据就能识别新的、未见过的类别。与此相反,FSL假设模型可以从目标领域或任务中获取少量标注数据。
零样本学习(zero-shot learning)是一个相对较新的概念,其基本思想是在多个领域或任务中学习共享表征。这样一来,模型就能够在没有明确训练数据的情况下识别并泛化到新的类别或任务。具体来说,这通常是通过使用共享嵌入层来实现的,该嵌入层可以将来自不同领域或任务的输入数据映射到一个共同的向量空间,其中保留了输入之间的相似性。
Lag-Llama 的训练语料库由 27 个时间序列数据集组成,涵盖能源、交通、经济、自然、空气质量和云计算等多个领域。训练数据的多样性包括频率、每个序列的长度、预测长度和多序列数量的差异。数据源的广泛性赋予了 Lag-Llama 零点学习的能力。
Lag-Llama 库使用 Python gluonTS 库进行数据格式化、预测和评估。
安装 gluonTS 时,需要把 numpy 降级到 1.23。所以建议你再创建一个 conda 虚拟环境,避免影响其他资源。
!pip install --upgrade mxnet==1.6.0
!pip install gluonts==0.14.2
!pip uninstall numpy # Downgrade numpy to 1.23
!pip3 install mxnet-mkl==1.6.0 numpy==1.23.1
!git clone https://github.com/time-series-foundation-models/lag-llama/
cd lag-llama
# !pip install -r requirements.txt --quiet # this could take some time
# Or pip install one by one of the libraries in requirements.txt so you can observe the progress
!pip install gluonts[torch]
!pip install torch>=2.0.0
!pip install wandb
!pip install scipy
!pip install pandas==2.1.4
!pip install huggingface_hub[cli]
从 Huggingface 下载 Lag_Llama:
!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir /content/lag-llama
导入需要用到的python库。
from itertools import islice
from matplotlib import pyplot as plt
import matplotlib.dates as mdates
import torch
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from lag_llama.gluon.estimator import LagLlamaEstimator
Kaggle公开了沃尔玛商店的历史销售数据。该数据集包含多个商店的每周销售系列。
%matplotlib inline
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
# 可以在#公众号:数据STUDIO 后台回复 云朵君 免费领取
data = pd.read_csv('/walmart2.csv')
# convert string to datetime64
data["ds"] = pd.to_datetime(data["Date"],format='%d-%m-%Y')
#data = data.sort_values(by=['Store','ds'])
data.tail()
数据集包括以下字段:
其他字段包括:本周是否为特殊假日周、销售当天的气温、商店所在地区的燃料成本、消费价格指数和失业率。
沃尔玛每周商店销售额将数据透视为所需的数据形状,并查看前 5 家商店的每周销售额。
# pivot the data into the correct shape
storewide = data.pivot(index='ds', columns='Store', values='Weekly_Sales')
some_stores = storewide.loc[:,1:5] # Plot only Store 1 - 5
storewide = some_stores. # Model only Store 1-10 in this demo
# plot the pivoted dataframe
some_stores.plot(figsize=(12, 4))
plt.legend(loc='upper left')
plt.title("Walmart Weekly Sales of Store 1 - 10")
将看到商店每周销售额的共同变化。
前 5 家商店的每周销售额我们需要为模型训练预留一些实时数据,为模型验证预留一些非实时数据。
我们将 85% 的数据作为 "实时" 训练数据,其余 15% 作为 "非实时" 测试数据。
print("The time series has", storewide.shape[0], "weeks")
len_train = int(storewide.shape[0] * 0.85)
train_data = storewide[0:len_train] # 121 weeks
test_data = storewide[len_train:] # 22 weeks
[train_data.shape, test_data.shape]
时间序列有 143 周。我们将 85% 作为训练数据,其余作为测试数据。训练数据有 121 周,测试数据有 22 周。
任何时间序列数据都应包含三个基本要素:开始日期、目标数据和数据频率。GluonTS 要求数据格式包含这三个要素。下面的代码将数据集转换为与 gluonTS 兼容的格式,通过计算最小日期获得起始日期,并将列作为目标。
# Prepare the data for deepAR format
from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName
def to_deepar_format(dataframe, freq):
start_index = dataframe.index.min()
data = [{
FieldName.START: start_index,
FieldName.TARGET: dataframe[c].values,
}
for c in dataframe.columns]
print(data[0])
return ListDataset(data, freq=freq)
train_data_lds = to_deepar_format(train_data, 'W')
test_data_lds = to_deepar_format(test_data, 'W')
同样的处理方法也适用于其他时间序列数据。加载完成后,我们可以开始建模过程。GluonTS要求在训练过程中使用上下文数据的长度以及在预测时使用的长度。在这里,我们将指定训练数据的长度作为上下文数据,并将指定时间外数据的长度作为预测数据。
context_length = len(train_data) # 121
prediction_length = len(test_data) # 22
使用以下代码加载并打开 Lag-Llama 模型
ckpt = torch.load("lag-llama.ckpt", map_location=torch.device('cuda:0'))
estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
在这里,我们将介绍.ckpt
和 torch.load()
的新概念。CKPT 文件是由 PyTorch Lightning 创建的检查点文件,其中包含了一个 PyTorch Lightning 模型的转储。这个文件包含了加载模型所需的所有内容。在模型训练过程中,开发人员可以保存模型的当前状态为检查点文件,以便在未来继续开发。另一个重要概念是 PyTorch Lightning 库,它实际上是 PyTorch 的一个接口。PyTorch 作为一个强大的深度学习框架而闻名,后来也被用于生产可扩展的深度学习模型。
我们将使用 LagLlamaEstimator()
声明模型。它需要.ckpt
文件、上下文和预测长度,以及下面的一些其他参数。
estimator = LagLlamaEstimator(
ckpt_path="lag-llama.ckpt",
prediction_length=prediction_length,
context_length=context_length,
# estimator args
input_size=estimator_args["input_size"],
n_layer=estimator_args["n_layer"],
n_embd_per_head=estimator_args["n_embd_per_head"],
n_head=estimator_args["n_head"],
scaling=estimator_args["scaling"],
time_feat=estimator_args["time_feat"],
)
lightning_module = estimator.create_lightning_module()
transformation = estimator.create_transformation()
predictor = estimator.create_predictor(transformation, lightning_module)
通过predictor()
函数,我们可以生成预测结果。需要注意的是,Lag-Llama会进行零点学习(ZSL),以对沃尔玛数据进行预测,这可能包括之前未曾见过的情况。
forecast_it, ts_it = make_evaluation_predictions(
dataset=train_data_lds,
predictor=predictor,
)
forecasts = list(forecast_it)
tss = list(ts_it)
可视化输出值:
plt.figure(figsize=(20, 15))
date_formater = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})
for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):
ax = plt.subplot(3, 3, idx+1)
plt.plot(ts[-4 * dataset.metadata.prediction_length:].to_timestamp(), label="target", )
forecast.plot( color='g')
plt.xticks(rotation=60)
ax.xaxis.set_major_formatter(date_formater)
ax.set_title(forecast.item_id)
plt.gcf().tight_layout()
plt.legend()
plt.show()
Lag-Llama 为每个时间序列提供概率预测
在文本末尾,将介绍连续排序概率得分(CRPS),它是一种常用的评估指标,特别适用于概率预测。当预测涉及一系列概率值时,我们应如何评估性能?对于点估计,可以使用MSE、MAE或MAPE。但对于概率预测,我们关注预测分布的扩散和中心倾向。如果预测分布的扩散极大,导致任何预测都有可能,则该模型不可被视为优秀模型。
CRPS范围从0到正无穷大,当预测的累积分布函数(CDF)与观测结果完全吻合时,CRPS为0,我们希望CRPS越低越好。连续排序概率得分(CRPS)的计算公式:
公式中的整合意味着评分考虑了整个潜在结果范围及其相关概率。
evaluator = Evaluator()
agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))
print("CRPS:", agg_metrics['mean_wQuantileLoss'])
CRPS: 0.07203626685077341
结果虽然不是零,但可以接受。如果你还有其他改进的方法,欢迎评论区讨论!
在本文中,云朵君和大家一起学习了使用 Lag-Llama 进行数据零样本预测的方法,包括 Lag-Llama 的架构和零样本学习的概念。我们还探讨了概率预测的评估指标,即连续排序概率得分(CRPS)。
长按👇关注- 数据STUDIO -设为星标,干货速递