bert4torch's Introduction

Hi there 👋

  • 🔭 I’m currently working on bert4torch.
  • 🌱 I’m currently learning NLP and REC.
  • 📫 How to reach me:

bert4torch's People

Contributors

nejweka407, nuass, skykiseki, tongjilibo

bert4torch's Issues

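Error when running t5-pegasus

[image]
Hello, I used the convert_t5_pegasus script you provided to produce pytorch_model.bin, and reused the config.json from the base version, which gives the error above.
After changing "hidden_act": ["gelu", "linear"] to "hidden_act": "gelu", I get the following error instead:
[image]
I would appreciate an answer. Thanks!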

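Averaging the outputs of the model's last few layers

I want to use the outputs of the model's last few layers as the final output. Can bert4torch do this at the moment?
I looked through the model definition and could not find a way.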
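To pin down what "averaging the last few layers" means, here is a minimal sketch in plain Python (no PyTorch needed to run it). It assumes you can already obtain the per-layer outputs as a list; how to get that list from bert4torch is not covered here, and the function name `average_last_layers` is hypothetical.

```python
def average_last_layers(layer_outputs, k):
    """Element-wise mean of the last k layer outputs.

    layer_outputs: list of per-layer outputs, each a [seq_len][hidden] nested list.
    """
    if not 1 <= k <= len(layer_outputs):
        raise ValueError("k must be between 1 and the number of layers")
    last = layer_outputs[-k:]
    seq_len, hidden = len(last[0]), len(last[0][0])
    return [
        [sum(layer[t][h] for layer in last) / k for h in range(hidden)]
        for t in range(seq_len)
    ]


# Toy example: 3 layers, 2 tokens, 2 hidden units (hypothetical values).
layers = [
    [[0.0, 0.0], [0.0, 0.0]],
    [[1.0, 2.0], [3.0, 4.0]],
    [[3.0, 4.0], [5.0, 6.0]],
]
pooled = average_last_layers(layers, k=2)
print(pooled)
```

With real tensors of identical shape, the same reduction can be written as `torch.stack(layer_outputs[-k:]).mean(dim=0)`.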

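Text summarization on the CSL dataset

Hello Bo, is the CSL dataset you used for text summarization the 10K-sample one or the 3K-sample one? When I run it, F1 plateaus at around 60% and never reaches 68%.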

A question about RoPE details

class RoPEPositionEncoding(nn.Module):
    """Rotary position embedding (RoPE): https://kexue.fm/archives/8265
    """
    def __init__(self, max_position, embedding_size):
        super(RoPEPositionEncoding, self).__init__()
        position_embeddings = get_sinusoid_encoding_table(max_position, embedding_size)  # [seq_len, hdsz]
        # cos_position = position_embeddings[:, 1::2].repeat(1, 2) 
        # sin_position = position_embeddings[:, ::2].repeat(1, 2)
        cos_position = position_embeddings[:, 1::2].repeat_interleave(2, dim=-1)  # modified
        sin_position = position_embeddings[:, ::2].repeat_interleave(2, dim=-1)  # modified
        self.register_buffer('cos_position', cos_position)
        self.register_buffer('sin_position', sin_position)
    
    def forward(self, qw, seq_len_dim=1):
        dim = len(qw.shape)
        assert (dim >= 2) and (dim <= 4), 'Input units should >= 2 dims(seq_len and hdsz) and usually <= 4 dims'
        seq_len = qw.shape[seq_len_dim]
        # qw2 = torch.cat([-qw[..., 1::2], qw[..., ::2]], dim=-1)
        qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], dim=-1).reshape_as(qw)  # modified

        if dim == 2:
            return qw * self.cos_position[:seq_len] + qw2 * self.sin_position[:seq_len]
        if dim == 3:
            return qw * self.cos_position[:seq_len].unsqueeze(0) + qw2 * self.sin_position[:seq_len].unsqueeze(0)
        else:
            return qw * self.cos_position[:seq_len].unsqueeze(0).unsqueeze(2) + qw2 * self.sin_position[:seq_len].unsqueeze(0).unsqueeze(2)

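Hello, is there a problem with bert4torch's RoPE implementation? According to Su Jianlin's blog post, I think it should be the code above with the lines marked "modified".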
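The interleaved layout in the lines marked "modified" can be sketched in plain Python. This is only an illustration under stated assumptions: the angle for pair i is taken as pos * 10000^(-2i/d), the usual sinusoidal convention, and the helper names `rope` and `dot` are hypothetical. It rotates each consecutive pair (x[2i], x[2i+1]) and checks the property RoPE is designed for, namely that the query-key dot product depends only on the relative position. It does not by itself settle which layout bert4torch should use.

```python
import math


def rope(x, pos, base=10000.0):
    """Rotate consecutive pairs (x[2i], x[2i+1]) by the angle pos * theta_i."""
    d = len(x)
    assert d % 2 == 0, "feature size must be even"
    out = []
    for i in range(d // 2):
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        # Equivalent to x * cos + x2 * sin with x2 = [-x1, x0, -x3, x2, ...]
        out += [a * c - b * s, b * c + a * s]
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.7, 2.0]
k = [1.5, 0.4, -0.6, 0.9]

# Both pairs of positions have the same offset (3), so the scores should match.
s1 = dot(rope(q, 5), rope(k, 2))
s2 = dot(rope(q, 13), rope(k, 10))
print(abs(s1 - s2) < 1e-9)
```

Because each pair undergoes a pure rotation, the vector norm is also preserved, which is a quick sanity check for any RoPE implementation.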

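Converting ONNX to an FP16 TensorRT engine

Please provide the following information where possible when asking a question:
A question: when I use TensorRT's trtexec to convert the ONNX model to an FP32 TRT engine, there is no problem, but when I convert to FP16 the error is very large and the result is basically unusable. Have you converted to an FP16 TRT engine on your side, and did you see any problems?

Basic information

  • Operating system:
  • Python version:
  • PyTorch version:
  • bert4torch version:
  • Pretrained model loaded:

Core code

# Paste your core code here

Output

# Paste your debug output here

What I tried

Describe what you have tried here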

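Where was class CRF(nn.Module) ported from?

Could you share a link? I would like to read the original source before it was ported.
Mainly I want to know why the transition matrix has +2 here:
init_transitions = torch.zeros(self.num_labels + 2, self.num_labels + 2)
I have read other implementations and this is the first time I have seen +2. Your code comments say it adds start and end tags, but I do not understand what problem adding them solves.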
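As background for the +2, here is a minimal plain-Python sketch of one common convention for linear-chain CRFs: two extra states, START and END, are appended to the tag set so that the model can score which tags may begin or end a sequence. The index choice (START = num_labels, END = num_labels + 1), the [from][to] orientation, and the name `path_score` are assumptions for illustration, not taken from bert4torch's code.

```python
def path_score(emissions, tags, transitions, num_labels):
    """Score of one tag path under a linear-chain CRF with START/END states."""
    start, end = num_labels, num_labels + 1  # the two extra rows/columns
    score = transitions[start][tags[0]]
    for t, tag in enumerate(tags):
        score += emissions[t][tag]
        if t > 0:
            score += transitions[tags[t - 1]][tag]
    return score + transitions[tags[-1]][end]


num_labels = 2  # hypothetical tag set: 0 = "B", 1 = "I"
size = num_labels + 2
transitions = [[0.0] * size for _ in range(size)]
# START -> "I": a sequence should not begin inside an entity.
transitions[num_labels][1] = -1000.0
emissions = [[0.5, 0.6], [0.2, 0.9]]  # [seq_len][num_labels]

good = path_score(emissions, [0, 1], transitions, num_labels)  # B I
bad = path_score(emissions, [1, 1], transitions, num_labels)   # I I
print(good > bad)
```

Without the START row, the path "I I" would win here on emissions alone; implementations without +2 typically keep separate start and end score vectors instead, which carries the same information.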

basic_masked_language_model.py

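In this example,
the input is: [CLS]科学[MASK][MASK]是第一生产力[SEP]
The prediction is two commas (,,) rather than 技术 ("technology").
The model is bert-base-chinese from the Hugging Face model hub.
A large number of warnings appear while the model is loading:
[image]
What is the problem here?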

Training with task_sequence_labeling_ner_global_pointer.py gives a result of 0

I am trying out the task_sequence_labeling_ner_global_pointer.py script.
The changes I made are as follows.
1. Load the BERT model from the Hugging Face weights
self.bert_dir = "/home/BERT/bert_torch/bert-base-chinese/"
self.config_path = self.bert_dir + 'config.json'
self.checkpoint_path = self.bert_dir + 'pytorch_model.bin'
self.dict_path = self.bert_dir + 'vocab.txt'
2. Change the model.fit arguments
model.fit(train_dataloader, epochs=20, steps_per_epoch=5, callbacks=[evaluator])
3. Output of the full run (excerpt):
1/5 [=====>........................] - ETA: 0s - loss: 2.5862
2/5 [===========>..................] - ETA: 0s - loss: 2.5127
3/5 [=================>............] - ETA: 0s - loss: 2.7961
4/5 [=======================>......] - ETA: 0s - loss: 2.8367
5/5 [==============================] - 1s 125ms/step - loss: 2.4887
[val] f1: 0.00000, p: 0.00000 r: 0.00000 best_f1: 0.00000
============Finish Training=============
Process finished with exit code 0
OS: ubuntu 20.0.4
PyTorch version: 1.11.0+cu113
Python: 3.7

What is causing F1 to stay at 0?

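Convergence problems during model training

A question for Bo.
I compared several models (classification, sequence labeling, text generation, and so on), using bert4torch on one side and the Hugging Face tokenizer and model loading on the other.

The Hugging Face version reaches a fairly good result after about five or six epochs.
bert4torch needs more than 20 epochs to get there.
The final evaluation score is also 1 to 2 points higher with Hugging Face.
I compared the code and have not found the cause yet, which is puzzling.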

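Read this if loading a pretrained model fails

Many issues are about errors when loading pretrained models, including warnings and mismatched config parameters. The explanation is as follows.

  • Cause: the warnings appear because the keys in the model file are not fully aligned with bert4torch's keys; the config parameters differ because I modified the original config files (to unify parameter names)
  • Solution: see the end of the README; convert scripts are provided for some pretrained weights, and notes are provided for the config parameters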

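Sharing a few problems I ran into, and how Bo helped solve them

1. When testing basic_language_model_CDial_GPT.py, the generated text is garbled.
[image]
Fix: in @AutoRegressiveDecoder.wraps(default_rtype='probas'), change probas to logits.
2. Testing basic_language_model_nezha_gpt_dialog.py raises an error.
[image]
Fix: the main problem was the model conversion. I had written it myself and got some layers wrong; following the conversion script in convert resolves it. In addition, the relative distance computation is set through the config file in Bo's version, whereas bert4keras implements it in code.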

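Hello, can you provide prediction code? For example, sentiment prediction for unlabeled text, or entity prediction for text without entity annotations.

Take this code as an example:
can you provide inference code for predicting on new text?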

#! -*- coding:utf-8 -*-
# Sentiment classification task, loading BERT weights
# valid_acc: 94.72, test_acc: 94.11

from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.snippets import sequence_padding, Callback, text_segmentate, ListDataset
import torch.nn as nn
import torch
import torch.optim as optim
import random, os, numpy as np
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

maxlen = 256
batch_size = 16
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
writer = SummaryWriter(log_dir='./summary') # prepare summary writer

# Fix the random seed
seed = 42
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)

# Load the dataset
class MyDataset(ListDataset):
    @staticmethod
    def load_data(filenames):
        """Load the data and, where possible, split it into sentences no longer than maxlen
        """
        D = []
        seps, strips = u'\n。!?!?;;,, ', u';;,, '
        for filename in filenames:
            with open(filename, encoding='utf-8') as f:
                for l in f:
                    text, label = l.strip().split('\t')
                    for t in text_segmentate(text, maxlen - 2, seps, strips):
                        D.append((t, int(label)))
        return D

def collate_fn(batch):
    batch_token_ids, batch_segment_ids, batch_labels = [], [], []
    for text, label in batch:
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        batch_labels.append([label])

    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
    batch_labels = torch.tensor(batch_labels, dtype=torch.long, device=device)
    return [batch_token_ids, batch_segment_ids], batch_labels.flatten()

# Build the data loaders
train_dataloader = DataLoader(MyDataset(['E:/Github/bert4torch/examples/datasets/sentiment/sentiment.train.data']), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(MyDataset(['E:/Github/bert4torch/examples/datasets/sentiment/sentiment.valid.data']), batch_size=batch_size, collate_fn=collate_fn)
test_dataloader = DataLoader(MyDataset(['E:/Github/bert4torch/examples/datasets/sentiment/sentiment.test.data']), batch_size=batch_size, collate_fn=collate_fn)

# Define the model structure on top of BERT
class Model(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.bert, self.config = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_pool=True, return_model_config=True)
        self.dropout = nn.Dropout(0.1)
        self.dense = nn.Linear(self.config['hidden_size'], 2)

    def forward(self, token_ids, segment_ids):
        _, pooled_output = self.bert([token_ids, segment_ids])
        output = self.dropout(pooled_output)
        output = self.dense(output)
        return output

model = Model().to(device)

# Define the loss and optimizer; custom ones are supported here
model.compile(
    loss=nn.CrossEntropyLoss(),
    optimizer=optim.Adam(model.parameters(), lr=2e-5), # use a sufficiently small learning rate
    metrics=['accuracy']
)

# Define the evaluation function
def evaluate(data):
    total, right = 0., 0.
    for x_true, y_true in data:
        y_pred = model.predict(x_true).argmax(axis=1)
        total += len(y_true)
        right += (y_true == y_pred).sum().item()
    return right / total

class Evaluator(Callback):
    """Evaluate and save
    """
    def __init__(self):
        self.best_val_acc = 0.

    # def on_batch_end(self, global_step, batch, logs=None):
    #     if global_step % 10 == 0:
    #         writer.add_scalar(f"train/loss", logs['loss'], global_step)
    #         val_acc = evaluate(valid_dataloader)
    #         writer.add_scalar(f"valid/acc", val_acc, global_step)

    def on_epoch_end(self, global_step, epoch, logs=None):
        val_acc = evaluate(valid_dataloader)
        test_acc = evaluate(test_dataloader)
        if val_acc > self.best_val_acc:
            self.best_val_acc = val_acc
            # model.save_weights('best_model.pt')
        print(f'val_acc: {val_acc:.5f}, test_acc: {test_acc:.5f}, best_val_acc: {self.best_val_acc:.5f}\n')

if __name__ == '__main__':
    evaluator = Evaluator()
    model.fit(train_dataloader, epochs=10, steps_per_epoch=None, callbacks=[evaluator])
else:
    model.load_weights('best_model.pt')

BERT-MRC: P, R and F1 are always 0

Epoch 1/50
2000/2000 [==============================] - 534s 267ms/step - loss: 0.0394
Evaluation: 100%|██████████| 1159/1159 [01:34<00:00, 12.25it/s]
[val] f1: 0.00000, p: 0.00000 r: 0.00000

Virtual adversarial training

Hello, I tried adding adversarial training to my model. After adding it, the F1 score is always 0.

object has no attribute 'copy' at prediction time

Calling model.load_weights('../../output/best_model.pt') like this at prediction time raises an error:

Traceback (most recent call last):
  File "/opt/pycharm-2021.1.3/plugins/python/helpers/pydev/pydevd.py", line 1483, in _exec
    pydev_imports.execfile(file, globals, locals) # execute the script
  File "/opt/pycharm-2021.1.3/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/gallup/study/search/kuafu/kuafu/matching/run_poly_encoders.py", line 149, in <module>
    model.load_weights('../../output/' + args.model_name + '_best_model.pt')
  File "/home/gallup/anaconda3/envs/tf2-torch1/lib/python3.6/site-packages/bert4torch/models.py", line 272, in load_weights
    self.load_state_dict(state_dict, strict=strict)
  File "/home/gallup/anaconda3/envs/tf2-torch1/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1455, in load_state_dict
    state_dict = state_dict.copy()
  File "/home/gallup/anaconda3/envs/tf2-torch1/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1178, in __getattr__
    type(self).__name__, name))
AttributeError: 'BiEncoder' object has no attribute 'copy'
python-BaseException

basic_language_model_nezha_gen_gpt.py

Was the pytorch_model.bin used as input in this module produced by running the convert_nezha_gpt_dialog.py script on the TF model chinese_nezha_gpt_L-12_H-768_A-12 provided by Su Jianlin?

How do I use a custom acc or loss?

For example:
how can I pass a custom function directly as the model's evaluation metric, the way Keras allows?

How should the following be modified?
def accuracy(y_pred, y_true):
    y_pred = torch.where(y_pred>0.5,torch.ones_like(y_pred,dtype = torch.float32),
                         torch.zeros_like(y_pred,dtype = torch.float32))
    acc = torch.mean(1-torch.abs(y_true-y_pred))
    return acc

model.compile(
    loss=nn.CrossEntropyLoss(),
    optimizer=optim.Adam(model.parameters(), lr=2e-5),
    metrics={LogLoss}
)
[image]

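A small question

In the CasRel relation extraction model, suppose I add another network structure to the model and use it inside def forward(self, inputs). Do I also need to apply that structure in def predict_subject(self, inputs) and def predict_object(self, inputs) under class Model(BaseModel)?
[image]
[image]
[image]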

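How do I load a Hugging Face BERT model? Bo, please help

I am running task_sequence_labeling_ner_global_pointer.py.
The only changes in my code are the paths to the BERT I downloaded (I use absolute paths) and the line that defines bert (I added model='bert').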

config_path = '/mnt/hdd0/lsn/bert/bert-base-uncased/config.json'
checkpoint_path = '/mnt/hdd0/lsn/bert/bert-base-uncased/pytorch_model.bin'
dict_path = '/mnt/hdd0/lsn/bert/bert-base-uncased/vocab.txt'

self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, model='bert', segment_vocab_size=0)

Running it then prints a lot of [warning] messages like the following:
[WARNIMG] bert.embeddings.LayerNorm.weight not found in pretrain models
[WARNIMG] bert.embeddings.LayerNorm.bias not found in pretrain models
[WARNIMG] bert.encoder.layer.0.attention.output.LayerNorm.weight not found in pretrain models
[WARNIMG] bert.encoder.layer.0.attention.output.LayerNorm.bias not found in pretrain models
[WARNIMG] bert.encoder.layer.0.output.LayerNorm.weight not found in pretrain models
[WARNIMG] bert.encoder.layer.0.output.LayerNorm.bias not found in pretrain models
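One possible cause, offered as an assumption rather than a confirmed diagnosis: some older BERT checkpoints store LayerNorm parameters under the names gamma and beta instead of weight and bias, so every LayerNorm key goes unmatched. The sketch below renames such keys on a plain dict standing in for a loaded state dict; the function name `rename_layernorm_keys` and the toy checkpoint are hypothetical.

```python
def rename_layernorm_keys(state_dict):
    """Return a copy with LayerNorm.gamma/beta renamed to LayerNorm.weight/bias."""
    renamed = {}
    for key, value in state_dict.items():
        if key.endswith("LayerNorm.gamma"):
            key = key[: -len("gamma")] + "weight"
        elif key.endswith("LayerNorm.beta"):
            key = key[: -len("beta")] + "bias"
        renamed[key] = value
    return renamed


# Stand-in for a loaded checkpoint; in practice the values would be tensors.
checkpoint = {
    "bert.embeddings.LayerNorm.gamma": "g",
    "bert.embeddings.LayerNorm.beta": "b",
    "bert.embeddings.word_embeddings.weight": "w",
}
fixed = rename_layernorm_keys(checkpoint)
print(sorted(fixed))
```

Printing the keys of the checkpoint before loading is a quick way to check whether this naming mismatch applies to a given file.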

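The global_pointer example

The global_pointer example does not add [CLS] and [SEP] after tokenization. Adding these tokens improves the metrics. Was this left out on purpose, or is it an oversight?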

Code problem

Line 115 is commented out, but the later inference step needs this file.

def on_epoch_end(self, global_step, epoch, logs=None):
    val_acc = self.evaluate(valid_dataloader)
    test_acc = self.evaluate(test_dataloader)
    logs['val/acc'] = val_acc
    logs['test/acc'] = test_acc
    if val_acc > self.best_val_acc:
        self.best_val_acc = val_acc
        # model.save_weights('best_model.pt')
    print(f'val_acc: {val_acc:.5f}, test_acc: {test_acc:.5f}, best_val_acc: {self.best_val_acc:.5f}\n')

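Support for gradient checkpointing

Hello!

Many thanks to the author for writing this torch framework. Gradient checkpointing is a training technique that saves GPU memory and helps a great deal when training large models with limited resources. It is described on Su Jianlin's blog, and Hugging Face transformers has built-in support for it. Could this feature be added later on?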

t5_pegasus_small support

Hello, I want to use t5_pegasus_small for a seq2seq task. My config.json is as follows:
{
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 512,
"initializer_range": 0.02,
"intermediate_size": 1024,
"num_attention_heads": 6,
"attention_head_size": 64,
"num_hidden_layers": 8,
"vocab_size": 50000,
"hidden_act": "gelu",
"relative_attention_num_buckets": 32
}
However, at runtime the layer code raises assert hidden_size % num_attention_heads == 0. Is something wrong in the config file?
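The arithmetic behind the assertion can be checked directly from the values in the config above. The sketch only reproduces the numbers; whether bert4torch's T5 layers honor attention_head_size independently of hidden_size is not stated in this thread and is left open.

```python
import json

# The three relevant fields copied from the config.json above.
config_text = """
{
  "hidden_size": 512,
  "num_attention_heads": 6,
  "attention_head_size": 64
}
"""
cfg = json.loads(config_text)

# The failing check: 512 is not divisible by 6.
remainder = cfg["hidden_size"] % cfg["num_attention_heads"]

# The config instead implies an attention width of 6 * 64, decoupled from hidden_size.
projected = cfg["num_attention_heads"] * cfg["attention_head_size"]

print(remainder, projected)
```

So the check fails because it derives the head size as hidden_size / num_attention_heads, while this config specifies the head size explicitly. The duplicated "hidden_act" key in the original file is harmless to `json.loads`, which keeps the last value.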

Hello, how should I write the prediction code for W2NER?

RuntimeError: masked_select: expected BoolTensor or ByteTensor for mask

Traceback (most recent call last):
  File "C:/Users/Administrator/Desktop/bert_crf/train.py", line 191, in <module>
    model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
  File "D:\python36\lib\site-packages\bert4torch\models.py", line 213, in fit
    output, loss, loss_detail = self.train_step(train_X, train_y, grad_accumulation_steps)
  File "D:\python36\lib\site-packages\bert4torch\models.py", line 131, in train_step
    loss_detail = self.criterion(output, train_y)
  File "D:\python36\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:/Users/Administrator/Desktop/bert_crf/train.py", line 122, in forward
    return model.crf.neg_log_likelihood_loss(*outputs, labels)
  File "D:\python36\lib\site-packages\bert4torch\layers.py", line 913, in neg_log_likelihood_loss
    forward_score, scores = self._forward_alg(feats, mask)
  File "D:\python36\lib\site-packages\bert4torch\layers.py", line 862, in _forward_alg
    masked_cur_partition = cur_partition.masked_select(mask_idx) # [x * tag_size]

Raw data

Hello, could you share the raw datasets you used for training and testing? Mainly I would like to understand how you reached such high accuracy. Thanks!

Traceback (most recent call last):
  File "task_sentiment_classification_hierarchical_position.py", line 117, in <module>
    model.fit(train_dataloader, epochs=10, steps_per_epoch=None, callbacks=[evaluator])
  File "/home/huangjiaheng/.local/lib/python3.8/site-packages/bert4torch/models.py", line 274, in fit
    loss.backward(retain_graph=retain_graph)
  File "/opt/conda/lib/python3.8/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
RuntimeError: cuda runtime error (710) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1616554793803/work/aten/src/THC/generic/THCTensorMath.cu:29

A possible small problem

In layers.py, line 55, cond = cond.unsqueeze(dim=1): shouldn't dim be 0 here? While debugging I found that it does not look right. Could you confirm?

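Model results are not very good and I do not know why

I ran task_sequence_labeling_ner_W2NER on the People's Daily data from your examples, but the results are not good and I do not know what went wrong. F1 on the validation set is only slightly above 90. I am on Windows with a 3090 GPU, torch=1.11.3, bert4torch=0.2.2, Python=3.8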

[val-token  level] f1: 0.91029, p: 0.89145 r: 0.93052
[val-entity level] f1: 0.90059, p: 0.91598 r: 0.88571 best_f1: 0.90154
============Finish Training=============

Evaluation inference is fast the first time and very slow afterwards

Hello Bo, why is self.evaluate(valid_dataset.data[:valid_len]) fast the first time it runs inference on the validation set, while the second run on the same validation set is only about one twentieth as fast? The evaluate code is as follows:
def evaluate(self, data):
    total = 0
    rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0
    for title, content in tqdm(data):
        total += 1
        title = ' '.join(title).lower()
        # with torch.no_grad():
        pred_title = ' '.join(autosumm.generate(content)).lower()
        if pred_title.strip():
            scores = self.rouge.get_scores(hyps=pred_title, refs=title)
            rouge_1 += scores[0]['rouge-1']['f']
            rouge_2 += scores[0]['rouge-2']['f']
            rouge_l += scores[0]['rouge-l']['f']
            bleu += sentence_bleu(references=[title.split(' ')], hypothesis=pred_title.split(' '),
                                  smoothing_function=self.smooth)
    rouge_1, rouge_2, rouge_l, bleu = rouge_1 / total, rouge_2 / total, rouge_l / total, bleu / total
    return {'rouge-1': rouge_1, 'rouge-2': rouge_2, 'rouge-l': rouge_l, 'bleu': bleu}
[image]

Hello Bo, is the generate strategy of T5-PEGASUS different from that of mt5 in the transformers library?

With the same weights file and the same test data, the following code generates summaries:
model = MT5ForConditionalGeneration.from_pretrained(args.pretrain_model).to(device)
gen = model.generate(max_length=args.max_len_generate,
                     eos_token_id=tokenizer.sep_token_id,
                     decoder_start_token_id=tokenizer.cls_token_id,
                     **content)

With the generate method from the autotitle example, the output is empty. I later found that the first value predicted in beam_search is already end_id, [SEP].
class AutoSummarize(AutoRegressiveDecoder):
    """seq2seq decoder
    """
    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, states):
        # inputs contains [decoder_ids, encoder_hidden_state, encoder_attention_mask]
        return model.decoder.predict([output_ids] + inputs)[-1][:, -1, :] # keep only the last position

    def generate(self, text):
        gc.collect()
        torch.cuda.empty_cache()
        token_ids, _ = tokenizer.encode(text, maxlen=args.max_c_len)
        token_ids = torch.tensor([token_ids], device=device)
        encoder_output = model.encoder.predict([token_ids])
        output_ids = self.beam_search(encoder_output, topk=1)  # beam search
        return tokenizer.decode([int(i) for i in output_ids.cpu().numpy()])

Could it be because states is empty? I noticed that encoder_output does not contain states.
