bit's Introduction

BiT

This repository contains the training code for BiT, introduced in our work "BiT: Robustly Binarized Multi-distilled Transformer".

In this work, we identify a series of improvements that enable binary transformers to reach much higher accuracy than was previously possible. These include a two-set binarization scheme, a novel elastic binary activation function with learned parameters, and a multi-step distillation method. Together, these approaches yield, for the first time, fully binarized transformer models at a practical level of accuracy, coming within as little as 5.9% of a full-precision BERT baseline on the GLUE language understanding benchmark.
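
As a rough illustration of the first two ideas, the sketch below binarizes scalars into one of two value sets using a scale and a threshold. It assumes that one set is {-alpha, +alpha} (for values that take both signs) and the other is {0, alpha} (for non-negative values), and that `alpha` and `beta` stand in for the learned parameters. The function names and constants are hypothetical, and this is not the repository's implementation.

```python
# Sketch only: in training, alpha (scale) and beta (threshold) would be
# learned; here they are fixed, illustrative constants.

def binarize_signed(x, alpha, beta=0.0):
    """Map a real value to one of {-alpha, +alpha}."""
    return alpha if (x - beta) >= 0 else -alpha


def binarize_unsigned(x, alpha, beta=0.0):
    """Map a real value to one of {0, alpha}."""
    scaled = (x - beta) / alpha
    clipped = min(max(scaled, 0.0), 1.0)        # clip to [0, 1]
    return alpha * (1.0 if clipped >= 0.5 else 0.0)  # round to 0 or 1


values = [-1.2, -0.1, 0.3, 0.9]
signed = [binarize_signed(v, alpha=0.5) for v in values]
unsigned = [binarize_unsigned(v, alpha=0.8) for v in values]
print(signed)    # [-0.5, -0.5, 0.5, 0.5]
print(unsigned)  # [0.0, 0.0, 0.0, 0.8]
```

A real implementation would also need a gradient rule for the rounding step so that the scale and threshold can be trained; that part is omitted here.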

Citation

If you find our code useful for your research, please consider citing:

@article{liu2022bit,
title={BiT: Robustly Binarized Multi-distilled Transformer},
author={Liu, Zechun and Oguz, Barlas and Pappu, Aasish and Xiao, Lin and Yih, Scott and Li, Meng and Krishnamoorthi, Raghuraman and Mehdad, Yashar},
journal={arXiv preprint arXiv:2205.13016},
year={2022}
}

Run

1. Requirements:

  • python 3.6, pytorch 1.7.1

2. Data:

3. Pretrained models:

4. Steps to run:

  • Specify num_bits, the data path, and the pre-trained model path in the scripts/run.sh file.

  • Run bash scripts/run_glue.sh GLUE_dataset or bash scripts/run_squad.sh.

    E.g., bash scripts/run_glue.sh MNLI runs the MNLI task of the GLUE dataset.
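
To queue several runs, the commands above can be assembled programmatically. The sketch below only builds and prints the command lines. The task names are taken from the column headers of the GLUE results table, and whether run_glue.sh accepts each of these exact spellings is an assumption.

```python
import shlex

# Task names copied from the GLUE results table; the exact spellings
# expected by run_glue.sh are an assumption.
GLUE_TASKS = ["MNLI", "QQP", "QNLI", "SST-2", "CoLA", "STS-B", "MRPC", "RTE"]


def glue_command(task):
    """Return the argument list for one GLUE run, mirroring the README example."""
    return ["bash", "scripts/run_glue.sh", task]


def squad_command():
    """Return the argument list for the SQuAD run."""
    return ["bash", "scripts/run_squad.sh"]


commands = [glue_command(t) for t in GLUE_TASKS] + [squad_command()]
for argv in commands:
    # To actually launch a run, call subprocess.run(argv, check=True)
    # from the repository root instead of printing.
    print(shlex.join(argv))
```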

Models

1. GLUE dataset

(1) Without data augmentation

| Method | #Bits | Size (M) | FLOPs (G) | MNLI m/mm | QQP | QNLI | SST-2 | CoLA | STS-B | MRPC | RTE | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| BERT | 32-32-32 | 418 | 22.5 | 84.9/85.5 | 91.4 | 92.1 | 93.2 | 59.7 | 90.1 | 86.3 | 72.2 | 83.9 |
| BinaryBert | 1-1-4 | 16.5 | 1.5 | 83.9/84.2 | 91.2 | 90.9 | 92.3 | 44.4 | 87.2 | 83.3 | 65.3 | 79.9 |
| BinaryBert | 1-1-2 | 16.5 | 0.8 | 62.7/63.9 | 79.9 | 52.6 | 82.5 | 14.6 | 6.5 | 68.3 | 52.7 | 53.7 |
| BinaryBert | 1-1-1 | 16.5 | 0.4 | 35.6/35.3 | 66.2 | 51.5 | 53.2 | 0 | 6.1 | 68.3 | 52.7 | 41.0 |
| BiBert | 1-1-1 | 13.4 | 0.4 | 66.1/67.5 | 84.8 | 72.6 | 88.7 | 25.4 | 33.6 | 72.5 | 57.4 | 63.2 |
| BiT * | 1-1-4 | 13.4 | 1.5 | 83.6/84.4 | 87.8 | 91.3 | 91.5 | 42.0 | 86.3 | 86.8 | 66.4 | 79.5 |
| BiT * | 1-1-2 | 13.4 | 0.8 | 82.1/82.5 | 87.1 | 89.3 | 90.8 | 32.1 | 82.2 | 78.4 | 58.1 | 75.0 |
| BiT * | 1-1-1 | 13.4 | 0.4 | 77.1/77.5 | 82.9 | 85.7 | 87.7 | 25.1 | 71.1 | 79.7 | 58.8 | 71.0 |
| BiT | 1-1-1 | 13.4 | 0.4 | 79.5/79.4 | 85.4 | 86.4 | 89.9 | 32.9 | 72 | 79.9 | 62.1 | 73.5 |

(2) With data augmentation

| Method | #Bits | Size (M) | FLOPs (G) | MNLI m/mm | QQP | QNLI | SST-2 | CoLA | STS-B | MRPC | RTE | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| BinaryBert | 1-1-2 | 16.5 | 0.8 | 62.7/63.9* | 79.9* | 51.0 | 89.6 | 33.0 | 11.4 | 71.0 | 55.9 | 57.6 |
| BinaryBert | 1-1-1 | 16.5 | 0.4 | 35.6/35.3* | 66.2* | 66.1 | 78.3 | 7.3 | 22.1 | 69.3 | 57.7 | 48.7 |
| BiBert | 1-1-1 | 13.4 | 0.4 | 66.1/67.5* | 84.8* | 76.0 | 90.9 | 37.8 | 56.7 | 78.8 | 61.0 | 68.8 |
| BiT * | 1-1-2 | 13.4 | 0.8 | 82.1/82.5* | 87.1* | 88.8 | 92.5 | 43.2 | 86.3 | 90.4 | 72.9 | 80.4 |
| BiT * | 1-1-1 | 13.4 | 0.4 | 77.1/77.5* | 82.9* | 85.0 | 91.5 | 32.0 | 84.1 | 88.0 | 67.5 | 76.0 |
| BiT | 1-1-1 | 13.4 | 0.4 | 79.5/79.4* | 85.4* | 86.5 | 92.3 | 38.2 | 84.2 | 88 | 69.7 | 78.0 |
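
The headline figures can be checked with a little arithmetic on the two GLUE tables. The sketch below uses only numbers copied from the BERT row and the final BiT rows. Reading the "5.9%" in the introduction as the gap in the Avg column with data augmentation is an interpretation of the tables, not something the README states explicitly.

```python
# Figures copied from the GLUE tables: BERT (32-32-32) and the final
# BiT (1-1-1) rows.
bert = {"size": 418.0, "flops": 22.5, "avg": 83.9}
bit_no_aug = {"size": 13.4, "flops": 0.4, "avg": 73.5}
bit_aug_avg = 78.0

size_ratio = round(bert["size"] / bit_no_aug["size"], 1)    # model size reduction
flops_ratio = bert["flops"] / bit_no_aug["flops"]           # FLOPs reduction
gap_no_aug = round(bert["avg"] - bit_no_aug["avg"], 1)      # Avg gap, no augmentation
gap_aug = round(bert["avg"] - bit_aug_avg, 1)               # Avg gap, with augmentation

print(f"size reduction:  {size_ratio}x")
print(f"FLOPs reduction: {flops_ratio:.2f}x")
print(f"Avg gap without augmentation: {gap_no_aug}")
print(f"Avg gap with augmentation:    {gap_aug}")
```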

2. SQuAD dataset

| Method | #Bits | SQuADv1.1 em/f1 |
|---|---|---|
| BERT | 32-32-32 | 82.6/89.7 |
| BinaryBert | 1-1-4 | 77.9/85.8 |
| BinaryBert | 1-1-2 | 72.3/81.8 |
| BinaryBert | 1-1-1 | 1.5/8.2 |
| BiBert | 1-1-1 | 8.5/18.9 |
| BiT | 1-1-1 | 63.1/74.9 |

Acknowledgement

The original code is borrowed from BinaryBERT.

Contact

Zechun Liu, Reality Labs, Meta Inc (liuzechun0216 at gmail.com)

License

BiT is CC-BY-NC 4.0 licensed as of now.

bit's Issues

Model not binary

When the final model was saved, I checked the weights and found that they were not binary. Also, the models attached in the last row of each table are larger than 13 MB.
Can you please clarify whether you combined the non-binary and binary models, or whether the binary model was not attached?
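
One way to run the check described above is to count the distinct values in a weight tensor. The sketch below works on plain lists of floats with made-up numbers, and the helper names are hypothetical. Note that quantization-aware training commonly stores real-valued latent weights and binarizes them only in the forward pass, which would produce a checkpoint like the `latent` example; whether this repository does so is not confirmed in this thread.

```python
def distinct_levels(weights, tol=1e-6):
    """Return the sorted distinct values in `weights`, merging values within `tol`."""
    levels = []
    for w in sorted(weights):
        if not levels or abs(w - levels[-1]) > tol:
            levels.append(w)
    return levels


def looks_binary(weights, tol=1e-6):
    """True if the weights take at most two distinct values."""
    return len(distinct_levels(weights, tol)) <= 2


# Illustrative values only, not taken from any released checkpoint.
latent = [0.13, -0.42, 0.07, -0.91]      # real-valued latent weights
binarized = [0.25, -0.25, 0.25, -0.25]   # two levels: -alpha and +alpha

print(looks_binary(latent))     # False
print(looks_binary(binarized))  # True
```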

Problem in reproduce multi-distillation approach

Hello, thank you for providing the code.

I can get the right results for W1A1 with bash scripts/run_glue.sh MNLI (around 77 accuracy on MNLI).

But when I try to reproduce W1A1 with the multi-distillation approach (W32A32->W1A2->W1A1), I cannot reproduce the W1A2 results in the paper by simply changing abits=1 to abits=2 in scripts/run_glue.sh (the W1A2 result I get is 80.96/81.36).

Can you share the detailed settings of the multi-distillation approach?
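
For reference, the schedule in question can be written as a chain of teacher/student pairs. The sketch below assumes that the student trained in one stage becomes the teacher of the next. Which hyperparameters change between stages is not stated in this thread, so none are included, and the helper name is hypothetical.

```python
def distillation_stages(schedule):
    """Pair consecutive precisions into (teacher, student) stages."""
    return list(zip(schedule[:-1], schedule[1:]))


# Precision chain from the question: W = weight bits, A = activation bits.
schedule = ["W32A32", "W1A2", "W1A1"]
stages = distillation_stages(schedule)

for step, (teacher, student) in enumerate(stages, start=1):
    print(f"stage {step}: distill {teacher} teacher into {student} student")
```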

How many epochs do you train with data augmentation?

In the paper, you mentioned how many epochs you trained without data augmentation. However, I am not sure if you use the same number of epochs when training with data augmentation.

How many epochs do you train with data augmentation?

Reproduction Issue of BiT on STS-B

Great work, and thanks a lot for open-sourcing it!

While reusing the released code, I ran into the issue below:

I cannot reproduce the W1A1 BiT accuracy on the STS-B dataset reported in the paper (68.7% vs 71.1%).

I have basically followed the settings in the code and paper. Can you share some suggestions for this issue?

"sts-b": {"num_train_epochs": 20, "max_seq_length": 128, "batch_size": 8, "learning_rate": 5e-4 },

StopIteration encountered running MNLI

Hi there! I got a StopIteration when following the steps to run your code with scripts/run_glue.sh:

... previous messages hidden...
2024-05-13 16:17:00,172 [INFO]: module.classifier: Linear(in_features=768, out_features=3, bias=True)
Evaluating:   0%|                                                                                                                         | 0/614 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/cy/bit/quant_task_distill_glue.py", line 251, in <module>
    main()
  File "/home/cy/bit/quant_task_distill_glue.py", line 242, in main
    learner.train(train_examples, task_name, output_mode, eval_labels,
  File "/home/cy/bit/kd_learner_glue.py", line 152, in train
    teacher_results = self._do_eval(self.teacher_model, task_name, eval_dataloader, output_mode, eval_labels, num_labels)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/bit/kd_learner_glue.py", line 95, in _do_eval
    logits, _, _ = model(input_ids, segment_ids, input_mask)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
    output.reraise()
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/_utils.py", line 722, in reraise
    raise exception
StopIteration: Caught StopIteration in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/bit/transformer/modeling_bert.py", line 498, in forward
    sequence_output, att_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
                                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/bit/transformer/modeling_bert.py", line 471, in forward
    dtype=next(self.parameters()).dtype)  # fp16 compatibility
          ^^^^^^^^^^^^^^^^^^^^^^^
StopIteration

It seems there is an error in the training routine. I used the provided pretrained full-precision bert_base for MNLI and modified the paths for the models and dataset accordingly. I suspect this might be due to a library version conflict, since there is no longer a pytorch_model.bin file by default (instead, model.save_pretrained() produces a model.safetensors file). Can the environment configuration file be provided to address this issue?

Since I'm using W1A1 as the config, I guess I could hard-code the data type to make it run, but I would also like to understand why the next value from the parameters generator is used as the data type. If an environment conflict is not the cause, answering this question is crucial to getting the code to work.

Thank you.
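
On the last question: the expression next(self.parameters()).dtype in the traceback reads the dtype of the module's first parameter so that another tensor can be cast to match it, as the "# fp16 compatibility" comment suggests. Calling next() on an iterator that yields nothing raises StopIteration. One commonly reported cause is that replicas created by nn.DataParallel in newer PyTorch releases expose no parameters; whether that is what happens here is an assumption. The sketch below reproduces the pattern with plain Python stand-ins (the class and helper names are hypothetical) and shows a fallback using the default argument of next().

```python
from types import SimpleNamespace


class ReplicaLike:
    """Hypothetical stand-in for a module replica that exposes no parameters."""

    def parameters(self):
        return iter(())


class ModuleLike:
    """Hypothetical stand-in for a module with one fp16 parameter."""

    def parameters(self):
        return iter([SimpleNamespace(dtype="float16")])


def first_param_dtype(module, default="float32"):
    """Return the dtype of the first parameter, or `default` if there is none."""
    first = next(module.parameters(), None)
    return default if first is None else first.dtype


# The pattern from the traceback fails when there are no parameters.
try:
    next(ReplicaLike().parameters())
    raised = False
except StopIteration:
    raised = True

print(raised)                            # True
print(first_param_dtype(ReplicaLike()))  # float32
print(first_param_dtype(ModuleLike()))   # float16
```

Running on a single GPU, so that no replicas are created, is another way to test whether this is the cause.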