
Introduction

This repository contains part of the dataset and code for the paper Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. It presents the main prompt content, dataset samples, evaluation methods, and fine-tuning code described in the paper. Out of concern for enterprise data security, we will release more of the data after it has been anonymized. Our project is built on LLaMA Factory; thanks for their awesome work.

Special Note: For data security, the tags in the data have been anonymized and rewritten. They may differ slightly from the tags actually deployed online, but this does not hinder understanding. We appreciate your understanding in this matter.

Test Prompt

In the row_data directory, we provide some examples of the reasoning library, the tag table, the training data, and the test data. For data security, they are currently anonymized.

test_prompt.py provides the different test prompts. The basic instruction comes from the file test_instruction.txt. After building a prompt, you can test it by calling the GPT-3.5 Turbo API; refer to the OpenAI API documentation (https://platform.openai.com/docs/api-reference/chat/object) for details.
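
For reference, a minimal test call might look like the sketch below. It assumes the openai Python package (v1-style client) and an OPENAI_API_KEY environment variable; the prompt placeholder stands in for whichever prompt function from test_prompt.py you want to test:

# Minimal sketch of calling the GPT-3.5 Turbo chat API with a prompt
# produced by test_prompt.py. The prompt below is a placeholder;
# substitute the output of the prompt function you want to evaluate.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

prompt = "..."  # e.g. the output of a prompt function in test_prompt.py

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": prompt}],
    temperature=0,  # deterministic output is usually preferable for evaluation
)
print(response.choices[0].message.content)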

How to fine-tune

Training data

In sft_data/sft_train_data.json, we provide two simple training samples corresponding to the two training tasks (i.e., predicting the answers or predicting the reasoning steps). For analogical-reasoning-based fine-tuning, we add analogical examples to the input, similar to the ara_prompt function in test_prompt.py.
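
For orientation, the entries follow LLaMA Factory's instruction-tuning format. A schematic entry might look like the following (the instruction and field values here are invented placeholders, not real data from sft_train_data.json):

[
  {
    "instruction": "Translate the marketer demand into a structured tag query.",
    "input": "Demand: ...\nAnalogical examples: ...",
    "output": "..."
  },
  {
    "instruction": "Write out the reasoning steps for the marketer demand.",
    "input": "Demand: ...\nAnalogical examples: ...",
    "output": "Step 1: ...\nStep 2: ..."
  }
]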

Environment

git clone https://github.com/wjj0122/ARALLM.git
conda create -n llama_factory python=3.10
conda activate llama_factory
cd ARALLM
pip install -r requirements.txt

Fine-tune command

ChatGLM2-6B-32K

deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
    --deepspeed ds_config.json \
    --stage sft \
    --model_name_or_path YOUR_MODEL_PATH \
    --do_train \
    --dataset DATASET_NAME \
    --template chatglm2 \
    --cutoff_len 4096 \
    --finetuning_type lora \
    --lora_target query_key_value \
    --lora_rank 8 \
    --output_dir OUTPUT_DIR \
    --overwrite_cache \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_strategy "epoch" \
    --learning_rate 5e-5 \
    --num_train_epochs 1 \
    --plot_loss \
    --report_to tensorboard \
    --bf16

Baichuan2-13B-Chat

deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
    --deepspeed ds_config.json \
    --stage sft \
    --model_name_or_path YOUR_MODEL_PATH \
    --do_train \
    --dataset DATASET_NAME \
    --template baichuan2 \
    --cutoff_len 4096 \
    --finetuning_type lora \
    --lora_target W_pack \
    --lora_rank 8 \
    --output_dir OUTPUT_DIR \
    --overwrite_cache \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_strategy "epoch" \
    --learning_rate 5e-5 \
    --num_train_epochs 1 \
    --plot_loss \
    --report_to tensorboard \
    --bf16
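
Both commands above assume a ds_config.json in the working directory. If you do not already have one, a minimal ZeRO-2 configuration along the lines of the DeepSpeed/LLaMA Factory examples should work; this is only a sketch ("auto" values are filled in from the Trainer arguments), so adjust the ZeRO stage and bucket sizes to your hardware:

{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8
  }
}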

Export checkpoint

python src/export_model.py \
    --model_name_or_path YOUR_MODEL_PATH \
    --template chatglm2 \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint \
    --export_dir path_to_export

Here YOUR_MODEL_PATH is the base model, and --template should match it (chatglm2 for ChatGLM2-6B-32K, baichuan2 for Baichuan2-13B-Chat).
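
After exporting, the merged weights load like any Hugging Face checkpoint. A quick sanity check might look like the sketch below (it assumes the transformers and accelerate packages; trust_remote_code is required for ChatGLM2 and Baichuan2):

# Sanity-check that the exported (merged) model loads and generates.
# "path_to_export" is the --export_dir used above.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path_to_export", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "path_to_export", trust_remote_code=True, device_map="auto"
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))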

Predict

Note that the commands below pass both an exported model path and a LoRA checkpoint; in practice you need only one of the two: either the merged model from the export step (in which case --finetuning_type lora and --checkpoint_dir can be dropped), or the base model together with --checkpoint_dir.

ChatGLM2-6B-32K

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path YOUR_EXPORT_MODEL_PATH \
    --do_predict \
    --dataset TEST_DATASET \
    --template chatglm2 \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_predict_result \
    --per_device_eval_batch_size 8 \
    --max_samples 100 \
    --predict_with_generate

Baichuan2-13B-Chat

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path YOUR_EXPORT_MODEL_PATH \
    --do_predict \
    --dataset TEST_DATASET \
    --template baichuan2 \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_predict_result \
    --per_device_eval_batch_size 8 \
    --max_samples 100 \
    --predict_with_generate
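
The decoded outputs are written under --output_dir; with --predict_with_generate, LLaMA Factory saves them as generated_predictions.jsonl with "label" and "predict" keys (confirm the layout against your version's output). A minimal loader for downstream evaluation might be:

# Sketch: load LLaMA Factory prediction output for evaluation.
# The file name and keys follow the convention noted above; check
# your own output directory if your version differs.
import json

predictions, references = [], []
with open("path_to_predict_result/generated_predictions.jsonl", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        predictions.append(record["predict"])
        references.append(record["label"])

print(f"loaded {len(predictions)} predictions")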

Evaluate

In the eval directory, we provide the evaluation scripts we used for evaluating the results generated by LLMs (ChatGPT or fine-tuned LLMs). The metrics of structural accuracy and overall accuracy are computed by struc_and_overall_eval.py, while the GPTEval score is obtained via gpt4_eval_prompt.py. The gpt4_eval_instruction.txt file provides the scoring example for GPT-4 evaluation.
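
As a rough illustration of the two automatic metrics only (this sketch is not the logic of struc_and_overall_eval.py; it assumes that structural accuracy checks whether a generated query is well-formed and that overall accuracy checks normalized exact match against the reference):

# Illustrative sketch under the assumptions stated above; the released
# struc_and_overall_eval.py implements the paper's actual criteria.
def normalize(text: str) -> str:
    return " ".join(text.split())

def is_well_structured(query: str) -> bool:
    # Hypothetical structural check: balanced parentheses as a
    # stand-in for validating the real tag-query grammar.
    depth = 0
    for ch in query:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                return False
    return depth == 0

def evaluate(predictions, references):
    structural_acc = sum(map(is_well_structured, predictions)) / len(predictions)
    overall_acc = sum(
        normalize(p) == normalize(r) for p, r in zip(predictions, references)
    ) / len(references)
    return structural_acc, overall_acc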

License

This repository is licensed under the Apache-2.0 License.

Please follow the model licenses to use the corresponding model weights: Baichuan / Baichuan2 / BLOOM / ChatGLM3 / Falcon / InternLM / LLaMA / LLaMA-2 / Mistral / Phi-1.5 / Qwen / XVERSE
