Comments (4)
@alexander-rakhlin Thanks for your help.
It has been reported as a Keras bug affecting Python 3 with Keras 1.2.0:
keras-team/keras#4904
Thanks!
from cnn-for-sentence-classification-in-keras.
Hi,
I'm not sure what you mean by the test set here. We perform validation during training, and the validation set is generated randomly according to val_split (=0.1) and the random seed.
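A small NumPy sketch of what "random according to val_split and the random seed" means in practice. It assumes, as the Keras `fit` documentation states for `validation_split`, that the validation samples are the *last* fraction of the arrays passed in, so all of the randomness comes from the seeded shuffle done beforehand. The helper `split_indices` is hypothetical and not part of this repository.

```python
import numpy as np

def split_indices(n_samples, val_split):
    """Shuffle with the global NumPy RNG, then hold out the last
    val_split fraction (Keras' validation_split also takes the last samples)."""
    indices = np.random.permutation(np.arange(n_samples))
    split_at = int(n_samples * (1.0 - val_split))
    return indices[:split_at], indices[split_at:]

# Same seed and same sequence of RNG calls -> identical validation set
np.random.seed(20)
_, val_a = split_indices(100, 0.1)
np.random.seed(20)
_, val_b = split_indices(100, 0.1)
print(np.array_equal(val_a, val_b))  # True

# Same seed, but one extra draw before shuffling -> a different split
np.random.seed(20)
np.random.rand()  # stands in for any other code that consumes random numbers
_, val_c = split_indices(100, 0.1)
print(np.array_equal(val_a, val_c))  # False: the extra draw shifted the permutation
```

The second comparison is the practical caveat: the seed pins down the split only if exactly the same sequence of global RNG calls runs before the shuffle in every run.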
@alexander-rakhlin Thanks for your response.
Training set: a set of examples used for learning, i.e. to fit the parameters of the classifier
Validation set: a set of examples used to tune the parameters of a classifier
Test set: a set of examples used only to assess the performance of a fully-trained classifier
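To make the three definitions concrete, here is a sketch of how the script posted below ends up with three disjoint sets: one seeded shuffle, the first two thirds kept for fitting and the rest held out as the test set, and then `validation_split` taking the last 10% of the fitting portion. The function `three_way_split` is a hypothetical NumPy-only illustration; it uses its own `RandomState` instead of the script's global `np.random.seed`.

```python
import numpy as np

def three_way_split(n_samples, val_split, seed):
    """Return disjoint (train, val, test) index arrays.

    1. Shuffle all indices once with a fixed seed.
    2. Keep the first 2/3 for fitting and hold out the rest as the test set
       (the script uses int(len(y)*2/3) for the same cut).
    3. Take the last val_split fraction of the fitting portion as the
       validation set, the way Keras' validation_split does.
    """
    indices = np.random.RandomState(seed).permutation(n_samples)
    fit_end = n_samples * 2 // 3
    fit_idx, test_idx = indices[:fit_end], indices[fit_end:]
    split_at = int(len(fit_idx) * (1.0 - val_split))
    return fit_idx[:split_at], fit_idx[split_at:], test_idx

train_idx, val_idx, test_idx = three_way_split(300, val_split=0.1, seed=20)
print(len(train_idx), len(val_idx), len(test_idx))  # 180 20 100
```

Only the test indices are never seen during `fit`, which is why they are the right ones for judging the fully trained (or reloaded) classifier.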
Yes, you perform validation during training. We also keep a separate test set to evaluate the overall performance of the classifier. The problem is:
- train, then evaluate on the test set: good performance
- load the saved classifier, then evaluate on the same test set: noticeably worse performance
Here is code that reproduces the problem.
Running it as-is produces a classifier with good performance.
Then comment out the following lines:
model.fit(x_train, y_train, batch_size=batch_size,
nb_epoch=num_epochs, validation_split=val_split, verbose=2)
model.save('test.h5')
and run it again, so that it only loads test.h5 and evaluates it. The performance drops noticeably.
```python
from __future__ import print_function
import numpy as np
import data_helpers
from w2v import train_word2vec
from keras.models import Sequential, Model, load_model
from keras.layers import Activation, Dense, Dropout, Embedding, Flatten, Input, Merge, Convolution1D, MaxPooling1D
from keras.optimizers import SGD

# Fixed seed so that the shuffle, and hence the train/test split, is reproducible
np.random.seed(20)

model_variation = 'CNN-non-static'
print('Model variation is %s' % model_variation)

# Model hyperparameters
sequence_length = 56
embedding_dim = 20
filter_sizes = (3, 4)
num_filters = 3
dropout_prob = (0.25, 0.5)
#dropout_prob = (0, 0)
hidden_dims = 100

# Training parameters
batch_size = 32
num_epochs = 30
val_split = 0.1

# Word2vec parameters
min_word_count = 1
context = 10

print("Loading data...")
x, y, vocabulary, vocabulary_inv = data_helpers.load_data()

if model_variation == 'CNN-non-static' or model_variation == 'CNN-static':
    embedding_weights = train_word2vec(x, vocabulary_inv, embedding_dim, min_word_count, context)
    if model_variation == 'CNN-static':
        x = embedding_weights[0][x]
elif model_variation == 'CNN-rand':
    embedding_weights = None
else:
    raise ValueError('Unknown model variation')

# Shuffle, then use the first 2/3 for training and the last 1/3 as the test set
shuffle_indices = np.random.permutation(np.arange(len(y)))
x_shuffled = x[shuffle_indices]
y_shuffled = y[shuffle_indices].argmax(axis=1)
x_train = x_shuffled[:int(len(y)*2/3)]
y_train = y_shuffled[:int(len(y)*2/3)]
x_test = x_shuffled[int(len(y)*2/3):]
y_test = y_shuffled[int(len(y)*2/3):]

# Parallel convolution branches, one per filter size, concatenated
graph_in = Input(shape=(sequence_length, embedding_dim))
convs = []
for fsz in filter_sizes:
    conv = Convolution1D(nb_filter=num_filters,
                         filter_length=fsz,
                         border_mode='valid',
                         activation='relu',
                         subsample_length=1)(graph_in)
    pool = MaxPooling1D(pool_length=2)(conv)
    flatten = Flatten()(pool)
    convs.append(flatten)

if len(filter_sizes) > 1:
    out = Merge(mode='concat')(convs)
else:
    out = convs[0]

graph = Model(input=graph_in, output=out)

# Main sequential model
model = Sequential()
if not model_variation == 'CNN-static':
    model.add(Embedding(len(vocabulary), embedding_dim, input_length=sequence_length,
                        weights=embedding_weights))
model.add(Dropout(dropout_prob[0], input_shape=(sequence_length, embedding_dim)))
model.add(graph)
model.add(Dense(hidden_dims))
model.add(Dropout(dropout_prob[1]))
model.add(Activation('relu'))
model.add(Dense(1))
model.add(Activation('sigmoid'))

opt = SGD(lr=0.01, momentum=0.80, decay=1e-6, nesterov=True)
model.compile(loss='binary_crossentropy', optimizer=opt,
              metrics=['accuracy', 'precision', 'recall', 'fbeta_score'])
model.summary()

# On the second run, comment out model.fit(...) and model.save(...)
# so that only the saved model is loaded and evaluated
model.fit(x_train, y_train, batch_size=batch_size,
          nb_epoch=num_epochs, validation_split=val_split, verbose=2)
model.save('test.h5')

loaded_model = load_model('test.h5')
loaded_model.summary()
score = loaded_model.evaluate(x_test, y_test, verbose=2)
# score[0] is the loss; score[1:] are the compiled metrics
print("%s: %.2f%%" % (loaded_model.metrics_names[1], score[1] * 100))
print("%s: %.2f%%" % (loaded_model.metrics_names[2], score[2] * 100))
print("%s: %.2f%%" % (loaded_model.metrics_names[3], score[3] * 100))
print("%s: %.2f%%" % (loaded_model.metrics_names[4], score[4] * 100))
```
Provided the random seed is the same, both runs should give you the same test set, and the model you use at the test stage appears to be the same too. So I don't see how it gives different results. Try debugging, and first of all verify that the test set is identical in both runs. As a last resort, you can compare the model weights at both stages.
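Both debugging steps suggested above (check that the test set is identical, then compare weights) can be sketched with NumPy alone. The helpers `array_fingerprint` and `compare_weights` are hypothetical names for illustration; the only Keras call they are meant to be fed from is the standard `model.get_weights()`, which returns a list of NumPy arrays.

```python
import hashlib
import numpy as np

def array_fingerprint(*arrays):
    """SHA-256 over the dtype, shape and raw bytes of each array.
    Equal fingerprints in two runs mean the arrays are identical."""
    h = hashlib.sha256()
    for a in arrays:
        a = np.ascontiguousarray(a)
        h.update(str(a.dtype).encode())
        h.update(str(a.shape).encode())
        h.update(a.tobytes())
    return h.hexdigest()

def compare_weights(weights_a, weights_b, atol=0.0):
    """Compare two lists of weight arrays (e.g. from model.get_weights()).
    Returns the indices of arrays that differ in shape or value."""
    if len(weights_a) != len(weights_b):
        raise ValueError('different number of weight arrays: %d vs %d'
                         % (len(weights_a), len(weights_b)))
    mismatched = []
    for i, (wa, wb) in enumerate(zip(weights_a, weights_b)):
        # rtol=0 so that atol=0.0 means exact equality
        if wa.shape != wb.shape or not np.allclose(wa, wb, rtol=0.0, atol=atol):
            mismatched.append(i)
    return mismatched
```

In each run, one could print `array_fingerprint(x_test, y_test)` right after the split; if the two runs print different values, the test sets differ. Within the training run, `compare_weights(model.get_weights(), loaded_model.get_weights())` should return an empty list if saving and loading preserved the weights.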