My understanding is that you include the gradients and `self.net.gradients_nodes` for debugging purposes.
If `Unet.summaries` is set to `True`, you get histograms in TensorBoard that display the gradient values for each epoch. If, for instance, the values remain very large, it might indicate that the training is not working.
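The point of keeping the gradients around can be illustrated without TensorFlow. Below is a minimal NumPy sketch of momentum gradient descent on a toy least-squares problem. It is not tf_unet's code: `momentum_train`, the data, and the hyperparameters are hypothetical. The same explicitly computed gradient both drives the update and is summarized as a per-step norm, which is the kind of quantity the TensorBoard histograms let you watch.

```python
import numpy as np


def momentum_train(x, y, lr=0.1, momentum=0.2, steps=200):
    """Fit y ~ x @ w with momentum gradient descent, logging the gradient norm.

    Hypothetical toy example; it only mirrors the idea of keeping the
    gradients available for monitoring, not tf_unet's implementation.
    """
    w = np.zeros(x.shape[1])
    velocity = np.zeros_like(w)
    grad_norms = []
    losses = []
    for _ in range(steps):
        residual = x @ w - y
        losses.append(float(np.mean(residual ** 2)))
        # Explicit gradient of the mean squared error with respect to w.
        grad = 2.0 * x.T @ residual / len(y)
        # Record a summary of the gradient for debugging (cf. a norm/histogram summary).
        grad_norms.append(float(np.linalg.norm(grad)))
        # Momentum update: v <- momentum * v + grad; w <- w - lr * v.
        velocity = momentum * velocity + grad
        w = w - lr * velocity
    return w, losses, grad_norms


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 3))
true_w = np.array([1.5, -2.0, 0.5])
y = x @ true_w
w, losses, grad_norms = momentum_train(x, y)
print(losses[0], losses[-1], grad_norms[-1])
```

In TensorFlow 1.x, `optimizer.minimize(loss)` is documented as a combination of `compute_gradients()` and `apply_gradients()`. Building gradient tensors separately, for example with `tf.gradients`, therefore mainly adds something to inspect. It should not change the update itself.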
In other words, if I only want to run the training process, I can remove these two items and rewrite your code as:
Yes, in principle. But I still don't understand why you want to make changes to the `Trainer` class.
Hi Joel,
No, I do not plan to change the `Trainer` class. I am asking because I looked at some other TensorFlow implementations, and most of them do not compute the gradients explicitly in the optimization (back-propagation) step as you do. I was curious why you chose this approach, or whether there is a trick or special consideration behind it. Based on your explanation, the purpose is mainly debugging and performance tracking.
Out of curiosity I downloaded the Kaggle ultrasound dataset. To start the training I reused most of the `script/rfi_launcher.py` and `script/radio_util.py` code. I more or less just replaced the loading of the data with a call to the PIL library:

```python
data = np.array(Image.open(self.data_files[self.file_idx]), dtype=np.float32)
```

I don't get black predictions, and the learning rate is decreasing as expected.
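The loading line above keeps the raw pixel values as float32 (0 to 255 for 8-bit images). Whether they are rescaled afterwards depends on the data provider. If they are not, here is a minimal sketch of per-image min-max normalization to [0, 1]. `normalize_image` is a hypothetical helper, and min-max scaling is only one common choice.

```python
import numpy as np


def normalize_image(data):
    """Scale an image to [0, 1] with per-image min-max normalization.

    Hypothetical helper: one common choice, not necessarily what tf_unet does.
    """
    data = np.asarray(data, dtype=np.float32)
    data = data - np.amin(data)
    max_value = np.amax(data)
    if max_value != 0:  # avoid dividing by zero on constant images
        data = data / max_value
    return data


# A fake 8-bit patch standing in for an array loaded via PIL.
raw = np.array([[0, 64], [128, 255]], dtype=np.uint8)
img = normalize_image(raw)
print(img.min(), img.max())
```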
Hi Joel,
Thanks for trying this dataset as well. I really want to fully understand your code and get it working, because I really like it; it is very well designed.
Did you use the same parameters, such as three layers and `features_root=16`? Have you normalized the images?
I didn't change the optimization step. The only thing I changed is the code that loads the image files. I am quite confused about why mine does not work.
(A quick update: I just got the figure after finishing 12 epochs, and the prediction is not completely black. However, the learning rate does not change either.)
In the demo problem, you generally generate the training data randomly on the fly within each epoch. For the Kaggle case, I need to iterate over a training set of 5635 pairs of images and mask images.
This is what I did in the training step: I replaced `batch_x, batch_y = data_provider(self.batch_size)` with `batch_x, batch_y = data_provider(ij, self.batch_size)`. Here `ij` is the starting index, because I need to iterate over the training set.
The following is `new_get_image_gen`, which is similar to `get_image_gen`. The only difference is that I introduce an input parameter that controls the starting point of each iteration:
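```python
def new_get_image_gen(raw_image, mask_image):
    def new_create_batch(starting_point, n_image):
        # Slice n_image consecutive samples starting at starting_point.
        # Near the end of the arrays this returns fewer than n_image samples.
        batch_x = raw_image[starting_point:starting_point + n_image, :, :, :]
        batch_y = mask_image[starting_point:starting_point + n_image, :, :, :]
        return batch_x, batch_y

    # Read by the launch script to configure the Unet.
    new_create_batch.channels = 1
    new_create_batch.n_class = 2
    return new_create_batch
```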
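You could avoid changing the call inside the `Trainer` by keeping the running index inside the provider itself. The sketch below is a hypothetical variant (`make_batch_fn` is not part of tf_unet). It walks through a shuffled order of the samples and wraps around at the end of each pass, so every batch is full. It also keeps the original `data_provider(batch_size)` call signature.

```python
import numpy as np


def make_batch_fn(raw_image, mask_image, seed=0):
    """Hypothetical variant of new_get_image_gen that wraps around the data set.

    A running position walks through a shuffled order of the samples, so every
    call returns a full batch and a new permutation starts after each pass.
    """
    assert len(raw_image) == len(mask_image)
    rng = np.random.default_rng(seed)
    state = {"order": rng.permutation(len(raw_image)), "pos": 0}

    def create_batch(n_image):
        idx = []
        while len(idx) < n_image:
            if state["pos"] == len(state["order"]):
                # Finished one pass over the data: reshuffle and start again.
                state["order"] = rng.permutation(len(raw_image))
                state["pos"] = 0
            take = min(n_image - len(idx), len(state["order"]) - state["pos"])
            idx.extend(state["order"][state["pos"]:state["pos"] + take])
            state["pos"] += take
        idx = np.asarray(idx, dtype=int)
        return raw_image[idx], mask_image[idx]

    # Same attributes as new_create_batch, derived from the array shapes.
    create_batch.channels = raw_image.shape[-1]
    create_batch.n_class = mask_image.shape[-1]
    return create_batch


raw = np.arange(5 * 4 * 4 * 1, dtype=np.float32).reshape(5, 4, 4, 1)
masks = np.zeros((5, 4, 4, 2), dtype=np.float32)
provider = make_batch_fn(raw, masks)
bx, by = provider(2)
print(bx.shape, by.shape)  # (2, 4, 4, 1) (2, 4, 4, 2)
```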
This is how I launch the training process:

```python
import os

import numpy as np

from tf_unet import unet
from tf_unet import image_gen  # locally modified: adds create_train_data and new_get_image_gen

# raw_data_path and train_data_path are defined elsewhere in my script.
image_gen.create_train_data(raw_data_path, train_data_path)  # only needed the first time
imgs_train = np.load(os.path.join(train_data_path, "imgs_train.npy"))
imgs_mask_train = np.load(os.path.join(train_data_path, "imgs_mask_train.npy"))
generator = image_gen.new_get_image_gen(imgs_train, imgs_mask_train)
net = unet.Unet(channels=generator.channels, n_class=generator.n_class, layers=3, features_root=16)
trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2))
path = trainer.train(generator, "/data/unet_trained", training_iters=2816, epochs=100, display_step=2)
```
The `create_train_data` function is attached. It produces two arrays: `imgs_train` with shape `(5635, 420, 580, 1)` and `imgs_mask_train` with shape `(5635, 420, 580, 2)`.
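The two-channel mask shape suggests a one-hot encoding of binary masks for `n_class=2`. Here is a minimal sketch of that conversion. It assumes the raw masks are images in which nonzero pixels mark the structure of interest, and it puts background in channel 0. `to_two_class` and that channel order are assumptions, not taken from the attached `create_train_data`.

```python
import numpy as np


def to_two_class(mask):
    """Convert a binary mask of shape (N, H, W) or (N, H, W, 1) to (N, H, W, 2).

    Hypothetical helper. Channel 0 marks background and channel 1 marks the
    structure of interest; this channel order is an assumption.
    """
    mask = np.asarray(mask)
    if mask.ndim == 4 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    foreground = mask > 0
    labels = np.zeros(mask.shape + (2,), dtype=np.float32)
    labels[..., 1] = foreground
    labels[..., 0] = ~foreground
    return labels


mask = np.array([[[0, 255], [255, 0]]], dtype=np.uint8)  # one 2x2 mask
labels = to_two_class(mask)
print(labels.shape)  # (1, 2, 2, 2)
```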
Here is my implementation: https://gist.github.com/jakeret/cd1e8b51009b995c38b44101c38ad831
To create the instance I use something like this:

```python
import glob

from scripts.ultrasound_util import DataProvider

data_provider = DataProvider(data_files=glob.glob(data_root + "/[0-9]?_[0-9]?.tif"),
                             mask_files=glob.glob(data_root + "/[0-9]?_[0-9]?_mask.tif"))
```

Just like in the RFI example.
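`glob.glob` returns paths in an order that depends on the file system. Two independent glob calls are therefore not guaranteed to line up image-for-mask unless the provider sorts or matches them itself. The sketch below pairs each image with its mask by name instead. `pair_image_and_mask` is a hypothetical helper, and the `<name>.tif` / `<name>_mask.tif` convention is taken from the patterns above. It also uses `fnmatch`, the matcher `glob` relies on, to show that `?` stands for exactly one character.

```python
import fnmatch


def pair_image_and_mask(paths):
    """Pair each image with its mask, e.g. '1_2.tif' with '1_2_mask.tif'.

    Hypothetical helper: matching by name avoids relying on two glob
    results coming back in the same order.
    """
    masks = {p for p in paths if p.endswith("_mask.tif")}
    pairs = []
    for path in sorted(paths):
        if path in masks or not path.endswith(".tif"):
            continue
        mask_path = path[: -len(".tif")] + "_mask.tif"
        if mask_path in masks:
            pairs.append((path, mask_path))
    return pairs


files = ["data/1_2.tif", "data/1_2_mask.tif", "data/10_3_mask.tif", "data/10_3.tif"]
print(pair_image_and_mask(files))
# In glob patterns '?' matches exactly one character, so '[0-9]?' needs two characters.
print(fnmatch.fnmatch("1_2.tif", "[0-9]?_[0-9]?.tif"))  # False
print(fnmatch.fnmatch("10_32.tif", "[0-9]?_[0-9]?.tif"))  # True
```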
Hi Joel,
Thanks.