isaaccorley / pytorch-enhance
Open-source Library of Image Super-Resolution Models, Datasets, and Metrics for Benchmarking or Pretrained Use
License: Apache License 2.0
If a user runs the tests with CHANNELS != 3, they will fail because some channel counts are hardcoded, e.g. assert sr.shape == (1, 3, 64, 64). These could be updated to assert sr.shape == (1, CHANNELS, 64, 64).
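A minimal sketch of the suggested change, assuming a test that builds a model and checks the output shape. `TinyUpscaler` and `check_output_shape` are hypothetical stand-ins written for this illustration and are not part of the library; only the idea of asserting against the channel count instead of a literal 3 comes from the issue.

```python
import torch
from torch import nn

SCALE = 2


class TinyUpscaler(nn.Module):
    """Hypothetical stand-in for a super-resolution model (not from the library)."""

    def __init__(self, scale_factor: int, channels: int):
        super().__init__()
        self.body = nn.Sequential(
            # Produce scale_factor**2 feature maps per output channel ...
            nn.Conv2d(channels, channels * scale_factor ** 2, kernel_size=3, padding=1),
            # ... and rearrange them into a spatially larger image.
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


def check_output_shape(channels: int) -> tuple:
    """Assert on the configured channel count rather than a hardcoded 3."""
    model = TinyUpscaler(SCALE, channels)
    lr = torch.randn(1, channels, 32, 32)
    sr = model(lr)
    assert sr.shape == (1, channels, 64, 64)
    return tuple(sr.shape)


for channels in (1, 3):
    print(channels, check_output_shape(channels))
```

Written this way, the same assertion holds for greyscale and RGB configurations.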
On running the pytorch lightning example but swapping in the Historical dataset, the following error is raised:
RuntimeError: Given groups=1, weight of size [64, 3, 9, 9], expected input[1, 1, 256, 256] to have 3 channels, but got 1 channels instead
I assume there is a trivial fix in the dataset?
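The error message says the first convolution was built for 3 input channels while the batch has 1. Two possible workarounds are to construct the model with channels=1, or to repeat the single channel before the forward pass. The sketch below shows the second option; `match_channels` is a hypothetical helper, and the convolution only mirrors the weight shape from the error message rather than the library's actual layer.

```python
import torch


def match_channels(batch: torch.Tensor, channels: int) -> torch.Tensor:
    """Repeat a single-channel (N, 1, H, W) batch so it has `channels` channels."""
    if batch.shape[1] == channels:
        return batch
    if batch.shape[1] != 1:
        raise ValueError(f"cannot expand {batch.shape[1]} channels to {channels}")
    return batch.repeat(1, channels, 1, 1)


# Same shapes as in the error message: weight [64, 3, 9, 9], input [1, 1, 256, 256].
conv = torch.nn.Conv2d(3, 64, kernel_size=9, padding=4)
gray = torch.rand(1, 1, 256, 256)

rgb_like = match_channels(gray, 3)
with torch.no_grad():
    out = conv(rgb_like)

print(rgb_like.shape, out.shape)
```

Repeating the channel keeps a 3-channel model usable, at the cost of redundant computation; a 1-channel model avoids that but would need to be trained on greyscale data.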
I assumed I could do the following:
# Training with pl
# save for use in production environment
torch.save(model.state_dict(), 'ESPCN.pth')
# Load and evaluate model:
model = ESPCN(scale_factor, channels)
model.load_state_dict(torch.load("ESPCN.pth"))
model.eval()
test_images, test_labels = next(iter(test_dataloader))
lr_img = test_images[0].squeeze()
hr_img = test_labels[0].squeeze()
sr_img = model.enhance(test_images[0])
but this results in RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 1, 5, 5], but got 3-dimensional input of size [1, 200, 200] instead. Am I using enhance incorrectly?
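The error suggests the model received a single (C, H, W) image where a (N, C, H, W) batch was expected, since indexing `test_images[0]` drops the batch dimension. A minimal sketch of restoring it with `unsqueeze(0)` follows. The convolution here only mirrors the weight shape from the error message; whether `enhance` itself accepts the batched tensor unchanged is an assumption.

```python
import torch

# Same weight shape as in the error message: [64, 1, 5, 5].
conv = torch.nn.Conv2d(1, 64, kernel_size=5, padding=2)

single = torch.rand(1, 200, 200)   # one image: (C, H, W)
batched = single.unsqueeze(0)      # batch of one: (N, C, H, W)

with torch.no_grad():
    out = conv(batched)

print(batched.shape, out.shape)
```

In the snippet from the issue, the equivalent change would be passing `test_images[0].unsqueeze(0)`, or the slice `test_images[:1]`, which keeps the batch dimension.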
I don't see links anywhere. Is there a hub for these?
Thanks for your good work.
However, when I try to use your library, it produces strange results.
I loaded the image using OpenCV and converted it to a PyTorch tensor with values in the range [0., 1.].
Is the default model untrained?
I followed the example below:
import torch
import torch_enhance
# increase resolution by factor of 2 (e.g. 128x128 -> 256x256)
model = torch_enhance.models.SRResNet(scale_factor=2, channels=3)
lr = torch.randn(1, 3, 128, 128)
sr = model(lr)  # [1, 3, 256, 256]; in my case, model(my_img)
Hi @isaaccorley
I modified the dataset to load from a local dir (in my case a mounted Google Drive). It might be one to add to the wiki, if not to the codebase.
import glob
import os
from typing import List, Tuple

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor


class BaseDataset(torch.utils.data.Dataset):
    """Base Super Resolution Dataset Class
    """
    color_space: str = "RGB"
    lr_transform: T.Compose = None
    hr_transform: T.Compose = None

    def get_lr_transforms(self):
        """Returns HR to LR image transformations
        """
        return Compose([
            Resize(size=(
                    self.image_size//self.scale_factor,
                    self.image_size//self.scale_factor
                ),
                interpolation=Image.BICUBIC
            ),
            ToTensor(),
        ])

    def get_hr_transforms(self):
        """Returns HR image transformations
        """
        return Compose([
            Resize((self.image_size, self.image_size), Image.BICUBIC),
            ToTensor(),
        ])

    def get_files(self, data_dir: str) -> List[str]:
        """Returns a list of valid image files in a directory
        Parameters
        ----------
        data_dir : str
            Path to directory of images.
        Returns
        -------
        List[str]
            List of valid images in `data_dir` directory.
        """
        # os.path.join works whether or not data_dir ends with a separator
        return glob.glob(os.path.join(data_dir, '*.jpg'))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns a tuple of lr and hr torch tensors
        Parameters
        ----------
        idx : int
            Index value to index the list of images
        Returns
        -------
        lr: torch.Tensor
            Low Resolution transformed indexed image.
        hr: torch.Tensor
            High Resolution transformed indexed image.
        """
        lr = self.load_img(self.file_names[idx])
        hr = lr.copy()
        if self.lr_transform:
            lr = self.lr_transform(lr)
        if self.hr_transform:
            hr = self.hr_transform(hr)
        return lr, hr

    def __len__(self) -> int:
        """Return number of images in dataset
        Returns
        -------
        int
            Number of images in dataset file_names list
        """
        return len(self.file_names)

    def load_img(self, file_path: str) -> Image.Image:
        """Returns a PIL Image of the image located at `file_path`
        Parameters
        ----------
        file_path : str
            Path to image file to be loaded
        Returns
        -------
        PIL.Image.Image
            Loaded image as PIL Image
        """
        return Image.open(file_path).convert(self.color_space)
Possibly a mistake? In the master branch, setup.py has __version__ = "0.1.4", but PyPI has 0.1.5: https://pypi.org/project/torch-enhance/
I am using greyscale PNGs which have been contrast stretched to have pixel values in the range 0-255. I have a value for test_psnr, but when I calculate it independently I get a different result. I note that in your implementation of PSNR max_val is set to 1, whereas other implementations use 255. Also, in the enhance method you multiply by 255 and then clip to [0, 255], which seems at odds with setting max_val to 1. My functions are below; can you advise what is causing the difference?
import math

import numpy as np


def contrast_stretch(img: np.ndarray) -> np.ndarray:
    """Contrast stretch an image
    Return pixels in range 0 to 255.
    Parameters
    ----------
    img : np.ndarray
        Image to be contrast stretched
    Returns
    -------
    np.ndarray
        Contrast stretched image
    """
    img_min = img.min()
    img_max = img.max()
    return (img - img_min) / (img_max - img_min) * 255


def psnr(img1: np.ndarray, img2: np.ndarray, max_value: float = 255.0) -> float:
    """
    Compute the PSNR between two images.
    For 8bit greyscale max value is 255 and min is zero.
    """
    mse = np.mean((img1 - img2) ** 2)
    # Result is in dB, rounded to one decimal place
    return round(20 * math.log10(max_value / math.sqrt(mse)), 1)
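One thing that can be checked independently of the library: PSNR does not change when the images and the maximum value are rescaled together, so max_val = 1 on [0, 1] data should agree with 255 on [0, 255] data. A mismatch between the data range and the maximum value shifts the result by a constant 20·log10(255), about 48.1 dB. The sketch below uses plain Python lists with made-up pixel values purely for illustration; it is not the library's implementation, and `psnr_db` is a hypothetical name.

```python
import math


def psnr_db(a, b, max_value):
    """PSNR in dB between two equal-length pixel sequences."""
    mse = sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
    return 20 * math.log10(max_value / math.sqrt(mse))


# Illustrative pixel values in the range 0-255 (not real data).
hr_255 = [0.0, 64.0, 128.0, 255.0]
sr_255 = [2.0, 60.0, 130.0, 250.0]

# The same pixels rescaled to the range 0-1.
hr_unit = [v / 255 for v in hr_255]
sr_unit = [v / 255 for v in sr_255]

psnr_255 = psnr_db(hr_255, sr_255, max_value=255.0)      # data and max both on 0-255
psnr_unit = psnr_db(hr_unit, sr_unit, max_value=1.0)     # data and max both on 0-1
psnr_mismatch = psnr_db(hr_255, sr_255, max_value=1.0)   # 0-255 data, max of 1

print(psnr_255, psnr_unit, psnr_mismatch)
```

If the two calculations in the issue differ by something other than this constant offset, the cause may lie elsewhere, for example in comparing images on different scales, in quantisation to 8-bit values, or in averaging per-batch PSNR values instead of computing one PSNR from the overall MSE.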
The crucial code is as follows:
if conv_index == '22':
    self.vgg = nn.Sequential(*modules[:8])
elif conv_index == '54':
    self.vgg = nn.Sequential(*modules[:35])
According to explanations such as https://paperswithcode.com/method/vgg-loss, the feature maps used to compute perceptual similarity are usually taken after an activation function such as ReLU. More specifically, conv_index 'i,j' is taken to mean the j-th convolution (after activation) before the i-th max-pooling layer. If so, back to the code, '22' would correspond to *modules[:9], which ends at the layer after the ReLU activation, and similarly '54' would correspond to *modules[:36].
I am not sure whether my understanding is correct. Your help would be highly appreciated.
Best.
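The slice indices above can be checked without downloading any weights. The sketch below assumes that `modules` is the feature layer list of torchvision's VGG19 without batch normalisation, which the snippet in the issue does not confirm. Under that assumption, each convolution is followed by a ReLU and each block ends with a max-pooling layer. The names `VGG19_CFG` and `vgg19_feature_names` are hypothetical.

```python
# VGG19 layer configuration: numbers are conv output channels, "M" is max-pooling.
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]


def vgg19_feature_names(cfg=VGG19_CFG):
    """Name each module in order: conv{i}_{j}, relu{i}_{j}, pool{i}."""
    names = []
    block, conv = 1, 0
    for item in cfg:
        if item == "M":
            names.append(f"pool{block}")
            block += 1
            conv = 0
        else:
            conv += 1
            names.append(f"conv{block}_{conv}")
            names.append(f"relu{block}_{conv}")
    return names


names = vgg19_feature_names()
for stop in (8, 9, 35, 36):
    # The last module kept by modules[:stop] has index stop - 1.
    print(f"modules[:{stop}] ends with {names[stop - 1]}")
```

Under this layout, `modules[:8]` and `modules[:35]` stop at the convolution itself (before the ReLU), while `modules[:9]` and `modules[:36]` include the activation, which matches the reading in the issue. Both before-activation and after-activation features are used in the super-resolution literature, so which one is intended here is a question for the maintainer.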