Comments (4)

fighting41love commented on July 20, 2024

@alexander-rakhlin Thanks for your help.
This has been reported as a Keras bug that occurs with Python 3 and Keras 1.2.0:
keras-team/keras#4904
Thanks!

from cnn-for-sentence-classification-in-keras.

alexander-rakhlin commented on July 20, 2024

Hi,

I'm not sure what you mean by the test set here. We perform validation during training, and the validation set is generated randomly according to val_split (= 0.1) and the random seed.
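For illustration, a seeded hold-out split can be sketched as below. This is a minimal NumPy sketch, not the Keras internals: the function name `holdout_split` and the toy arrays are hypothetical, and it assumes the data is shuffled with a fixed seed and the last `val_split` fraction is then held out.

```python
import numpy as np


def holdout_split(x, y, val_split=0.1, seed=20):
    """Shuffle with a fixed seed, then hold out the last val_split fraction.

    Hypothetical helper for illustration only; it is not part of Keras.
    """
    rng = np.random.RandomState(seed)           # fixed seed => same permutation every run
    idx = rng.permutation(len(y))
    split_at = int(len(y) * (1.0 - val_split))  # index where the held-out part starts
    train_idx, val_idx = idx[:split_at], idx[split_at:]
    return x[train_idx], y[train_idx], x[val_idx], y[val_idx]


# Toy data standing in for the real padded sentences and labels.
x = np.arange(40).reshape(20, 2)
y = np.arange(20) % 2

x_tr, y_tr, x_val, y_val = holdout_split(x, y)
x_tr2, y_tr2, x_val2, y_val2 = holdout_split(x, y)

print(len(x_tr), len(x_val))
print(np.array_equal(x_val, x_val2))  # same seed, so the held-out rows are identical
```

With the same seed and the same input order, two calls return the same held-out rows, which is the property the reply above relies on.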

fighting41love commented on July 20, 2024

@alexander-rakhlin Thanks for your response.
Training set: a set of examples used for learning: to fit the parameters of the classifier
Validation set: a set of examples used to tune the parameters of a classifier
Test set: a set of examples used only to assess the performance of a fully-trained classifier
Yes, you perform validation during training. We also have a test set to evaluate the overall performance of the classifier. The problem is:
train + test = good performance
load the classifier + test = not very good performance

Here is code that reproduces the problem.
Running it as is produces a classifier with good performance.
Then comment out the following lines:
model.fit(x_train, y_train, batch_size=batch_size,
          nb_epoch=num_epochs, validation_split=val_split, verbose=2)
model.save('test.h5')

With those lines commented out, so that the classifier is only loaded from test.h5, the performance is no longer good.

`
from __future__ import print_function
import numpy as np
import data_helpers
from w2v import train_word2vec
from keras.models import load_model
from keras.models import Sequential, Model
from keras.layers import Activation, Dense, Dropout, Embedding, Flatten, Input, Merge, Convolution1D, MaxPooling1D
from keras.optimizers import SGD
from keras.models import model_from_json
import data_helpers as dh
np.random.seed(20)

model_variation = 'CNN-non-static'
print('Model variation is %s' % model_variation)

sequence_length = 56
embedding_dim = 20
filter_sizes = (3, 4)
num_filters = 3
dropout_prob = (0.25, 0.5)
#dropout_prob = (0, 0)
hidden_dims = 100

batch_size = 32
num_epochs = 30
val_split = 0.1

min_word_count = 1
context = 10

print("Loading data...")
x, y, vocabulary, vocabulary_inv = data_helpers.load_data()

if model_variation=='CNN-non-static' or model_variation=='CNN-static':
    embedding_weights = train_word2vec(x, vocabulary_inv, embedding_dim, min_word_count, context)
    if model_variation=='CNN-static':
        x = embedding_weights[0][x]
elif model_variation=='CNN-rand':
    embedding_weights = None
else:
    raise ValueError('Unknown model variation')

shuffle_indices = np.random.permutation(np.arange(len(y)))
x_shuffled = x[shuffle_indices]
y_shuffled = y[shuffle_indices].argmax(axis=1)

x_train = x_shuffled[:int(len(y)*2/3)]
y_train = y_shuffled[:int(len(y)*2/3)]
x_test = x_shuffled[int(len(y)*2/3):]
y_test = y_shuffled[int(len(y)*2/3):]

graph_in = Input(shape=(sequence_length, embedding_dim))
convs = []
for fsz in filter_sizes:
    conv = Convolution1D(nb_filter=num_filters,
                         filter_length=fsz,
                         border_mode='valid',
                         activation='relu',
                         subsample_length=1)(graph_in)
    pool = MaxPooling1D(pool_length=2)(conv)
    flatten = Flatten()(pool)
    convs.append(flatten)

if len(filter_sizes)>1:
    out = Merge(mode='concat')(convs)
else:
    out = convs[0]

graph = Model(input=graph_in, output=out)

model = Sequential()
if not model_variation=='CNN-static':
    model.add(Embedding(len(vocabulary), embedding_dim, input_length=sequence_length,
                        weights=embedding_weights))
model.add(Dropout(dropout_prob[0], input_shape=(sequence_length, embedding_dim)))
model.add(graph)
model.add(Dense(hidden_dims))
model.add(Dropout(dropout_prob[1]))
model.add(Activation('relu'))
model.add(Dense(1))
model.add(Activation('sigmoid'))
opt = SGD(lr=0.01, momentum=0.80, decay=1e-6, nesterov=True)

model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy','precision','recall','fbeta_score'])
model.summary()
model.fit(x_train, y_train, batch_size=batch_size,
          nb_epoch=num_epochs, validation_split=val_split, verbose=2)
model.save('test.h5')
loaded_model = load_model('test.h5')
loaded_model.summary()
score = loaded_model.evaluate(x_test,y_test, verbose=2)
print("%s: %.2f%%" % (loaded_model.metrics_names[1], score[1] * 100))
print("%s: %.2f%%" % (loaded_model.metrics_names[2], score[2] * 100))
print("%s: %.2f%%" % (loaded_model.metrics_names[3], score[3] * 100))
print("%s: %.2f%%" % (loaded_model.metrics_names[4], score[4] * 100))
`

alexander-rakhlin commented on July 20, 2024

Provided the random seed is the same, both runs should give you the same test set. The model you use at the test stage appears to be the same too, so I don't know why it gives different results. Try debugging, and first of all verify that the test set is the same. As a last resort, you can compare the model weights at both stages.
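Both checks can be sketched roughly as follows. This is a minimal NumPy sketch under stated assumptions: the helper names `array_fingerprint` and `weights_match` are hypothetical, and the toy weight lists stand in for what `model.get_weights()` returns (a list of NumPy arrays) for the freshly trained model and for the model returned by `load_model`.

```python
import hashlib

import numpy as np


def array_fingerprint(arr):
    """Return a short hash of an array's dtype, shape and raw bytes (hypothetical helper)."""
    arr = np.ascontiguousarray(arr)
    h = hashlib.sha256()
    h.update(str(arr.dtype).encode())
    h.update(str(arr.shape).encode())
    h.update(arr.tobytes())
    return h.hexdigest()[:16]


def weights_match(weights_a, weights_b, atol=1e-6):
    """Compare two lists of weight arrays layer by layer (hypothetical helper)."""
    if len(weights_a) != len(weights_b):
        return False
    # Check shapes first so np.allclose cannot silently broadcast.
    return all(a.shape == b.shape and np.allclose(a, b, atol=atol)
               for a, b in zip(weights_a, weights_b))


# Toy stand-ins for x_test as seen in the two runs.
x_test_run1 = np.arange(12).reshape(4, 3)
x_test_run2 = x_test_run1.copy()
same_test_set = array_fingerprint(x_test_run1) == array_fingerprint(x_test_run2)
print(same_test_set)

# Toy stand-ins for model.get_weights() and loaded_model.get_weights().
trained = [np.ones((3, 2)), np.zeros(2)]
reloaded = [w.copy() for w in trained]
altered = [np.ones((3, 2)), np.full(2, 0.5)]
print(weights_match(trained, reloaded))
print(weights_match(trained, altered))
```

In the script above, one could print `array_fingerprint(x_test)` and `array_fingerprint(y_test)` in both runs, and call `weights_match(model.get_weights(), loaded_model.get_weights())` right after `load_model` in the run that still trains. If the fingerprints agree but the weights do not, the save/load round trip is the likely culprit.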
