
chatsql's Introduction

ChatSQL

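Implements NL2SQL on top of ChatGLM-6B/MOSS. It connects directly to the database and returns the query results. Only MySQL syntax is supported for now; support for other databases' SQL dialects is planned.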

  • So far, ChatGLM-6B seems to handle at most 3-4 rounds of multi-turn conversation. After that it can no longer generate correct SQL statements.
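Given that observation, one possible mitigation is to send only the most recent turns back to the model. The sketch below assumes the history is a list of `(query, response)` pairs, which is the format ChatGLM-6B's `model.chat(tokenizer, query, history=...)` uses. The helper name `trim_history` and the default of 3 turns are illustrative choices, not part of ChatSQL.

```python
from typing import List, Tuple

Turn = Tuple[str, str]  # (user query, model response)


def trim_history(history: List[Turn], max_turns: int = 3) -> List[Turn]:
    """Keep only the last `max_turns` exchanges; older turns are dropped."""
    if max_turns <= 0:
        return []  # guard: history[-0:] would return the whole list
    return history[-max_turns:]
```

Usage would look like `response, history = model.chat(tokenizer, query, history=trim_history(history))`. The trade-off is that the model forgets earlier context, so follow-up questions that refer back more than three turns may need to be restated.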

🚀 Try it on HuggingFace

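HuggingFace/ChatSQL ❤️ Thanks to HuggingFace for providing free CPU resources

The current deployment runs on 2 vCPUs with 16 GB RAM using the ChatGLM int4 model, so inference is painfully slow. If you have the hardware, download the project and try it locally 😘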

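✨ Overall approach

The overall approach is shown in the figure above. For now, a YAML file is used in place of the Table_info table schema.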
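To illustrate how a YAML file can stand in for the table-schema table, here is a minimal sketch that renders a parsed schema as prompt text. The YAML layout (table name mapped to `{column: description}`) and the function name `schema_to_prompt` are assumptions; ChatSQL's actual YAML format may differ. Parsing would normally be done with PyYAML's `yaml.safe_load`; the sketch takes the already-parsed dict so it needs no third-party package.

```python
from typing import Dict

# Hypothetical layout, e.g. the result of yaml.safe_load on:
#   cargo:
#     cargo_name: name of the goods
#     net_yield: net yield
Schema = Dict[str, Dict[str, str]]


def schema_to_prompt(schema: Schema) -> str:
    """Render each table and its column descriptions as plain prompt lines."""
    blocks = []
    for table, columns in schema.items():
        lines = [f"Table name: {table}", "Columns:"]
        lines += [f"  {column}: {description}" for column, description in columns.items()]
        blocks.append("\n".join(lines))
    # Blank line between tables so the model can tell them apart.
    return "\n\n".join(blocks)
```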

🎬 Getting started

git clone git@github.com:yysirs/ChatSQL.git
cd ChatSQL
conda create -n chatsql python=3.9
conda activate chatsql
pip install -r requirements.txt
# Create the working directories
mkdir DB
mkdir logs
# Create the local database and insert the data
python local_database.py
# Generate SQL with ChatGLM
python main_gui.py
# Or generate SQL with MOSS
python main_gui_moss.py

😁 Demo

👍 Features

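  • 🛒 Multi-table join queries
  • 🖼️ 2023/04/24 Web front end
  • 🎉 2023/04/24 Custom database schema via YAML
  • 😁 2023/04/25 Custom data via YAML
  • 🎗️ 2023/04/25 Queries run directly against the local database to verify that the generated SQL is correct
  • 👌 2023/04/30 MOSS model support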

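🔍 Query types

# Single-table, multi-condition query
Please find the names of goods whose net yield from sales in 2019 was greater than 10

# Two-table join
Please find the names of the sales managers in 2019 where the net yield was greater than 10 and the sales volume was greater than 100

# Two-table, multi-condition join
Please find the names of the sales managers in 2019 for goods with a net yield greater than 10, a sales volume greater than 100, and a manager performance greater than 1000

# max
Please find the name of the goods with the highest sales volume

# min
Please find the name of the goods with the lowest sales volume

# COUNT
Please find the number of goods whose net yield from sales in 2019 was greater than 10

# AVG
Please find the average sales volume of goods in 2019 and 2020

# GROUP BY
Please group by year and return the sales volume and the year

# ORDER BY
Please sort the goods names by quantity

# SUM
Please find the total sales volume of goods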
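For reference, here is how several of the single-table questions above might map to SQL, run against a toy in-memory SQLite table. The table is hypothetical: `cargo_name` and `net_yield` appear in generated SQL elsewhere on this page, while `year` and `sales_volume` are assumed column names, and the rows are made up. The two join examples are omitted because they need a second table whose schema is not shown here.

```python
import sqlite3

# Hypothetical table: cargo_name and net_yield appear in generated SQL elsewhere on
# this page; year and sales_volume are assumed column names; the rows are made up.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE cargo (cargo_name TEXT, year INTEGER, "
             "net_yield REAL, sales_volume INTEGER)")
conn.executemany("INSERT INTO cargo VALUES (?, ?, ?, ?)", [
    ("water", 2019, 12.0, 150),
    ("juice", 2019, 8.0, 300),
    ("tea", 2020, 15.0, 90),
])

EXAMPLES = {
    "single table, multiple conditions":
        "SELECT cargo_name FROM cargo WHERE year = 2019 AND net_yield > 10",
    "max": "SELECT cargo_name FROM cargo "
           "WHERE sales_volume = (SELECT MAX(sales_volume) FROM cargo)",
    "min": "SELECT cargo_name FROM cargo "
           "WHERE sales_volume = (SELECT MIN(sales_volume) FROM cargo)",
    "COUNT": "SELECT COUNT(cargo_name) FROM cargo WHERE year = 2019 AND net_yield > 10",
    "AVG": "SELECT AVG(sales_volume) FROM cargo WHERE year IN (2019, 2020)",
    "GROUP BY": "SELECT year, SUM(sales_volume) FROM cargo GROUP BY year ORDER BY year",
    "ORDER BY": "SELECT cargo_name FROM cargo ORDER BY sales_volume",
    "SUM": "SELECT SUM(sales_volume) FROM cargo",
}

for label, sql in EXAMPLES.items():
    print(f"{label}: {conn.execute(sql).fetchall()}")
```

The GROUP BY question is ambiguous about how to aggregate sales volume per year; summing is one reasonable reading.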

🔨 TODO LIST

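  • Add a web front end
  • Make the database schema configurable via YAML
  • Use a local SQLite database to verify that generated SQL statements are correct
  • Load the in-context prompt from an external YAML file
  • Improve the various query types, e.g. complex queries with ORDER BY, GROUP BY / HAVING
  • Improve the similarity-search module
  • Support other query languages, e.g. Oracle (relational database) and Cypher (graph database)
  • Docker deployment
  • Fine-tune ChatGLM/MOSS for the SQL domain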
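The tracebacks further down show `prompt.py` creating an `embedder` and `corpus_embeddings`, and `utils.py` calling `retrieval_related_table(..., top_k=3)`, which suggests the similarity-search module ranks tables by embedding similarity to the question. A minimal sketch of that ranking step, assuming cosine similarity and plain Python lists for the vectors; in practice the vectors would come from `embedder.encode(...)`, and sentence-transformers' `util.semantic_search` offers the same kind of ranking. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def top_k_tables(query_vec: Sequence[float], table_vecs: List[Sequence[float]],
                 table_names: List[str], k: int = 3) -> List[str]:
    """Return the names of the k tables whose embeddings are closest to the query."""
    scores = [cosine(query_vec, vec) for vec in table_vecs]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [table_names[i] for i in ranked[:k]]
```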

❤️ Acknowledgements

  • ChatGLM-6B: provides the large language model capability
  • MOSS: provides the large language model capability

chatsql's People

Contributors

yysirs


chatsql's Issues

python main_gui.py fails to connect to HF

(chatsql) zhanghui@ubuntu:/home1/zhanghui/ChatSQL$ python main_gui.py
Traceback (most recent call last):
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/urllib3/connectionpool.py", line 467, in _make_request
self._validate_conn(conn)
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/urllib3/connectionpool.py", line 1092, in _validate_conn
conn.connect()
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/urllib3/connection.py", line 642, in connect
sock_and_verified = _ssl_wrap_socket_and_match_hostname(
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/urllib3/connection.py", line 783, in _ssl_wrap_socket_and_match_hostname
ssl_sock = ssl_wrap_socket(
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/urllib3/util/ssl_.py", line 469, in ssl_wrap_socket
ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls, server_hostname)
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/urllib3/util/ssl_.py", line 513, in _ssl_wrap_socket_impl
return ssl_context.wrap_socket(sock, server_hostname=server_hostname)
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/ssl.py", line 500, in wrap_socket
return self.sslsocket_class._create(
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/ssl.py", line 1040, in _create
self.do_handshake()
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/ssl.py", line 1309, in do_handshake
self._sslobj.do_handshake()
ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: Hostname mismatch, certificate is not valid for 'huggingface.co'. (_ssl.c:1131)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/urllib3/connectionpool.py", line 790, in urlopen
response = self._make_request(
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/urllib3/connectionpool.py", line 491, in _make_request
raise new_e
urllib3.exceptions.SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: Hostname mismatch, certificate is not valid for 'huggingface.co'. (_ssl.c:1131)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/requests/adapters.py", line 486, in send
resp = conn.urlopen(
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/urllib3/connectionpool.py", line 844, in urlopen
retries = retries.increment(
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/urllib3/util/retry.py", line 515, in increment
raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type]
urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 (Caused by SSLError(SSLCertVerificationError(1, "[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: Hostname mismatch, certificate is not valid for 'huggingface.co'. (_ssl.c:1131)")))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "main_gui.py", line 15, in <module>
from utils import obtain_sql, retrieval_related_table, execute_sql
File "/home1/zhanghui/ChatSQL/utils.py", line 8, in <module>
from prompt import embedder, corpus_embeddings, table_schema, corpus, In_context_prompt
File "/home1/zhanghui/ChatSQL/prompt.py", line 9, in <module>
embedder = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/sentence_transformers/SentenceTransformer.py", line 87, in __init__
snapshot_download(model_name_or_path,
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/sentence_transformers/util.py", line 442, in snapshot_download
model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token)
File "/home/zhanghui/.local/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
return fn(*args, **kwargs)
File "/home/zhanghui/.local/lib/python3.8/site-packages/huggingface_hub/hf_api.py", line 1675, in model_info
r = get_session().get(path, headers=headers, timeout=timeout, params=params)
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/requests/sessions.py", line 602, in get
return self.request("GET", url, **kwargs)
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/requests/sessions.py", line 589, in request
resp = self.send(prep, **send_kwargs)
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/requests/sessions.py", line 703, in send
r = adapter.send(request, **kwargs)
File "/home/zhanghui/archiconda3/envs/chatsql/lib/python3.8/site-packages/requests/adapters.py", line 517, in send
raise SSLError(e, request=request)
requests.exceptions.SSLError: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 (Caused by SSLError(SSLCertVerificationError(1, "[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: Hostname mismatch, certificate is not valid for 'huggingface.co'. (_ssl.c:1131)")))
(chatsql) zhanghui@ubuntu:/home1/zhanghui/ChatSQL$

Could this be made into an offline version that doesn't need to connect to HF?
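The traceback shows the failure happens while `prompt.py` loads the embedding model `paraphrase-multilingual-MiniLM-L12-v2` by name, which makes sentence-transformers contact huggingface.co. A minimal sketch of offline loading, assuming the model has already been copied to a local directory: passing a directory path to `SentenceTransformer` loads it from disk, and the `HF_HUB_OFFLINE` / `TRANSFORMERS_OFFLINE` environment variables tell the Hugging Face libraries not to make network calls. Those variables may only be read when the libraries are imported, so the sketch sets them before a lazy import. The directory `./models/...` and the helper names are hypothetical.

```python
import os
from pathlib import Path

# Name used in prompt.py; the local directory below is a hypothetical copy of it.
HUB_MODEL_ID = "paraphrase-multilingual-MiniLM-L12-v2"
LOCAL_EMBEDDER_DIR = Path("./models/paraphrase-multilingual-MiniLM-L12-v2")


def resolve_embedder_path(local_dir: Path = LOCAL_EMBEDDER_DIR) -> str:
    """Prefer a local model directory; fall back to the Hub name otherwise."""
    if local_dir.is_dir():
        # Ask huggingface_hub / transformers not to make network calls.
        os.environ["HF_HUB_OFFLINE"] = "1"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
        return str(local_dir)
    return HUB_MODEL_ID


def load_embedder():
    path = resolve_embedder_path()
    # Imported lazily so the environment variables above are already set.
    from sentence_transformers import SentenceTransformer
    return SentenceTransformer(path)
```

To create the local copy, run `SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2").save("./models/paraphrase-multilingual-MiniLM-L12-v2")` once on a machine with network access, then copy the directory over.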

python main_gui.py fails with an error

huggingface_hub.utils._validators.HFValidationError: Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: './ChatGlm-6b'.
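This error usually means `from_pretrained` did not find a directory at `./ChatGlm-6b` (relative to the current working directory) and fell back to treating the string as a Hub repo id, which fails validation because it starts with `.`. A minimal sketch of a guard, assuming the model files were downloaded beforehand: it expands `~` (Python does not do this automatically for paths like `~/THUDM/chatglm3-6b`), makes the path absolute, and fails with a clear message if the directory is missing. The helper name is hypothetical.

```python
from pathlib import Path


def local_model_dir(path_str: str) -> str:
    """Expand '~', make the path absolute, and fail loudly if the directory is missing."""
    path = Path(path_str).expanduser().resolve()
    if not path.is_dir():
        # Without this check, from_pretrained would treat the string as a Hub repo id.
        raise FileNotFoundError(f"model directory not found: {path}")
    return str(path)
```

Usage: `AutoModel.from_pretrained(local_model_dir("./ChatGlm-6b"), trust_remote_code=True)`. Case also matters on Linux, so the directory name must match exactly (e.g. `ChatGlm-6b` vs `chatglm-6b`).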

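Following this

I previously P-tuned ChatGLM on the NL2SQL and DuSQL datasets, but it generalized poorly. Later I built templates to generate an in-domain training set and tested that too, and it didn't work either.

There are a few reasons:

1. We usually split tables when designing a schema. Values such as net yield, loss rate, or number of goals are mostly computed on the fly from several tables; statistics-style tables like the ones in these datasets are rare.

2. Take the example at prompt.py#L44 in your code.
When users ask a question, they usually don't know which tables and columns exist. For example, for "Where was the mineral water purchased from?", anyone who knows SQL would query cargo and supply_company, but the LLM cannot associate "mineral water" with cargo_name (it can with a large enough training set, but swap in "crispy instant noodles" (干脆面) and it fails again), and it cannot associate "purchased from where" with supply_company either.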
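One common mitigation for the first half of point 2 is value linking: before prompting, look up which stored values appear literally in the question and tell the model which table and column they belong to. This is a swapped-in technique, not something ChatSQL does; the sketch below uses SQLite and scans every distinct text value, which is only practical for small tables. It does nothing for phrases like "purchased from where", which still need a schema-level description.

```python
import sqlite3
from typing import List, Tuple


def link_values(conn: sqlite3.Connection, question: str) -> List[Tuple[str, str, str]]:
    """Return (table, column, value) for every stored text value found in the question."""
    hits = []
    tables = [row[0] for row in conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table'")]
    for table in tables:
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
        for _cid, column, col_type, *_rest in conn.execute(f'PRAGMA table_info("{table}")'):
            if not any(t in col_type.upper() for t in ("CHAR", "TEXT")):
                continue  # only text columns can hold names like "mineral water"
            for (value,) in conn.execute(f'SELECT DISTINCT "{column}" FROM "{table}"'):
                if isinstance(value, str) and value and value in question:
                    hits.append((table, column, value))
    return hits
```

A hit such as `("cargo", "cargo_name", "mineral water")` can then be added to the prompt as a hint, e.g. "'mineral water' is a value of cargo.cargo_name".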

main_gui.py error: local variable 'In_context_prompt' referenced before assignment

python F:\chat\ChatSQL-main\main_gui.py
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 7/7 [00:13<00:00, 1.89s/it]
F:\chat\ChatSQL-main\main_gui.py:107: GradioDeprecationWarning: The style method is deprecated. Please set these arguments in the constructor instead.
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
Running on local URL: http://127.0.0.1:7860

To create a public link, set share=True in launch().
Traceback (most recent call last):
File "D:\work\miniconda3\envs\chatsql\lib\site-packages\gradio\routes.py", line 442, in run_predict
output = await app.get_blocks().process_api(
File "D:\work\miniconda3\envs\chatsql\lib\site-packages\gradio\blocks.py", line 1392, in process_api
result = await self.call_function(
File "D:\work\miniconda3\envs\chatsql\lib\site-packages\gradio\blocks.py", line 1097, in call_function
prediction = await anyio.to_thread.run_sync(
File "D:\work\miniconda3\envs\chatsql\lib\site-packages\anyio\to_thread.py", line 56, in run_sync
return await get_async_backend().run_sync_in_worker_thread(
File "D:\work\miniconda3\envs\chatsql\lib\site-packages\anyio\_backends\_asyncio.py", line 2134, in run_sync_in_worker_thread
return await future
File "D:\work\miniconda3\envs\chatsql\lib\site-packages\anyio\_backends\_asyncio.py", line 851, in run
result = context.run(func, *args)
File "D:\work\miniconda3\envs\chatsql\lib\site-packages\gradio\utils.py", line 703, in wrapper
response = f(*args, **kwargs)
File "F:\chat\ChatSQL-main\main_gui.py", line 76, in predict
input_prompt = retrieval_related_table(input_prompt, input, history, top_k=3)
File "F:\chat\ChatSQL-main\utils.py", line 28, in retrieval_related_table
input_prompt += In_context_prompt
UnboundLocalError: local variable 'In_context_prompt' referenced before assignment
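`utils.py` is not shown here, but in Python this error means the function assigns to `In_context_prompt` somewhere in its body (for example only inside one `if` branch, or later in the function). Any assignment makes the name local for the whole function, so the module-level value imported from `prompt.py` is no longer visible and the read on line 28 fails. A minimal reproduction and one way to fix it; the function names and prompt text are hypothetical.

```python
# Stand-in for the module-level value that utils.py imports from prompt.py.
In_context_prompt = "Example: question -> SQL\n"


def retrieval_broken(input_prompt: str) -> str:
    # The assignment below makes In_context_prompt local to this whole function,
    # so this read raises UnboundLocalError.
    input_prompt += In_context_prompt
    In_context_prompt = ""  # hypothetical reassignment that triggers the error
    return input_prompt


def retrieval_fixed(input_prompt: str) -> str:
    # Read the module-level value into a differently named local; never rebind it.
    context = In_context_prompt
    return input_prompt + context
```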

python main_gui.py hangs after launch

Local LLM: chatglm3-6b
python=3.9
Ubuntu20.04

I changed some of the code to try to load the model locally:
tokenizer = AutoTokenizer.from_pretrained("~/THUDM/chatglm3-6b", trust_remote_code=True)

model = AutoModel.from_pretrained("~/THUDM/chatglm3-6b", trust_remote_code=True).half().cuda()

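config.cfg has been filled in.

python local_database.py runs normally without problems.

After running python main_gui.py, the whole console hangs with no response, even after waiting at least 5 minutes.
Setting a breakpoint on the first line of this file in the debugger doesn't help either.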


Error after force-stopping with Ctrl+C:

^CTraceback (most recent call last):
  File "/data/bch/LLM/forSQL/ChatSQL/main_gui.py", line 14, in <module>
    from utils import obtain_sql, retrieval_related_table, execute_sql
  File "/data/bch/LLM/forSQL/ChatSQL/utils.py", line 8, in <module>
    from prompt import embedder, corpus_embeddings, table_schema, corpus, In_context_prompt
  File "/data/bch/LLM/forSQL/ChatSQL/prompt.py", line 9, in <module>
    embedder = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/sentence_transformers/SentenceTransformer.py", line 87, in __init__
    snapshot_download(model_name_or_path,
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/sentence_transformers/util.py", line 442, in snapshot_download
    model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token)
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/huggingface_hub/hf_api.py", line 2227, in model_info
    r = get_session().get(path, headers=headers, timeout=timeout, params=params)
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/requests/sessions.py", line 602, in get
    return self.request("GET", url, **kwargs)
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/requests/sessions.py", line 589, in request
    resp = self.send(prep, **send_kwargs)
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/requests/sessions.py", line 703, in send
    r = adapter.send(request, **kwargs)
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/huggingface_hub/utils/_http.py", line 68, in send
    return super().send(request, *args, **kwargs)
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/requests/adapters.py", line 486, in send
    resp = conn.urlopen(
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/urllib3/connectionpool.py", line 793, in urlopen
    response = self._make_request(
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/urllib3/connectionpool.py", line 467, in _make_request
    self._validate_conn(conn)
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/urllib3/connectionpool.py", line 1099, in _validate_conn
    conn.connect()
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/urllib3/connection.py", line 616, in connect
    self.sock = sock = self._new_conn()
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/urllib3/connection.py", line 198, in _new_conn
    sock = connection.create_connection(
  File "/data/bch/miniconda3/envs/chatsql/lib/python3.9/site-packages/urllib3/util/connection.py", line 73, in create_connection
    sock.connect(sa)
KeyboardInterrupt
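The Ctrl+C traceback shows the process blocked in `sock.connect` while `SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')` in `prompt.py` tries to reach huggingface.co, i.e. it is still importing modules, not running the model. A minimal sketch for failing fast instead of hanging: probe the host with a short timeout before loading anything, and switch to locally stored models if it is unreachable. The function name and the 3-second default are assumptions.

```python
import socket


def can_reach(host: str = "huggingface.co", port: int = 443, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # covers refused connections, DNS failures and timeouts
        return False
```

For example, at the top of `prompt.py`: `if not can_reach(): raise SystemExit("huggingface.co unreachable; load the models from local directories instead")`. Note that a successful TCP connection does not rule out the certificate problem shown in the first issue.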


How to construct the prompt for training ChatGLM

The prompt format I'm currently using looks like this:

`
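We want to translate the user's question into a MySQL SQL statement.
The information about the tables involved in the question is:

Table name: a_activity_instance
Table structure:
activity_instance_id Activity instance ID
activity_type_id Activity type ID
activity_code Activity code
busi_category Business category; see the configuration in the "activity business category" table
The user's question is:

get all activity_instance_id of activity?
Translate the user's question into a MySQL SQL statement. The SQL statement is: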

`

Is this correct? If so, pre_seq_len would need to be set fairly large; would that affect the final results? I'd be very grateful for any advice.
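Generating training prompts from a function keeps the format identical across every example, which matters more for P-tuning than the exact wording. A minimal sketch that renders the format shown above; the function name is hypothetical. As far as we understand the P-tuning v2 scripts shipped with ChatGLM-6B, `pre_seq_len` sets the number of trainable prefix tokens, while the length of the input prompt is bounded by `max_source_length`, so a long schema description mainly calls for a larger `max_source_length`.

```python
from typing import List, Tuple


def build_nl2sql_prompt(table: str, columns: List[Tuple[str, str]], question: str) -> str:
    """Render the single-table NL2SQL prompt format used in the question above."""
    lines = [
        "We want to translate the user's question into a MySQL SQL statement.",
        "The information about the tables involved in the question is:",
        "",
        f"Table name: {table}",
        "Table structure:",
    ]
    lines += [f"{name} {description}" for name, description in columns]
    lines += [
        "The user's question is:",
        "",
        question,
        "Translate the user's question into a MySQL SQL statement. The SQL statement is:",
    ]
    return "\n".join(lines)
```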

Returned SQL fails to execute because of missing spaces

Sure, let's run the query:

SELECTcargo_name
FROMcargo
WHEREnet_yield>10

This will return the names of the goods sold in 2019.

SQL statement execution failed

near "SELECTcargo_name": syntax error
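Until the root cause of the lost whitespace is found, a post-processing step can re-insert the space between an SQL keyword and an identifier glued to it. This is a heuristic sketch, not ChatSQL's code: it assumes keywords are uppercase and identifiers are lowercase snake_case, as in this schema, and it only fixes spacing. The query above would still lack the 2019 filter that the reply claims to apply.

```python
import re

# Uppercase keyword immediately followed by a lowercase identifier, e.g. "FROMcargo".
# Requiring a lowercase letter or "_" next keeps words like "ORDER" intact.
_GLUED_KEYWORD = re.compile(
    r"\b(SELECT|FROM|WHERE|AND|OR|JOIN|ON|BY|HAVING|LIMIT)(?=[a-z_])")


def fix_keyword_spacing(sql: str) -> str:
    """Insert the missing space after a keyword that runs into the next identifier."""
    return _GLUED_KEYWORD.sub(r"\1 ", sql)
```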

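Keep it up!

I've been looking into text2sql recently as well. Following this for now, and I hope the project keeps improving. Full support!

main.py error: KeyError: 'database'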

(chatsql) F:\chat\ChatSQL-main>python F:\chat\ChatSQL-main\main.py
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 7/7 [00:12<00:00, 1.75s/it]
Traceback (most recent call last):
File "F:\chat\ChatSQL-main\main.py", line 99, in <module>
main()
File "F:\chat\ChatSQL-main\main.py", line 31, in main
db_con = Cur_db()
File "F:\chat\ChatSQL-main\utility\db_tools.py", line 14, in __init__
self.db_name = self.config['database']['DB']
File "D:\work\miniconda3\envs\chatsql\lib\configparser.py", line 963, in __getitem__
raise KeyError(key)
KeyError: 'database'

(chatsql) F:\chat\ChatSQL-main>a
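`configparser.ConfigParser.read()` silently skips files it cannot open and simply returns the list of files it did parse, so a missing, misplaced, or empty `config.cfg` only surfaces later as `KeyError: 'database'` on `self.config['database']['DB']`. A minimal sketch of a stricter loader; the `[database]` section and `DB` key come from `db_tools.py` line 14, while the function name is hypothetical and any other keys ChatSQL expects are not known from this page.

```python
import configparser
from pathlib import Path


def load_db_config(path: str = "config.cfg") -> configparser.ConfigParser:
    """Read config.cfg and fail with a clear message if it is missing or incomplete."""
    cfg = configparser.ConfigParser()
    # read() returns the files it parsed; an empty list means nothing was read.
    if not cfg.read(path, encoding="utf-8"):
        raise FileNotFoundError(f"config file not found: {Path(path).resolve()}")
    if not cfg.has_section("database"):
        raise KeyError(f"no [database] section in {path}; db_tools.py expects one with a DB key")
    return cfg
```

If the script is launched from a different working directory, passing an absolute path such as `Path(__file__).parent / "config.cfg"` avoids reading the wrong location.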

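About deploying the models locally

Hi, is it possible to run everything fully privately, with the models downloaded locally and no connection to Hugging Face?

If so, which models need to be downloaded, and how should the configuration and code files be changed?

Many thanks!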
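Judging from the tracebacks on this page, at least two models are fetched from Hugging Face: the sentence-transformers embedder used for table retrieval and the ChatGLM checkpoint (MOSS would be a third for `main_gui_moss.py`). A minimal sketch for downloading them once on a networked machine with `huggingface_hub.snapshot_download` (the `local_dir` argument exists in recent versions of huggingface_hub). The embedder repo id appears in the tracebacks; `THUDM/chatglm-6b` and the target directories are assumptions, so adjust them to the checkpoint you use. The `downloader` parameter exists only so the loop can be exercised without network access.

```python
from pathlib import Path
from typing import Callable, Dict, Optional

# repo id -> local sub-directory. The embedder id appears in the tracebacks above;
# "THUDM/chatglm-6b" is an assumption about which ChatGLM checkpoint you want.
MODELS: Dict[str, str] = {
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2":
        "models/paraphrase-multilingual-MiniLM-L12-v2",
    "THUDM/chatglm-6b": "models/chatglm-6b",
}


def download_all(models: Dict[str, str] = MODELS, root: str = ".",
                 downloader: Optional[Callable[..., str]] = None) -> Dict[str, str]:
    """Download every model once (needs network) and return repo id -> local path."""
    if downloader is None:
        from huggingface_hub import snapshot_download  # the real downloader
        downloader = snapshot_download
    paths = {}
    for repo_id, rel_dir in models.items():
        target = str(Path(root) / rel_dir)
        downloader(repo_id=repo_id, local_dir=target)
        paths[repo_id] = target
    return paths
```

After copying the directories to the offline machine, the code would load them by path instead of by name, e.g. `SentenceTransformer("models/paraphrase-multilingual-MiniLM-L12-v2")` and `AutoModel.from_pretrained("models/chatglm-6b", trust_remote_code=True)`, as the chatglm3-6b issue above already does for the LLM.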

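config.cfg is empty

When testing on a server, it looks like the database query step fails because this file is empty. Could you show how to connect to the database?
When I try it myself, the db_tools file never runs properly.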
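The features list says generated SQL is checked against a local SQLite database, so connecting boils down to opening the SQLite file that `local_database.py` created (presumably under the `DB` folder from the setup steps) and executing the SQL. A minimal sketch, with a hypothetical function name; it returns SQLite's own error text, which is where messages such as `near "SELECTcargo_name": syntax error` come from. Note that `sqlite3.connect` silently creates a new, empty database if the path is wrong, which then fails with "no such table".

```python
import sqlite3
from typing import List, Optional, Tuple


def run_query(db_path: str, sql: str) -> Tuple[List[tuple], Optional[str]]:
    """Execute `sql` against the SQLite file at `db_path`.

    Returns (rows, None) on success or ([], error message) on failure.
    """
    conn = sqlite3.connect(db_path)
    try:
        return conn.execute(sql).fetchall(), None
    except sqlite3.Error as exc:
        return [], str(exc)
    finally:
        conn.close()
```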

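SQL statement execution failed

Following the example, I asked "Please find the name of the goods with the highest sales volume" and got: "SQL statement execution failed near "sql": syntax error".
It looks like the generated SQL is wrong: it queries fields that don't exist in the YAML. How can this kind of problem be solved?

Details below:

Please find the name of the goods with the highest sales volume

Based on the table information you provided, I assume you want the name of the goods with the highest sales volume in 2019. You can use the following SQL statement to query this information: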

SELECT c.cargo_name
FROM cargo c
JOIN cargo_info i ON c.cargo_id = i.cargo_id
JOIN source_cargo s ON i.cargo_id = s.cargo_id
JOIN storage_warehouse w ON s.cargo_id = w.cargo_id
JOIN sales_department d ON w.cargo_id = d.cargo_id
JOIN sales_person s_per ON d.sales_person_id = s_per.sales_person_id
JOIN sales_person s_市场 ON s_per.sales_person_id = s_市场.sales_person_id
JOIN cargo_category c_cat ON c.cargo_category = c_cat.cargo_category
JOIN source_cargo s_src ON c_cat.source_cargo_id = s_src.cargo_id
JOIN storage_warehouse s_ware on s_src.storage_warehouse_id = s_ware.storage_warehouse_id
JOIN sales_volume s_volume ON s_ware.cargo_id = s_volume.cargo_id
JOIN month_on_month_growth_rate m_g on s_volume.month_on_month_growth_rate = m_g.month_on_month_growth_rate
JOIN loss_rate l_r on m_g.

SQL statement execution failed

near "sql": syntax error
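One plausible cause of `near "sql": syntax error` is that the model wrapped its answer in a Markdown code fence tagged `sql`, and the extraction step kept the language tag as the first token of the statement. This cannot be confirmed from the page alone. A minimal sketch of a more forgiving extractor, with a hypothetical function name: it takes the body of the first fenced block, tolerates a missing closing fence (the reply above is cut off mid-JOIN), and falls back to the whole reply. It does not address the hallucinated tables; SQLite will still reject those at execution time.

```python
import re

TICKS = "`" * 3  # a Markdown code fence, built here to avoid a literal fence in the text

# Optional "sql" tag after the opening fence; body runs to the closing fence or the end.
_FENCE = re.compile(TICKS + r"(?:sql)?[ \t]*\n?(.*?)(?:" + TICKS + r"|\Z)",
                    re.DOTALL | re.IGNORECASE)


def extract_sql(reply: str) -> str:
    """Return the SQL inside the first fenced block, or the whole reply if there is none."""
    match = _FENCE.search(reply)
    if match:
        return match.group(1).strip()
    return reply.strip()
```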
