Comments (4)
Sorry for the late reply. I just pushed a generic data provider for numpy arrays.
import numpy as np
from tf_unet import image_util
data = np.zeros((5, 10, 10, 3))
label = np.zeros((5, 10, 10), dtype=bool)
gen = image_util.SimpleDataProvider(data, label, channels=3, n_class=2)
Hope it helps
Thanks for your reply. It helps me a lot.
I copied SimpleDataProvider into image_util.py, but got the error message "SimpleDataProvider is not defined".
I then downloaded the latest version and ran the setup again, which resolved that error.
I tested my data with the following code. Unfortunately, a new error occurred.
Here are the NIfTI data and code.
I can't solve this problem. Would it be possible for you to help me with the error?
I'd be happy to provide more information as needed.
I look forward to your reply.
from tf_unet import unet
from tf_unet import util
from tf_unet import image_util
import numpy as np
import nibabel as nb
import matplotlib.pyplot as plt
# Read nii data
TrainFilePath = "G:/Tensorflow/tf_unet-master/demo/"
# TrainFilePath = "X:/tf_unet-master/demo/"
Train_Case = 1
Train_nii = nb.load(TrainFilePath + 'data-' + str(Train_Case) + '.nii')
Train_Label_nii = nb.load(TrainFilePath + 'data-label-' + str(Train_Case) + '.nii') # read label
# get data
Train_data = Train_nii.get_data()
Train_Label_data = Train_Label_nii.get_data() # 0: background; 1&2 ground truth
Train_data = Train_data.astype(np.uint8)
Train_Label_data = Train_Label_data.astype(np.uint8)
Train_Label_data[Train_Label_data == 2] = 1 # combine label 2 with label 1
# reshape data and label to [n, X, Y, channels], label [n, X, Y, classes].
nx = Train_data.shape[0]
ny = Train_data.shape[1]
nz = Train_data.shape[2]
Train_data = Train_data.reshape(nx, ny, nz, 1)
Train_data_label = Train_Label_data.reshape(nx, ny, nz, 1)
Train_data = Train_data.transpose(2, 0, 1, 3)
Train_data_label = Train_data_label.transpose(2, 0, 1, 3)
# provide data for the unet
data_provider = image_util.SimpleDataProvider(Train_data, Train_data_label)
# set parameter
net = unet.Unet(channels=1, n_class=3, layers=3, features_root=16)
trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2))
# train
path = trainer.train(data_provider, "./seg_3dim_data_trained_0509", training_iters=5, epochs=1, display_step=2)
# predict
Test_Slice = Train_data[40]
Test_Slice = Test_Slice.reshape(1, nx, ny, 1)
Test_label = Train_data_label[40]
Test_label = Test_label.reshape(1, nx, ny, 1)
prediction = net.predict("./seg_3dim_data_trained_0509", Test_Slice)
# show the result
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(12,6))
ax[0].imshow(Test_Slice[0, ..., 0], cmap='gray', aspect="auto")
ax[1].imshow(Test_label[0, ..., 0], cmap='gray', aspect="auto")
mask = prediction[0, ..., 1] > 0.9
ax[2].imshow(mask, aspect="auto")
ax[0].set_title("Input")
ax[1].set_title("Ground truth")
ax[2].set_title("Prediction")
fig.tight_layout()
plt.show()
The main error produced is this:
Traceback (most recent call last):
File "E:\Program Files (x86)\JetBrains\PyCharm Community Edition 2017.1\helpers\pydev\pydev_run_in_console.py", line 78, in <module>
globals = run_file(file, None, None)
File "E:\Program Files (x86)\JetBrains\PyCharm Community Edition 2017.1\helpers\pydev\pydev_run_in_console.py", line 35, in run_file
pydev_imports.execfile(file, globals, locals) # execute the script
File "E:\Program Files (x86)\JetBrains\PyCharm Community Edition 2017.1\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "G:/Tensorflow/tf_unet-master/Seg_3_dim_Nii_data.py", line 36, in <module>
path = trainer.train(data_provider, "./seg_3dim_data_trained_0509", training_iters=5, epochs=1, display_step=2)
File "G:/Tensorflow/tf_unet-master\tf_unet\unet.py", line 403, in train
test_x, test_y = data_provider(self.verification_batch_size)
File "G:/Tensorflow/tf_unet-master\tf_unet\image_util.py", line 98, in __call__
train_data, labels = self._load_data_and_label()
File "G:/Tensorflow/tf_unet-master\tf_unet\image_util.py", line 58, in _load_data_and_label
return train_data.reshape(1, ny, nx, self.channels), labels.reshape(1, ny, nx, self.n_class),
ValueError: cannot reshape array of size 262144 into shape (1,512,512,2)
If I understand your code correctly, you are configuring the network for 3 classes (n_class=3), but the label seems to have only 2 classes.
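The size mismatch in the traceback can be reproduced with NumPy alone. The sketch below is not taken from tf_unet. It assumes the provider wants labels with one channel per class, which is what the target shape (1, 512, 512, 2) in the error suggests. The helper to_one_hot and the foreground square are hypothetical, added only for illustration.

```python
import numpy as np

nx = ny = 512
n_class = 2  # number of label channels the reshape in the traceback asks for

# One slice of an integer label map with a single trailing channel,
# like the slices produced by reshape(nx, ny, nz, 1) in the script above.
label_slice = np.zeros((nx, ny, 1), dtype=np.uint8)
label_slice[100:200, 100:200, 0] = 1  # hypothetical foreground region

available = label_slice.size      # 512 * 512 * 1 = 262144 elements
required = 1 * ny * nx * n_class  # 512 * 512 * 2 = 524288 elements
print(available, required)

# The element counts differ, so the reshape cannot succeed.
try:
    label_slice.reshape(1, ny, nx, n_class)
except ValueError as err:
    print("reshape failed:", err)


def to_one_hot(label_map, n_class):
    """Turn an integer label map of shape (X, Y) into (X, Y, n_class).

    Hypothetical helper: row k of the identity matrix is the one-hot
    vector for class k, so indexing with the label map expands it.
    """
    return np.eye(n_class, dtype=np.float32)[label_map]


one_hot = to_one_hot(label_slice[..., 0], n_class)
batch = one_hot.reshape(1, ny, nx, n_class)  # now the sizes agree
print(batch.shape)
```

Whichever label layout is used, the number of label channels has to match the n_class passed to both the data provider and unet.Unet.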
Could you give an example that uses image_util.SimpleDataProvider to load the input and runs the whole training successfully?