ry / tensorflow-vgg16
conversation of caffe vgg16 model to tensorflow
Hello, it's really nice of you to release the source code.
I wonder whether you have implemented the training phase. Or do you have any ideas on how to implement VGG net, including both the training and testing phases?
Thx!
Andrew
I want to extract features with vgg16. I can use "get_tensor_by_name" to get the features of a certain layer, but I don't know the exact names of the tensors. How can I get the names of all the tensors in the network?
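One possible way to answer the question above is to walk the nodes of the parsed GraphDef. The following is a minimal sketch, not the author's method: the helper name `list_tensor_names` is hypothetical, and it assumes the graph is imported with tf.import_graph_def under its default name scope "import" (as the "import/prob:0" name used elsewhere on this page suggests).

```python
def list_tensor_names(graph_def, prefix="import"):
    """Return the name of the first output tensor of every node in a GraphDef.

    Assumption: the graph is imported with tf.import_graph_def, whose default
    name scope is "import". A node's first output tensor is then named
    "<scope>/<node name>:0". Nodes without outputs have no such tensor.
    """
    scope = prefix + "/" if prefix else ""
    return ["{}{}:0".format(scope, node.name) for node in graph_def.node]


# Hypothetical usage, with graph_def parsed from the .tfmodel file:
#   for name in list_tensor_names(graph_def):
#       print(name)
```

Any name printed this way can then be tried with graph.get_tensor_by_name.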
First, thank you so much for this library. It is exactly what I was looking for and is working well. I do have a question about inputs.
The caffe project page (https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-md), linked from the README (https://github.com/ry/tensorflow-vgg16/blob/master/README.md), states the following:
The input images should be zero-centered by mean pixel (rather than mean image) subtraction. Namely, the following BGR values should be subtracted: [103.939, 116.779, 123.68].
This suggests that, for the caffe model, images should be BGR with the specified vector subtracted. This is different from utils.load_image(), which provides RGB images with pixels divided by 255.0. I just want to confirm that the format we should use is RGB, img /= 255.0, and not BGR, img -= [103.939, 116.779, 123.68], as specified on the caffe project page.
Note: I confirmed that skimage.io.imread() returns an RGB image by creating a blue jpg, loading it, and noting [0,0,255] for all pixels.
Thank you,
Mark
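For reference, the Caffe-style preprocessing quoted in this post can be sketched in NumPy as below. This only illustrates the BGR, mean-subtracted format; it does not settle which format the converted TensorFlow graph expects, which is the open question here. The function name `rgb01_to_caffe_bgr` is hypothetical.

```python
import numpy as np

# Mean pixel in BGR order, as quoted from the caffe project page.
VGG_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def rgb01_to_caffe_bgr(rgb):
    """Convert an RGB image in [0, 1] to mean-subtracted BGR on a 0-255 scale."""
    rgb = np.asarray(rgb, dtype=np.float32)
    bgr = rgb[..., ::-1] * 255.0  # reverse the channel axis: RGB -> BGR
    return bgr - VGG_MEAN_BGR     # the mean broadcasts over height and width


# A pure-blue 2x2 RGB image, mirroring the blue-JPEG check described above.
blue = np.zeros((2, 2, 3), dtype=np.float32)
blue[..., 2] = 1.0
converted = rgb01_to_caffe_bgr(blue)
```

After conversion, the blue value sits in channel 0, which is where a BGR model would look for it.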
I want to use only part of the model (discarding everything after the last convolution). Can I do that with your model?
Thank you very much for the code. I got the following error message. Any idea?
Thank you!
caffe session
prob shape (1000,)
Top1: n02123045 tabby, tabby cat
Top5: ['n02123045 tabby, tabby cat', 'n02124075 Egyptian cat', 'n02123159 tiger cat', 'n02119789 kit fox, Vulpes macrotis', 'n02119022 red fox, Vulpes vulpes']
tensorflow session
F tensorflow/stream_executor/cuda/cuda_driver.cc:302] current context was not created by the StreamExecutor cuda_driver API: 0x2c3a910; a CUDA runtime call was likely performed without using a StreamExecutor context
Aborted (core dumped)
Hello, thank you for this great source code.
I would like to use the FC layers below prob layer, but I am not sure how I should do it.
I tried something like graph.get_tensor_by_name("import/fc7:0"), but it does not work.
Thank you
-Taeksoo
This project lacks a license. The resnet-converter has the MIT license but this one does not. This prevents people from uploading the converted model on their own site (with due credit of course) if they want to err on the side of caution.
It would be nice to get a license for the converter code, since the original VGG 16 model is released under the Creative Commons Attribution License.
Running the following after download gives an error at ParseFromString:
import tensorflow as tf
import sys
import skimage
import skimage.io
import skimage.transform
import numpy as np

synset = [l.strip() for l in open('/home/ubuntu/tensorflow-vgg16/synset.txt').readlines()]
VGG_MEAN = [103.939, 116.779, 123.68]

# returns image of shape [224, 224, 3]
# [height, width, depth]
def load_image(path):
    # load image
    img = skimage.io.imread(path)
    img = img / 255.0
    assert (0 <= img).all() and (img <= 1.0).all()
    #print "Original Image Shape: ", img.shape
    # we crop image from center
    short_edge = min(img.shape[:2])
    yy = int((img.shape[0] - short_edge) / 2)
    xx = int((img.shape[1] - short_edge) / 2)
    crop_img = img[yy : yy + short_edge, xx : xx + short_edge]
    # resize to 224, 224
    resized_img = skimage.transform.resize(crop_img, (224, 224))
    return resized_img

# returns the top1 string
def print_prob(prob):
    #print prob
    print "prob shape", prob.shape
    pred = np.argsort(prob)[::-1]
    # Get top1 label
    top1 = synset[pred[0]]
    print "Top1: ", top1
    # Get top5 label
    top5 = [synset[pred[i]] for i in range(5)]
    print "Top5: ", top5
    return top1

with open("/home/ubuntu/vgg16-v4.tfmodel", mode='rb') as f:
    fileContent = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
images = tf.placeholder("float", [None, 224, 224, 3])
tf.import_graph_def(graph_def, input_map={ "images": images })
print "graph loaded from disk"
graph = tf.get_default_graph()
cat = load_image("cat.jpg")

with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)
    print "variables initialized"
    batch = cat.reshape((1, 224, 224, 3))
    assert batch.shape == (1, 224, 224, 3)
    feed_dict = { images: batch }
    prob_tensor = graph.get_tensor_by_name("import/prob:0")
    prob = sess.run(prob_tensor, feed_dict=feed_dict)
    print_prob(prob[0])
Error
---------------------------------------------------------------------------
DecodeError Traceback (most recent call last)
<ipython-input-1-c8f1d9f927de> in <module>()
48
49 graph_def = tf.GraphDef()
---> 50 graph_def.ParseFromString(fileContent)
51
52 images = tf.placeholder("float", [None, 224, 224, 3])
/usr/local/lib/python2.7/dist-packages/google/protobuf/message.py in ParseFromString(self, serialized)
183 """
184 self.Clear()
--> 185 self.MergeFromString(serialized)
186
187 def SerializeToString(self):
/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py in MergeFromString(self, serialized)
1006 length = len(serialized)
1007 try:
-> 1008 if self._InternalParse(serialized, 0, length) != length:
1009 # The only reason _InternalParse would return early is if it
1010 # encountered an end-group tag.
/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py in InternalParse(self, buffer, pos, end)
1042 pos = new_pos
1043 else:
-> 1044 pos = field_decoder(buffer, new_pos, end, self, field_dict)
1045 if field_desc:
1046 self._UpdateOneofState(field_desc)
/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py in DecodeRepeatedField(buffer, pos, end, message, field_dict)
626 new_pos = pos + size
627 if new_pos > end:
--> 628 raise _DecodeError('Truncated message.')
629 # Read sub-message.
630 if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
DecodeError: Truncated message.
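The traceback shows the decoder raising "Truncated message." when a declared field length runs past the end of the buffer, which means the file holds fewer bytes than the serialized graph claims. An incomplete download is one plausible cause, though nothing on this page confirms it. A minimal sketch for checking the file follows; the helper name `describe_file` is hypothetical, and the reference size or checksum to compare against would have to come from whoever distributes the model.

```python
import hashlib
import os


def describe_file(path, chunk_size=1 << 20):
    """Return (size_in_bytes, sha256_hex) of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel stops when read() returns an empty bytes object.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return os.path.getsize(path), digest.hexdigest()


# Hypothetical usage with the path from the report above:
#   size, sha256 = describe_file("/home/ubuntu/vgg16-v4.tfmodel")
#   print(size, sha256)
```

If the size or hash differs from a known-good copy, downloading the file again is the first thing to try.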
Hi,
the VGG-16 model you referenced requires that we subtract the mean image from the current image. Your code does not seem to do that, does it?
Have you checked what image classification performance you achieve with this model?
caffe session
prob shape (1000,)
Top1: n02123045 tabby, tabby cat
Top5: ['n02123045 tabby, tabby cat', 'n02124075 Egyptian cat', 'n02123159 tiger cat', 'n02119789 kit fox, Vulpes macrotis', 'n02119022 red fox, Vulpes vulpes']
tensorflow session
Top1: n02123159 tiger cat
Top5: ['n02123159 tiger cat', 'n02124075 Egyptian cat', 'n02123045 tabby, tabby cat', 'n02113023 Pembroke, Pembroke Welsh corgi', 'n02094258 Norwich terrier']
Any idea why this happens? Awesome work by the way!
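In the output above, three of the top-5 labels are shared between the two sessions, but their order differs. One way to quantify such a mismatch is to compare the two probability vectors directly. This is a minimal sketch under the assumption that both sessions return a 1-D array of 1000 probabilities for the same image; the function name `compare_predictions` is hypothetical.

```python
import numpy as np


def compare_predictions(prob_a, prob_b, k=5):
    """Return (indices shared by both top-k lists, largest absolute difference)."""
    prob_a = np.asarray(prob_a, dtype=np.float64)
    prob_b = np.asarray(prob_b, dtype=np.float64)
    top_a = np.argsort(prob_a)[::-1][:k]  # indices of the k largest entries
    top_b = np.argsort(prob_b)[::-1][:k]
    shared = sorted(set(top_a.tolist()) & set(top_b.tolist()))
    max_abs_diff = float(np.max(np.abs(prob_a - prob_b)))
    return shared, max_abs_diff
```

A large maximum difference would point to a systematic mismatch such as input preprocessing, whereas a tiny one would point to numerical noise.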
Hi,
Thanks for sharing your work.
The torrent download for the TensorFlow VGG16 model is not working. Could you please share the file some other way? I need it.
Thanks,
Shahid
Hello,
I tried to run the file tf_forward.py, but I get the error:
ValueError: Attempted to map inputs that were not found in graph_def: [images:0]
The code is:
import tensorflow as tf
import utils
with open("vgg16-v4.tfmodel", mode='rb') as f:
    fileContent = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
images = tf.placeholder("float", [None, 224, 224, 3])
tf.import_graph_def(graph_def, input_map={ "images": images })
tf.import_graph_def(graph_def, input_map={ "images": images })
Traceback (most recent call last):
File "", line 1, in
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.py", line 283, in import_graph_def
% ', '.join(unused_input_keys))
ValueError: Attempted to map inputs that were not found in graph_def: [images:0]
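The error says that the GraphDef contains no node named "images", so the key passed to input_map does not match anything. One way to investigate is to list the graph's input nodes. This is a minimal sketch that assumes the model's input is a node of op type "Placeholder", which nothing on this page confirms; the helper name `find_placeholders` is hypothetical.

```python
def find_placeholders(graph_def):
    """Return the names of all nodes whose op type is "Placeholder"."""
    return [node.name for node in graph_def.node if node.op == "Placeholder"]


# Hypothetical usage, with graph_def parsed from the .tfmodel file:
#   print(find_placeholders(graph_def))
# If a different name turns up, it can be used as the input_map key in
# place of "images".
```

If the list comes back empty, printing the first few entries of graph_def.node may show how the input is actually represented.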
I downloaded the result (cuz I don't wanna install Caffe) and tried to read it in code with:
with open(FLAGS.model_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
And I got
Traceback (most recent call last):
File "synthesis.py", line 131, in <module>
main()
File "synthesis.py", line 56, in main
X = load_vgg16() #load_inception()
File "synthesis.py", line 35, in load_vgg16
graph_def.ParseFromString(s)
File "/Users/taiyuanz/tensorflow/lib/python2.7/site-packages/google/protobuf/message.py", line 185, in ParseFromString
self.MergeFromString(serialized)
File "/Users/taiyuanz/tensorflow/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1090, in MergeFromString
raise message_mod.DecodeError('Unexpected end-group tag.')
google.protobuf.message.DecodeError: Unexpected end-group tag.
I don't see anything particularly wrong, so I can only guess that it is due to an incompatibility between protobuf versions. Can you help take a look?
"conversation of caffe vgg16 model to tensorflow"
I guess you meant "conversion".
Thanks for loading up the model in TensorFlow. I have the torrent file but have no idea how to get the model at this point. How do I do this?