Comments (5)
你是说你用脚本加载模型预训练是可以的,但是保存出来的ckpt想用于微调发现加载不进去,是这个意思吗?
如果是的话,加载finetune的模型使用model.load_weights(ckpt_path)
来加载就可以了
from bert4torch.
预训练脚本
#! -*- coding: utf-8 -*-
# 预训练脚本,单GPU版方便测试
# 改DDP需几行代码,参考https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_distributed_data_parallel.py
from bert4torch.models import build_transformer_model
from bert4torch.snippets import sequence_padding
from bert4torch.callbacks import Callback
from bert4torch.optimizers import get_linear_schedule_with_warmup
from torch.utils.data import Dataset
import torch.nn as nn
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import json
import os
import shelve
import random
import time
# Corpus path and model save path
model_saved_path = './saved_model/bert_pretrain_model.ckpt'
dir_training_data = 'pretrain_data' # directory holding the pre-generated training shards
task_name = 'bert'
# Other configuration
maxlen = 256
batch_size = 20
config_path = '../../seq2seq/simbert_model/simbert-chinese-base/config.json'
checkpoint_path = '../../seq2seq/simbert_model/simbert-chinese-base/pytorch_model.bin' # set to None to train from scratch
learning_rate = 0.00176
weight_decay_rate = 0.01 # weight decay
num_warmup_steps = 3125
num_train_steps = 10000
steps_per_epoch = 1000
grad_accum_steps = 1 # >1 enables gradient accumulation
epochs = num_train_steps * grad_accum_steps // steps_per_epoch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Read the dataset and build the data tensors
class MyDataset(Dataset):
    """A shelve-backed shard of pretraining samples.

    A shard consists of the shelve files ``<file>.bak/.dat/.dir`` plus a
    ``<file>.json`` sidecar recording the sample count.
    """

    def __init__(self, file):
        super().__init__()
        self.file = file
        self.len = self._read_sample_count()
        # Left open for the dataset's lifetime; closed externally via dataset.db.close()
        self.db = self._open_db()

    def __getitem__(self, index):
        # shelve keys are the stringified sample indices
        return self.db[str(index)]

    def __len__(self):
        return self.len

    def _read_sample_count(self):
        # The ".json" sidecar stores {"samples_num": <int>}
        with open(self.file + ".json", "r", encoding="utf-8") as sidecar:
            return json.load(sidecar)["samples_num"]

    def _open_db(self):
        return shelve.open(self.file)
def collate_fn(batch):
    """Collate shelve records into padded (inputs, labels) tensors.

    Each record holds ``input_ids`` and ``masked_lm_labels`` of equal,
    sample-dependent length, so BOTH must be padded to the batch maximum.
    Padding value 0 matches the ``ignore_index=0`` used by MyLoss, so padded
    label positions contribute nothing to the loss.
    """
    batch_token_ids, batch_labels = [], []
    for item in batch:
        batch_token_ids.append(item['input_ids'])
        batch_labels.append(item['masked_lm_labels'])
    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    # BUGFIX: labels were previously turned into a tensor without padding,
    # which raises for batches of varying sequence lengths.
    batch_labels = torch.tensor(sequence_padding(batch_labels), dtype=torch.long, device=device)
    return [batch_token_ids], batch_labels
# Randomly pick one shard from the corpus directory and build a dataloader from it
def get_train_dataloader():
    """Block until a complete training shard exists, claim it, and return its DataLoader.

    A shard is considered complete only when all four files
    (.bak/.dat/.dir/.json) sharing a stem are present; this guards against
    picking a shard that is still being written by the data-generation job.
    The chosen shard is renamed to ``<task_name>.*`` so it is consumed once.
    """
    while True:
        # prepare dataset
        files_training_data = os.listdir(dir_training_data)
        files_training_data = [file.split(".")[0] for file in files_training_data if "train" in file]
        # Skip shards still being generated (fewer than 4 files share the stem)
        files_training_data = [i for i in set(files_training_data) if files_training_data.count(i)==4]
        if files_training_data:
            file_train = random.choice(files_training_data)
            for suffix in [".bak", ".dat", ".dir", ".json"]:
                file_old = os.path.join(dir_training_data, file_train + suffix)
                file_new = os.path.join(dir_training_data, task_name + suffix)
                os.renames(file_old, file_new)
            # NOTE: relies on file_new leaking out of the for-loop above
            cur_load_file = file_new.split(".")[0]
            train_dataloader = DataLoader(MyDataset(cur_load_file), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
            break
        else:
            # No complete shard yet: wait and poll again
            sleep_seconds = 300
            print(f"No training data! Sleep {sleep_seconds}s!")
            time.sleep(sleep_seconds)
            continue
    return train_dataloader
train_dataloader = get_train_dataloader()
model = build_transformer_model(config_path, checkpoint_path, segment_vocab_size=0, with_mlm=True, add_trainer=True).to(device)
# weight decay: exclude biases and LayerNorm parameters, per the standard BERT recipe
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay_rate},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
class MyLoss(nn.CrossEntropyLoss):
    """Cross-entropy over the MLM logits, flattening (batch, seq) into one axis."""

    def forward(self, output, batch_labels):
        # The model returns several outputs; the MLM logits come last.
        logits = output[-1]
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_labels = batch_labels.flatten()
        return super().forward(flat_logits, flat_labels)
# Define the loss and optimizer (customizable)
# NOTE: the per-group 'weight_decay' entries above take precedence over the
# optimizer-level weight_decay for those groups.
optimizer = optim.Adam(optimizer_grouped_parameters, lr=learning_rate, weight_decay=weight_decay_rate)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)
# ignore_index=0 skips padded / non-masked positions labeled 0
model.compile(loss=MyLoss(ignore_index=0), optimizer=optimizer, scheduler=scheduler, grad_accumulation_steps=grad_accum_steps)
class ModelCheckpoint(Callback):
    """Save the latest weights each epoch and rotate consumed training shards."""

    def on_dataloader_end(self, logs=None):
        # When the dataloader is exhausted, close the shelve db before
        # deleting the shard's backing files.
        model.train_dataloader.dataset.db.close()
        for suffix in [".bak", ".dat", ".dir", ".json"]:
            file_remove = os.path.join(dir_training_data, task_name + suffix)
            try:
                os.remove(file_remove)
            except OSError:
                # Narrowed from a bare except: only file-system errors are
                # expected here; anything else should surface.
                print(f"Failed to remove training data {file_remove}.")
        # Claim the next shard and rebuild the dataloader
        model.train_dataloader = get_train_dataloader()

    def on_epoch_end(self, global_step, epoch, logs=None):
        # Always overwrite with the newest weights
        model.save_weights(model_saved_path)
if __name__ == '__main__':
    # Checkpointing / shard-rotation callback
    checkpoint = ModelCheckpoint()
    # Train
    model.fit(
        train_dataloader,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[checkpoint],
    )
微调脚本
#! -*- coding: utf-8 -*-
# SimBERT预训练代码,也可用于微调,微调方式用其他方式比如sentence_bert的可能更好
# 官方项目:https://github.com/ZhuiyiTechnology/simbert
import json
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.snippets import sequence_padding, ListDataset, text_segmentate, get_pool_emb
from bert4torch.generation import AutoRegressiveDecoder
from bert4torch.callbacks import Callback
from bert4torch.tokenizers import Tokenizer, load_vocab
import os
# Basic settings
maxlen = 22
batch_size = 64
# These are the simbert weights; continue pretraining / finetuning on your own data from here.
# To pretrain from scratch you can also load a plain bert/roberta checkpoint directly.
config_path = './simbert_model/simbert-chinese-base/config.json'
checkpoint_path = './simbert_model/simbert-chinese-base/pytorch_model.bin'
# checkpoint_path = '../pretrain/roberta_pretrain/saved_model/bert_pretrain_model.ckpt'
dict_path = './simbert_model/simbert-chinese-base/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
# Load the vocab and build the tokenizer (no vocabulary simplification for now)
token_dict = load_vocab(
    dict_path=dict_path,
    simplified=False,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)
filename = 'data/ais0824.json'
class MyDataset(ListDataset):
    @staticmethod
    def load_data(filename):
        """Read a JSON-lines corpus, one record per line.

        Example record:
        {"text": "懂英语的来!", "synonyms": ["懂英语的来!!!", "懂英语的来", "一句英语翻译 懂英语的来"]}
        """
        with open(filename, encoding='utf-8') as f:
            return [json.loads(line) for line in f]
def truncate(text):
    """Cut *text* to at most ``maxlen - 2`` chars at a natural sentence boundary."""
    separators = u'\n。!?!?;;,, '
    strip_chars = u';;,, '
    segments = text_segmentate(text, maxlen - 2, separators, strip_chars)
    return segments[0]
def collate_fn(batch):
    """Build UniLM-style (text, synonym) pairs; each pair is added in both orders.

    Returns the token/segment ids twice: once as model input and once as the
    seq2seq target (the segment ids double as the prediction mask).
    """
    batch_token_ids, batch_segment_ids = [], []
    for d in batch:
        text, synonyms = d['text'], d['synonyms']
        synonyms = [text] + synonyms
        np.random.shuffle(synonyms)
        # NOTE(review): assumes at least two entries after the shuffle — an empty
        # "synonyms" list in a record would make this unpacking raise; confirm data.
        text, synonym = synonyms[:2]
        text, synonym = truncate(text), truncate(synonym)
        token_ids, segment_ids = tokenizer.encode(text, synonym, maxlen=maxlen * 2)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        # Same pair in the reverse direction
        token_ids, segment_ids = tokenizer.encode(synonym, text, maxlen=maxlen * 2)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
    return [batch_token_ids, batch_segment_ids], [batch_token_ids, batch_segment_ids]
train_dataloader = DataLoader(MyDataset(filename), batch_size=batch_size, shuffle=True,
                              collate_fn=collate_fn)
# Build and load the model
class Model(BaseModel):
    """SimBERT: a UniLM BERT with both a pooled sentence head and an MLM head."""

    def __init__(self, pool_method='cls'):
        super().__init__()
        self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path,
                                            with_pool='linear',
                                            with_mlm='linear', application='unilm')
        self.pool_method = pool_method

    def forward(self, token_ids, segment_ids):
        # seq_logit feeds the seq2seq loss; sen_emb feeds the similarity loss
        hidden_state, pool_cls, seq_logit = self.bert([token_ids, segment_ids])
        sen_emb = get_pool_emb(hidden_state, pool_cls, token_ids.gt(0).long(), self.pool_method)
        return seq_logit, sen_emb
model = Model(pool_method='cls').to(device)
class TotalLoss(nn.Module):
    """Two-part loss: seq2seq cross-entropy plus in-batch similarity cross-entropy."""

    def forward(self, outputs, target):
        seq_logit, sen_emb = outputs
        seq_label, seq_mask = target
        s2s = self.compute_loss_of_seq2seq(seq_logit, seq_label, seq_mask)
        sim = self.compute_loss_of_similarity(sen_emb)
        return {'loss': s2s + sim, 'seq2seq_loss': s2s, 'similarity_loss': sim}

    def compute_loss_of_seq2seq(self, y_pred, y_true, y_mask):
        """Token-level CE, shifted by one; y_mask marks the answer segment.

        Shapes: y_pred (btz, seq_len, vocab), y_true / y_mask (btz, seq_len).
        """
        shifted_true = y_true[:, 1:]        # target token ids
        shifted_mask = y_mask[:, 1:]        # 1 where a prediction is required
        shifted_pred = y_pred[:, :-1, :]    # predictions, offset by one position
        flat_pred = shifted_pred.reshape(-1, shifted_pred.shape[-1])
        # Masked-out positions become id 0 and are dropped via ignore_index
        flat_true = (shifted_true * shifted_mask).flatten()
        return F.cross_entropy(flat_pred, flat_true, ignore_index=0)

    def compute_loss_of_similarity(self, y_pred):
        labels = self.get_labels_of_similarity(y_pred)  # pairing labels
        normed = F.normalize(y_pred, p=2, dim=-1)       # unit-norm sentence vectors
        sims = torch.matmul(normed, normed.T)           # cosine similarity matrix
        sims = sims - torch.eye(y_pred.shape[0], device=device) * 1e12  # mask the diagonal
        sims = sims * 30                                # temperature scale
        return F.cross_entropy(sims, labels)

    def get_labels_of_similarity(self, y_pred):
        # Rows come in adjacent (text, synonym) pairs: 0<->1, 2<->3, ...
        idxs = torch.arange(0, y_pred.shape[0], device=device)
        partner = (idxs + 1 - idxs % 2 * 2)[:, None]
        labels = idxs[None, :].eq(partner).float()
        return labels
# Expose the two component losses as metrics alongside the total loss
model.compile(loss=TotalLoss(), optimizer=optim.Adam(model.parameters(), 1e-5),
              metrics=['seq2seq_loss', 'similarity_loss'])
class SynonymsGenerator(AutoRegressiveDecoder):
    """seq2seq decoder producing synonyms via random sampling."""

    @AutoRegressiveDecoder.wraps('logits')
    def predict(self, inputs, output_ids, states):
        token_ids, segment_ids = inputs
        # Append the generated ids; generated positions get segment id 1
        # (the UniLM "answer" segment)
        token_ids = torch.cat([token_ids, output_ids], 1)
        segment_ids = torch.cat([segment_ids, torch.ones_like(output_ids, device=device)], 1)
        seq_logit, _ = model.predict([token_ids, segment_ids])
        return seq_logit[:, -1, :]  # logits of the last step only

    def generate(self, text, n=1, topk=5):
        """Sample *n* candidate synonyms for *text* with top-k random sampling."""
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        output_ids = self.random_sample([token_ids, segment_ids], n=n, topk=topk)  # random sampling
        return [tokenizer.decode(ids.cpu().numpy()) for ids in output_ids]
synonyms_generator = SynonymsGenerator(start_id=None, end_id=tokenizer._token_end_id, maxlen=maxlen, device=device)
def cal_sen_emb(text_list):
    """Encode a list of texts and return their sentence embeddings."""
    token_id_rows, segment_id_rows = [], []
    for text in text_list:
        token_ids, segment_ids = tokenizer.encode(text)
        token_id_rows.append(token_ids)
        segment_id_rows.append(segment_ids)
    token_batch = torch.tensor(sequence_padding(token_id_rows), dtype=torch.long, device=device)
    segment_batch = torch.tensor(sequence_padding(segment_id_rows), dtype=torch.long, device=device)
    _, embeddings = model.predict([token_batch, segment_batch])
    return embeddings
def gen_synonyms(text, n=100, k=20):
    """Generate n paraphrase candidates for *text* and return the k most similar.

    Method: sample candidates with the seq2seq decoder, then rank them by
    cosine similarity of the encoder sentence embeddings.
    Example:
    >>> gen_synonyms(u'微信和支付宝哪个好?')
    [
        u'微信和支付宝,哪个好?',
        u'微信和支付宝哪个好',
        u'支付宝和微信哪个好',
        u'支付宝和微信哪个好啊',
        u'微信和支付宝那个好用?',
        u'微信和支付宝哪个好用',
        u'支付宝和微信那个更好',
        u'支付宝和微信哪个好用',
        u'微信和支付宝用起来哪个好?',
        u'微信和支付宝选哪个好',
    ]
    """
    r = synonyms_generator.generate(text, n)
    r = [i for i in set(r) if i != text]  # drop candidates identical to the input
    r = [text] + r                        # keep the original at index 0 as the ranking anchor
    Z = cal_sen_emb(r)
    Z /= (Z ** 2).sum(dim=1, keepdims=True) ** 0.5  # L2-normalize the embeddings
    # negate so that ascending argsort yields most-similar first
    argsort = torch.matmul(Z[1:], -Z[0]).argsort()
    return [r[i + 1] for i in argsort[:k]]
def just_show(some_samples):
    """Print generated synonyms for a few randomly chosen samples."""
    # BUGFIX: the random 3-sample selection was previously computed but never
    # used — the loop iterated over all of some_samples instead, contradicting
    # the "random samples" intent documented above.
    chosen = [np.random.choice(some_samples) for _ in range(3)]
    for s in chosen:
        try:
            print(u'原句子:%s' % s)
            print(u'同义句子:', gen_synonyms(s, 10, 10))
            print()
        except Exception:
            # Generation may fail on odd inputs; keep the demo running.
            # (Narrowed from a bare except so KeyboardInterrupt still works.)
            pass
class Evaluator(Callback):
    """Save the best (lowest-loss) model at each epoch end and show demo outputs."""

    def __init__(self):
        # NOTE(review): Callback.__init__ is not called here — presumably the
        # base class needs no initialization; verify against bert4torch's Callback.
        self.lowest = 1e10  # lowest loss seen so far

    def on_epoch_end(self, global_step, epoch, logs=None):
        # Keep only the best checkpoint
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            print('saved epoch: {}'.format(epoch))
            model.save_weights('./saved_model/best_model0828.bin')
        # Demo the current generation quality
        just_show(['微信和支付宝拿个好用?',
                   '微信和支付宝,哪个好?',
                   '微信和支付宝哪个好',
                   '支付宝和微信哪个好',
                   '支付宝和微信哪个好啊',
                   '微信和支付宝那个好用?',
                   '微信和支付宝哪个好用',
                   '支付宝和微信那个更好',
                   '支付宝和微信哪个好用',
                   '微信和支付宝用起来哪个好?',
                   '微信和支付宝选哪个好'
                   ])
if __name__ == '__main__':
    choice = 'generate'  # one of: train / generate / similarity
    if choice == 'train':
        evaluator = Evaluator()
        model.fit(train_dataloader, epochs=50, steps_per_epoch=None, callbacks=[evaluator])
    elif choice == 'generate':
        model.load_weights('./saved_model/bert_pretrain_model.ckpt') ####### this load is where the reported problem occurs
        param = list(model.named_parameters())
        just_show(['微信和支付宝拿个好用?',
                   '微信和支付宝,哪个好?',
                   '微信和支付宝哪个好',
                   '支付宝和微信哪个好',
                   '支付宝和微信哪个好啊',
                   '微信和支付宝那个好用?',
                   '微信和支付宝哪个好用',
                   '支付宝和微信那个更好',
                   '支付宝和微信哪个好用',
                   '微信和支付宝用起来哪个好?',
                   '微信和支付宝选哪个好'
                   ])
    elif choice == 'similarity':
        target_text = '我想去首都北京玩玩'
        text_list = ['我想去北京玩', '北京有啥好玩的吗?我想去看看', '好渴望去北京游玩啊']
        Z = cal_sen_emb([target_text] + text_list)
        Z /= (Z ** 2).sum(dim=1, keepdims=True) ** 0.5  # L2-normalize
        similarity = torch.matmul(Z[1:], Z[0])          # cosine similarity vs the target
        for i, line in enumerate(text_list):
            print(f'cos_sim: {similarity[i].item():.4f}, tgt_text: "{target_text}", cal_text: "{line}"')
直接尝试generate会出现如下问题,(好像不能直接加载huggingface上的模型,需要经过转化脚本转化后再继续微调?我现在在尝试将苏神给的tf版本模型转为本项目支持的格式再进行预训练试试):
Missing key(s) in state_dict: "bert.mlmBias", "bert.embeddings.word_embeddings.weight", "bert.embeddings.position_embeddings.weight", "bert.embeddings.segment_embeddings.weight", "bert.embeddings.layerNorm.weight", "bert.embeddings.layerNorm.bias", "bert.encoderLayer.0.multiHeadAttention.q.weight", "bert.encoderLayer.0.multiHeadAttention.q.bias", "bert.encoderLayer.0.multiHeadAttention.k.weight", "bert.encoderLayer.0.multiHeadAttention.k.bias", "bert.encoderLayer.0.multiHeadAttention.v.weight", "bert.encoderLayer.0.multiHeadAttention.v.bias", "bert.encoderLayer.0.multiHeadAttention.o.weight", "bert.encoderLayer.0.multiHeadAttention.o.bias", "bert.encoderLayer.0.layerNorm1.weight", "bert.encoderLayer.0.layerNorm1.bias", "bert.encoderLayer.0.feedForward.intermediateDense.weight", "bert.encoderLayer.0.feedForward.intermediateDense.bias", "bert.encoderLayer.0.feedForward.outputDense.weight", "bert.encoderLayer.0.feedForward.outputDense.bias", "bert.encoderLayer.0.layerNorm2.weight", "bert.encoderLayer.0.layerNorm2.bias", "bert.encoderLayer.1.multiHeadAttention.q.weight", "bert.encoderLayer.1.multiHeadAttention.q.bias", "bert.encoderLayer.1.multiHeadAttention.k.weight", "bert.encoderLayer.1.multiHeadAttention.k.bias", "bert.encoderLayer.1.multiHeadAttention.v.weight", "bert.encoderLayer.1.multiHeadAttention.v.bias", "bert.encoderLayer.1.multiHeadAttention.o.weight", "bert.encoderLayer.1.multiHeadAttention.o.bias", "bert.encoderLayer.1.layerNorm1.weight", "bert.encoderLayer.1.layerNorm1.bias", "bert.encoderLayer.1.feedForward.intermediateDense.weight", "bert.encoderLayer.1.feedForward.intermediateDense.bias", "bert.encoderLayer.1.feedForward.outputDense.weight", "bert.encoderLayer.1.feedForward.outputDense.bias", "bert.encoderLayer.1.layerNorm2.weight", "bert.encoderLayer.1.layerNorm2.bias", "bert.encoderLayer.2.multiHeadAttention.q.weight", "bert.encoderLayer.2.multiHeadAttention.q.bias", "bert.encoderLayer.2.multiHeadAttention.k.weight", 
"bert.encoderLayer.2.multiHeadAttention.k.bias", "bert.encoderLayer.2.multiHeadAttention.v.weight", "bert.encoderLayer.2.multiHeadAttention.v.bias", "bert.encoderLayer.2.multiHeadAttention.o.weight", "bert.encoderLayer.2.multiHeadAttention.o.bias", "bert.encoderLayer.2.layerNorm1.weight", "bert.encoderLayer.2.layerNorm1.bias", "bert.encoderLayer.2.feedForward.intermediateDense.weight", "bert.encoderLayer.2.feedForward.intermediateDense.bias", "bert.encoderLayer.2.feedForward.outputDense.weight", "bert.encoderLayer.2.feedForward.outputDense.bias", "bert.encoderLayer.2.layerNorm2.weight", "bert.encoderLayer.2.layerNorm2.bias", "bert.encoderLayer.3.multiHeadAttention.q.weight", "bert.encoderLayer.3.multiHeadAttention.q.bias", "bert.encoderLayer.3.multiHeadAttention.k.weight", "bert.encoderLayer.3.multiHeadAttention.k.bias", "bert.encoderLayer.3.multiHeadAttention.v.weight", "bert.encoderLayer.3.multiHeadAttention.v.bias", "bert.encoderLayer.3.multiHeadAttention.o.weight", "bert.encoderLayer.3.multiHeadAttention.o.bias", "bert.encoderLayer.3.layerNorm1.weight", "bert.encoderLayer.3.layerNorm1.bias", "bert.encoderLayer.3.feedForward.intermediateDense.weight", "bert.encoderLayer.3.feedForward.intermediateDense.bias", "bert.encoderLayer.3.feedForward.outputDense.weight", "bert.encoderLayer.3.feedForward.outputDense.bias", "bert.encoderLayer.3.layerNorm2.weight", "bert.encoderLayer.3.layerNorm2.bias", "bert.encoderLayer.4.multiHeadAttention.q.weight", "bert.encoderLayer.4.multiHeadAttention.q.bias", "bert.encoderLayer.4.multiHeadAttention.k.weight", "bert.encoderLayer.4.multiHeadAttention.k.bias", "bert.encoderLayer.4.multiHeadAttention.v.weight", "bert.encoderLayer.4.multiHeadAttention.v.bias", "bert.encoderLayer.4.multiHeadAttention.o.weight", "bert.encoderLayer.4.multiHeadAttention.o.bias", "bert.encoderLayer.4.layerNorm1.weight", "bert.encoderLayer.4.layerNorm1.bias", "bert.encoderLayer.4.feedForward.intermediateDense.weight", 
"bert.encoderLayer.4.feedForward.intermediateDense.bias", "bert.encoderLayer.4.feedForward.outputDense.weight", "bert.encoderLayer.4.feedForward.outputDense.bias", "bert.encoderLayer.4.layerNorm2.weight", "bert.encoderLayer.4.layerNorm2.bias", "bert.encoderLayer.5.multiHeadAttention.q.weight", "bert.encoderLayer.5.multiHeadAttention.q.bias", "bert.encoderLayer.5.multiHeadAttention.k.weight", "bert.encoderLayer.5.multiHeadAttention.k.bias", "bert.encoderLayer.5.multiHeadAttention.v.weight", "bert.encoderLayer.5.multiHeadAttention.v.bias", "bert.encoderLayer.5.multiHeadAttention.o.weight", "bert.encoderLayer.5.multiHeadAttention.o.bias", "bert.encoderLayer.5.layerNorm1.weight", "bert.encoderLayer.5.layerNorm1.bias", "bert.encoderLayer.5.feedForward.intermediateDense.weight", "bert.encoderLayer.5.feedForward.intermediateDense.bias", "bert.encoderLayer.5.feedForward.outputDense.weight", "bert.encoderLayer.5.feedForward.outputDense.bias", "bert.encoderLayer.5.layerNorm2.weight", "bert.encoderLayer.5.layerNorm2.bias", "bert.encoderLayer.6.multiHeadAttention.q.weight", "bert.encoderLayer.6.multiHeadAttention.q.bias", "bert.encoderLayer.6.multiHeadAttention.k.weight", "bert.encoderLayer.6.multiHeadAttention.k.bias", "bert.encoderLayer.6.multiHeadAttention.v.weight", "bert.encoderLayer.6.multiHeadAttention.v.bias", "bert.encoderLayer.6.multiHeadAttention.o.weight", "bert.encoderLayer.6.multiHeadAttention.o.bias", "bert.encoderLayer.6.layerNorm1.weight", "bert.encoderLayer.6.layerNorm1.bias", "bert.encoderLayer.6.feedForward.intermediateDense.weight", "bert.encoderLayer.6.feedForward.intermediateDense.bias", "bert.encoderLayer.6.feedForward.outputDense.weight", "bert.encoderLayer.6.feedForward.outputDense.bias", "bert.encoderLayer.6.layerNorm2.weight", "bert.encoderLayer.6.layerNorm2.bias", "bert.encoderLayer.7.multiHeadAttention.q.weight", "bert.encoderLayer.7.multiHeadAttention.q.bias", "bert.encoderLayer.7.multiHeadAttention.k.weight", 
"bert.encoderLayer.7.multiHeadAttention.k.bias", "bert.encoderLayer.7.multiHeadAttention.v.weight", "bert.encoderLayer.7.multiHeadAttention.v.bias", "bert.encoderLayer.7.multiHeadAttention.o.weight", "bert.encoderLayer.7.multiHeadAttention.o.bias", "bert.encoderLayer.7.layerNorm1.weight", "bert.encoderLayer.7.layerNorm1.bias", "bert.encoderLayer.7.feedForward.intermediateDense.weight", "bert.encoderLayer.7.feedForward.intermediateDense.bias", "bert.encoderLayer.7.feedForward.outputDense.weight", "bert.encoderLayer.7.feedForward.outputDense.bias", "bert.encoderLayer.7.layerNorm2.weight", "bert.encoderLayer.7.layerNorm2.bias", "bert.encoderLayer.8.multiHeadAttention.q.weight", "bert.encoderLayer.8.multiHeadAttention.q.bias", "bert.encoderLayer.8.multiHeadAttention.k.weight", "bert.encoderLayer.8.multiHeadAttention.k.bias", "bert.encoderLayer.8.multiHeadAttention.v.weight", "bert.encoderLayer.8.multiHeadAttention.v.bias", "bert.encoderLayer.8.multiHeadAttention.o.weight", "bert.encoderLayer.8.multiHeadAttention.o.bias", "bert.encoderLayer.8.layerNorm1.weight", "bert.encoderLayer.8.layerNorm1.bias", "bert.encoderLayer.8.feedForward.intermediateDense.weight", "bert.encoderLayer.8.feedForward.intermediateDense.bias", "bert.encoderLayer.8.feedForward.outputDense.weight", "bert.encoderLayer.8.feedForward.outputDense.bias", "bert.encoderLayer.8.layerNorm2.weight", "bert.encoderLayer.8.layerNorm2.bias", "bert.encoderLayer.9.multiHeadAttention.q.weight", "bert.encoderLayer.9.multiHeadAttention.q.bias", "bert.encoderLayer.9.multiHeadAttention.k.weight", "bert.encoderLayer.9.multiHeadAttention.k.bias", "bert.encoderLayer.9.multiHeadAttention.v.weight", "bert.encoderLayer.9.multiHeadAttention.v.bias", "bert.encoderLayer.9.multiHeadAttention.o.weight", "bert.encoderLayer.9.multiHeadAttention.o.bias", "bert.encoderLayer.9.layerNorm1.weight", "bert.encoderLayer.9.layerNorm1.bias", "bert.encoderLayer.9.feedForward.intermediateDense.weight", 
"bert.encoderLayer.9.feedForward.intermediateDense.bias", "bert.encoderLayer.9.feedForward.outputDense.weight", "bert.encoderLayer.9.feedForward.outputDense.bias", "bert.encoderLayer.9.layerNorm2.weight", "bert.encoderLayer.9.layerNorm2.bias", "bert.encoderLayer.10.multiHeadAttention.q.weight", "bert.encoderLayer.10.multiHeadAttention.q.bias", "bert.encoderLayer.10.multiHeadAttention.k.weight", "bert.encoderLayer.10.multiHeadAttention.k.bias", "bert.encoderLayer.10.multiHeadAttention.v.weight", "bert.encoderLayer.10.multiHeadAttention.v.bias", "bert.encoderLayer.10.multiHeadAttention.o.weight", "bert.encoderLayer.10.multiHeadAttention.o.bias", "bert.encoderLayer.10.layerNorm1.weight", "bert.encoderLayer.10.layerNorm1.bias", "bert.encoderLayer.10.feedForward.intermediateDense.weight", "bert.encoderLayer.10.feedForward.intermediateDense.bias", "bert.encoderLayer.10.feedForward.outputDense.weight", "bert.encoderLayer.10.feedForward.outputDense.bias", "bert.encoderLayer.10.layerNorm2.weight", "bert.encoderLayer.10.layerNorm2.bias", "bert.encoderLayer.11.multiHeadAttention.q.weight", "bert.encoderLayer.11.multiHeadAttention.q.bias", "bert.encoderLayer.11.multiHeadAttention.k.weight", "bert.encoderLayer.11.multiHeadAttention.k.bias", "bert.encoderLayer.11.multiHeadAttention.v.weight", "bert.encoderLayer.11.multiHeadAttention.v.bias", "bert.encoderLayer.11.multiHeadAttention.o.weight", "bert.encoderLayer.11.multiHeadAttention.o.bias", "bert.encoderLayer.11.layerNorm1.weight", "bert.encoderLayer.11.layerNorm1.bias", "bert.encoderLayer.11.feedForward.intermediateDense.weight", "bert.encoderLayer.11.feedForward.intermediateDense.bias", "bert.encoderLayer.11.feedForward.outputDense.weight", "bert.encoderLayer.11.feedForward.outputDense.bias", "bert.encoderLayer.11.layerNorm2.weight", "bert.encoderLayer.11.layerNorm2.bias", "bert.pooler.weight", "bert.pooler.bias", "bert.mlmDense.weight", "bert.mlmDense.bias", "bert.mlmLayerNorm.weight", "bert.mlmLayerNorm.bias", 
"bert.mlmDecoder.weight", "bert.mlmDecoder.bias".
Unexpected key(s) in state_dict: "mlmBias", "embeddings.word_embeddings.weight", "embeddings.position_embeddings.weight", "embeddings.layerNorm.weight", "embeddings.layerNorm.bias", "encoderLayer.0.multiHeadAttention.q.weight", "encoderLayer.0.multiHeadAttention.q.bias", "encoderLayer.0.multiHeadAttention.k.weight", "encoderLayer.0.multiHeadAttention.k.bias", "encoderLayer.0.multiHeadAttention.v.weight", "encoderLayer.0.multiHeadAttention.v.bias", "encoderLayer.0.multiHeadAttention.o.weight", "encoderLayer.0.multiHeadAttention.o.bias", "encoderLayer.0.layerNorm1.weight", "encoderLayer.0.layerNorm1.bias", "encoderLayer.0.feedForward.intermediateDense.weight", "encoderLayer.0.feedForward.intermediateDense.bias", "encoderLayer.0.feedForward.outputDense.weight", "encoderLayer.0.feedForward.outputDense.bias", "encoderLayer.0.layerNorm2.weight", "encoderLayer.0.layerNorm2.bias", "encoderLayer.1.multiHeadAttention.q.weight", "encoderLayer.1.multiHeadAttention.q.bias", "encoderLayer.1.multiHeadAttention.k.weight", "encoderLayer.1.multiHeadAttention.k.bias", "encoderLayer.1.multiHeadAttention.v.weight", "encoderLayer.1.multiHeadAttention.v.bias", "encoderLayer.1.multiHeadAttention.o.weight", "encoderLayer.1.multiHeadAttention.o.bias", "encoderLayer.1.layerNorm1.weight", "encoderLayer.1.layerNorm1.bias", "encoderLayer.1.feedForward.intermediateDense.weight", "encoderLayer.1.feedForward.intermediateDense.bias", "encoderLayer.1.feedForward.outputDense.weight", "encoderLayer.1.feedForward.outputDense.bias", "encoderLayer.1.layerNorm2.weight", "encoderLayer.1.layerNorm2.bias", "encoderLayer.2.multiHeadAttention.q.weight", "encoderLayer.2.multiHeadAttention.q.bias", "encoderLayer.2.multiHeadAttention.k.weight", "encoderLayer.2.multiHeadAttention.k.bias", "encoderLayer.2.multiHeadAttention.v.weight", "encoderLayer.2.multiHeadAttention.v.bias", "encoderLayer.2.multiHeadAttention.o.weight", "encoderLayer.2.multiHeadAttention.o.bias", "encoderLayer.2.layerNorm1.weight", 
"encoderLayer.2.layerNorm1.bias", "encoderLayer.2.feedForward.intermediateDense.weight", "encoderLayer.2.feedForward.intermediateDense.bias", "encoderLayer.2.feedForward.outputDense.weight", "encoderLayer.2.feedForward.outputDense.bias", "encoderLayer.2.layerNorm2.weight", "encoderLayer.2.layerNorm2.bias", "encoderLayer.3.multiHeadAttention.q.weight", "encoderLayer.3.multiHeadAttention.q.bias", "encoderLayer.3.multiHeadAttention.k.weight", "encoderLayer.3.multiHeadAttention.k.bias", "encoderLayer.3.multiHeadAttention.v.weight", "encoderLayer.3.multiHeadAttention.v.bias", "encoderLayer.3.multiHeadAttention.o.weight", "encoderLayer.3.multiHeadAttention.o.bias", "encoderLayer.3.layerNorm1.weight", "encoderLayer.3.layerNorm1.bias", "encoderLayer.3.feedForward.intermediateDense.weight", "encoderLayer.3.feedForward.intermediateDense.bias", "encoderLayer.3.feedForward.outputDense.weight", "encoderLayer.3.feedForward.outputDense.bias", "encoderLayer.3.layerNorm2.weight", "encoderLayer.3.layerNorm2.bias", "encoderLayer.4.multiHeadAttention.q.weight", "encoderLayer.4.multiHeadAttention.q.bias", "encoderLayer.4.multiHeadAttention.k.weight", "encoderLayer.4.multiHeadAttention.k.bias", "encoderLayer.4.multiHeadAttention.v.weight", "encoderLayer.4.multiHeadAttention.v.bias", "encoderLayer.4.multiHeadAttention.o.weight", "encoderLayer.4.multiHeadAttention.o.bias", "encoderLayer.4.layerNorm1.weight", "encoderLayer.4.layerNorm1.bias", "encoderLayer.4.feedForward.intermediateDense.weight", "encoderLayer.4.feedForward.intermediateDense.bias", "encoderLayer.4.feedForward.outputDense.weight", "encoderLayer.4.feedForward.outputDense.bias", "encoderLayer.4.layerNorm2.weight", "encoderLayer.4.layerNorm2.bias", "encoderLayer.5.multiHeadAttention.q.weight", "encoderLayer.5.multiHeadAttention.q.bias", "encoderLayer.5.multiHeadAttention.k.weight", "encoderLayer.5.multiHeadAttention.k.bias", "encoderLayer.5.multiHeadAttention.v.weight", "encoderLayer.5.multiHeadAttention.v.bias", 
"encoderLayer.5.multiHeadAttention.o.weight", "encoderLayer.5.multiHeadAttention.o.bias", "encoderLayer.5.layerNorm1.weight", "encoderLayer.5.layerNorm1.bias", "encoderLayer.5.feedForward.intermediateDense.weight", "encoderLayer.5.feedForward.intermediateDense.bias", "encoderLayer.5.feedForward.outputDense.weight", "encoderLayer.5.feedForward.outputDense.bias", "encoderLayer.5.layerNorm2.weight", "encoderLayer.5.layerNorm2.bias", "encoderLayer.6.multiHeadAttention.q.weight", "encoderLayer.6.multiHeadAttention.q.bias", "encoderLayer.6.multiHeadAttention.k.weight", "encoderLayer.6.multiHeadAttention.k.bias", "encoderLayer.6.multiHeadAttention.v.weight", "encoderLayer.6.multiHeadAttention.v.bias", "encoderLayer.6.multiHeadAttention.o.weight", "encoderLayer.6.multiHeadAttention.o.bias", "encoderLayer.6.layerNorm1.weight", "encoderLayer.6.layerNorm1.bias", "encoderLayer.6.feedForward.intermediateDense.weight", "encoderLayer.6.feedForward.intermediateDense.bias", "encoderLayer.6.feedForward.outputDense.weight", "encoderLayer.6.feedForward.outputDense.bias", "encoderLayer.6.layerNorm2.weight", "encoderLayer.6.layerNorm2.bias", "encoderLayer.7.multiHeadAttention.q.weight", "encoderLayer.7.multiHeadAttention.q.bias", "encoderLayer.7.multiHeadAttention.k.weight", "encoderLayer.7.multiHeadAttention.k.bias", "encoderLayer.7.multiHeadAttention.v.weight", "encoderLayer.7.multiHeadAttention.v.bias", "encoderLayer.7.multiHeadAttention.o.weight", "encoderLayer.7.multiHeadAttention.o.bias", "encoderLayer.7.layerNorm1.weight", "encoderLayer.7.layerNorm1.bias", "encoderLayer.7.feedForward.intermediateDense.weight", "encoderLayer.7.feedForward.intermediateDense.bias", "encoderLayer.7.feedForward.outputDense.weight", "encoderLayer.7.feedForward.outputDense.bias", "encoderLayer.7.layerNorm2.weight", "encoderLayer.7.layerNorm2.bias", "encoderLayer.8.multiHeadAttention.q.weight", "encoderLayer.8.multiHeadAttention.q.bias", "encoderLayer.8.multiHeadAttention.k.weight", 
"encoderLayer.8.multiHeadAttention.k.bias", "encoderLayer.8.multiHeadAttention.v.weight", "encoderLayer.8.multiHeadAttention.v.bias", "encoderLayer.8.multiHeadAttention.o.weight", "encoderLayer.8.multiHeadAttention.o.bias", "encoderLayer.8.layerNorm1.weight", "encoderLayer.8.layerNorm1.bias", "encoderLayer.8.feedForward.intermediateDense.weight", "encoderLayer.8.feedForward.intermediateDense.bias", "encoderLayer.8.feedForward.outputDense.weight", "encoderLayer.8.feedForward.outputDense.bias", "encoderLayer.8.layerNorm2.weight", "encoderLayer.8.layerNorm2.bias", "encoderLayer.9.multiHeadAttention.q.weight", "encoderLayer.9.multiHeadAttention.q.bias", "encoderLayer.9.multiHeadAttention.k.weight", "encoderLayer.9.multiHeadAttention.k.bias", "encoderLayer.9.multiHeadAttention.v.weight", "encoderLayer.9.multiHeadAttention.v.bias", "encoderLayer.9.multiHeadAttention.o.weight", "encoderLayer.9.multiHeadAttention.o.bias", "encoderLayer.9.layerNorm1.weight", "encoderLayer.9.layerNorm1.bias", "encoderLayer.9.feedForward.intermediateDense.weight", "encoderLayer.9.feedForward.intermediateDense.bias", "encoderLayer.9.feedForward.outputDense.weight", "encoderLayer.9.feedForward.outputDense.bias", "encoderLayer.9.layerNorm2.weight", "encoderLayer.9.layerNorm2.bias", "encoderLayer.10.multiHeadAttention.q.weight", "encoderLayer.10.multiHeadAttention.q.bias", "encoderLayer.10.multiHeadAttention.k.weight", "encoderLayer.10.multiHeadAttention.k.bias", "encoderLayer.10.multiHeadAttention.v.weight", "encoderLayer.10.multiHeadAttention.v.bias", "encoderLayer.10.multiHeadAttention.o.weight", "encoderLayer.10.multiHeadAttention.o.bias", "encoderLayer.10.layerNorm1.weight", "encoderLayer.10.layerNorm1.bias", "encoderLayer.10.feedForward.intermediateDense.weight", "encoderLayer.10.feedForward.intermediateDense.bias", "encoderLayer.10.feedForward.outputDense.weight", "encoderLayer.10.feedForward.outputDense.bias", "encoderLayer.10.layerNorm2.weight", "encoderLayer.10.layerNorm2.bias", 
"encoderLayer.11.multiHeadAttention.q.weight", "encoderLayer.11.multiHeadAttention.q.bias", "encoderLayer.11.multiHeadAttention.k.weight", "encoderLayer.11.multiHeadAttention.k.bias", "encoderLayer.11.multiHeadAttention.v.weight", "encoderLayer.11.multiHeadAttention.v.bias", "encoderLayer.11.multiHeadAttention.o.weight", "encoderLayer.11.multiHeadAttention.o.bias", "encoderLayer.11.layerNorm1.weight", "encoderLayer.11.layerNorm1.bias", "encoderLayer.11.feedForward.intermediateDense.weight", "encoderLayer.11.feedForward.intermediateDense.bias", "encoderLayer.11.feedForward.outputDense.weight", "encoderLayer.11.feedForward.outputDense.bias", "encoderLayer.11.layerNorm2.weight", "encoderLayer.11.layerNorm2.bias", "mlmDense.weight", "mlmDense.bias", "mlmLayerNorm.weight", "mlmLayerNorm.bias", "mlmDecoder.weight", "mlmDecoder.bias".
from bert4torch.
我晚上看一下
from bert4torch.
在model.load_weights()前加个映射似乎是可以跑的,参数似乎确实缺少了一个bert类型。
mapping = {'mlmDecoder.bias': 'bert.mlmDecoder.bias'}
param = list(model.named_parameters())
for p in param:
mapping[p[0][5:]] = p[0]
model.load_weights('../pretrain/roberta_pretrain/saved_model/bert_pretrain_model0831.ckpt', mapping=mapping)
这种方式当不加strict=False时出现如下问题是正常的,预训练阶段没有simbert具体的pooler层等,所以预测的结果也很差。
RuntimeError: Error(s) in loading state_dict for Model:
Missing key(s) in state_dict: "bert.embeddings.segment_embeddings.weight", "bert.pooler.weight", "bert.pooler.bias",
from bert4torch.
你好,看了你的代码我感觉应该是这样加载进去,你看看是否可行
# 建立加载模型
class Model(BaseModel):
def __init__(self, pool_method='cls'):
super().__init__()
self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path,
with_pool='linear', with_mlm='linear', application='unilm', add_trainer=True)
self.bert.load_weights('./bert_pretrain_model.ckpt', strict=False) # 这里加载进去
self.pool_method = pool_method
def forward(self, token_ids, segment_ids):
hidden_state, pool_cls, seq_logit = self.bert([token_ids, segment_ids])
sen_emb = get_pool_emb(hidden_state, pool_cls, token_ids.gt(0).long(), self.pool_method)
return seq_logit, sen_emb
from bert4torch.
Related Issues (20)
- tokenizer_encode_config和tokenizer_decode_config问题 HOT 3
- DataLoader的num_workers设置为大于0时出错 HOT 3
- 请问怎么调整最后的输出结果的小数点位数,目前最后的测试结果是两位,在哪里可以调整一下啊? HOT 7
- 请问在用bart bert等模型生成标题时,如果修改标题的最大长度? HOT 3
- 显示每个epoch的已用时间和预估剩余时间 HOT 3
- chinese-xlnet-mid载入 HOT 1
- XLnet分类报错 HOT 4
- simbert微调问题 HOT 2
- LLM:chatGLM2推理加速 HOT 1
- bert4torch版本0.2.8升级到0.3.4问题 HOT 2
- 分类算法sentence句子编码的时候,没理解到mask处理逻辑 HOT 1
- Trainer.compile与torch2.0新加的model.compile重名 HOT 3
- bert4torch中,使用单卡多GPU训练,使用accelerate后,不能使用梯度剪裁。 HOT 3
- bert4torch/examples/llm /task_chatglm2_lora.py中的config_path = dir_path + 'bert4torch_config.json' 在哪里? HOT 3
- callback变红 HOT 3
- self.random_sample和 self.beam_search返回值tensor未统一转化numpy,导致两函数同时调用报错 HOT 2
- 加载Yi-6B进行lora微调报错 HOT 1
- 请求大佬写个NER指针网络的预测代码 HOT 4
- Text2Vec加载BAAI_bge-large-zh-v1.5进行相似度计算得到的结果有误 HOT 2
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from bert4torch.