
Comments (6)

YassineYousfi commented on June 18, 2024

The final trained models were shared as releases here, you can try them out without training.
https://github.com/YassineYousfi/comma10k-baseline/releases
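
For reference, loading one of those release checkpoints might look like the minimal sketch below; the checkpoint filename is the one mentioned later in this thread, `from LitModel import LitModel` assumes the module exports a class of the same name, and whether extra hyperparameters must be passed depends on how the checkpoint was saved.

from LitModel import LitModel  # assumption: the repo's Lightning module exports this class

model = LitModel.load_from_checkpoint('epoch.28_val_loss.0.0439.ckpt')  # illustrative filename
model.eval()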


Zeleni9 commented on June 18, 2024

Wow, thank you for the fast response, lovely.


Zeleni9 commented on June 18, 2024

Hey, I ran inference on a custom dataset, but it outputs something weird. Is the release model trained on half or full image resolution? I will add the inference script and output; I also tried the dataset images and they produce the same problematic mask. Maybe the problem is in the preprocessing?

[attached image: semantic_map]

import os
from argparse import ArgumentParser

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms as T

from LitModel import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
   
def preprocess(img: np.ndarray):
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes frames as BGR; the model expects RGB
    img = Image.fromarray(img)  # keep uint8 [H, W, C]; ToTensor handles layout and scaling
    transforms = get_valid_transforms()
    img = transforms(img)  # apply transforms -> [C, H, W] float tensor
    img = img.unsqueeze(0)  # add batch dimension -> [1, C, H, W]
    return img


def get_valid_transforms(height: int = 896,
                         width: int = 1184):
    return T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            T.Resize((height, width), interpolation=Image.NEAREST)])



# def get_colormap(five=True):
#   f32 = lambda x: (x % 256, x//256 % 256, x//(256*256) % 256)
#   if five:
#     key = [2105408, 255, 0x608080, 6749952, 16711884]
#   else:
#     key = [0, 0xc4c4e2, 2105408, 255, 0x608080, 6749952, 16737792, 16711884]
#   return {i: f32(key[i]) for i in range(len(key))}
colormap = np.zeros((256, 3), dtype=np.uint8)
colormap[0] = [0, 255, 0]
colormap[1] = [244, 35, 232]
colormap[2] = [70, 255, 70]
colormap[3] = [125, 125, 0]
colormap[4] = [0, 0, 255]
colormap[5] = [255, 0, 0]

def decode_segmentation_masks(mask, colormap, n_classes):
    # Map each class index in the mask to its RGB color
    r = np.zeros_like(mask, dtype=np.uint8)
    g = np.zeros_like(mask, dtype=np.uint8)
    b = np.zeros_like(mask, dtype=np.uint8)
    for l in range(n_classes):
        idx = mask == l
        r[idx] = colormap[l, 0]
        g[idx] = colormap[l, 1]
        b[idx] = colormap[l, 2]
    rgb = np.stack([r, g, b], axis=2)
    return rgb


def get_overlay(image, colored_mask):
    image = cv2.resize(image.astype(np.uint8), (1184, 896))  # match the mask resolution
    overlay = cv2.addWeighted(image, 0.35, colored_mask, 0.65, 0)
    return overlay


def plot_samples_matplotlib(display_list, figsize=(5, 3)):
    _, axes = plt.subplots(nrows=1, ncols=len(display_list), figsize=figsize)
    for i in range(len(display_list)):
        axes[i].imshow(display_list[i])  # imshow handles both RGB and single-channel arrays
    plt.show()


def inference_model(args):
    """Run segmentation inference on every frame of a video file."""
    video_path = args.video_path
    # Fall back to the release checkpoint name if --weights_path is not given
    pretrained_checkpoint = args.weights_path or 'epoch.28_val_loss.0.0439.ckpt'
    model = LitModel.load_from_checkpoint(pretrained_checkpoint, **vars(args))
    model.eval()
    model.to(device)

    # Prepare video capture
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error opening video stream or file")
        return

    # Run inference on frames from the video
    while cap.isOpened():
        ret, image = cap.read()
        if not ret:
            break
        image_vis = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR -> RGB for display
        image = preprocess(image)
        image = image.to(device)

        with torch.no_grad():
            output = model(image)
            output = output.squeeze(0)  # [1, C, H, W] -> [C, H, W]
            output_predictions = output.argmax(0)

            output_predictions = output_predictions.cpu().numpy()
            prediction_colormap = decode_segmentation_masks(output_predictions, colormap, 5)

            print(prediction_colormap.shape)  # (896, 1184, 3)
            overlay = get_overlay(image_vis, prediction_colormap)
            plot_samples_matplotlib(
                [image_vis, overlay, prediction_colormap], figsize=(18, 14))

    cap.release()




if __name__ == '__main__':
    """
    Inference on custom images
    Run with:
        python inference.py --backbone efficientnet-b4 --video_path .\center_video.avi  --gpus 1 --batch-size 1 --learning-rate 5e-5 --epochs 30 --height 874 --width 1164 --augmentation-level hard    
    """
    root_dir = os.path.dirname(os.path.realpath(__file__))
    
    # Init arguments
    parent_parser = ArgumentParser(add_help=False)
    parser = LitModel.add_model_specific_args(parent_parser)
    parser.add_argument('--video_path', type=str, help='The video path or the src image save dir')
    parser.add_argument('--weights_path', type=str, help='The model weights path')
    args = parser.parse_args()
    
    inference_model(args)


YassineYousfi commented on June 18, 2024

The first thing I see is that you are using different transforms from those used in training.
cf. https://github.com/YassineYousfi/comma10k-baseline/blob/main/retriever.py#L137
You can also use https://github.com/YassineYousfi/comma10k-baseline/blob/predict-folder/predict_folder.py to predict on your own images in a folder.
The models in the release are trained with --height 874 --width 1164
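
For comparison, here is a rough sketch of validation-time transforms at the training resolution, assuming the linked retriever.py uses albumentations with ImageNet normalization; check the link above for the exact pipeline.

import albumentations as A
from albumentations.pytorch import ToTensorV2

def get_valid_transforms(height: int = 874, width: int = 1164):
    # Approximation of the training-time validation transforms; see retriever.py for the real ones
    return A.Compose([
        A.Resize(height=height, width=width),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

# albumentations takes an HWC RGB numpy image and returns a dict:
# tensor = get_valid_transforms()(image=rgb_image)['image'].unsqueeze(0)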


Zeleni9 commented on June 18, 2024

Tested it now; it works pretty well. It is also interesting that the model works this well trained on only 10k images. Is it possible to skip PyTorch Lightning for inference? It seems I am loading all the training data and the whole training workflow just to test images in a folder. I also tried exporting the model to ONNX, but the Swish activation function causes problems for the export. Thanks


YassineYousfi commented on June 18, 2024

Normalizing the pixels makes a big difference... so I am not surprised.
Yes, you can use vanilla PyTorch as you did in your previous script. I hear the new torch.inference_mode is pretty good.
The script does not load the training data; it only lists the files in your training directories.
Swish causing ONNX export problems is a known issue.
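
A minimal sketch of what per-frame inference could look like in vanilla PyTorch with torch.inference_mode (available in PyTorch 1.9+); model, preprocess, and device are the ones from the script above, and frame stands for a BGR frame read with cv2.

with torch.inference_mode():  # like no_grad, but also skips autograd bookkeeping on the outputs
    output = model(preprocess(frame).to(device))
    prediction = output.squeeze(0).argmax(0).cpu().numpy()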
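
And a hedged sketch of a commonly cited workaround for the Swish export issue, assuming the EfficientNet backbone comes from efficientnet-pytorch, which exposes set_swish(); the model.encoder attribute path is an assumption about how the backbone is attached.

# Assumption: the encoder is an efficientnet-pytorch model exposing set_swish()
model.encoder.set_swish(memory_efficient=False)  # swap memory-efficient Swish for an exportable one

dummy = torch.randn(1, 3, 874, 1164)
torch.onnx.export(model.cpu(), dummy, 'model.onnx', opset_version=11)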

