
ptr's Introduction

PTR: Prompt Tuning with Rules for Text Classification

The code and datasets of our paper "PTR: Prompt Tuning with Rules for Text Classification"

To clone the repository, please run the following command:

git clone https://github.com/thunlp/PTR.git --depth 1

If you use the code, please cite the following paper:

@article{han2021ptr,
  title={PTR: Prompt Tuning with Rules for Text Classification},
  author={Han, Xu and Zhao, Weilin and Ding, Ning and Liu, Zhiyuan and Sun, Maosong},
  journal={arXiv preprint arXiv:2105.11259},
  year={2021}
}


Overview

In this work, we propose prompt tuning with rules (PTR) for many-class text classification, applying logic rules to construct prompts from several sub-prompts. In this way, PTR is able to encode the prior knowledge of each class into prompt tuning. You can find more details in our paper: https://arxiv.org/pdf/2105.11259.pdf
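
To make the sub-prompt idea concrete, here is a minimal illustrative sketch of composing a relation-classification prompt from sub-prompts; the template strings paraphrase the paper and are not the exact templates shipped in this repository.

# Illustrative sketch of PTR-style sub-prompt composition for relation
# classification. Templates are paraphrases, not the repo's exact ones.

def entity_sub_prompt(entity):
    # Sub-prompt asking for an entity's type: "the [MASK] <entity>"
    return f"the [MASK] {entity}"

def compose_prompt(sentence, subj, obj):
    # Conjoin the sub-prompts: subject type, relation phrase (the middle
    # [MASK]), and object type, mirroring the logic-rule view in the paper.
    return (f"{sentence} {entity_sub_prompt(subj)} [MASK] "
            f"{entity_sub_prompt(obj)}")

print(compose_prompt("Mark Twain was the father of Langdon.",
                     "Mark Twain", "Langdon"))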

Requirements

The model is implemented using PyTorch. The versions of packages used are shown below.

  • numpy==1.18.0
  • scikit-learn==0.22.1
  • scipy==1.4.1
  • torch==1.4.0
  • tqdm==4.41.1
  • transformers==4.0.0

To set up the dependencies, you can run the following command:

pip install -r requirements.txt

Data Preparation

We provide a script to download all the datasets used in our paper. You can run the following command to download them:

bash data/download.sh all

The above command will download all of the following datasets:

  • Retacred
  • Tacred
  • Tacrev
  • Semeval

If you only want to download a specific dataset, you can run the following command:

bash data/download.sh $dataset_name1 $dataset_name2 ...

where each $dataset_nameX is one of retacred, tacred, tacrev, or semeval (one or more may be given).
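
For example, to download only TACRED and SemEval:

bash data/download.sh tacred semeval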

Experiments

Baselines

Some baselines, especially those using entity markers, come from the project RE_improved_baseline.

Reproduce Results in Our Work

1. For TACRED

bash scripts/run_large_tacred.sh

2. For TACREV

bash scripts/run_large_tacrev.sh

3. For RETACRED

bash scripts/run_large_retacred.sh

4. For Semeval

bash scripts/run_large_semeval.sh

ptr's People

Contributors

chenweize1998, thucsthanxu13


ptr's Issues

A question about fine-tuning

I can express myself more fully in Chinese, so please forgive me for asking in Chinese.

The paper says that an advantage of this work is that it does not require fine-tuning, but in the model, tuning actually proceeds along two tracks: the BERT model itself, and a separate mapping network. In my shallow understanding, this should still count as fine-tuning, usually called prompt tuning. Is that right?

As I understand it, if there were no tuning, one would only use the BERT model's own query or hidden-layer outputs as representations, without back-propagating to adjust the BERT model itself; but the code actually does make this adjustment.
Please clarify.
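
For reference, the run scripts do pass two learning rates (--learning_rate for the backbone, --learning_rate_for_new_token for the new prompt tokens), which is consistent with tuning both the encoder and the extra mapping network. Below is a minimal sketch of how such a two-group optimizer is typically built; the module names are illustrative assumptions, not the exact names in src/.

import torch
from torch import nn

# Toy stand-in for the real model: an encoder backbone plus newly added
# prompt-token parameters. The names "backbone"/"extra_token_embeddings"
# are assumptions for illustration.
model = nn.ModuleDict({
    "backbone": nn.Linear(8, 8),
    "extra_token_embeddings": nn.Embedding(4, 8),
})

backbone_params = [p for n, p in model.named_parameters() if "extra_token" not in n]
new_token_params = [p for n, p in model.named_parameters() if "extra_token" in n]

optimizer = torch.optim.AdamW(
    [
        {"params": backbone_params, "lr": 3e-5},   # cf. --learning_rate
        {"params": new_token_params, "lr": 1e-5},  # cf. --learning_rate_for_new_token
    ],
    weight_decay=1e-2,  # cf. --weight_decay
    eps=1e-6,           # cf. --adam_epsilon
)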

F1 is 0.0 for tacred and tacrev?

I tested the scripts with tacred, retacred, and tacrev, and found that tacred and tacrev get F1=0 after five epochs.
I did not change anything inside the code except the batch sizes and the number of GPUs:

for mode in train val test; do
    if [ ! -d "data/tacred/$mode" ]; then
        mkdir -p results/tacred/$mode
    fi
done

export CUDA_VISIBLE_DEVICES=0,1

python3 src/run_prompt.py \
--data_dir data/tacred \
--output_dir results/tacred \
--model_type roberta \
--model_name_or_path roberta-large \
--per_gpu_train_batch_size 2 \
--gradient_accumulation_steps 1 \
--max_seq_length 512 \
--warmup_steps 500 \
--learning_rate 3e-5 \
--learning_rate_for_new_token 1e-5 \
--num_train_epochs 5 \
--weight_decay 1e-2 \
--adam_epsilon 1e-6 \
--temps temp.txt

out of memory error

I am already using --max_seq_length 16 with run_large_tacrev.sh.
Could anyone help?

  File "/shared/home/yerong/local/Conda/envs/ptr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 639, in _apply
    module._apply(fn)
  File "/shared/home/yerong/local/Conda/envs/ptr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 639, in _apply
    module._apply(fn)
  [Previous line repeated 4 more times]
  File "/shared/home/yerong/local/Conda/envs/ptr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 662, in _apply
    param_applied = fn(param)
  File "/shared/home/yerong/local/Conda/envs/ptr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 747, in <lambda>
    return self._apply(lambda t: t.cuda(device))
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 23.69 GiB total capacity; 729.76 MiB already allocated; 3.31 MiB free; 736.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

The script I used:

for mode in train val test; do
    if [ ! -d "data/tacrev/$mode" ]; then
        mkdir -p results/tacrev/$mode
    fi
done

export CUDA_VISIBLE_DEVICES=0,1,2,3

python3 src/run_prompt.py \
--data_dir data/tacrev \
--output_dir results/tacrev \
--model_type roberta \
--model_name_or_path roberta-large \
--per_gpu_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--max_seq_length 16 \
--warmup_steps 1 \
--learning_rate 3e-5 \
--learning_rate_for_new_token 1e-5 \
--num_train_epochs 2 \
--weight_decay 1e-2 \
--adam_epsilon 1e-6 \
--temps temp.txt
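
One thing worth trying, which the error message itself points at, is capping the CUDA caching allocator's split size to reduce fragmentation (128 below is an arbitrary example value):

export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128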

How to load the BioBERT model

Hello, I would like to test the PTR model on relation classification with biomedical datasets, but the transformers package here has no built-in BioBERT. How can I load BioBERT? Thanks 🙏
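
BioBERT checkpoints are published on the Hugging Face hub and load through the standard BERT classes, so no special wrapper is needed. A minimal sketch (dmis-lab/biobert-base-cased-v1.1 is one published checkpoint id; substitute whichever variant you need):

from transformers import AutoModelForMaskedLM, AutoTokenizer

# Load a BioBERT checkpoint exactly like any other BERT-family model.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
model = AutoModelForMaskedLM.from_pretrained("dmis-lab/biobert-base-cased-v1.1")

If the run scripts accept arbitrary checkpoints, passing --model_type bert together with --model_name_or_path pointing at the BioBERT id may be all that is required.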

About modeling.py

Hello, I have some questions I would like to ask you.
While reading the code in modeling.py, I found that (as I understand it) your code:

  • uses RoBERTa to generate embeddings for the original input x, and uses random embeddings plus a linear layer to generate the embeddings for the prompt part;
  • splices the two together with torch.where (original-input embeddings + prompt embeddings);
  • and then feeds the result into RoBERTa to produce the hidden states.

Why do it this way?
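
For readers hitting the same question, below is a minimal sketch of the pattern being described: splicing learned prompt embeddings into the word embeddings and running the encoder on inputs_embeds. Variable names and the "first three positions are prompt slots" convention are illustrative assumptions, not the exact code in modeling.py.

import torch
from torch import nn
from transformers import RobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
encoder = RobertaModel.from_pretrained("roberta-base")
hidden = encoder.config.hidden_size

# Randomly initialized embeddings for the new prompt tokens, refined by an MLP.
new_token_emb = nn.Embedding(3, hidden)
mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))

inputs = tokenizer("Mark Twain was the father of Langdon.", return_tensors="pt")
word_emb = encoder.embeddings.word_embeddings(inputs["input_ids"])  # (1, L, H)

# Pretend the first three positions are prompt slots; the real code marks
# prompt positions via special token ids and selects them with torch.where.
is_prompt = torch.zeros_like(inputs["input_ids"], dtype=torch.bool)
is_prompt[:, :3] = True

mixed = word_emb.clone()
mixed[is_prompt] = mlp(new_token_emb.weight)  # overwrite prompt slots

# Feed the mixed embeddings (pretrained word embeddings + learned prompt
# embeddings) back through the encoder to get the hidden states.
out = encoder(inputs_embeds=mixed, attention_mask=inputs["attention_mask"])

The apparent motivation is that the new prompt tokens have no pretrained embeddings, so they are learned from scratch (via the MLP) while ordinary tokens keep RoBERTa's pretrained embeddings.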

Question regarding the output

Hi,

Thanks for your solid work and for sharing the code!

May I ask why you choose to predict the label index (e.g., if the masked token has three possible values, you output an index from 0 to 2 instead of the actual word id corresponding to the label) when you generate the output? Have you tried predicting the actual word instead of the index?

Thank you!
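
For context, the index-based prediction in question looks roughly like the sketch below: vocabulary logits at a [MASK] position are restricted to the candidate label words, and the argmax over that restricted set is an index rather than a word id. Names and ids here are assumptions for illustration.

import torch

vocab_size = 50265                     # RoBERTa vocabulary size
mask_logits = torch.randn(vocab_size)  # stand-in for real model output at [MASK]

# Verbalizer: token ids of the candidate label words at this mask slot
# (the ids below are made up for the example).
label_word_ids = torch.tensor([1820, 2330, 5970])

# Score only the label words and predict an index 0..2 over them,
# instead of decoding over the full vocabulary.
scores = mask_logits[label_word_ids]
pred_index = scores.argmax().item()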

CUBLAS_STATUS_EXECUTION_FAILED error

When running bash scripts/run_large_tacred.sh I ran into a CUBLAS_STATUS_EXECUTION_FAILED error. Is this because my GPU memory is too small?

    new_token_embeddings = self.mlp(self.extra_token_embeddings.weight)
  File "/shared/home/yerong/local/Conda/envs/ptr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/shared/home/yerong/local/Conda/envs/ptr/lib/python3.7/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/shared/home/yerong/local/Conda/envs/ptr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/shared/home/yerong/local/Conda/envs/ptr/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/shared/home/yerong/local/Conda/envs/ptr/lib/python3.7/site-packages/torch/nn/functional.py", line 1370, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`

Question Regarding Reverse

Hi Han Xu,

May I ask in which part of the code you implemented the reverse operation? It seems to me that the code does not use it.

Thank you!

About the tokenizer

I would like to ask: why, in get_labels() in data_prompt.py, is the output of tokenizer.encode() not a tokenized (segmented) result?
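
For reference, tokenizer.encode() returns token ids rather than token strings, which may be why the output does not look segmented; convert_ids_to_tokens shows the underlying subwords. A quick check (assuming a RoBERTa tokenizer):

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
ids = tokenizer.encode("was born in", add_special_tokens=False)
print(ids)                                   # a list of integer token ids
print(tokenizer.convert_ids_to_tokens(ids))  # the subword strings behind them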

some questions about paper

Hi Xu,
I have some questions about this paper, and I am looking forward to your reply.

  1. I notice that this paper focuses on relation extraction (a classification problem). Why, then, is an entity classification task needed (e.g., equation (1))?
  2. I also notice that you use the REVERSED operation to reverse some of the relations. What is the criterion for REVERSED?
  3. I also notice that ENTITY MARKER also uses the REVERSED operation. How do the two combine? By exchanging the positions of [E1] and [E2]?
  4. There is also an implementation-detail question (a common issue in prompt-based learning). How many [MASK] tokens do we need in equation (3)? Does this require the same number of tokens in the label words (the label words in equation (4))? E.g., do all words in V_{[MASK]_1} have 2 tokens after BPE? (A sketch of the multi-mask scoring appears after this list.)
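
A sketch of the multi-mask scoring that equations (3)-(4) describe, as referenced in question 4: each class contributes one label word per [MASK] slot, and the class score aggregates the per-mask probabilities (the sum of log-probabilities below equals the product of probabilities). Shapes and ids are assumptions for illustration.

import torch

num_masks, vocab_size, num_classes = 3, 50265, 4
mask_logits = torch.randn(num_masks, vocab_size)  # stand-in model outputs

# label_word_ids[c][j] = token id of class c's label word at mask slot j
# (random ids here, purely for illustration).
label_word_ids = torch.randint(0, vocab_size, (num_classes, num_masks))

log_probs = mask_logits.log_softmax(dim=-1)          # (num_masks, vocab_size)
per_mask = log_probs.gather(1, label_word_ids.t())   # (num_masks, num_classes)
class_scores = per_mask.sum(dim=0)                   # sum over masks
pred = class_scores.argmax().item()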

OOM error

Why do I get an OOM error even with seven 24 GB 3090 GPUs? (When I first ran run_large_tacred.sh, it reported that the file PTR/datasets/tacred/temp_dd.txt could not be found, so I changed temp_dd.txt to temp.txt.) Am I using something incorrectly?

About using the BERT model

Hi, when running the BERT model I got an assert error. It looks like BERT splits 's into ' and s during encoding, and the code apparently cannot handle this?
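
For reference, the two tokenizers do treat the clitic differently, which is easy to verify (sketch below; the commented outputs are indicative):

from transformers import BertTokenizer, RobertaTokenizer

bert = BertTokenizer.from_pretrained("bert-base-cased")
roberta = RobertaTokenizer.from_pretrained("roberta-base")

print(bert.tokenize("Twain's"))     # WordPiece splits the clitic: 'Twain', "'", 's'
print(roberta.tokenize("Twain's"))  # byte-level BPE keeps "'s" together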
