Code Monkey home page Code Monkey logo

Comments (5)

Tongjilibo avatar Tongjilibo commented on June 7, 2024

你是说你用脚本加载模型预训练是可以的,但是保存出来的ckpt想用于微调发现加载不进去,是这个意思吗?

如果是的话,加载finetune的模型使用model.load_weights(ckpt_path)来加载就可以了

from bert4torch.

Smile-L-up avatar Smile-L-up commented on June 7, 2024

预训练脚本

#! -*- coding: utf-8 -*-
# 预训练脚本,单GPU版方便测试
# 改DDP需几行代码,参考https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_distributed_data_parallel.py

from bert4torch.models import build_transformer_model
from bert4torch.snippets import sequence_padding
from bert4torch.callbacks import Callback
from bert4torch.optimizers import get_linear_schedule_with_warmup
from torch.utils.data import Dataset
import torch.nn as nn
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import json
import os
import shelve
import random
import time


# Corpus directory and model-save path
model_saved_path = './saved_model/bert_pretrain_model.ckpt'
dir_training_data = 'pretrain_data'  # directory holding the pre-generated training shards
task_name = 'bert'

# Other settings
maxlen = 256
batch_size = 20
config_path = '../../seq2seq/simbert_model/simbert-chinese-base/config.json'
checkpoint_path = '../../seq2seq/simbert_model/simbert-chinese-base/pytorch_model.bin'  # set to None to train from scratch

learning_rate = 0.00176
weight_decay_rate = 0.01  # weight decay
num_warmup_steps = 3125
num_train_steps = 10000
steps_per_epoch = 1000
grad_accum_steps = 1  # > 1 enables gradient accumulation
epochs = num_train_steps * grad_accum_steps // steps_per_epoch
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# 读取数据集,构建数据张量
# Read one pretraining shard (a `shelve` db plus a sidecar `<file>.json`
# holding {"samples_num": N}) and expose it as a map-style Dataset.
class MyDataset(Dataset):
    def __init__(self, file):
        """file: shelve db path WITHOUT suffix; `<file>.json` must exist beside it."""
        super(MyDataset, self).__init__()
        self.file = file
        self.len = self._get_dataset_length()
        # NOTE: the shelve db stays open for the Dataset's lifetime; the
        # training callback is responsible for calling `self.db.close()`.
        self.db = self._load_data()

    def __getitem__(self, index):
        # shelve keys are strings, hence the str() conversion.
        return self.db[str(index)]

    def __len__(self):
        return self.len

    def _get_dataset_length(self):
        # Fix: the original `json.load(open(...))` leaked the file handle;
        # a context manager closes it deterministically.
        file_record_info = self.file + ".json"
        with open(file_record_info, "r", encoding="utf-8") as f:
            record_info = json.load(f)
        return record_info["samples_num"]

    def _load_data(self):
        return shelve.open(self.file)

def collate_fn(batch):
    """Collate pretraining samples into padded tensors.

    Each item carries `input_ids` and `masked_lm_labels` of the same
    (variable) length. Fix: labels were previously converted with a bare
    `torch.tensor(...)`, which fails on ragged batches — pad them exactly
    like the token ids. Padding value 0 matches the loss's ignore_index=0,
    and for already-equal lengths the padding is a no-op.
    """
    batch_token_ids, batch_labels = [], []
    for item in batch:
        batch_token_ids.append(item['input_ids'])
        batch_labels.append(item['masked_lm_labels'])

    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_labels = torch.tensor(sequence_padding(batch_labels), dtype=torch.long, device=device)
    return [batch_token_ids], batch_labels


# Randomly pick one finished shard from the corpus folder and build a DataLoader.
def get_train_dataloader():
    """Block until a complete training shard exists, then wrap it in a DataLoader.

    A shard is "complete" only when all four files (.bak/.dat/.dir/.json)
    are present, i.e. the data-generation process has finished writing it.
    """
    while True:
        # prepare dataset
        files_training_data = os.listdir(dir_training_data)
        files_training_data = [file.split(".")[0] for file in files_training_data if "train" in file]
        # Avoid shards that are still being generated: keep only basenames
        # that appear with all 4 suffixes.
        files_training_data = [i for i in set(files_training_data) if files_training_data.count(i)==4]
        if files_training_data:
            file_train = random.choice(files_training_data)
            # Rename the chosen shard to a fixed name (task_name.*) so the
            # cleanup callback knows which files to delete later.
            for suffix in [".bak", ".dat", ".dir", ".json"]:
                file_old = os.path.join(dir_training_data, file_train + suffix)
                file_new = os.path.join(dir_training_data, task_name + suffix)
                os.renames(file_old, file_new)
            cur_load_file = file_new.split(".")[0]  # strip the suffix back off
            train_dataloader = DataLoader(MyDataset(cur_load_file), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
            break
        else:
            # No finished shard yet — wait for the generator process.
            sleep_seconds = 300
            print(f"No training data! Sleep {sleep_seconds}s!")
            time.sleep(sleep_seconds)
            continue
    return train_dataloader
train_dataloader = get_train_dataloader()

# segment_vocab_size=0 drops segment embeddings; with_mlm=True adds the MLM head.
model = build_transformer_model(config_path, checkpoint_path, segment_vocab_size=0, with_mlm=True, add_trainer=True).to(device)

# weight decay: exclude biases and LayerNorm parameters, per the usual BERT recipe
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay_rate},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

class MyLoss(nn.CrossEntropyLoss):
    """Cross entropy over the MLM logits (the last element of the model output)."""

    def forward(self, output, batch_labels):
        # The model returns several tensors; the MLM logits come last.
        mlm_logits = output[-1]
        flat_logits = mlm_logits.reshape(-1, mlm_logits.shape[-1])
        flat_labels = batch_labels.flatten()
        return super().forward(flat_logits, flat_labels)

# Define the loss / optimizer / scheduler (all customizable here).
# ignore_index=0 skips padding positions in the MLM loss.
optimizer = optim.Adam(optimizer_grouped_parameters, lr=learning_rate, weight_decay=weight_decay_rate)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)
model.compile(loss=MyLoss(ignore_index=0), optimizer=optimizer, scheduler=scheduler, grad_accumulation_steps=grad_accum_steps)


class ModelCheckpoint(Callback):
    """Save the latest model each epoch and rotate training shards between dataloaders."""

    def on_dataloader_end(self, logs=None):
        # When the current dataloader is exhausted, close the shelve db and
        # delete the consumed shard files.
        model.train_dataloader.dataset.db.close()
        for suffix in [".bak", ".dat", ".dir", ".json"]:
            file_remove = os.path.join(dir_training_data, task_name + suffix)
            try:
                os.remove(file_remove)
            except OSError:
                # Fix: was a bare `except:` — catch only filesystem errors so
                # KeyboardInterrupt/SystemExit are not swallowed.
                print(f"Failed to remove training data {file_remove}.")

        # Build a fresh dataloader from the next available shard.
        model.train_dataloader = get_train_dataloader()

    def on_epoch_end(self, global_step, epoch, logs=None):
        # Persist the latest weights after every epoch.
        model.save_weights(model_saved_path)

if __name__ == '__main__':
    # Callback that saves weights each epoch and rotates data shards.
    checkpoint = ModelCheckpoint()

    # Train the model.
    model.fit(
        train_dataloader,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[checkpoint],
    )

微调脚本

#! -*- coding: utf-8 -*-
# SimBERT预训练代码,也可用于微调,微调方式用其他方式比如sentence_bert的可能更好
# 官方项目:https://github.com/ZhuiyiTechnology/simbert

import json
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.snippets import sequence_padding, ListDataset, text_segmentate, get_pool_emb
from bert4torch.generation import AutoRegressiveDecoder
from bert4torch.callbacks import Callback
from bert4torch.tokenizers import Tokenizer, load_vocab
import os


# Basic settings
maxlen = 22
batch_size = 64

# These paths point at the released SimBERT weights; continue
# pretraining/finetuning on top of them with your own data.
# To pretrain from scratch you can load a plain bert/roberta checkpoint instead.
config_path = './simbert_model/simbert-chinese-base/config.json'
checkpoint_path = './simbert_model/simbert-chinese-base/pytorch_model.bin'
# checkpoint_path = '../pretrain/roberta_pretrain/saved_model/bert_pretrain_model.ckpt'
dict_path = './simbert_model/simbert-chinese-base/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
# Load the vocab and build the tokenizer (no vocab simplification for now).
token_dict = load_vocab(
    dict_path=dict_path,
    simplified=False,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)

filename = 'data/ais0824.json'


class MyDataset(ListDataset):
    @staticmethod
    def load_data(filename):
        """Load the corpus: one JSON object per line.

        Example line: {"text": "懂英语的来!", "synonyms": ["懂英语的来!!!", "懂英语的来", "一句英语翻译  懂英语的来"]}
        """
        with open(filename, encoding='utf-8') as f:
            return [json.loads(line) for line in f]


def truncate(text):
    """Keep only the first segment of *text*, split at sentence-ish
    separators, so it fits within maxlen - 2 tokens."""
    separators = u'\n。!?!?;;,, '
    strip_chars = u';;,, '
    return text_segmentate(text, maxlen - 2, separators, strip_chars)[0]


def collate_fn(batch):
    """Build UniLM training pairs: for each sample pick two random synonymous
    sentences and encode them in both orders. The inputs double as labels."""
    batch_token_ids, batch_segment_ids = [], []
    for sample in batch:
        candidates = [sample['text']] + sample['synonyms']
        np.random.shuffle(candidates)
        first, second = candidates[:2]
        first, second = truncate(first), truncate(second)
        # Encode the pair in both directions so the model learns a
        # symmetric notion of similarity.
        for left, right in ((first, second), (second, first)):
            token_ids, segment_ids = tokenizer.encode(left, right, maxlen=maxlen * 2)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)

    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
    return [batch_token_ids, batch_segment_ids], [batch_token_ids, batch_segment_ids]


# DataLoader over the synonym corpus; collate_fn builds the UniLM pairs.
train_dataloader = DataLoader(MyDataset(filename), batch_size=batch_size, shuffle=True,
                              collate_fn=collate_fn)


# Build the model: SimBERT = UniLM seq2seq head + pooled sentence embedding.
class Model(BaseModel):
    def __init__(self, pool_method='cls'):
        super().__init__()
        # with_pool='linear' adds the pooler; with_mlm='linear' adds LM logits;
        # application='unilm' enables the seq2seq attention mask.
        self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path,
                                            with_pool='linear',
                                            with_mlm='linear', application='unilm')
        self.pool_method = pool_method

    def forward(self, token_ids, segment_ids):
        """Return (token-level LM logits, pooled sentence embedding)."""
        hidden_state, pool_cls, seq_logit = self.bert([token_ids, segment_ids])
        # token_ids.gt(0) serves as the attention mask (0 = padding).
        sen_emb = get_pool_emb(hidden_state, pool_cls, token_ids.gt(0).long(), self.pool_method)
        return seq_logit, sen_emb


model = Model(pool_method='cls').to(device)  # pool sentence embeddings from the [CLS] vector


class TotalLoss(nn.Module):
    """Two-part loss: seq2seq cross entropy plus an in-batch similarity cross
    entropy over sentence embeddings (the SimBERT objective)."""

    def forward(self, outputs, target):
        seq_logit, sen_emb = outputs
        # target is [token_ids, segment_ids]: the token ids are the labels and
        # the segment ids mark which positions belong to the "answer" half.
        seq_label, seq_mask = target

        seq2seq_loss = self.compute_loss_of_seq2seq(seq_logit, seq_label, seq_mask)
        similarity_loss = self.compute_loss_of_similarity(sen_emb)
        return {'loss': seq2seq_loss + similarity_loss, 'seq2seq_loss': seq2seq_loss,
                'similarity_loss': similarity_loss}

    def compute_loss_of_seq2seq(self, y_pred, y_true, y_mask):
        '''
        y_pred: [btz, seq_len, hdsz] token logits
        y_true: [btz, seq_len]       target token ids
        y_mask: [btz, seq_len]       1 where a prediction is required
        '''
        y_true = y_true[:, 1:]  # target token ids
        y_mask = y_mask[:, 1:]  # marks the positions to predict
        y_pred = y_pred[:, :-1, :]  # shift predictions left by one step

        y_pred = y_pred.reshape(-1, y_pred.shape[-1])
        # Masked-out positions become 0, which cross_entropy ignores below.
        y_true = (y_true * y_mask).flatten()
        return F.cross_entropy(y_pred, y_true, ignore_index=0)

    def compute_loss_of_similarity(self, y_pred):
        y_true = self.get_labels_of_similarity(y_pred)  # build pairwise labels
        y_pred = F.normalize(y_pred, p=2, dim=-1)  # L2-normalize sentence embeddings
        similarities = torch.matmul(y_pred, y_pred.T)  # cosine similarity matrix
        similarities = similarities - torch.eye(y_pred.shape[0], device=device) * 1e12  # mask out self-similarity on the diagonal
        similarities = similarities * 30  # temperature scaling

        loss = F.cross_entropy(similarities, y_true)
        return loss

    def get_labels_of_similarity(self, y_pred):
        # collate_fn emits samples in adjacent pairs (2i, 2i+1) that are
        # synonyms of each other, so row 2i's positive is 2i+1 and vice versa.
        idxs = torch.arange(0, y_pred.shape[0], device=device)
        idxs_1 = idxs[None, :]
        idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
        labels = idxs_1.eq(idxs_2).float()
        return labels


# Compile with the combined loss; report both components as metrics.
model.compile(loss=TotalLoss(), optimizer=optim.Adam(model.parameters(), 1e-5),
              metrics=['seq2seq_loss', 'similarity_loss'])


class SynonymsGenerator(AutoRegressiveDecoder):
    """seq2seq decoder that samples synonym candidates."""

    @AutoRegressiveDecoder.wraps('logits')
    def predict(self, inputs, output_ids, states):
        token_ids, segment_ids = inputs
        # Append what has been generated so far; generated tokens live in
        # segment 1 (the UniLM "answer" segment).
        token_ids = torch.cat([token_ids, output_ids], 1)
        segment_ids = torch.cat([segment_ids, torch.ones_like(output_ids, device=device)], 1)
        seq_logit, _ = model.predict([token_ids, segment_ids])
        return seq_logit[:, -1, :]  # logits for the next token only

    def generate(self, text, n=1, topk=5):
        """Sample n synonym candidates for *text* via top-k random sampling."""
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        output_ids = self.random_sample([token_ids, segment_ids], n=n, topk=topk)  # random sampling
        return [tokenizer.decode(ids.cpu().numpy()) for ids in output_ids]


# Decoder instance: generation stops at the tokenizer's end token ([SEP]).
synonyms_generator = SynonymsGenerator(start_id=None, end_id=tokenizer._token_end_id, maxlen=maxlen, device=device)


def cal_sen_emb(text_list):
    """Encode a list of texts and return their sentence embeddings."""
    encoded = [tokenizer.encode(text) for text in text_list]
    token_ids = [pair[0] for pair in encoded]
    segment_ids = [pair[1] for pair in encoded]
    X = torch.tensor(sequence_padding(token_ids), dtype=torch.long, device=device)
    S = torch.tensor(sequence_padding(segment_ids), dtype=torch.long, device=device)
    _, Z = model.predict([X, S])
    return Z


def gen_synonyms(text, n=100, k=20):
    """Generate n candidate paraphrases of *text* and return the k most similar.

    Method: sample candidates with the seq2seq decoder, then rank them by
    encoder cosine similarity to the original sentence.
    Example:
        >>> gen_synonyms(u'微信和支付宝哪个好?')
        [
            u'微信和支付宝,哪个好?',
            u'微信和支付宝哪个好',
            u'支付宝和微信哪个好',
            u'支付宝和微信哪个好啊',
            u'微信和支付宝那个好用?',
            u'微信和支付宝哪个好用',
            u'支付宝和微信那个更好',
            u'支付宝和微信哪个好用',
            u'微信和支付宝用起来哪个好?',
            u'微信和支付宝选哪个好',
        ]
    """
    r = synonyms_generator.generate(text, n)
    r = [i for i in set(r) if i != text]  # drop candidates identical to the input
    r = [text] + r  # keep the original at index 0 as the similarity reference
    Z = cal_sen_emb(r)
    Z /= (Z ** 2).sum(dim=1, keepdims=True) ** 0.5  # L2-normalize rows
    # Negate so that ascending argsort yields the most-similar first.
    argsort = torch.matmul(Z[1:], -Z[0]).argsort()
    return [r[i + 1] for i in argsort[:k]]


def just_show(some_samples):
    """Print generated synonyms for a random handful of the given samples."""
    # Fix: the random selection S was computed but the loop iterated
    # `some_samples`, printing everything and contradicting the docstring;
    # iterate the sampled subset instead.
    S = [np.random.choice(some_samples) for _ in range(3)]
    for s in S:
        try:
            print(u'原句子:%s' % s)
            print(u'同义句子:', gen_synonyms(s, 10, 10))
            print()
        except Exception as e:
            # Fix: was a bare `except: pass` that silently hid every error.
            # Demo output is best-effort: report the failure and keep going.
            print(f'just_show failed on one sample: {e}')


class Evaluator(Callback):
    """Track the best (lowest) training loss and save the model when it improves."""

    def __init__(self):
        super().__init__()
        self.lowest = 1e10  # best loss seen so far

    def on_epoch_end(self, global_step, epoch, logs=None):
        # Fix: the original updated `self.lowest` inside the `if` but ran the
        # save unconditionally, so "save the best" actually saved every epoch.
        # Save only when the loss improves.
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            print('saved epoch: {}'.format(epoch))
            model.save_weights('./saved_model/best_model0828.bin')
        # Demonstrate generation quality on a few fixed sentences.
        just_show(['微信和支付宝拿个好用?',
                   '微信和支付宝,哪个好?',
                   '微信和支付宝哪个好',
                   '支付宝和微信哪个好',
                   '支付宝和微信哪个好啊',
                   '微信和支付宝那个好用?',
                   '微信和支付宝哪个好用',
                   '支付宝和微信那个更好',
                   '支付宝和微信哪个好用',
                   '微信和支付宝用起来哪个好?',
                   '微信和支付宝选哪个好'
                   ])


if __name__ == '__main__':
    choice = 'generate'  # one of: train / generate / similarity

    if choice == 'train':
        evaluator = Evaluator()
        model.fit(train_dataloader, epochs=50, steps_per_epoch=None, callbacks=[evaluator])

    elif choice == 'generate':
        # NOTE(review): loading a checkpoint saved by the pretraining script
        # raises Missing/Unexpected key errors here — its keys lack the
        # 'bert.' prefix this Model wraps the transformer under. A key
        # mapping or loading into self.bert with strict=False is needed.
        model.load_weights('./saved_model/bert_pretrain_model.ckpt')   # this load is the problematic step
        param = list(model.named_parameters())  # kept for inspecting parameter key names
        just_show(['微信和支付宝拿个好用?',
                   '微信和支付宝,哪个好?',
                   '微信和支付宝哪个好',
                   '支付宝和微信哪个好',
                   '支付宝和微信哪个好啊',
                   '微信和支付宝那个好用?',
                   '微信和支付宝哪个好用',
                   '支付宝和微信那个更好',
                   '支付宝和微信哪个好用',
                   '微信和支付宝用起来哪个好?',
                   '微信和支付宝选哪个好'
                   ])

    elif choice == 'similarity':
        target_text = '我想去首都北京玩玩'
        text_list = ['我想去北京玩', '北京有啥好玩的吗?我想去看看', '好渴望去北京游玩啊']
        # Embed target + candidates together, normalize, then score by dot product.
        Z = cal_sen_emb([target_text] + text_list)
        Z /= (Z ** 2).sum(dim=1, keepdims=True) ** 0.5
        similarity = torch.matmul(Z[1:], Z[0])
        for i, line in enumerate(text_list):
            print(f'cos_sim: {similarity[i].item():.4f}, tgt_text: "{target_text}", cal_text: "{line}"')

直接尝试generate会出现如下问题,(好像不能直接加载huggingface上的模型,需要经过转化脚本转化后再继续微调?我现在在尝试将苏神给的tf版本模型转为本项目支持的格式再进行预训练试试):

Missing key(s) in state_dict: "bert.mlmBias", "bert.embeddings.word_embeddings.weight", "bert.embeddings.position_embeddings.weight", "bert.embeddings.segment_embeddings.weight", "bert.embeddings.layerNorm.weight", "bert.embeddings.layerNorm.bias", "bert.encoderLayer.0.multiHeadAttention.q.weight", "bert.encoderLayer.0.multiHeadAttention.q.bias", "bert.encoderLayer.0.multiHeadAttention.k.weight", "bert.encoderLayer.0.multiHeadAttention.k.bias", "bert.encoderLayer.0.multiHeadAttention.v.weight", "bert.encoderLayer.0.multiHeadAttention.v.bias", "bert.encoderLayer.0.multiHeadAttention.o.weight", "bert.encoderLayer.0.multiHeadAttention.o.bias", "bert.encoderLayer.0.layerNorm1.weight", "bert.encoderLayer.0.layerNorm1.bias", "bert.encoderLayer.0.feedForward.intermediateDense.weight", "bert.encoderLayer.0.feedForward.intermediateDense.bias", "bert.encoderLayer.0.feedForward.outputDense.weight", "bert.encoderLayer.0.feedForward.outputDense.bias", "bert.encoderLayer.0.layerNorm2.weight", "bert.encoderLayer.0.layerNorm2.bias", "bert.encoderLayer.1.multiHeadAttention.q.weight", "bert.encoderLayer.1.multiHeadAttention.q.bias", "bert.encoderLayer.1.multiHeadAttention.k.weight", "bert.encoderLayer.1.multiHeadAttention.k.bias", "bert.encoderLayer.1.multiHeadAttention.v.weight", "bert.encoderLayer.1.multiHeadAttention.v.bias", "bert.encoderLayer.1.multiHeadAttention.o.weight", "bert.encoderLayer.1.multiHeadAttention.o.bias", "bert.encoderLayer.1.layerNorm1.weight", "bert.encoderLayer.1.layerNorm1.bias", "bert.encoderLayer.1.feedForward.intermediateDense.weight", "bert.encoderLayer.1.feedForward.intermediateDense.bias", "bert.encoderLayer.1.feedForward.outputDense.weight", "bert.encoderLayer.1.feedForward.outputDense.bias", "bert.encoderLayer.1.layerNorm2.weight", "bert.encoderLayer.1.layerNorm2.bias", "bert.encoderLayer.2.multiHeadAttention.q.weight", "bert.encoderLayer.2.multiHeadAttention.q.bias", "bert.encoderLayer.2.multiHeadAttention.k.weight", 
"bert.encoderLayer.2.multiHeadAttention.k.bias", "bert.encoderLayer.2.multiHeadAttention.v.weight", "bert.encoderLayer.2.multiHeadAttention.v.bias", "bert.encoderLayer.2.multiHeadAttention.o.weight", "bert.encoderLayer.2.multiHeadAttention.o.bias", "bert.encoderLayer.2.layerNorm1.weight", "bert.encoderLayer.2.layerNorm1.bias", "bert.encoderLayer.2.feedForward.intermediateDense.weight", "bert.encoderLayer.2.feedForward.intermediateDense.bias", "bert.encoderLayer.2.feedForward.outputDense.weight", "bert.encoderLayer.2.feedForward.outputDense.bias", "bert.encoderLayer.2.layerNorm2.weight", "bert.encoderLayer.2.layerNorm2.bias", "bert.encoderLayer.3.multiHeadAttention.q.weight", "bert.encoderLayer.3.multiHeadAttention.q.bias", "bert.encoderLayer.3.multiHeadAttention.k.weight", "bert.encoderLayer.3.multiHeadAttention.k.bias", "bert.encoderLayer.3.multiHeadAttention.v.weight", "bert.encoderLayer.3.multiHeadAttention.v.bias", "bert.encoderLayer.3.multiHeadAttention.o.weight", "bert.encoderLayer.3.multiHeadAttention.o.bias", "bert.encoderLayer.3.layerNorm1.weight", "bert.encoderLayer.3.layerNorm1.bias", "bert.encoderLayer.3.feedForward.intermediateDense.weight", "bert.encoderLayer.3.feedForward.intermediateDense.bias", "bert.encoderLayer.3.feedForward.outputDense.weight", "bert.encoderLayer.3.feedForward.outputDense.bias", "bert.encoderLayer.3.layerNorm2.weight", "bert.encoderLayer.3.layerNorm2.bias", "bert.encoderLayer.4.multiHeadAttention.q.weight", "bert.encoderLayer.4.multiHeadAttention.q.bias", "bert.encoderLayer.4.multiHeadAttention.k.weight", "bert.encoderLayer.4.multiHeadAttention.k.bias", "bert.encoderLayer.4.multiHeadAttention.v.weight", "bert.encoderLayer.4.multiHeadAttention.v.bias", "bert.encoderLayer.4.multiHeadAttention.o.weight", "bert.encoderLayer.4.multiHeadAttention.o.bias", "bert.encoderLayer.4.layerNorm1.weight", "bert.encoderLayer.4.layerNorm1.bias", "bert.encoderLayer.4.feedForward.intermediateDense.weight", 
"bert.encoderLayer.4.feedForward.intermediateDense.bias", "bert.encoderLayer.4.feedForward.outputDense.weight", "bert.encoderLayer.4.feedForward.outputDense.bias", "bert.encoderLayer.4.layerNorm2.weight", "bert.encoderLayer.4.layerNorm2.bias", "bert.encoderLayer.5.multiHeadAttention.q.weight", "bert.encoderLayer.5.multiHeadAttention.q.bias", "bert.encoderLayer.5.multiHeadAttention.k.weight", "bert.encoderLayer.5.multiHeadAttention.k.bias", "bert.encoderLayer.5.multiHeadAttention.v.weight", "bert.encoderLayer.5.multiHeadAttention.v.bias", "bert.encoderLayer.5.multiHeadAttention.o.weight", "bert.encoderLayer.5.multiHeadAttention.o.bias", "bert.encoderLayer.5.layerNorm1.weight", "bert.encoderLayer.5.layerNorm1.bias", "bert.encoderLayer.5.feedForward.intermediateDense.weight", "bert.encoderLayer.5.feedForward.intermediateDense.bias", "bert.encoderLayer.5.feedForward.outputDense.weight", "bert.encoderLayer.5.feedForward.outputDense.bias", "bert.encoderLayer.5.layerNorm2.weight", "bert.encoderLayer.5.layerNorm2.bias", "bert.encoderLayer.6.multiHeadAttention.q.weight", "bert.encoderLayer.6.multiHeadAttention.q.bias", "bert.encoderLayer.6.multiHeadAttention.k.weight", "bert.encoderLayer.6.multiHeadAttention.k.bias", "bert.encoderLayer.6.multiHeadAttention.v.weight", "bert.encoderLayer.6.multiHeadAttention.v.bias", "bert.encoderLayer.6.multiHeadAttention.o.weight", "bert.encoderLayer.6.multiHeadAttention.o.bias", "bert.encoderLayer.6.layerNorm1.weight", "bert.encoderLayer.6.layerNorm1.bias", "bert.encoderLayer.6.feedForward.intermediateDense.weight", "bert.encoderLayer.6.feedForward.intermediateDense.bias", "bert.encoderLayer.6.feedForward.outputDense.weight", "bert.encoderLayer.6.feedForward.outputDense.bias", "bert.encoderLayer.6.layerNorm2.weight", "bert.encoderLayer.6.layerNorm2.bias", "bert.encoderLayer.7.multiHeadAttention.q.weight", "bert.encoderLayer.7.multiHeadAttention.q.bias", "bert.encoderLayer.7.multiHeadAttention.k.weight", 
"bert.encoderLayer.7.multiHeadAttention.k.bias", "bert.encoderLayer.7.multiHeadAttention.v.weight", "bert.encoderLayer.7.multiHeadAttention.v.bias", "bert.encoderLayer.7.multiHeadAttention.o.weight", "bert.encoderLayer.7.multiHeadAttention.o.bias", "bert.encoderLayer.7.layerNorm1.weight", "bert.encoderLayer.7.layerNorm1.bias", "bert.encoderLayer.7.feedForward.intermediateDense.weight", "bert.encoderLayer.7.feedForward.intermediateDense.bias", "bert.encoderLayer.7.feedForward.outputDense.weight", "bert.encoderLayer.7.feedForward.outputDense.bias", "bert.encoderLayer.7.layerNorm2.weight", "bert.encoderLayer.7.layerNorm2.bias", "bert.encoderLayer.8.multiHeadAttention.q.weight", "bert.encoderLayer.8.multiHeadAttention.q.bias", "bert.encoderLayer.8.multiHeadAttention.k.weight", "bert.encoderLayer.8.multiHeadAttention.k.bias", "bert.encoderLayer.8.multiHeadAttention.v.weight", "bert.encoderLayer.8.multiHeadAttention.v.bias", "bert.encoderLayer.8.multiHeadAttention.o.weight", "bert.encoderLayer.8.multiHeadAttention.o.bias", "bert.encoderLayer.8.layerNorm1.weight", "bert.encoderLayer.8.layerNorm1.bias", "bert.encoderLayer.8.feedForward.intermediateDense.weight", "bert.encoderLayer.8.feedForward.intermediateDense.bias", "bert.encoderLayer.8.feedForward.outputDense.weight", "bert.encoderLayer.8.feedForward.outputDense.bias", "bert.encoderLayer.8.layerNorm2.weight", "bert.encoderLayer.8.layerNorm2.bias", "bert.encoderLayer.9.multiHeadAttention.q.weight", "bert.encoderLayer.9.multiHeadAttention.q.bias", "bert.encoderLayer.9.multiHeadAttention.k.weight", "bert.encoderLayer.9.multiHeadAttention.k.bias", "bert.encoderLayer.9.multiHeadAttention.v.weight", "bert.encoderLayer.9.multiHeadAttention.v.bias", "bert.encoderLayer.9.multiHeadAttention.o.weight", "bert.encoderLayer.9.multiHeadAttention.o.bias", "bert.encoderLayer.9.layerNorm1.weight", "bert.encoderLayer.9.layerNorm1.bias", "bert.encoderLayer.9.feedForward.intermediateDense.weight", 
"bert.encoderLayer.9.feedForward.intermediateDense.bias", "bert.encoderLayer.9.feedForward.outputDense.weight", "bert.encoderLayer.9.feedForward.outputDense.bias", "bert.encoderLayer.9.layerNorm2.weight", "bert.encoderLayer.9.layerNorm2.bias", "bert.encoderLayer.10.multiHeadAttention.q.weight", "bert.encoderLayer.10.multiHeadAttention.q.bias", "bert.encoderLayer.10.multiHeadAttention.k.weight", "bert.encoderLayer.10.multiHeadAttention.k.bias", "bert.encoderLayer.10.multiHeadAttention.v.weight", "bert.encoderLayer.10.multiHeadAttention.v.bias", "bert.encoderLayer.10.multiHeadAttention.o.weight", "bert.encoderLayer.10.multiHeadAttention.o.bias", "bert.encoderLayer.10.layerNorm1.weight", "bert.encoderLayer.10.layerNorm1.bias", "bert.encoderLayer.10.feedForward.intermediateDense.weight", "bert.encoderLayer.10.feedForward.intermediateDense.bias", "bert.encoderLayer.10.feedForward.outputDense.weight", "bert.encoderLayer.10.feedForward.outputDense.bias", "bert.encoderLayer.10.layerNorm2.weight", "bert.encoderLayer.10.layerNorm2.bias", "bert.encoderLayer.11.multiHeadAttention.q.weight", "bert.encoderLayer.11.multiHeadAttention.q.bias", "bert.encoderLayer.11.multiHeadAttention.k.weight", "bert.encoderLayer.11.multiHeadAttention.k.bias", "bert.encoderLayer.11.multiHeadAttention.v.weight", "bert.encoderLayer.11.multiHeadAttention.v.bias", "bert.encoderLayer.11.multiHeadAttention.o.weight", "bert.encoderLayer.11.multiHeadAttention.o.bias", "bert.encoderLayer.11.layerNorm1.weight", "bert.encoderLayer.11.layerNorm1.bias", "bert.encoderLayer.11.feedForward.intermediateDense.weight", "bert.encoderLayer.11.feedForward.intermediateDense.bias", "bert.encoderLayer.11.feedForward.outputDense.weight", "bert.encoderLayer.11.feedForward.outputDense.bias", "bert.encoderLayer.11.layerNorm2.weight", "bert.encoderLayer.11.layerNorm2.bias", "bert.pooler.weight", "bert.pooler.bias", "bert.mlmDense.weight", "bert.mlmDense.bias", "bert.mlmLayerNorm.weight", "bert.mlmLayerNorm.bias", 
"bert.mlmDecoder.weight", "bert.mlmDecoder.bias". 
	Unexpected key(s) in state_dict: "mlmBias", "embeddings.word_embeddings.weight", "embeddings.position_embeddings.weight", "embeddings.layerNorm.weight", "embeddings.layerNorm.bias", "encoderLayer.0.multiHeadAttention.q.weight", "encoderLayer.0.multiHeadAttention.q.bias", "encoderLayer.0.multiHeadAttention.k.weight", "encoderLayer.0.multiHeadAttention.k.bias", "encoderLayer.0.multiHeadAttention.v.weight", "encoderLayer.0.multiHeadAttention.v.bias", "encoderLayer.0.multiHeadAttention.o.weight", "encoderLayer.0.multiHeadAttention.o.bias", "encoderLayer.0.layerNorm1.weight", "encoderLayer.0.layerNorm1.bias", "encoderLayer.0.feedForward.intermediateDense.weight", "encoderLayer.0.feedForward.intermediateDense.bias", "encoderLayer.0.feedForward.outputDense.weight", "encoderLayer.0.feedForward.outputDense.bias", "encoderLayer.0.layerNorm2.weight", "encoderLayer.0.layerNorm2.bias", "encoderLayer.1.multiHeadAttention.q.weight", "encoderLayer.1.multiHeadAttention.q.bias", "encoderLayer.1.multiHeadAttention.k.weight", "encoderLayer.1.multiHeadAttention.k.bias", "encoderLayer.1.multiHeadAttention.v.weight", "encoderLayer.1.multiHeadAttention.v.bias", "encoderLayer.1.multiHeadAttention.o.weight", "encoderLayer.1.multiHeadAttention.o.bias", "encoderLayer.1.layerNorm1.weight", "encoderLayer.1.layerNorm1.bias", "encoderLayer.1.feedForward.intermediateDense.weight", "encoderLayer.1.feedForward.intermediateDense.bias", "encoderLayer.1.feedForward.outputDense.weight", "encoderLayer.1.feedForward.outputDense.bias", "encoderLayer.1.layerNorm2.weight", "encoderLayer.1.layerNorm2.bias", "encoderLayer.2.multiHeadAttention.q.weight", "encoderLayer.2.multiHeadAttention.q.bias", "encoderLayer.2.multiHeadAttention.k.weight", "encoderLayer.2.multiHeadAttention.k.bias", "encoderLayer.2.multiHeadAttention.v.weight", "encoderLayer.2.multiHeadAttention.v.bias", "encoderLayer.2.multiHeadAttention.o.weight", "encoderLayer.2.multiHeadAttention.o.bias", "encoderLayer.2.layerNorm1.weight", 
"encoderLayer.2.layerNorm1.bias", "encoderLayer.2.feedForward.intermediateDense.weight", "encoderLayer.2.feedForward.intermediateDense.bias", "encoderLayer.2.feedForward.outputDense.weight", "encoderLayer.2.feedForward.outputDense.bias", "encoderLayer.2.layerNorm2.weight", "encoderLayer.2.layerNorm2.bias", "encoderLayer.3.multiHeadAttention.q.weight", "encoderLayer.3.multiHeadAttention.q.bias", "encoderLayer.3.multiHeadAttention.k.weight", "encoderLayer.3.multiHeadAttention.k.bias", "encoderLayer.3.multiHeadAttention.v.weight", "encoderLayer.3.multiHeadAttention.v.bias", "encoderLayer.3.multiHeadAttention.o.weight", "encoderLayer.3.multiHeadAttention.o.bias", "encoderLayer.3.layerNorm1.weight", "encoderLayer.3.layerNorm1.bias", "encoderLayer.3.feedForward.intermediateDense.weight", "encoderLayer.3.feedForward.intermediateDense.bias", "encoderLayer.3.feedForward.outputDense.weight", "encoderLayer.3.feedForward.outputDense.bias", "encoderLayer.3.layerNorm2.weight", "encoderLayer.3.layerNorm2.bias", "encoderLayer.4.multiHeadAttention.q.weight", "encoderLayer.4.multiHeadAttention.q.bias", "encoderLayer.4.multiHeadAttention.k.weight", "encoderLayer.4.multiHeadAttention.k.bias", "encoderLayer.4.multiHeadAttention.v.weight", "encoderLayer.4.multiHeadAttention.v.bias", "encoderLayer.4.multiHeadAttention.o.weight", "encoderLayer.4.multiHeadAttention.o.bias", "encoderLayer.4.layerNorm1.weight", "encoderLayer.4.layerNorm1.bias", "encoderLayer.4.feedForward.intermediateDense.weight", "encoderLayer.4.feedForward.intermediateDense.bias", "encoderLayer.4.feedForward.outputDense.weight", "encoderLayer.4.feedForward.outputDense.bias", "encoderLayer.4.layerNorm2.weight", "encoderLayer.4.layerNorm2.bias", "encoderLayer.5.multiHeadAttention.q.weight", "encoderLayer.5.multiHeadAttention.q.bias", "encoderLayer.5.multiHeadAttention.k.weight", "encoderLayer.5.multiHeadAttention.k.bias", "encoderLayer.5.multiHeadAttention.v.weight", "encoderLayer.5.multiHeadAttention.v.bias", 
"encoderLayer.5.multiHeadAttention.o.weight", "encoderLayer.5.multiHeadAttention.o.bias", "encoderLayer.5.layerNorm1.weight", "encoderLayer.5.layerNorm1.bias", "encoderLayer.5.feedForward.intermediateDense.weight", "encoderLayer.5.feedForward.intermediateDense.bias", "encoderLayer.5.feedForward.outputDense.weight", "encoderLayer.5.feedForward.outputDense.bias", "encoderLayer.5.layerNorm2.weight", "encoderLayer.5.layerNorm2.bias", "encoderLayer.6.multiHeadAttention.q.weight", "encoderLayer.6.multiHeadAttention.q.bias", "encoderLayer.6.multiHeadAttention.k.weight", "encoderLayer.6.multiHeadAttention.k.bias", "encoderLayer.6.multiHeadAttention.v.weight", "encoderLayer.6.multiHeadAttention.v.bias", "encoderLayer.6.multiHeadAttention.o.weight", "encoderLayer.6.multiHeadAttention.o.bias", "encoderLayer.6.layerNorm1.weight", "encoderLayer.6.layerNorm1.bias", "encoderLayer.6.feedForward.intermediateDense.weight", "encoderLayer.6.feedForward.intermediateDense.bias", "encoderLayer.6.feedForward.outputDense.weight", "encoderLayer.6.feedForward.outputDense.bias", "encoderLayer.6.layerNorm2.weight", "encoderLayer.6.layerNorm2.bias", "encoderLayer.7.multiHeadAttention.q.weight", "encoderLayer.7.multiHeadAttention.q.bias", "encoderLayer.7.multiHeadAttention.k.weight", "encoderLayer.7.multiHeadAttention.k.bias", "encoderLayer.7.multiHeadAttention.v.weight", "encoderLayer.7.multiHeadAttention.v.bias", "encoderLayer.7.multiHeadAttention.o.weight", "encoderLayer.7.multiHeadAttention.o.bias", "encoderLayer.7.layerNorm1.weight", "encoderLayer.7.layerNorm1.bias", "encoderLayer.7.feedForward.intermediateDense.weight", "encoderLayer.7.feedForward.intermediateDense.bias", "encoderLayer.7.feedForward.outputDense.weight", "encoderLayer.7.feedForward.outputDense.bias", "encoderLayer.7.layerNorm2.weight", "encoderLayer.7.layerNorm2.bias", "encoderLayer.8.multiHeadAttention.q.weight", "encoderLayer.8.multiHeadAttention.q.bias", "encoderLayer.8.multiHeadAttention.k.weight", 
"encoderLayer.8.multiHeadAttention.k.bias", "encoderLayer.8.multiHeadAttention.v.weight", "encoderLayer.8.multiHeadAttention.v.bias", "encoderLayer.8.multiHeadAttention.o.weight", "encoderLayer.8.multiHeadAttention.o.bias", "encoderLayer.8.layerNorm1.weight", "encoderLayer.8.layerNorm1.bias", "encoderLayer.8.feedForward.intermediateDense.weight", "encoderLayer.8.feedForward.intermediateDense.bias", "encoderLayer.8.feedForward.outputDense.weight", "encoderLayer.8.feedForward.outputDense.bias", "encoderLayer.8.layerNorm2.weight", "encoderLayer.8.layerNorm2.bias", "encoderLayer.9.multiHeadAttention.q.weight", "encoderLayer.9.multiHeadAttention.q.bias", "encoderLayer.9.multiHeadAttention.k.weight", "encoderLayer.9.multiHeadAttention.k.bias", "encoderLayer.9.multiHeadAttention.v.weight", "encoderLayer.9.multiHeadAttention.v.bias", "encoderLayer.9.multiHeadAttention.o.weight", "encoderLayer.9.multiHeadAttention.o.bias", "encoderLayer.9.layerNorm1.weight", "encoderLayer.9.layerNorm1.bias", "encoderLayer.9.feedForward.intermediateDense.weight", "encoderLayer.9.feedForward.intermediateDense.bias", "encoderLayer.9.feedForward.outputDense.weight", "encoderLayer.9.feedForward.outputDense.bias", "encoderLayer.9.layerNorm2.weight", "encoderLayer.9.layerNorm2.bias", "encoderLayer.10.multiHeadAttention.q.weight", "encoderLayer.10.multiHeadAttention.q.bias", "encoderLayer.10.multiHeadAttention.k.weight", "encoderLayer.10.multiHeadAttention.k.bias", "encoderLayer.10.multiHeadAttention.v.weight", "encoderLayer.10.multiHeadAttention.v.bias", "encoderLayer.10.multiHeadAttention.o.weight", "encoderLayer.10.multiHeadAttention.o.bias", "encoderLayer.10.layerNorm1.weight", "encoderLayer.10.layerNorm1.bias", "encoderLayer.10.feedForward.intermediateDense.weight", "encoderLayer.10.feedForward.intermediateDense.bias", "encoderLayer.10.feedForward.outputDense.weight", "encoderLayer.10.feedForward.outputDense.bias", "encoderLayer.10.layerNorm2.weight", "encoderLayer.10.layerNorm2.bias", 
"encoderLayer.11.multiHeadAttention.q.weight", "encoderLayer.11.multiHeadAttention.q.bias", "encoderLayer.11.multiHeadAttention.k.weight", "encoderLayer.11.multiHeadAttention.k.bias", "encoderLayer.11.multiHeadAttention.v.weight", "encoderLayer.11.multiHeadAttention.v.bias", "encoderLayer.11.multiHeadAttention.o.weight", "encoderLayer.11.multiHeadAttention.o.bias", "encoderLayer.11.layerNorm1.weight", "encoderLayer.11.layerNorm1.bias", "encoderLayer.11.feedForward.intermediateDense.weight", "encoderLayer.11.feedForward.intermediateDense.bias", "encoderLayer.11.feedForward.outputDense.weight", "encoderLayer.11.feedForward.outputDense.bias", "encoderLayer.11.layerNorm2.weight", "encoderLayer.11.layerNorm2.bias", "mlmDense.weight", "mlmDense.bias", "mlmLayerNorm.weight", "mlmLayerNorm.bias", "mlmDecoder.weight", "mlmDecoder.bias". 

from bert4torch.

Tongjilibo avatar Tongjilibo commented on June 7, 2024

我晚上看一下

from bert4torch.

Smile-L-up avatar Smile-L-up commented on June 7, 2024

在model.load_weights()前加个映射似乎是可以跑的,checkpoint里的参数名似乎确实缺少了 "bert." 这个前缀。

        # Translate checkpoint parameter names (saved without the "bert."
        # prefix) back to this model's parameter names before loading.
        mapping = {'mlmDecoder.bias': 'bert.mlmDecoder.bias'}
        for name, _ in model.named_parameters():
            # Strip the leading "bert." (5 chars) to get the checkpoint-side key.
            mapping[name[5:]] = name
        model.load_weights('../pretrain/roberta_pretrain/saved_model/bert_pretrain_model0831.ckpt', mapping=mapping)

这种方式当不加strict=False时出现如下问题是正常的,预训练阶段没有simbert具体的pooler层等,所以预测的结果也很差。

RuntimeError: Error(s) in loading state_dict for Model:
	Missing key(s) in state_dict: "bert.embeddings.segment_embeddings.weight", "bert.pooler.weight", "bert.pooler.bias",

from bert4torch.

Tongjilibo avatar Tongjilibo commented on June 7, 2024

你好,看了你的代码我感觉应该是这样加载进去,你看看是否可行

# 建立加载模型
class Model(BaseModel):
    """Simbert-style wrapper: returns (seq2seq logits, pooled sentence embedding).

    Builds a unilm transformer with pool/mlm heads, then overwrites its
    weights from a locally pretrained checkpoint before fine-tuning.
    """

    def __init__(self, pool_method='cls'):
        super().__init__()
        self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path,
                                            with_pool='linear', with_mlm='linear', application='unilm', add_trainer=True)
        # Load the pretrained checkpoint here; strict=False tolerates keys
        # (e.g. pooler weights) absent from the pretraining-stage save.
        self.bert.load_weights('./bert_pretrain_model.ckpt', strict=False)
        self.pool_method = pool_method

    def forward(self, token_ids, segment_ids):
        hidden, cls_vec, logits = self.bert([token_ids, segment_ids])
        # Mask is 1 where token_ids > 0 (non-padding positions).
        embedding = get_pool_emb(hidden, cls_vec, token_ids.gt(0).long(), self.pool_method)
        return logits, embedding

from bert4torch.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.