
Comments (7)

MaybeShewill-CV commented on May 28, 2024

@dmsql0816 If you train the model on the Cityscapes dataset, you need to adjust some field values in the TRAIN section of the config:

TRAIN:
    MODEL_SAVE_DIR: 'model/cityscapes/'
    TBOARD_SAVE_DIR: 'tboard/cityscapes/'
    MODEL_PARAMS_CONFIG_FILE_NAME: "model_train_config.json"
    RESTORE_FROM_SNAPSHOT:
        ENABLE: False
        SNAPSHOT_PATH: ''
    SNAPSHOT_EPOCH: 8
    BATCH_SIZE: 16
    VAL_BATCH_SIZE: 4
    EPOCH_NUMS: 905
    WARM_UP:
        ENABLE: True
        EPOCH_NUMS: 8
    FREEZE_BN:
        ENABLE: False
    COMPUTE_MIOU:
        ENABLE: True
        EPOCH: 1
    MULTI_GPU:
        ENABLE: True
        GPU_DEVICES: ['0', '1', '2', '3']
        CHIEF_DEVICE_INDEX: 0
Other sections (such as OHEM in the solver section) do not need to be adjusted :)
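As a rough illustration of adjusting a few TRAIN fields without touching the rest, here is a minimal Python sketch. It assumes the config has already been loaded into a nested dict. The `override` helper, the variable names and the new values (batch size 8, two GPUs) are illustrative assumptions, not part of the repository.

```python
import copy

# Subset of the TRAIN fields quoted in this thread, held as a plain nested dict.
base_train_cfg = {
    "BATCH_SIZE": 16,
    "VAL_BATCH_SIZE": 4,
    "MULTI_GPU": {
        "ENABLE": True,
        "GPU_DEVICES": ["0", "1", "2", "3"],
        "CHIEF_DEVICE_INDEX": 0,
    },
}


def override(cfg, updates):
    """Return a copy of cfg with nested updates applied (hypothetical helper)."""
    result = copy.deepcopy(cfg)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            # Merge nested sections instead of replacing them wholesale.
            result[key] = override(result[key], value)
        else:
            result[key] = value
    return result


# Illustrative values only: a smaller batch on two GPUs.
two_gpu_cfg = override(
    base_train_cfg,
    {"BATCH_SIZE": 8, "MULTI_GPU": {"GPU_DEVICES": ["0", "1"]}},
)
```

Fields that are not listed in the update, such as `VAL_BATCH_SIZE` or `CHIEF_DEVICE_INDEX`, keep their original values, and the base dict is left unmodified.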

from bisenetv2-tensorflow.

MaybeShewill-CV commented on May 28, 2024

@dmsql0816 That is the minimum number of samples kept in the OHEM process. Check the code:

def _compute_ohem_cross_entropy_loss(cls, seg_logits, labels, class_nums, name, thresh, n_min):
    """
    :param seg_logits:
    :param labels:
    :param class_nums:
    :param name:
    :return:
    """
    with tf.variable_scope(name_or_scope=name):
        # first check if the logits' shape is matched with the labels'
        seg_logits_shape = seg_logits.shape[1:3]
        labels_shape = labels.shape[1:3]
        seg_logits = tf.cond(
            tf.reduce_all(tf.equal(seg_logits_shape, labels_shape)),
            true_fn=lambda: seg_logits,
            false_fn=lambda: tf.image.resize_bilinear(seg_logits, labels_shape)
        )
        seg_logits = tf.reshape(seg_logits, [-1, class_nums])
        labels = tf.reshape(labels, [-1, ])
        indices = tf.squeeze(tf.where(tf.less_equal(labels, class_nums - 1)), 1)
        seg_logits = tf.gather(seg_logits, indices)
        labels = tf.cast(tf.gather(labels, indices), tf.int32)
        # compute cross entropy loss
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels,
            logits=seg_logits
        )
        loss, _ = tf.nn.top_k(loss, tf.size(loss), sorted=True)
        # apply ohem
        ohem_thresh = tf.multiply(-1.0, tf.math.log(thresh), name='ohem_score_thresh')
        ohem_cond = tf.greater(loss[n_min], ohem_thresh)
        loss_select = tf.cond(
            pred=ohem_cond,
            true_fn=lambda: tf.gather(loss, tf.squeeze(tf.where(tf.greater(loss, ohem_thresh)), 1)),
            false_fn=lambda: loss[:n_min]
        )
        loss_value = tf.reduce_mean(loss_select, name='ohem_cross_entropy_loss')
    return loss_value
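The selection rule in the function above can be easier to follow without the TensorFlow graph ops. Here is a minimal plain-Python sketch of the same idea, working on a flat list of per-pixel losses. The names `ohem_select` and `ohem_loss` are hypothetical. The explicit length check is an addition for illustration; the quoted TensorFlow code indexes `loss[n_min]` directly.

```python
import math


def ohem_select(losses, thresh, n_min):
    """Pick the per-pixel losses that OHEM averages over (illustrative sketch).

    losses: per-pixel cross-entropy values, in any order
    thresh: probability threshold, turned into a loss threshold via -log(thresh)
    n_min:  minimum number of losses to keep
    """
    ordered = sorted(losses, reverse=True)  # hardest pixels first
    loss_thresh = -math.log(thresh)
    if n_min >= len(ordered):
        # Indexing ordered[n_min] needs more than n_min entries.
        raise ValueError("n_min must be smaller than the number of valid pixels")
    if ordered[n_min] > loss_thresh:
        # More than n_min pixels exceed the threshold: keep all of them.
        return [value for value in ordered if value > loss_thresh]
    # Otherwise fall back to the n_min hardest pixels.
    return ordered[:n_min]


def ohem_loss(losses, thresh, n_min):
    """Mean of the selected hard-example losses."""
    selected = ohem_select(losses, thresh, n_min)
    return sum(selected) / len(selected)


pixel_losses = [0.1, 2.0, 0.5, 1.5, 0.05]
hard_07 = ohem_select(pixel_losses, thresh=0.7, n_min=2)  # -> [2.0, 1.5, 0.5]
hard_05 = ohem_select(pixel_losses, thresh=0.5, n_min=2)  # -> [2.0, 1.5]
```

With `thresh=0.7` the loss threshold is about 0.357, so three pixels count as hard and all three are kept. With `thresh=0.5` the threshold rises to about 0.693, only two pixels exceed it, and the rule falls back to the top `n_min` losses.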


dmsql0816 commented on May 28, 2024

To fix that error, do I have to change something in bisenet_v2.py?


MaybeShewill-CV commented on May 28, 2024

@dmsql0816 No. If you are not familiar with OHEM, you can reduce MIN_SAMPLE_NUMS in the config file or disable OHEM there.
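One reading of the loss code quoted earlier in this thread is that `loss[n_min]` indexes into the sorted per-pixel loss vector, so the minimum sample count has to stay below the number of valid pixels in a batch. Assuming that reading, a small guard such as the one below could clamp the value before it is used. The helper name and the numbers in the usage line are hypothetical and chosen only for illustration.

```python
def safe_min_sample_nums(num_valid_pixels, requested_n_min):
    """Clamp n_min so that loss[n_min] stays in range (hypothetical helper)."""
    if num_valid_pixels < 2:
        # With fewer than two valid pixels there is no non-empty top-n_min slice
        # that still leaves index n_min in range.
        raise ValueError("need at least two valid pixels to apply OHEM")
    return min(requested_n_min, num_valid_pixels - 1)


# Illustrative numbers only: a tiny batch with 1000 valid pixels.
clamped = safe_min_sample_nums(num_valid_pixels=1000, requested_n_min=65536)  # -> 999
unchanged = safe_min_sample_nums(num_valid_pixels=1000, requested_n_min=256)  # -> 256
```

A requested value that already fits is returned as is, so the guard only changes behavior when the batch contains too few valid pixels.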


dmsql0816 commented on May 28, 2024

In the paper they use OHEM, so I want to use it too.
I want to train the same way you trained BiseNetv2-tensorflow, only changing the batch size to 8 and training with 2 GPUs.
What do I have to do to fix that error?
Sorry for bothering you, and thank you for your fast reply.


dmsql0816 commented on May 28, 2024

It WORKS!
Thank you so much!!


MaybeShewill-CV commented on May 28, 2024

@dmsql0816 No problem:)

