【机器学习】优化预测速度 部署机器学习模型的7个要点

在模型部署时,模型的性能和耗时都非常重要。但是我们在构建模型时,往往没有考虑模型的预测速度。虽然性能优化会损害预测准确性,但更简单的模型通常运行得更快,也不容易过拟合。

预测延迟被测量为进行预测所需的经过时间。延迟通常被视为一个分布,而运维工程师通常关注此分布的给定百分位数的延迟,如50%或99%情况下的耗时。

要点1:影响延迟的因素

对于机器学习模型,影响预测延迟的主要因素是:

  • 特征个数
  • 数据的稀疏性
  • 模型复杂度
  • 特征提取耗时

其中特征是否能够并行提取 & 模型能够批量预测,对预测耗时影响非常大。

要点2:关注模型本身

由于不同的模型原理上存在区别,本质由多种原因(分支可预测性、CPU 缓存、线性代数库优化等)导致模型速度存在差异。

模型速度存在以下规律:

  • 树模型速度比线性模型的速度慢
  • SVM速度比线性模型慢
  • 线性模型和贝叶斯模型速度相当
  • KNN速度会受到训练数据量影响

要点3:保留有效的特征

当特征数量增加时每个示例的内存消耗也会增加,此外也会影响运算速度。当特征数量增加时,模型计算速度增加的速度不同。一般情况下,模型预测速度与特征个数成正比。

要点4:存储稀疏数据

Scipy中包含稀疏矩阵结构,只会存储非0的数据,这样占用的内存会少很少。在线性模型上使用稀疏输入可以大大加快预测速度,因为只有非零值特征会影响点积,从而影响模型预测。

使用如下代码可以统计数据的稀疏比例:

def sparsity_ratio(X):
    return 1.0 - np.count_nonzero(X) / float(X.shape[0] * X.shape[1])
print("input sparsity ratio:", sparsity_ratio(X))

但对于非稀疏的数据,如果强行存储为的稀疏矩阵,反而会增加模型的预测速度。因此需要稀疏度通常非常高(非零数据只占比10%),这样才会有速度的增益。

要点5:限制模型复杂度

当模型复杂性增加时,模型精度和预测延迟应该会增加。对于sklearn.linear_model中很多模型,例如LassoElasticNetSGDClassifier/RegressorRidge & RidgeClassifierPassiveAgressiveClassifier/RegressorLinearSVCLogisticRegression在预测时应用的决策函数是相同的,因此预测延迟应该差不多。

但在线性模型中我们可以调节模型的惩罚因子,然后控制模型的稀疏性,然后进一步可以减少模型的复杂度。这里我们需要在精度和预测延迟之间进行一个折中,可以参考下面的统计逻辑。

要点6:注意底层实现

scikit-learn 依赖 Numpy/Scipy 底层函数,因此明确关注这些库的版本是有意义的。首先需要确保 Numpy 是使用优化的 BLAS / LAPACK库构建的。

可以使用以下命令显示 NumPy/SciPy/scikit-learn底层的 BLAS/LAPACK支持:

from numpy.distutils.system_info import get_info
print(get_info('blas_opt'))
print(get_info('lapack_opt'))

BLAS / LAPACK实现包括:

  • Atlas
  • OpenBLAS
  • MKL

当然如果你的CPU支持scikit-learn-intelex,你也可以获得更多的加速比:

要点7:线性模型稀疏转换

scikit-learn中的线性模型支持将系数矩阵转换为稀疏格式,其内存和存储效率比Numpy高得多。

clf = SGDRegressor(penalty='elasticnet', l1_ratio=0.25)
clf.fit(X_train, y_train).sparsify()
clf.predict(X_test)

当模型和输入都是稀疏的,上述操作可以加速30%的速度,还可以对内容更加友好。

支持sparsify的模型包括:

  • LogisticRegression
  • LogisticRegressionCV
  • PassiveAggressiveClassifier
  • Perceptron
  • SGDClassifier
  • SGDOneClassSVM
  • SGDRegressor
  • LinearSVC

参考文献

  • https://scikit-learn.org/0.15/modules/computational_performance.html
  • https://scikit-learn.org/0.15/developers/performance.html
  • https://github.com/scikit-learn/scikit-learn/blob/main/benchmarks/bench_sparsify.py


往期精彩回顾



  • 交流群

欢迎加入机器学习爱好者微信群一起和同行交流,目前有机器学习交流群、博士群、博士申报交流、CV、NLP等微信群,请扫描下面的微信号加群,备注:”昵称-学校/公司-研究方向“,例如:”张小明-浙大-CV“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~(也可以加入机器学习交流qq群772479961


相关推荐

  • 大模型+蒙特卡洛树搜索,一招让LLaMa-3 8B奥数水平直逼GPT-4
  • 公司现在只发50%的工资,我就出去面试!没想到碰上了领导,他说这公司不行,你不用面了!结果我面试后,HR给我涨薪30%。这是为啥
  • AI研究的主要推动力会是什么?ChatGPT团队研究科学家:算力成本下降
  • 网传南方医科大学老师为抢救患儿迟到29分钟,被举报扣款2000元?
  • RAG落地中的文档智能处理经验及6月份半月度大模型等进展分享回顾
  • 摸鱼网站精选分享第三番
  • 17岁中专女生姜萍拿下数学竞赛全球第12名!我试着做了这套题,给跪了...
  • 腾讯混元、北大发现Scaling law「浪涌现象」,解决学习率调参难题
  • KDD2024-WhoIsWho-Top3开源方案
  • VSCode无限画布模式(可能会惊艳到你的一个小功能)
  • 管理员如何踢掉登录用户?
  • 3D 版 SORA 来了!DreamTech 推出全球首个原生 3D-DiT 大模型 Direct3D
  • 2024阿里巴巴全球数学竞赛试题&答案
  • 65W!确实可以封神了!
  • Spring Boot集成vaadin快速入门demo
  • 全网最佳websocket封装:完美支持断网重连、自动心跳!
  • 实用技巧,用lsof命令监控tar文件解压进度,简单有效!
  • 10个非常炫酷的 JavaScript 动画库!!!
  • 同事给我介绍了个私活儿,说1万报酬全给我,昨天快要交片之前,我私下问了下甲方,结果甲方说你同事白拿了很多。
  • 超简单下载网站视频,两个日常生活中极为实用的开源高星项目