ASTormer: AST Structure-aware Transformer Decoder for Text-to-SQL

This repository contains the source code for the paper ASTormer: An AST Structure-aware Transformer Decoder for Text-to-SQL. If you find it useful, please cite our work.

@misc{cao2023astormer,
      title={ASTormer: An AST Structure-aware Transformer Decoder for Text-to-SQL}, 
      author={Ruisheng Cao and Hanchong Zhang and Hongshen Xu and Jieyu Li and Da Ma and Lu Chen and Kai Yu},
      year={2023},
      eprint={2310.18662},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Note: this work focuses on leveraging small pre-trained models and labeled training data to train a specialized, interpretable and efficient local text-to-SQL parser in low-resource scenarios, rather than chasing SOTA performance. For better results, please try LLMs with in-context learning (such as DIN-SQL and ACT-SQL), or resort to larger encoder-decoder architectures with billions of parameters (such as Picard-3B and RESDSQL-3B). Due to a shift in the author's research focus in the LLM era, this project is no longer maintained.

Create environment

The following commands are also provided in setup.sh.

  1. First, create the conda environment astormer:
$ conda create -n astormer python=3.8
$ conda activate astormer
$ pip install torch==1.8.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
$ pip install -r requirements.txt
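
After installation, a quick sanity check can confirm that PyTorch 1.8.1 and its CUDA build are visible. This is only a minimal sketch, not part of the provided scripts:

# sanity check: verify that the installed PyTorch build can see the GPU
import torch

print("torch version:", torch.__version__)        # expected: 1.8.1+cu111
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))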
  2. Next, download third-party dependencies:
$ python -c "import stanza; stanza.download('en')"
$ python -c "import nltk; nltk.download('stopwords'); nltk.download('punkt');"
  3. Download the required pre-trained language models from the Hugging Face Model Hub, such as electra-small-discriminator and chinese-electra-180g-small-discriminator, into the pretrained_models directory (please ensure that Git LFS is installed):
$ mkdir -p pretrained_models && cd pretrained_models
$ git lfs install
$ git clone https://huggingface.co/google/electra-small-discriminator
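
If Git LFS is unavailable, the same checkpoint can be fetched programmatically. The sketch below uses huggingface_hub.snapshot_download; huggingface_hub is not listed in this repository's requirements, so treat it as an assumed alternative that needs a separate pip install and a reasonably recent library version:

# alternative to `git clone` for fetching checkpoints
# (assumes `pip install huggingface_hub`; not part of requirements.txt)
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="google/electra-small-discriminator",
    local_dir="pretrained_models/electra-small-discriminator",
)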

Download and preprocess datasets

  1. Create a new directory data to store all the popular cross-domain multi-table text-to-SQL benchmarks, including Spider, SParC, CoSQL, DuSQL and Chase. Then download, unzip and rename spider.zip, sparc.zip, cosql_dataset.zip, DuSQL.zip and Chase.zip, as well as their databases (Spider-testsuite-database and Chase-database), into the directory data.
  • Variants of the Spider dev set, e.g., SpiderSyn, SpiderDK and SpiderRealistic, can also be downloaded and used at the evaluation stage.
  • These default paths can be changed by modifying the dict CONFIG_PATHS in nsts/transition_system.py (a hypothetical sketch of such a mapping is shown after the directory layout below).
  • The directory data should be organized as follows:
- data/
    - spider/
        - database/ # all databases, one directory for each db_id
        - database-testsuite/ # test-suite databases
        - *.json # datasets or tables, variants of dev set such as dev_syn.json are also downloaded and placed here
    - sparc/
        - database/
        - *.json
    - cosql/
        - database/
        - sql_state_tracking/
            - *.json # train and dev datasets
        - [other directories]/
        - tables.json
    - dusql/
        - *.json
    - chase/
        - database/
        - *.json
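
For reference, the CONFIG_PATHS dict mentioned above maps each benchmark to its data locations. The keys and structure below are only a hypothetical sketch of the kind of mapping to edit in nsts/transition_system.py if you store the data elsewhere; the real dict in the repository may differ:

# Hypothetical sketch of a path mapping like CONFIG_PATHS in nsts/transition_system.py;
# the real keys and structure in the repository may differ.
CONFIG_PATHS = {
    'spider': {
        'train': 'data/spider/train.json',
        'dev': 'data/spider/dev.json',
        'tables': 'data/spider/tables.json',
        'db_dir': 'data/spider/database',
        'testsuite_db_dir': 'data/spider/database-testsuite',
    },
    'sparc': {
        'train': 'data/sparc/train.json',
        'dev': 'data/sparc/dev.json',
        'tables': 'data/sparc/tables.json',
        'db_dir': 'data/sparc/database',
    },
    # 'cosql', 'dusql' and 'chase' follow the same pattern
}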
  2. Preprocess the datasets, which includes:
  • Merge data/spider/train_spider.json and data/spider/train_others.json into a single dataset data/spider/train.json
  • Transform the dataset and database formats of the Chinese benchmarks DuSQL and Chase
  • Fix some annotation errors in SQL queries and type errors in database schemas
  • Re-parse each SQL query into a unified JSON format for all benchmarks. We modify and unify the format of the sql field as follows (see nsts/parse_sql_to_json.py for details; an illustrative before/after sketch is given after the preprocessing command below):
    • For a single condition, the parsed tuple is changed from (not_op, op_id, val_unit, val1, val2) to (agg_id, op_id, val_unit, val1, val2). The not_op is removed and folded into op_id, e.g., not in and not like
    • For FROM conditions whose value is a column id, the target val1 must be a column unit (agg_id, col_id, isDistinct(bool)) to distinguish it from integer values
    • For the ORDER BY clause, the parsed tuple is changed from ('asc'/'desc', [val_unit1, val_unit2, ...]) to ('asc'/'desc', [(agg_id, val_unit1), (agg_id, val_unit2), ...])
  • It takes less than 10 minutes to preprocess each dataset (tokenization, schema linking and value linking). We use the PLM tokenizer to tokenize questions and schema items; schema linking is performed at the word level instead of the BPE/subword token level.
$ ./run/run_preprocessing.sh
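
To make the format changes above concrete, here is an illustrative before/after example for a single condition and an ORDER BY clause. All indices below are hypothetical and are not actual output of nsts/parse_sql_to_json.py:

# Illustrative only: the indices below are hypothetical, not real ids from the datasets.
# A val_unit referring to a plain column: (unit_op, col_unit, col_unit2)
val_unit = (0, (0, 3, False), None)

# A "col NOT IN (...)" condition in the original Spider-style format:
# (not_op, op_id, val_unit, val1, val2)
old_cond = (True, 8, val_unit, None, None)

# Unified format: not_op is dropped and folded into a dedicated "not in" op_id:
# (agg_id, op_id, val_unit, val1, val2)
new_cond = (0, 19, val_unit, None, None)

# ORDER BY in the original format vs. the unified format:
old_order = ('desc', [val_unit])
new_order = ('desc', [(0, val_unit)])  # each val_unit is now paired with an agg_id

print(old_cond, new_cond, old_order, new_order)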

Training

To train ASTormer with small-, base- or large-series pre-trained language models, respectively:

  • dataset can be chosen from ['spider', 'sparc', 'cosql', 'dusql', 'chase']
  • plm is the name of a pre-trained language model under the directory pretrained_models. Please keep it consistent with the PLM used in the preprocessing script (run/run_preprocessing.sh).
# swv means utilizing static word embeddings, extracted from small-series models such as electra-small-discriminator
$ ./run/run_train_and_eval_swv.sh [dataset] [plm]
        
# DDP is not needed, a single 2080Ti GPU is enough
$ ./run/run_train_and_eval_small.sh [dataset] [plm]

# if DDP is used, specify the environment variables below, e.g., one machine with two GPUs
$ GPU_PER_NODE=2 NUM_NODES=1 NODE_RANK=0 MASTER_ADDR="127.0.0.1" MASTER_PORT=23456 ./run/run_train_and_eval_base.sh [dataset] [plm]

# if DDP is used, specify the environment variables below, e.g., two machines, each with two GPUs
$ GPU_PER_NODE=2 NUM_NODES=2 NODE_RANK=0 MASTER_ADDR=[node0_ip] MASTER_PORT=23456 ./run/run_train_and_eval_large.sh [dataset] [plm]
$ GPU_PER_NODE=2 NUM_NODES=2 NODE_RANK=1 MASTER_ADDR=[node0_ip] MASTER_PORT=23456 ./run/run_train_and_eval_large.sh [dataset] [plm]
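
The environment variables above are consumed inside the run scripts. The sketch below only illustrates how such variables are typically translated into a torch.distributed process group; the actual launcher in this repository may wire things up differently, and LOCAL_RANK is assumed to be set per process by the launcher:

# Illustrative only: typical mapping of GPU_PER_NODE / NUM_NODES / NODE_RANK /
# MASTER_ADDR / MASTER_PORT to a torch.distributed process group.
import os
import torch.distributed as dist

gpu_per_node = int(os.environ.get('GPU_PER_NODE', 1))
num_nodes = int(os.environ.get('NUM_NODES', 1))
node_rank = int(os.environ.get('NODE_RANK', 0))
local_rank = int(os.environ.get('LOCAL_RANK', 0))  # set per process by the launcher

world_size = gpu_per_node * num_nodes              # total number of processes
rank = node_rank * gpu_per_node + local_rank       # global rank of this process

dist.init_process_group(
    backend='nccl',
    init_method=f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}",
    world_size=world_size,
    rank=rank,
)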

Inference and Submission

For inference, see run/run_eval.sh (evaluation on the preprocessed dev dataset) and run/run_eval_from_scratch.sh (SQL prediction only, for test set submission):

  • saved_model_dir is the directory containing the saved arguments (params.json) and model parameters (model.bin)
$ ./run/run_eval.sh [saved_model_dir]

$ ./run/run_eval_from_scratch.sh [saved_model_dir]

For both training and inference, you can also use the prepared Docker image rhythmcao/astormer:v0.3:

$ docker pull rhythmcao/astormer:v0.3
$ docker run -it -v $PWD:/workspace rhythmcao/astormer:v0.3 /bin/bash

Acknowledgements

We are grateful to the flexible semantic parser TranX, which inspired our work.

astormer's Issues

Pretrained model on Spider data

Hello @rhythmcao, do you have a previously trained model on the Spider dataset that I can use directly for inference? I'm only interested in using different metrics to check how different models perform, so it would be a real help if you could share a pretrained model ready for inference. Thanks in advance.

Help wanted! Why did you choose 'word-level schema linking' instead of 'token-level schema linking'?


Hello rhythm~ I want to ask why word-level schema linking is used instead of a subword-level schema linking process.
I noticed a curious issue while reading the code: if a PLM is used for tokenization, wouldn't schema linking at the subword level be better than at the word level? Have you run any tests on this? Thank you very much for your reply.


I found that LGESQL does apply an aggregation/pooling function at the subword token level, so I am curious why RAT-SQL only needs the word level. In addition, your LGESQL work is very well written; it is a rare and foundational article for the text-to-SQL task. Thank you very much for your contribution to the community.
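
As background for this discussion, the sketch below shows the kind of subword-to-word pooling that token-level schema linking relies on: a Hugging Face fast tokenizer splits each word into subwords, and word_ids() lets you mean-pool the subword embeddings back to one vector per word. This is a generic illustration using electra-small-discriminator, not the actual schema-linking code of this repository or of LGESQL:

# Generic illustration of pooling subword embeddings back to word level;
# not the actual schema-linking implementation of this repository or LGESQL.
import torch
from transformers import AutoTokenizer, AutoModel

name = 'google/electra-small-discriminator'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)

words = ['concert', 'singer', 'stadium', 'capacity']
enc = tokenizer(words, is_split_into_words=True, return_tensors='pt')
with torch.no_grad():
    hidden = model(**enc).last_hidden_state[0]   # (num_subwords, dim)

word_ids = enc.word_ids()                        # maps each subword position to its word index
word_vectors = []
for w in range(len(words)):
    idx = [i for i, wid in enumerate(word_ids) if wid == w]
    word_vectors.append(hidden[idx].mean(dim=0)) # mean-pool the subwords of word w
word_vectors = torch.stack(word_vectors)         # (num_words, dim)
print(word_vectors.shape)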
