文章主题:WizardMath, 大模型, 数学能力, 基准测试
编辑:陈萍
有了这项研究,大模型的数学能力更强了。
上周,微软与中国科学院联合发布的 WizardMath 大模型火了。
该模型有 70B、13B、7B 三个参数规模,研究者在两个数学推理基准 GSM8k 和 MATH 上的测试表明,WizardMath 优于所有其他开源 LLM,达到 SOTA。
在GSM8K平台上,WizardMath-70B-V1.0模型的表现相较于部分开源的LLM(如ChatGPT 3.5、Claude Instant 1以及PaLM 2 540B)而言,略有优势。
WizardMath-70B-V1.0模型在GSM8k基准测试中取得了令人瞩目的成绩,达到了81.6 pass@1的水平,这使得它比SOTA开源LLM高出24.8分。这一表现充分展示了该模型的卓越性能,值得进一步的研究和应用。
WizardMath-70B-V1.0模型在MATH基准测试中取得了22.7 pass@1的成绩,比SOTA开源LLM高出9.2分。
GSM8k数据集是大型数学问题数据集之一,它包含了约7500个训练样本和1319个测试数据。这个数据集主要针对小学生的数学教育,其题目涵盖了基本的算术运算(加、减、乘、除)等方面,通常需要2到8步的计算过程才能得出正确答案。另一方面,MATH数据集则来自于一些著名的数学竞赛,如AMC10、AMC12以及AIME等。这个数据集包含了7500个训练样本和5000个具有挑战性的测试数据。这些题目涵盖了初等代数、代数、数论、几何、微积分等多个数学领域。
如下图所示,WizardMath在GSM8k基准测试中名列第五,超越了Claude Instant 1(81.6 vs. 80.9)、ChatGPT(81.6 vs. 80.8)以及PaLM 2 540B(81.6 vs. 80.7)。值得关注的是,相较于这些模型,WizardMath模型的体积要小得多。
Hugging Face 现已发布三个版本,分别为 7B、13B 和 70B 参数。与此同时,相应的论文已对外公开。
方法介绍
这项研究提出了一种名为Reinforced Evol-Instruct的方法,如图1所示,该方法主要分为三个步骤:首先,进行监督微调;其次,训练指令奖励模型和过程监督奖励模型;最后,采用Active Evol-Instruct和PPO进行训练。
监督微调:继 InstructGPT 之后,该研究还使用了监督指令 – 响应对进行微调,其中包含:
为确保各步骤解析更为直观,本研究采用Alpha版WizardLM 70B(经微调的LLaMA模型)对GSM8k与MATH数据集进行了15k个答案的生成,采取分步方式呈现解决方案,进而确定正确答案。基于这些数据,我们对基础Llama模型进行了微调。
该研究还从 WizardLM 的训练数据中采样了 1.5k 个开放域对话,然后将其与上述数学语料库合并作为最终的 SFT ( supervised fine-tuning )训练数据。
Evol-Instruct 原则:受 WiazrdLM 提出的 Evol-Instruct 方法及其在 WizardCoder 上有效应用的启发,该研究试图制作具有各种复杂性和多样性的数学指令,以增强预训练 LLM。具体来说:
向下进化:首先是增强指令,通过使问题变得更加容易来实现。例如,i):将高难度问题转化为较低难度,或 ii) 用另一个不同主题制作一个新的更简单的问题。
向上进化:源自原始的 Evol-Instruct 方法,通过 i)添加更多约束,ii)具体化,iii)增加推理来深化并产生新的更难的问题。
Reinforced Evol-Instruct :受 InstructGPT 和 PRMs 的启发,该研究训练了两个奖励模型,分别用来预测指令的质量和答案中每一步的正确性。
实验及结果
该研究主要在 GSM8k 和 MATH 这两个常见的数学基准上测试了模型的性能,并使用大量基线模型,包括闭源模型:OpenAI 的 GPT-3、GPT-3.5、ChatGPT、GPT-4,谷歌的 PaLM 2、PaLM、 Minerva,Anthropic 的 Claude Instant、Claude 1.3、Claude 2, DeepMind 的 Chinchilla;开源模型:Llama 1、Llama 2、GAL、GPT-J、GPT-Neo、Vicuna、MPT、Falcon、Baichuan、ChatGLM、Qwen 和 RFT。
与闭源模型的比较。在表 1 中,WizardMath 70B 稍微优于 GSM8k 上的一些闭源 LLM,包括 ChatGPT、Claude Instant 和 PaLM 2 540B。
如图 2 所示(见上文),WizardMath 目前在所有模型上排名前五。同时,WizardMath 70B 在 MATH 上也超越了 Text-davinci-002。详细结果如下:
WizardMath 13B 在 GSM8k 上优于 PaLM 1 540B(63.9 vs 56.5)、Minerva 540B(63.9 vs 58.8)和 GPT-3.5(63.9 vs 57.1)。同时,它在 MATH 上超越了 PaLM 1 540B(14.0 vs. 8.8)、GPT-3 175B(14.0 vs. 5.2)。
WizardMath 70B 在 GSM8k 上实现了与 Claude Instant(81.6 vs 80.9)、ChatGPT(81.6 vs 80.8)和 PaLM 2(81.6 vs 80.7)更好或相当的性能。同时,WizardMath 70B 在 MATH 基准测试中也超过了 Text-davinci-002(22.7 比 19.1)。
与开源模型的比较。表 1 中所示的结果表明,WizardMath 70B 在 GSM8k 和 MATH 基准测试中明显优于所有开源模型。详细结果如下:
WizardMath 7B 超越了大多数开源模型,这些模型的参数数量约为 7B 到 40B 不等,包括 MPT、Falcon、Baichuan-chat、Vicuna v1.3、ChatGLM 2、Qwen、Llama 1 和 Llama 2 。尽管它的参数数量要少得多。
WizardMath 13B 在 GSM8k 上明显优于 Llama 1 65B(63.9 vs. 50.9)和 Llama 2 70B(63.9 vs. 56.8)。此外,它在 MATH 上的表现远远优于 Llama 1 65B(14.0 vs. 10.6)和 Llama 2 70B(14.0 vs. 13.5)。
WizardMath 70B 在 GSM8k 上超越了 Llama 2 70B(81.6 比 56.8),提升达到 24.8%。同时,它在数学方面也比 Llama 2 70B(22.7 比 13.5)高出 9.2%。
表 2 显示了 WizardMath 70B 模型在 MATH Subtopics上的结果。
© THE END
WizardMath, 大模型, 数学能力, 基准测试