Comments (6)

jakeret commented on July 17, 2024

My understanding for your including gradients and self.net.gradients_nodes is for the debugging usages.

If Unet.summaries is set to true, you get histograms in TensorBoard showing the gradient values for each epoch. If the values remain very large, for instance, it might indicate that the training is not working.
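As a rough illustration of the kind of gradient tracking described here, the sketch below keeps a running mean of per-layer gradients and reports their L2 norms. It is plain NumPy and is not taken from tf_unet; the function names `update_running_mean` and `gradient_norms` are hypothetical.

```python
import numpy as np

def update_running_mean(avg_grads, grads, step):
    """Fold the gradients of step `step` (0-based) into a running mean."""
    if avg_grads is None:
        avg_grads = [np.zeros_like(g, dtype=np.float64) for g in grads]
    # Incremental mean: new_avg = avg + (g - avg) / (step + 1)
    return [avg + (g - avg) / (step + 1) for avg, g in zip(avg_grads, grads)]

def gradient_norms(grads):
    """L2 norm of each gradient array, one value per layer."""
    return [float(np.linalg.norm(g)) for g in grads]
```

Norms that stay large epoch after epoch would be the numeric counterpart of the histogram behaviour mentioned above.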

In other words, if I am only trying to run the training process, I can remove these two items and re-write your code as.

Yes, in principle, but I still don't understand why you want to make changes in the Trainer class.

from tf_unet.

surfreta commented on July 17, 2024

Hi Joel,

No, I do not plan to change the Trainer class. I am asking because I have looked at some other TensorFlow implementations, and most of them do not handle the gradients explicitly in the optimization (back-propagation) step as you did. I am just curious why you chose to do it this way, and whether there is a trick or special consideration behind it. According to your explanation, the purpose is mainly debugging and performance tracking.

jakeret commented on July 17, 2024

Out of curiosity I downloaded the Kaggle ultrasound dataset. To start the training I reused most of the script/rfi_launcher.py and script/radio_util.py code. I more or less just replaced the loading of the data with a call to the PIL library:

data = np.array(Image.open(self.data_files[self.file_idx]), dtype=np.float32)

I don't get black predictions, and the learning rate is decreasing as expected.

surfreta commented on July 17, 2024

Hi Joel,
Thanks for trying this data set as well. I really want to fully understand your code and get it working, because I really like it; it is very well designed.

Do you use the same parameters, such as three layers and features_root=16? Have you normalized the images?
I didn't change the optimization step. The only thing I changed is the code for loading the image files. I am quite confused about why mine does not work.

(A quick update: I just got the figure after finishing 12 epochs. It looks like the prediction is not fully black, but the learning rate still does not change.)
epoch_12

In the demo problem, you randomly generate the training data on the fly within each epoch. For the Kaggle case, I need to iterate over a training set consisting of 5635 pairs of images and mask images.

This is what I did in the training step: I use batch_x, batch_y = data_provider(ij, self.batch_size) in place of batch_x, batch_y = data_provider(self.batch_size). Here ij is the starting index, because I need to iterate over the training set.

The following is new_get_image_gen, which is similar to get_image_gen. The only difference is that the returned function takes one additional input parameter that controls the starting point for each iteration:

def new_get_image_gen(raw_image, mask_image):
    def new_create_batch(starting_point, n_image):
        batch_x = raw_image[starting_point:starting_point + n_image, :, :, :]
        batch_y = mask_image[starting_point:starting_point + n_image, :, :, :]
        return batch_x, batch_y

    new_create_batch.channels = 1
    new_create_batch.n_class = 2
    return new_create_batch
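One possible alternative, sketched below, keeps the position inside the generator so that it can still be called as data_provider(self.batch_size), the call quoted above, without an extra index argument. The name `get_sequential_image_gen` and the wrap-around at the end of the data set are assumptions made for illustration; this is not code from tf_unet.

```python
import numpy as np

def get_sequential_image_gen(raw_image, mask_image):
    """Return a provider callable as provider(n) that walks through the data."""
    state = {"start": 0}          # index of the next unread sample
    n_total = raw_image.shape[0]  # number of image/mask pairs

    def create_batch(n_image):
        idx = np.arange(state["start"], state["start"] + n_image)
        state["start"] = (state["start"] + n_image) % n_total
        # mode="wrap" restarts from sample 0 once the end is passed
        batch_x = np.take(raw_image, idx, axis=0, mode="wrap")
        batch_y = np.take(mask_image, idx, axis=0, mode="wrap")
        return batch_x, batch_y

    # Read channel and class counts from the arrays instead of hard-coding them
    create_batch.channels = raw_image.shape[-1]
    create_batch.n_class = mask_image.shape[-1]
    return create_batch
```

With this shape of provider, every call returns the next slice of the training set, so the training loop itself would not need to pass a starting index.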

This is how I launch the training process:

    image_gen.create_train_data(raw_data_path, train_data_path)  # only the first time needs to include this operation
    imgs_train = np.load(os.path.join(train_data_path, "imgs_train.npy"))
    imgs_mask_train = np.load(os.path.join(train_data_path, "imgs_mask_train.npy"))
    generator = image_gen.new_get_image_gen(imgs_train, imgs_mask_train)
    net = unet.Unet(channels=generator.channels, n_class=generator.n_class, layers=3, features_root=16)
    trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2))
    path = trainer.train(generator, "/data/unet_trained", training_iters=2816, epochs=100, display_step=2)

The create_train_data function is attached. It returns two arrays: imgs_train of shape (5635, 420, 580, 1) and imgs_mask_train of shape (5635, 420, 580, 2).

create_train_data.txt
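For readers without the attachment, here is a minimal sketch of how a binary mask could be expanded to a two-channel array like the (5635, 420, 580, 2) one described above, together with a simple min-max normalization of the images. The helper names, the channel order (background first, foreground second) and the choice of normalization are assumptions; the attached create_train_data may do this differently.

```python
import numpy as np

def to_one_hot(mask):
    """(N, H, W) binary mask -> (N, H, W, 2) array of [background, foreground]."""
    foreground = mask > 0
    return np.stack([~foreground, foreground], axis=-1).astype(np.float32)

def normalize(images):
    """Scale pixel values to [0, 1] using the global minimum and maximum."""
    images = images.astype(np.float32)
    lo, hi = images.min(), images.max()
    if hi == lo:
        # Constant input: avoid dividing by zero
        return np.zeros_like(images)
    return (images - lo) / (hi - lo)
```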

jakeret commented on July 17, 2024

Here is my implementation: https://gist.github.com/jakeret/cd1e8b51009b995c38b44101c38ad831

To create the instance I use something like this

import glob

from scripts.ultrasound_util import DataProvider

data_provider = DataProvider(data_files=glob.glob(data_root+"/[0-9]?_[0-9]?.tif"),
                             mask_files=glob.glob(data_root+"/[0-9]?_[0-9]?_mask.tif"),)

Just like in the rfi example
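To see which file names the two glob patterns above select, the sketch below applies them with fnmatch, the standard-library module that implements the same shell-style wildcards. In these patterns `[0-9]` matches one digit and `?` matches exactly one arbitrary character, so each `[0-9]?` group needs two characters. The file names in the list are hypothetical and the helper `split_by_pattern` is introduced only for this example.

```python
from fnmatch import fnmatchcase

IMAGE_PATTERN = "[0-9]?_[0-9]?.tif"
MASK_PATTERN = "[0-9]?_[0-9]?_mask.tif"

def split_by_pattern(names):
    """Return (images, masks), each sorted so matching pairs line up."""
    images = sorted(n for n in names if fnmatchcase(n, IMAGE_PATTERN))
    masks = sorted(n for n in names if fnmatchcase(n, MASK_PATTERN))
    return images, masks

# Hypothetical file names, for illustration only
names = ["10_12.tif", "10_12_mask.tif", "1_1.tif", "1_1_mask.tif"]
images, masks = split_by_pattern(names)
print(images)  # ['10_12.tif']
print(masks)   # ['10_12_mask.tif']
```

Names with a single character before or after the underscore are skipped by these patterns, and the mask files do not match the image pattern. Since glob.glob returns files in arbitrary order, sorting both lists is one way to keep each image aligned with its mask.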

surfreta commented on July 17, 2024

Hi Joel,

Thanks.
