
crnn_attention_ocr_chinese's Issues

Training output is always '的'

Hi, I generated data following train.txt and modified the code a bit, but the predictions are always '的'. Is this because I haven't trained for enough iterations, or is there a problem with my changes?

[+] epoch:0, batch:371, loss:3.7565720081329346, acc:0.0,
 train_decode:[',的的的的的的的的的', ',,的的的的的的的的', ',,的的的的的的的的', ',,的的的的的的的的', ',的的的的的的的的的'], 
 val_decode:[',的的的的的的的的的', ',的的的的的的的的的', ',的的的的的的的的的', ',的的的的的', ',,的的的的的的的的'], 
 val_truth:['管理我们要质疑!当然', '那大仙才能保持人体健', '种色彩因其波长不同而', '命和**共和制的产生', '盟者将按照年营业额的']

config.py: I mainly added batched data loading and changed the image size to 280*32; I'm not sure whether that matters.

import numpy as np
import cv2
import os

learning_rate = 0.001
momentum = 0.9
START_TOKEN = 0
END_TOKEN = 1
UNK_TOKEN = 2
VOCAB = {'<GO>': 0, '<EOS>': 1, '<UNK>': 2, '<PAD>': 3}  # start, end, unknown and padding tokens
VOC_IND = {}


# charset='0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
def get_class(path):
    """

    :param path:
    :return: 文字到int , int到文字
    """
    f = open(path, 'r', encoding='UTF-8')
    line = f.readline().strip()
    class2int = {}
    int2class = {}
    i = 0
    while line != '':
        class2int[line] = i
        int2class[i] = line
        line = f.readline().strip()
        i = i + 1
    f.close()
    return class2int, int2class


_, charset = get_class('char_std_5990.txt')

for i in range(len(charset)):
    VOCAB[charset[i]] = i + 4
for key in VOCAB:
    VOC_IND[VOCAB[key]] = key

NUM_BATCHES = 0
MAX_LEN_WORD = 20  # maximum label length; shorter labels are filled with <PAD>
VOCAB_SIZE = len(VOCAB)
BATCH_SIZE = 320
RNN_UNITS = 256
EPOCH = 10000
IMAGE_WIDTH = 280
IMAGE_HEIGHT = 32
MAXIMUM__DECODE_ITERATIONS = 20
DISPLAY_STEPS = 20
LOGS_PATH = 'log'
CKPT_DIR = 'save_model'
train_dir = '/media/lhj/0C3313300C331330/Images'
val_dir = 'data/data'
is_restore = True


def label2int(label):  # label shape: (num, len)
    # seq_len=[]
    target_input = np.ones((len(label), MAX_LEN_WORD), dtype=np.float32) + 2  # initialise everything to <PAD> (index 3)
    target_out = np.ones((len(label), MAX_LEN_WORD), dtype=np.float32) + 2  # initialise everything to <PAD> (index 3)
    for i in range(len(label)):
        # seq_len.append(len(label[i]))
        target_input[i][0] = 0  # the first decoder input is <GO>
        for j in range(len(label[i])):
            target_input[i][j + 1] = VOCAB[label[i][j]]
            target_out[i][j] = VOCAB[label[i][j]]
        target_out[i][len(label[i])] = 1  # terminate the target with <EOS>
    return target_input, target_out


def int2label(decode_label):
    label = []
    for i in range(decode_label.shape[0]):
        temp = ''
        for j in range(decode_label.shape[1]):
            if VOC_IND[decode_label[i][j]] == '<EOS>':
                break
            elif decode_label[i][j] == 3:
                continue
            else:
                temp += VOC_IND[decode_label[i][j]]
        label.append(temp)
    return label


def batch_iter(data_dir, file):
    """生成批次数据"""
    global NUM_BATCHES
    f = open(file, 'r', encoding='UTF-8')
    lines = f.read().strip().split('\n')

    data_len = len(lines)
    NUM_BATCHES = int((data_len - 1) / BATCH_SIZE) + 1

    print("[+] You have {} batches!".format(NUM_BATCHES))

    for i in range(NUM_BATCHES):
        image = []
        labels = []
        start_id = i * BATCH_SIZE
        end_id = min((i + 1) * BATCH_SIZE, data_len)
        iter_lines = lines[start_id:end_id]
        for line in iter_lines:
            s = line.strip().split(' ')
            label = ''
            image_name = os.path.join(data_dir, s[0])
            im = cv2.imread(image_name, 0)  # read the image as grayscale (the /255. normalisation is left out)
            if im.shape != (IMAGE_HEIGHT, IMAGE_WIDTH):  # im.shape is a tuple, so compare against a tuple
                im = cv2.resize(im, (IMAGE_WIDTH, IMAGE_HEIGHT))
            img = im.swapaxes(0, 1)  # put width first so it becomes the time axis
            image.append(np.array(img[:, :, np.newaxis]))
            for j in range(len(s) - 1):  # use j so the batch index i is not clobbered
                label += charset[int(s[j + 1])]
            labels.append(label)
        yield np.array(image), labels

def cal_acc(pred, label):
    """Sequence-level accuracy: fraction of predictions that exactly match their label."""
    num = 0
    for i in range(len(pred)):
        if pred[i] == label[i]:
            num += 1
    return num * 1.0 / len(pred)

train.py: I adjusted the code structure and plugged in the batch-loading function.

from model import *
import config as cfg
import time
import os


with tf.name_scope("optimizer") as scope:
    loss, train_decode_result, pred_decode_result = build_network(is_training=True)  # TODO
    optimizer = tf.train.MomentumOptimizer(learning_rate=cfg.learning_rate, momentum=cfg.momentum, use_nesterov=True)
    train_op = optimizer.minimize(loss)

# create the initializer only after the network and optimizer variables exist
init = tf.global_variables_initializer()


with tf.name_scope('summaries'):
    saver = tf.train.Saver(max_to_keep=5)
    tf.summary.scalar("cost", loss)
    summary_op = tf.summary.merge_all()
    writer = tf.summary.FileWriter(cfg.LOGS_PATH)


with tf.Session() as sess:
    sess.run(init)  # initialise all variables
    if cfg.is_restore:
        ckpt = tf.train.latest_checkpoint(cfg.CKPT_DIR)
        if ckpt:
            saver.restore(sess, ckpt)
            print('[*] restore from the checkpoint{0}'.format(ckpt))

    for cur_epoch in range(cfg.EPOCH):
        train_dataloader = cfg.batch_iter(cfg.train_dir, 'data/train.txt')
        test_dataloader = cfg.batch_iter(cfg.train_dir, 'data/my_test')

        train_cost = 0
        start_time = time.time()
        val_img, val_label = next(test_dataloader)  # one fixed validation batch for this epoch
        # iterate over the generator itself instead of a hard-coded batch count
        for cur_batch, (batch_inputs, batch_label) in enumerate(train_dataloader):
            batch_time = time.time()
            # cfg.NUM_BATCHES is only filled in once batch_iter has started running
            num_batches_per_epoch = cfg.NUM_BATCHES
            print("Batch Label: ", batch_label)
            batch_target_in, batch_target_out = cfg.label2int(batch_label)
            sess.run(train_op,
                     feed_dict={image: batch_inputs, train_output: batch_target_in, target_output: batch_target_out,
                                sample_rate: np.min([1., 0.2 * cur_epoch + 0.2])})

            if cur_batch % 1 == 0:  # evaluate every batch; raise this (e.g. to cfg.DISPLAY_STEPS) to reduce overhead
                summary_loss, loss_result = sess.run([summary_op, loss],
                                                     feed_dict={image: batch_inputs, train_output: batch_target_in,
                                                                target_output: batch_target_out,
                                                                sample_rate: np.min([1., 1.])})
                writer.add_summary(summary_loss, cur_epoch * num_batches_per_epoch + cur_batch)
                val_predict = sess.run(pred_decode_result, feed_dict={image: val_img[0:cfg.BATCH_SIZE]})
                train_predict = sess.run(pred_decode_result, feed_dict={image: batch_inputs,
                                                                        train_output: batch_target_in,
                                                                        target_output: batch_target_out,
                                                                        sample_rate: np.min([1., 1.])})
                predit = cfg.int2label(np.argmax(val_predict, axis=2))
                train_pre = cfg.int2label(np.argmax(train_predict, axis=2))
                gt = val_label[0:cfg.BATCH_SIZE]
                acc = cfg.cal_acc(predit, gt)

                print("[+] epoch:{}, batch:{}, loss:{}, acc:{},\n train_decode:{}, \n val_decode:{}, \n val_truth:{}".
                      format(cur_epoch, cur_batch,
                             loss_result, acc,
                             train_pre[0:5],
                             predit[0:5],
                             gt[0:5]))

                if not os.path.exists(cfg.CKPT_DIR):
                    os.makedirs(cfg.CKPT_DIR)
                saver.save(sess, os.path.join(cfg.CKPT_DIR, 'attention_ocr.model'),
                           global_step=cur_epoch * num_batches_per_epoch + cur_batch)

Question about single-image inference

Hi, I found that when using infer.py, the number of input images has to equal BATCH_SIZE.
1) With BATCH_SIZE = 40, when I predict with the following code:
for img_pd in val_img:
    val_predict = sess.run(pred_decode_result, feed_dict={image: np.array([img_pd])})
    predit = cfg.int2label(np.argmax(val_predict, axis=2))
    print(predit)
I get this error: ConcatOp : Dimensions of inputs should match: shape[0] = [40,326] vs. shape[1] = [1,256]
So the number of input images has to match BATCH_SIZE.

2) With BATCH_SIZE = 2, when I predict with the following code:
for img_pd in val_img:
    val_predict = sess.run(pred_decode_result, feed_dict={image: np.array([img_pd, img_pd])})
    predit = cfg.int2label(np.argmax(val_predict, axis=2))
    print(predit)
it does produce output, but the results are noticeably less accurate.

3) With BATCH_SIZE = 1, build_network fails with: 0 is not an invariant for the loop.
The error comes from inside tf.contrib.seq2seq.dynamic_decode.
So changing BATCH_SIZE to 1 is not a viable way to predict on a single image.

So, how can I run prediction on a single image?
Looking forward to your reply, thanks!
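
One possible workaround (a sketch, not code from this repo): tile the single image so it fills a whole BATCH_SIZE batch, run the fixed-size graph, and keep only the first decoded result. It assumes the image placeholder and pred_decode_result tensor from build_network(), used the same way as in train.py.

import numpy as np
import config as cfg

def predict_single(sess, image, pred_decode_result, img):
    # img: one preprocessed image shaped (IMAGE_WIDTH, IMAGE_HEIGHT, 1),
    # i.e. the same layout cfg.batch_iter() produces for each sample.
    batch = np.repeat(img[np.newaxis, ...], cfg.BATCH_SIZE, axis=0)  # repeat to a full batch
    decoded = sess.run(pred_decode_result, feed_dict={image: batch})
    # Every row decodes the same image, so the first result is enough.
    return cfg.int2label(np.argmax(decoded, axis=2))[0]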

How to increase training speed

Hi, wu
I found that changing this line (model.py, line 81) from loss = tf.reduce_mean(att_loss) to loss = tf.reduce_sum(att_loss) greatly increases the training speed.
The reason, I think, is that reduce_sum returns a much larger loss, so every trainable parameter receives a proportionally larger update.
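
As a quick sketch of that scale difference (assuming att_loss stands for the per-token attention loss tensor in model.py): the sum is just the mean multiplied by the number of summed elements, so every gradient grows by the same factor, much like raising the learning rate.

import tensorflow as tf

att_loss = tf.random_uniform([32, 20])  # stand-in for the per-token attention loss
loss_mean = tf.reduce_mean(att_loss)
loss_sum = tf.reduce_sum(att_loss)      # equals loss_mean * 32 * 20 here
# Gradients of loss_sum w.r.t. any variable are 640x those of loss_mean,
# which is why training appears to move much faster.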

Brs
SJHB

synthetic Chinese string dataset

I downloaded the synthetic Chinese string data from Baidu Yun, but I cannot open it.

I would like to know whether the file is corrupted.

Please give me some advice.

Model

Hi, is there no trained model released?

an error

Hi, wu
I ran into this error when running train.py. Do you know how to solve it?

Traceback (most recent call last):
File "train.py", line 50, in
sess.run( train_op,feed_dict={image: batch_inputs,train_output: batch_target_in,target_output: batch_target_out,sample_rate:np.min([1.,0.2*cur_e
poch+0.2])})

File "/home/sjhbxs/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 789, in run
run_metadata_ptr)
File "/home/sjhbxs/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 997, in _run
feed_dict_string, options, run_metadata)
File "/home/sjhbxs/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1132, in _do_run
target_list, options, run_metadata)
File "/home/sjhbxs/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1152, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'gradients/decode/decoder/while/BasicDecoderStep/ScheduledEmbeddingTrainingHe
lperSample_1/cond/Gather_1_grad/ExpandDims/StackPush' has inputs from different frames. The input 'gradients/decode/decoder/while/BasicDecoderStep/S
cheduledEmbeddingTrainingHelperSample_1/cond/Gather_1_grad/Size' is in frame ''. The input 'gradients/decode/decoder/while/BasicDecoderStep/Schedule
dEmbeddingTrainingHelperSample_1/cond/Gather_1_grad/ExpandDims/StackPush/Switch' is in frame 'decode/decoder/while/decode/decoder/while/'.

dic.txt

What is the Chinese dictionary used for?

Do I need to modify charset in config.py in order to recog Chinese?

It seems to me that in order to have K Chinese characters recognized, I have to initialize K logits in the compute graph. With the attention mechanism, there will be K^2 attention nodes in the graph.
Considering there are many Chinese characters (20,000+), the compute graph becomes very large and fails to allocate memory on a dual GTX 1080 Ti machine (22 GB of VRAM).
Is my understanding correct?

Loss convergence problem

Hi, I used this code to train on the Synth90k English data, which is fairly large. At the start of training the loss decreases normally, but after each epoch it actually goes back up. Is this approach unsuitable for large datasets? It just won't converge. I also tried training on a small dataset: the loss decreases normally but stops at around 0.2. Have you run into anything similar?

(screenshot of the training log, 2018-11-28 09-35-28)

Problem when converting to a .pb file

When converting to .pb, is the output name supposed to be output_node_names = "target_output"? Why is the file I export only 84 bytes?

About generating a .pb file

Thanks for the earlier answer. Once the model has been saved under the save_model directory, how do I convert those checkpoint files into a .pb file?
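
For what it's worth, target_output in train.py is the label placeholder fed during training, not the decode output, so freezing it as the output node produces an almost empty graph (which would explain a tiny .pb file). A hedged TF1 freeze sketch, assuming build_network(is_training=False) returns the same (loss, train_decode, pred_decode) triple and image placeholder as in train.py:

import tensorflow as tf
from tensorflow.python.framework import graph_util
import config as cfg
from model import *

# Rebuild the graph, restore the latest checkpoint, then freeze the decode output node.
_, _, pred_decode_result = build_network(is_training=False)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(cfg.CKPT_DIR))
    output_node = pred_decode_result.op.name  # the actual decode output, not "target_output"
    frozen = graph_util.convert_variables_to_constants(sess, sess.graph_def, [output_node])
    with tf.gfile.GFile('attention_ocr.pb', 'wb') as f:
        f.write(frozen.SerializeToString())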

Error after switching to Chinese recognition

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [584,256] rhs shape= [578,256]
[[Node: save/Assign_4 = Assign[T=DT_FLOAT, _class=["loc:@decode/decoder/attention_wrapper/gru_cell/candidate/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/gpu:0"](decode/decoder/attention_wrapper/gru_cell/candidate/kernel, save/RestoreV2_4/_81)]]

Error: Can not squeeze dim[2]

ValueError: Can not squeeze dim[2], expected a dimension of 1, got 5 for 'encode_features/Squeeze' (op: 'Squeeze') with input shapes: [?,98,5,512].
My input images are 400*80. The error is raised when the CNN features are fed into the RNN. What exactly does tf.squeeze(net, axis=2) do? Isn't net after the convolutions of shape [batch_size, h, w, channel]? Why squeeze axis 2?
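
For reference, tf.squeeze(net, axis=2) removes axis 2 only when that axis equals 1: the encoder expects the CNN to have collapsed the feature map to [batch, T, 1, channels] so that squeezing leaves the [batch, T, channels] sequence the RNN consumes. With an 80-pixel-high input that axis ends up as 5 instead of 1, hence the error. One possible workaround sketch (an assumption, not code from the repo) is to collapse the axis explicitly, or adjust the conv/pool strides so it really is 1:

net = tf.reduce_max(net, axis=2)  # [?, 98, 5, 512] -> [?, 98, 512], keeps a 3-D sequence for the RNN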

All predictions become <GO><GO><GO>

Hi, I was training the model on the project's training set. When the loss reached 0.18 and accuracy reached 71%, and I lowered learning_rate from 0.001 to 0.0001, all of the model's predictions suddenly turned into a string of <GO> tokens....
What could be causing this? Thanks!

MAX_LEN_WORD issue

Hello sir,

Nice job!
Here I have a question: I need to recognize long digit sequences with a maximum length of about 20, so I changed the input width W to 240 (2*120=240) and changed the stride of the last pooling layer from (1, 2) to (2, 2).
However, training no longer converges once I add more samples (>100), and the acc is always 0.0.
Do you have any advice?

How to run prediction

After training I end up with attention_ocr.ckpt-239998.data-00000-of-00001, attention_ocr.ckpt-239998.index, attention_ocr.ckpt-239998.meta and checkpoint. How do I use these to run prediction on character images?
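
A minimal restore-and-predict sketch, assuming build_network() exposes the same image placeholder and pred_decode_result tensor used in train.py, and that the images to recognise are listed in an annotation file such as data/my_test:

import numpy as np
import tensorflow as tf
import config as cfg
from model import *

_, _, pred_decode_result = build_network(is_training=False)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(cfg.CKPT_DIR))
    # Reuse the training-time loader to read and preprocess a batch of images.
    imgs, labels = next(cfg.batch_iter(cfg.val_dir, 'data/my_test'))
    decoded = sess.run(pred_decode_result, feed_dict={image: imgs[:cfg.BATCH_SIZE]})
    print(cfg.int2label(np.argmax(decoded, axis=2)))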

Embedding Dimension

This line in model.py may be problematic:
embeddings = tf.get_variable(name='embed_matrix',shape=[cfg.VOCAB_SIZE, cfg.VOCAB_SIZE])
The embedding is defined as a VOCAB_SIZE * VOCAB_SIZE matrix.
If you train Chinese recognition with char_std_5990 as the dictionary, this matrix becomes 5994 * 5994 (the 5990 characters plus the 4 special tokens <GO>, <EOS>, <UNK> and <PAD>).
It also inflates the parameter count of the decoder GRU that follows.
With that many parameters the model cannot converge.
To train on Chinese you need to adjust the embedding dimension.
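
A sketch of the suggested change (EMBED_DIM is a hypothetical new config value; 256 is only an example), decoupling the embedding width from the vocabulary size:

EMBED_DIM = 256  # hypothetical setting, much smaller than cfg.VOCAB_SIZE
embeddings = tf.get_variable(name='embed_matrix', shape=[cfg.VOCAB_SIZE, EMBED_DIM])
# roughly 5994 x 256 parameters instead of 5994 x 5994, and the decoder GRU input shrinks accordingly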

Is this training process right?

I used the default parameters from the original code and got the following training log. Both "train_decode" and "val_decode" stay essentially unchanged until "epoch:0, batch:1700", yet the loss keeps decreasing. (I changed the output format as shown below.) The acc stays at 0.0; is something wrong with the training process?

train_decode: [,,,的的的的的的的] | val_decode: [,,,的的的的的的的] | ground_truth: [职顺带休息网友奥斯摩]
train_decode: [,,的的的的的的的的] | val_decode: [,,的的的的的的的的] | ground_truth: [中距空对空导弹实施攻]
train_decode: [,,的的的的的的的的] | val_decode: [,,,,的的的的的的] | ground_truth: [参见第十四章“无状之]
train_decode: [,,,的的的的的的的] | val_decode: [,,,,的的的的的的] | ground_truth: [面冒个泡,得三公之教]
epoch:0, batch:19300, loss:3.55256605148, acc:0.0
train_decode: [,,,的的的的的的的] | val_decode: [,,,的的的的的的的] | ground_truth: [也许你通过跟人力资源]
train_decode: [,,,的的的的的的的] | val_decode: [,,,的的的的的的的] | ground_truth: [狼或熊则是一个好兆头]
train_decode: [,,的的的的的的的的] | val_decode: [,,,,的的的的的的] | ground_truth: [网来说是一个不小的挑]
train_decode: [,,的的的的的的的的] | val_decode: [,,,,的的的的的的] | ground_truth: [无神论科学家的说法是]
train_decode: [,,的的的的的的的的] | val_decode: [,,,,的的的的的的] | ground_truth: [闻玉帝求又破右胁而出]
epoch:0, batch:19400, loss:3.53783345222, acc:0.0
train_decode: [,,的的的的的的的的] | val_decode: [,,的的的的的的的的] | ground_truth: [机械学院→华北工学院]
train_decode: [,的的的的的的的的的] | val_decode: [,的的的的的的的的的] | ground_truth: [中。同时身体向下弯曲]
train_decode: [,的的的的的的的的的] | val_decode: [,,,,的的的的的的] | ground_truth: [到了德佩罗(Desp]
train_decode: [,,,的的的的的的的] | val_decode: [,,,的的的的的的的] | ground_truth: [这完全是一个主权国家]
train_decode: [,的的的的的的的的的] | val_decode: [,,的的的的的的的的] | ground_truth: [时候他会永远离不开老]

Accuracy (acc) stays at 0.0

Hi, I used your code with my own dataset. During training the loss gets down to around 0.0x, but the val decode output is always wrong and acc stays at 0.0. Do you know what the problem is?

Error

When I execute this code an error message is displayed. Can someone please help me?
(screenshot of the error message, 2018-04-23 13-38-10)

acc is always 0.0

Hi, I'm using the dataset you recommended and have trained for 3 epochs, but acc stays at 0. Do you know what the problem might be?

Training won't start

The program gets stuck at sess.run(tf.global_variables_initializer()) and is force-killed after a while. What could be the reason?

Code error, how can I fix it?

img = cv2.resize(im, (IMAGE_WIDTH,IMAGE_HEIGHT))

cv2.error: C:\bld\opencv_1498171314629\work\opencv-3.2.0\modules\imgproc\src\imgwarp.cpp:3492: error: (-215) ssize.width > 0 && ssize.height > 0 in function cv::resize
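
This cv2 error usually means cv2.imread returned None (a wrong path or an unreadable file), so resize receives an empty image. A small defensive check, sketched with the variable names used in this repo's config.py:

im = cv2.imread(image_name, 0)
if im is None:  # imread returns None instead of raising when the file cannot be read
    raise IOError('cannot read image: {}'.format(image_name))
im = cv2.resize(im, (IMAGE_WIDTH, IMAGE_HEIGHT))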

Batch normalization parameters are not updated

Hi, during training the BN parameters are not updated, so at test time you have to feed the same batch of data as during training to get correct results; if is_training is set to False during testing, errors appear.
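
If the batch-norm layers come from tf.layers or slim (an assumption about this repo), the usual TF1 remedy is to run the UPDATE_OPS collection, where those layers register their moving-average updates, together with the training step, e.g. in train.py's optimizer scope:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)  # moving mean/variance now update every step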

run error

Traceback (most recent call last):
File "/home/jxf/CRNN_Attention_OCR-master/train.py", line 24, in
img,label=cfg.read_data(cfg.train_dir)
File "/home/jxf/CRNN_Attention_OCR-master/config.py", line 64, in read_data
img = cv2.resize(im, (IMAGE_WIDTH,IMAGE_HEIGHT))
cv2.error: /io/opencv/modules/imgproc/src/imgwarp.cpp:3483: error: (-215) ssize.width > 0 && ssize.height > 0 in function resize
