152 changes: 152 additions & 0 deletions .gitignore
@@ -0,0 +1,152 @@
.vscode/

# Model files
model/*

# ckpt
ckpt*

# Temporary directory
tmp/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/


# Debug
Debug

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

# Temp files
*.dump
3 changes: 0 additions & 3 deletions .idea/dictionaries/liuchong.xml

This file was deleted.

7 changes: 0 additions & 7 deletions .idea/misc.xml

This file was deleted.

8 changes: 0 additions & 8 deletions .idea/modules.xml

This file was deleted.

12 changes: 0 additions & 12 deletions .idea/seq2seq_chatbot.iml

This file was deleted.

5 changes: 5 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,5 @@
{
"python.pythonPath": "C:\\Program Files\\Python37\\python.exe",
"python.linting.pylintEnabled": true,
"python.linting.enabled": true
}
22 changes: 9 additions & 13 deletions README.md
@@ -1,41 +1,37 @@
=================================================Update===========================================================
The trained model has been uploaded to Baidu Cloud; feel free to download it if you need it. As for training speed: with a CPU and 16 GB of RAM, training finishes in about a day~~~

Link: https://pan.baidu.com/s/1hrNxaSk  Password: d2sn
# seq2seq_chatbot

=================================================Divider, main text below===============================================

This is a simple TensorFlow implementation of a seq2seq-based chatbot dialogue system.

For a walkthrough of the code, see my Zhihu column article:

[Building a deep-learning dialogue system from scratch -- a simple chatbot implementation](https://zhuanlan.zhihu.com/p/32455898)

The code is based on DeepQA, with beam search and an attention mechanism added on top.

The final results look like this:

![](https://i.imgur.com/pN7AfAB.png)

![](https://i.imgur.com/RnvBDwO.png)

Test output: given the user's input, the bot replies with the beam_size most probable sentences.

![](https://i.imgur.com/EdsQ5FE.png)

# Usage

1. Download the code (the data folder already contains the preprocessed dataset, so no extra download is needed).

2. Train the model: set the decode parameter on line 34 of chatbot.py to False and start training (a rough sketch of this switch is given at the end of this README).

(I will upload my trained model later so everyone can use it.)

3. After training finishes (roughly a day, about 30 epochs), set decode back to True and you can start testing. Type whatever you want to ask and see what it replies==

One more thing to watch out for: remember to update the absolute paths to the dataset and the final model files, otherwise you may get errors.

They are in three places: lines 44, 57, and 82. That's it -- have fun~~
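
For readers who have not opened chatbot.py yet, here is a minimal sketch of the kind of switch step 2 refers to. chatbot.py itself is not part of this diff, so apart from the `decode` flag and the need for absolute paths, everything below (variable names, paths) is a hypothetical placeholder:

```python
# Hedged sketch only -- the real chatbot.py may organise this differently.
decode = False  # False: train (~30 epochs, about a day on CPU); True: load the checkpoint and chat

# The absolute paths mentioned above (lines 44, 57, 82 of chatbot.py) must point
# at your own machine; these are hypothetical examples:
data_path = "/home/you/seq2seq_chatbot/data/dataset.pkl"
model_dir = "/home/you/seq2seq_chatbot/model/"
```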
1 change: 1 addition & 0 deletions __init__.py
@@ -0,0 +1 @@
__version__ = "0.0.1"
67 changes: 67 additions & 0 deletions beam_search.py
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Date : Apr-09-20 21:03
# @Author : Your Name (you@example.org)
# @Link : http://example.org

import os
import sys
import math
import time
import tensorflow as tf
from data_utils import *
from model import *
from tqdm import tqdm
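
# The beam_search function below decodes a single input sentence and returns
# up to beam_size distinct replies. The code treats model.step's outputs as,
# per decoding step, beam_path (back-pointers to the parent beam slot at the
# previous step) and beam_symbol (the token id chosen in each beam slot); it
# walks these back-pointers from the last step to the first and then reverses
# each path to get the reply in normal word order.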


def beam_search(sess, sentence, word2id, id2word, model, beam_size=5):
if sentence:
batch = sentence2enco(sentence, word2id, model.en_de_seq_len)
beam_path, beam_symbol = model.step(sess, batch.encoderSeqs, batch.decoderSeqs, batch.targetSeqs,
batch.weights, goToken)
paths = [[] for _ in range(beam_size)]
indices = [i for i in range(beam_size)]
num_steps = len(beam_path)
for i in reversed(range(num_steps)):
for kk in range(beam_size):
paths[kk].append(beam_symbol[i][indices[kk]])
indices[kk] = beam_path[i][indices[kk]]

recos = []
for kk in range(beam_size):
foutputs = [int(logit) for logit in paths[kk][::-1]]
if eosToken in foutputs:
foutputs = foutputs[:foutputs.index(eosToken)]
rec = " ".join([tf.compat.as_str(id2word[output])
for output in foutputs if output in id2word])
if rec not in recos:
recos.append(rec)
return recos
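
# Toy illustration of the back-pointer walk above (made-up values, not from a
# real run), with beam_size = 2 and two decoding steps:
#   beam_symbol = [[7, 9], [3, 5]]   # token id in each beam slot, per step
#   beam_path   = [[0, 0], [1, 0]]   # parent slot at the previous step
# Starting at the last step, slot 0 collects token 3, follows its parent
# pointer to slot 1 of step 0 and collects token 9, so paths[0] == [3, 9];
# the final paths[0][::-1] yields the id sequence [9, 3].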


def main():
pass
# with tf.Session() as sess:
# beam_size = 5
# if_beam_search = True
# model = create_model(
# sess, True, beam_search=if_beam_search, beam_size=beam_size)
# model.batch_size = 1
# data_path = DATA_PATH
# word2id, id2word, trainingSamples = load_dataset(data_path)

# sys.stdout.write("> ")
# sys.stdout.flush()
# sentence = sys.stdin.readline()
# while sentence:
# recos = beam_search(sess, sentence=sentence, word2id=word2id,
# id2word=id2word, model=model)
# print("Replies --------------------------------------->")
# print(recos)
# sys.stdout.write("> ")
# sys.stdout.flush()
# sentence = sys.stdin.readline()


if __name__ == "__main__":
main()
Binary file not shown.