将 PyTorch 版 bin 模型转换成 Tensorflow 版 ckpt

Pytorch bin to Tensorflow checkpoint

Posted by Xusheng on April 12, 2021

最近由于工作上的需求,需要使用Tensorflow加载语言模型 SpanBERT(Facebook 发布的 BERT 模型的变体),但是作者只发布了 Pytorch 版的预训练权重,因此需要将其转换为 Tensorflow 可以加载的 checkpoint。

在 Pytorch 框架下,大多数开发者使用 Huggingface 发布的 Transformers 工具来加载语言模型,它同时支持加载 Pytorch 和 Tensorflow 版的模型。但是,目前基于 Tensorflow(或 Keras)的工具基本上都不支持加载 Pytorch 版的 bin 模型,转换代码在网上也很难找到,这带来了很多不便。

article_186_1

通过搜索,目前能够找到的有以下几个转换代码片段可供参考:

通过分析可以看到,将 PyTorch 版 bin 模型转换成 Tensorflow 版 ckpt 的过程并不复杂,可以分为以下几步:

  1. 读取出模型中每一层网络结构的名称和参数;
  2. 针对 PyTorch 和 Tensorflow 的模型格式差异对参数做一些调整;
  3. 按照 Tensorflow 的格式保存模型。

读取和保存模型

PyTorch 和 Tensorflow 框架都提供了模型的读取和保存功能,因此读取和保存语言模型的过程非常简单。

读取模型直接使用 PyTorch 自带函数 torch.load() 或者 Transformers 提供的对应模型包的 from_pretrained() 函数就可以了;而保存模型则使用 Tensorflow 自带的模型保存器 tf.train.Saver 来完成。

以 BERT 模型为例,读取模型的过程就是:

model = BertModel.from_pretrained(
    pretrained_model_name_or_path=pytorch_bin_path,
    state_dict=torch.load(os.path.join(pytorch_bin_path, pytorch_bin_model), map_location='cpu')
)

或者

model = torch.load(os.path.join(pytorch_bin_path, pytorch_bin_model), map_location='cpu')

模型的保存过程则通过 Tensorflow 提供的保存器 tf.train.Saver 来完成:

tf.reset_default_graph()
with tf.Session() as session:
    for var_name in state_dict:
        tf_name = to_tf_var_name(var_name) # 将层名称改为Tensorflow模型格式
        torch_tensor = state_dict[var_name].numpy()
        # 将参数矩阵改为Tensorflow模型格式
        if any([x in var_name for x in tensors_to_transpose]):
            torch_tensor = torch_tensor.T
        tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
        tf.keras.backend.set_value(tf_var, torch_tensor)
        tf_weight = session.run(tf_var)
        print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))

    saver = tf.train.Saver(tf.trainable_variables())
    saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_").replace(".ckpt", "") + ".ckpt"))

整个过程就是先逐层读取模型的层名称和对应参数,然后将格式调整为 Tensorflow 模型的格式,再一次性写入到 checkpoint 文件中。

注意:部分转换脚本忽略了 reset_default_graph() 这一操作,会导致生成的 meta 文件不仅保存网络结构,还会保存完整的网络参数,从而体积庞大。

调整模型格式

由于 PyTorch 和 Tensorflow 的模型格式定义有所差异,因此转换的关键就是对部分层的名称和参数矩阵进行调整。具体来说,首先需要构建名称映射字典,对部分层的名称进行调整:

var_map = (
    ("layer.", "layer_"),
    ("word_embeddings.weight", "word_embeddings"),
    ("position_embeddings.weight", "position_embeddings"),
    ("token_type_embeddings.weight", "token_type_embeddings"),
    (".", "/"),
    ("LayerNorm/weight", "LayerNorm/gamma"),
    ("LayerNorm/bias", "LayerNorm/beta"),
    ("weight", "kernel"),
)

def to_tf_var_name(name: str):
    for patt, repl in iter(var_map):
        name = name.replace(patt, repl)
    return "bert/{}".format(name)

注意:这里演示的是转换 BERT 模型,所以转换后的层名以 bert/ 开头。如果转换的是其他模型,需要做相应的修改。

然后,由于 PyTorch 和 Tensorflow 模型中 dense/kernelattention/self/queryattention/self/keyattention/self/value 层的参数矩阵互为转置,因此还需要对模型中的对应层的参数进行调整:

tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")

if any([x in var_name for x in tensors_to_transpose]):
    torch_tensor = torch_tensor.T

至此,转换过程就全部完成了。

完整的代码

综上所述,将 PyTorch 版 bin 模型转换成 Tensorflow 版 ckpt 的过程还是比较清晰的。本文对 VoidOc 编写的脚本进行了进一步的简化,以转换 BERT 模型为例,完整的代码如下(Github):

# coding=utf-8

"""
Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.
"""

import numpy as np
import tensorflow.compat.v1 as tf
import torch
from transformers import BertModel
import os

def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):

    """
    :param model:BertModel Pytorch model instance to be converted
    :param ckpt_dir: Tensorflow model directory
    :param model_name: model name
    :return:

    Currently supported Huggingface models:
        Y BertModel
        N BertForMaskedLM
        N BertForPreTraining
        N BertForMultipleChoice
        N BertForNextSentencePrediction
        N BertForSequenceClassification
        N BertForQuestionAnswering
    """

    tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")

    var_map = (
        ("layer.", "layer_"),
        ("word_embeddings.weight", "word_embeddings"),
        ("position_embeddings.weight", "position_embeddings"),
        ("token_type_embeddings.weight", "token_type_embeddings"),
        (".", "/"),
        ("LayerNorm/weight", "LayerNorm/gamma"),
        ("LayerNorm/bias", "LayerNorm/beta"),
        ("weight", "kernel"),
    )

    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        for patt, repl in iter(var_map):
            name = name.replace(patt, repl)
        return "bert/{}".format(name)

    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        session.run(tf_var)
        return tf_var

    tf.reset_default_graph()
    with tf.Session() as session:
        for var_name in state_dict:
            tf_name = to_tf_var_name(var_name)
            torch_tensor = state_dict[var_name].numpy()
            if any([x in var_name for x in tensors_to_transpose]):
                torch_tensor = torch_tensor.T
            tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
            tf.keras.backend.set_value(tf_var, torch_tensor)
            tf_weight = session.run(tf_var)
            print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))

        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_").replace(".ckpt", "") + ".ckpt"))

def convert(pytorch_bin_path: str, pytorch_bin_model: str, tf_ckpt_path: str, tf_ckpt_model: str):

    model = BertModel.from_pretrained(
        pretrained_model_name_or_path=pytorch_bin_path,
        state_dict=torch.load(os.path.join(pytorch_bin_path, pytorch_bin_model), map_location='cpu')
    )

    convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=tf_ckpt_path, model_name=tf_ckpt_model)

if __name__ == '__main__':
    bin_path = './pretrained_model/pytorch_model/'
    bin_model = 'pytorch_model.bin'
    ckpt_path = './pretrained_model/tensorflow_model/'
    ckpt_model = 'bert_model.ckpt'

    convert(bin_path, bin_model, ckpt_path, ckpt_model)

转换过程被包装为 convert() 函数,输入 PyTorch 版 bin 模型的路径和名称,以及 Tensorflow 版 ckpt 的保存路径和名称即可。

再次提醒一下,由于本文转换的 SpanBERT 只是 BERT 的一个变体,因此模型的层名称是与 BERT 模型完全一致的,如果需要转换其他模型,请自行修改 to_tf_var_name() 函数和 tensors_to_transpose 变量。

SpanBERT 目前已被实体同指在内的很多任务作为基准模型使用,因此 Facebook 没有提供 Tensorflow 版权重会给很多研究者和公司带来不便,尤其是目前很多商业项目是使用 Tensorflow 而不是 PyTorch 框架实现的。下面提供了转换好的 SpanBERT 模型权重给大家使用(包括 base 和 large 版):

  • SpanBERT (base & cased): 12-layer, 768-hidden, 12-heads , 110M parameters
  • SpanBERT (large & cased): 24-layer, 1024-hidden, 16-heads, 340M parameters

下载地址:百度盘 (提取码: wtyr) | 城通网盘 | GoogleDrive