pytorch-enhance's Introduction

pytorch-enhance's People

Contributors

isaaccorley, lawrence880301


pytorch-enhance's Issues

Some hardcoded channels in test_models.py

If a user runs the tests with CHANNELS != 3, they fail because some channel counts are hardcoded, e.g. assert sr.shape == (1, 3, 64, 64). These could be updated to assert sr.shape == (1, CHANNELS, 64, 64).
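One way to avoid hardcoding the channel count (or any other dimension) is to derive the expected output shape from the low-resolution input. The helper below is a hypothetical sketch, not part of test_models.py: it assumes the models keep batch size and channel count and multiply height and width by scale_factor.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]


def expected_sr_shape(lr_shape: Shape, scale_factor: int) -> Shape:
    """Expected (N, C, H, W) output shape of a super-resolution model.

    Batch size and channel count are carried over from the input; only the
    spatial dimensions are multiplied by `scale_factor`.
    """
    n, c, h, w = lr_shape
    return (n, c, h * scale_factor, w * scale_factor)
```

In a test this would read assert sr.shape == expected_sr_shape(tuple(lr.shape), scale_factor); torch.Size is a tuple subclass, so the comparison works directly.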

Historical dataset error

When running the PyTorch Lightning example with the Historical dataset swapped in, the following error is raised:

RuntimeError: Given groups=1, weight of size [64, 3, 9, 9], expected input[1, 1, 256, 256] to have 3 channels, but got 1 channels instead

I assume there is a trivial fix in the dataset?
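In PyTorch a Conv2d weight has shape (out_channels, in_channels, kernel_height, kernel_width), so [64, 3, 9, 9] means the first layer expects 3-channel input, while the Historical batches appear to be single-channel greyscale ([1, 1, 256, 256]). Two possible fixes, both assumptions rather than confirmed library behaviour: build the model with channels=1 (if its constructor takes a channels argument, as in the SRResNet(scale_factor=2, channels=3) example further down this page), or repeat the grey channel three times. A minimal NumPy sketch of the second option, using a hypothetical helper name:

```python
import numpy as np


def to_three_channels(batch: np.ndarray) -> np.ndarray:
    """Turn an (N, 1, H, W) greyscale batch into (N, 3, H, W).

    Batches that already have 3 channels are returned unchanged.
    """
    if batch.ndim != 4:
        raise ValueError(f"expected (N, C, H, W), got shape {batch.shape}")
    if batch.shape[1] == 3:
        return batch
    if batch.shape[1] != 1:
        raise ValueError(f"expected 1 or 3 channels, got {batch.shape[1]}")
    # Copy the single grey channel into the R, G and B positions
    return np.repeat(batch, 3, axis=1)
```

With torch tensors the equivalent is x.repeat(1, 3, 1, 1). Repeating channels keeps a pretrained 3-channel model usable, whereas channels=1 avoids tripling the input size.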

Inference example

I assumed I could do the following:

import torch
from torch_enhance.models import ESPCN

# Training with pl

# save for use in production environment
torch.save(model.state_dict(), 'ESPCN.pth')

# Load and evaluate model:
model = ESPCN(scale_factor, channels)
model.load_state_dict(torch.load("ESPCN.pth"))
model.eval()

test_images, test_labels = next(iter(test_dataloader))
lr_img = test_images[0].squeeze()
hr_img = test_labels[0].squeeze()

sr_img = model.enhance(test_images[0])

but this results in RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 1, 5, 5], but got 3-dimensional input of size [1, 200, 200] instead. Am I using enhance incorrectly?
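The error says the convolution with weight [64, 1, 5, 5] expects a 4-dimensional (N, C, H, W) input, but integer-indexing test_images[0] drops the batch axis and leaves [1, 200, 200]. The sketch below shows the shapes with NumPy arrays standing in for the dataloader batch (the array sizes are assumptions matching the error message); torch tensors index the same way, so test_images[0:1] or test_images[0].unsqueeze(0) would give a 4-D input.

```python
import numpy as np

# Stand-in for one batch from the test dataloader, shaped (N, C, H, W)
test_images = np.zeros((4, 1, 200, 200), dtype=np.float32)

dropped = test_images[0]          # (1, 200, 200): integer index removes the batch axis
kept = test_images[0:1]           # (1, 1, 200, 200): a slice keeps it
restored = test_images[0][None]   # (1, 1, 200, 200): re-insert a leading axis
```

The squeezed lr_img and hr_img are still fine for plotting; only the tensor passed to the model needs the batch axis.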

Strange results.

Thanks for your good work.

But when I try to use your library, I get strange results (attached image: sr_test).

I loaded the image using OpenCV and converted it to a PyTorch tensor with values in the range [0., 1.].

Is the default model untrained?

I followed the example below:

import torch
import torch_enhance

# increase resolution by factor of 2 (e.g. 128x128 -> 256x256)
model = torch_enhance.models.SRResNet(scale_factor=2, channels=3)

lr = torch.randn(1, 3, 128, 128)
sr = model(lr)  # [1, 3, 256, 256]; in my case: model(my_img)
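Two things may be worth checking here. First, as far as the example shows, nothing loads trained weights: a freshly constructed PyTorch module starts from random initialization, so its output is not a meaningful super-resolution until it is trained or weights are loaded with load_state_dict. Second, cv2.imread returns a uint8 array in BGR channel order with shape (H, W, C), while the model takes a float (N, C, H, W) tensor. A minimal NumPy sketch of that conversion, using a hypothetical helper name and assuming the model was trained on RGB images in [0, 1]:

```python
import numpy as np


def opencv_to_model_input(img_bgr: np.ndarray) -> np.ndarray:
    """Convert an OpenCV-style uint8 BGR image (H, W, 3) into a float32
    RGB batch of shape (1, 3, H, W) with values in [0, 1]."""
    if img_bgr.dtype != np.uint8 or img_bgr.ndim != 3 or img_bgr.shape[2] != 3:
        raise ValueError("expected a uint8 array of shape (H, W, 3)")
    rgb = img_bgr[..., ::-1]                # BGR -> RGB by reversing the channel axis
    chw = np.transpose(rgb, (2, 0, 1))      # HWC -> CHW
    x = chw.astype(np.float32) / 255.0      # [0, 255] -> [0, 1]
    # Add the batch axis; make memory contiguous because torch.from_numpy
    # rejects arrays with negative strides
    return np.ascontiguousarray(x[None])
```

The result can then be wrapped with torch.from_numpy(x) before calling the model.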

Modified dataset to load from a local dir

Hi @isaaccorley,
I modified the dataset to load from a local directory (in my case a mounted Google Drive). It might be one to add to the wiki, if not to the codebase:

import glob
import os
from typing import List, Optional, Tuple

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor


class BaseDataset(torch.utils.data.Dataset):
    """Base Super Resolution Dataset Class

    Subclasses must set `image_size`, `scale_factor` and `file_names`
    (e.g. `self.file_names = self.get_files(data_dir)`).
    """
    color_space: str = "RGB"
    lr_transform: Optional[Compose] = None
    hr_transform: Optional[Compose] = None

    def get_lr_transforms(self):
        """Returns HR to LR image transformations
        """
        return Compose([
            Resize(size=(
                    self.image_size//self.scale_factor,
                    self.image_size//self.scale_factor
                ),
                interpolation=Image.BICUBIC
            ),
            ToTensor(),
        ])

    def get_hr_transforms(self):
        """Returns HR image transformations
        """
        return Compose([
            Resize((self.image_size, self.image_size), Image.BICUBIC),
            ToTensor(),
        ])

    def get_files(self, data_dir: str) -> List[str]:
        """Returns a sorted list of .jpg image files in a directory
        Parameters
        ----------
        data_dir : str
            Path to directory of images.
        Returns
        -------
        List[str]
            List of .jpg images in `data_dir` directory.
        """
        # os.path.join works whether or not data_dir ends with a separator
        return sorted(glob.glob(os.path.join(data_dir, "*.jpg")))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns a tuple of lr and hr torch tensors
        Parameters
        ----------
        idx : int
            Index value to index the list of images
        Returns
        -------
        lr: torch.Tensor
            Low Resolution transformed indexed image.
        hr: torch.Tensor
            High Resolution transformed indexed image.
        """
        # Both tensors come from the same source image; lr_transform downsamples it
        hr = self.load_img(self.file_names[idx])
        lr = hr.copy()
        if self.lr_transform:
            lr = self.lr_transform(lr)
        if self.hr_transform:
            hr = self.hr_transform(hr)

        return lr, hr

    def __len__(self) -> int:
        """Return number of images in dataset
        Returns
        -------
        int
            Number of images in dataset file_names list
        """
        return len(self.file_names)

    def load_img(self, file_path: str) -> Image.Image:
        """Returns a PIL Image of the image located at `file_path`
        Parameters
        ----------
        file_path : str
            Path to image file to be loaded
        Returns
        -------
        PIL.Image.Image
            Loaded image as PIL Image
        """
        return Image.open(file_path).convert(self.color_space)
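The get_files method above only matches files ending in lowercase .jpg, so on case-sensitive file systems it skips .JPG, .jpeg and .png files. If a local folder mixes formats, a more permissive lookup could replace its body. This is a minimal sketch; the helper name and the extension list are assumptions, not part of pytorch-enhance.

```python
from pathlib import Path
from typing import Iterable, List

# Hypothetical default: extensions treated as images (compared case-insensitively)
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def list_images(data_dir: str, extensions: Iterable[str] = IMAGE_EXTENSIONS) -> List[str]:
    """Return sorted paths of regular files in data_dir whose suffix matches
    one of `extensions`, ignoring case (e.g. both .jpg and .JPG)."""
    exts = {e.lower() for e in extensions}
    return sorted(
        str(p)
        for p in Path(data_dir).iterdir()
        if p.is_file() and p.suffix.lower() in exts
    )
```

Inside the dataset this would be used as return list_images(data_dir). Sorting keeps the file order, and therefore dataset indices, stable across runs.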

Attempts to calc psnr independently result in differences

I am using greyscale PNGs that have been contrast-stretched to pixel values in the range 0-255. I have a value for test_psnr, but when I calculate it independently I get a different result. I note that your PSNR implementation sets max_val to 1, whereas other implementations use 255. Also, the enhance method multiplies by 255 and then clips to [0, 255], which seems at odds with max_val being 1. My functions are below; can you advise what causes the differences?

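import math

import numpy as np


def contrast_stretch(img: np.ndarray) -> np.ndarray:
    """Contrast stretch an image
    Return pixels in range 0 to 255.
    Parameters
    ----------
    img : np.ndarray
        Image to be contrast stretched
    Returns
    -------
    np.ndarray
        Contrast stretched image (float64)
    """
    img = img.astype(np.float64)
    img_min = img.min()
    img_max = img.max()
    if img_max == img_min:
        # Constant image: avoid division by zero
        return np.zeros_like(img)
    return (img - img_min) / (img_max - img_min) * 255


def psnr(img1: np.ndarray, img2: np.ndarray, max_value: float = 255.0) -> float:
    """
    Compute the PSNR between two images.
    For 8-bit greyscale the max value is 255 and the min is zero.
    """
    # Cast to float first: subtracting uint8 arrays wraps around (e.g. 10 - 20 == 246)
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return round(20 * math.log10(max_value / math.sqrt(mse)), 1)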
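PSNR is 10·log10(MAX² / MSE). Scaling both images and MAX by 255 multiplies MSE and MAX² by the same factor 255², so max_val=1 on [0, 1] images and max_value=255 on [0, 255] images give the same number. Differences are therefore more likely to come from elsewhere, for example: NumPy uint8 subtraction wrapping around, clipping or rounding outputs to 8-bit integers before comparing, applying contrast_stretch separately to each image (which changes pixel values per image), rounding the result to 0.1 dB, or averaging PSNR per image versus over a whole batch. The sketch below checks the first points on synthetic data (the noise level and image size are arbitrary assumptions):

```python
import numpy as np


def psnr(img1: np.ndarray, img2: np.ndarray, max_value: float) -> float:
    """PSNR in dB, 10 * log10(max_value**2 / MSE), computed in float64."""
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(max_value ** 2 / mse))


rng = np.random.default_rng(0)
hr = rng.random((64, 64))                                      # reference in [0, 1]
sr = np.clip(hr + rng.normal(0.0, 0.05, hr.shape), 0.0, 1.0)   # noisy estimate

# Same images at two scales: the PSNR values agree
psnr_unit = psnr(hr, sr, max_value=1.0)
psnr_255 = psnr(hr * 255, sr * 255, max_value=255.0)

# Quantizing to 8-bit integers shifts the value slightly
hr_u8 = np.round(hr * 255).astype(np.uint8)
sr_u8 = np.round(sr * 255).astype(np.uint8)
psnr_u8 = psnr(hr_u8, sr_u8, max_value=255.0)

# Pitfall: subtracting uint8 arrays directly wraps around instead of going negative
wrapped = (np.array([10], dtype=np.uint8) - np.array([20], dtype=np.uint8))[0]
```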

About the conv_index

The relevant code is as follows:

if conv_index == '22':
    self.vgg = nn.Sequential(*modules[:8])
elif conv_index == '54':
    self.vgg = nn.Sequential(*modules[:35])

According to explanations such as https://paperswithcode.com/method/vgg-loss, perceptual similarity is usually computed on feature maps taken after an activation function such as ReLU. More specifically, conv_index 'i,j' denotes the feature map produced by the j-th convolution (after activation) before the i-th max-pooling layer. If so, in the code above '22' should correspond to *modules[:9] (which ends with the ReLU after that convolution) and, similarly, '54' should correspond to *modules[:36].

I am not sure whether my understanding is correct. Any help would be highly appreciated.

Best.
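Assuming modules is the layer list of torchvision's VGG19 features without batch normalization (each convolution followed by a ReLU, with a max-pool closing each block), the index arithmetic can be checked with a small hypothetical helper that walks the VGG19 layer configuration. Printing the actual modules list remains the reliable way to confirm the layout in a given version.

```python
from typing import List, Union

# VGG19 configuration: ints are conv output channels, "M" is a max-pool.
# Without batch norm, each conv contributes 2 modules (conv, ReLU), "M" contributes 1.
VGG19_CFG: List[Union[int, str]] = [
    64, 64, "M",
    128, 128, "M",
    256, 256, 256, 256, "M",
    512, 512, 512, 512, "M",
    512, 512, 512, 512, "M",
]


def slice_end(cfg: List[Union[int, str]], i: int, j: int, after_activation: bool = True) -> int:
    """Return k such that modules[:k] ends at the j-th conv before the i-th
    max-pool, including that conv's ReLU when after_activation is True."""
    idx = 0    # index of the next module in the flat layer list
    block = 1  # which max-pool comes next
    conv = 0   # convs seen so far in the current block
    for v in cfg:
        if v == "M":
            idx += 1
            block += 1
            conv = 0
            continue
        conv += 1
        if block == i and conv == j:
            # The conv sits at idx and its ReLU at idx + 1
            return idx + 2 if after_activation else idx + 1
        idx += 2
    raise ValueError(f"no conv {j} before max-pool {i}")
```

Under that assumption, the helper gives 8 and 35 (the slice ends in the code above) when the ReLU is excluded, and 9 and 36 when it is included.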
