I don't know what is causing this error. I tried to debug it but couldn't find the root cause.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
class depthwise_conv(nn.Module):
    """Per-channel (depthwise) 2-D convolution.

    Wraps one ``nn.Conv2d`` whose ``groups`` equals its channel count, so each
    input channel is convolved with its own filter and the number of channels
    is preserved.

    Args:
        nin: number of input (and output) channels.
        kernel_size: spatial size of each filter.
        padding: zero padding added on each spatial border.
        stride: convolution stride (default 1).
        dilation: spacing between kernel elements (default 1).
    """

    def __init__(self, nin, kernel_size, padding, stride=1, dilation=1):
        super().__init__()
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=nin,  # one filter per channel
        )
        self.depthwise = nn.Conv2d(nin, nin, **conv_kwargs)

    def forward(self, x):
        """Apply the depthwise convolution to ``x``."""
        return self.depthwise(x)
class dw_block(nn.Module):
    """Depthwise-convolution block: an ``nn.Sequential`` around ``depthwise_conv``.

    Args:
        nin: number of input (and output) channels.
        kernel_size: spatial size of the depthwise filter.
        padding: zero padding on each spatial border (default 1).
        stride: convolution stride (default 1).
        dilation: spacing between kernel elements (default 1).

    Note:
        ``padding`` is NOT adapted to ``kernel_size``/``dilation``. With the
        default padding of 1 and stride 1, each spatial dimension changes by
        ``2 - dilation * (kernel_size - 1)``. A 3x3 kernel keeps the size, but
        a dilated (d=2) 3x3 kernel or a 5x5 kernel shrinks it by 2.
    """

    def __init__(self, nin, kernel_size, padding=1, stride=1, dilation=1):
        super(dw_block, self).__init__()
        # Pass by keyword: depthwise_conv's positional order is
        # (nin, kernel_size, padding, stride, dilation). The previous positional
        # call (..., stride, padding, ...) silently swapped stride and padding.
        self.dw_block = nn.Sequential(
            depthwise_conv(
                nin,
                kernel_size,
                padding=padding,
                stride=stride,
                dilation=dilation,
            ),
        )

    def forward(self, x):
        """Run the depthwise convolution on ``x``."""
        return self.dw_block(x)
class pointwise_conv(nn.Module):
    """1x1 (pointwise) convolution that mixes channels.

    Args:
        nin: number of input channels.
        nout: number of output channels.
        padding: zero padding on each spatial border (default 0). With a 1x1
            kernel and stride 1, ``padding=p`` grows each spatial dimension
            by ``2 * p``. The padded border contains only the conv bias.
        stride: convolution stride (default 1).
    """

    def __init__(self, nin, nout, padding=0, stride=1):
        # The original line had stray '####qweqwe##' text, which commented out
        # the rest of this call and caused a SyntaxError.
        super(pointwise_conv, self).__init__()
        self.pointwise_block = nn.Sequential(
            nn.Conv2d(nin, nout, kernel_size=1, stride=stride, padding=padding),
        )

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        return self.pointwise_block(x)
class SuperRes(nn.Module):
    """Single-image super-resolution network with a PixelShuffle upsampler.

    forward(x) pipeline:
      1. ``first_part``: 5x5 conv ``num_channels -> d`` + PReLU  -> x1
      2. ``mid_part``: 1x1 conv ``d -> s``, ``m`` dilated 3x3 convs ``s -> s``,
         1x1 conv ``s -> d`` (each followed by PReLU)              -> x2
      3. ``dp1``: depthwise/pointwise stack ``d -> ... -> d``; its output
         is added to x1 (local residual)
      4. ``conv``: 3x3 conv ``d -> num_channels * scale_factor**2``; the input
         is added once per sub-pixel channel (global residual)
      5. ``last_part``: ``PixelShuffle(scale_factor)`` -> ``num_channels``
         channels at ``scale_factor`` times the input height/width.

    Args:
        scale_factor: upscaling factor (default 3).
        num_channels: image channels in and out (default 1).
        d: feature width of first_part / dp1 (default 32).
        s: reduced width inside mid_part (default 12).
        m: number of dilated 3x3 conv layers in mid_part (default 4).

    NOTE(review): assumes input is (batch, num_channels, H, W), which is what
    Conv2d/PixelShuffle expect here. Not validated.
    """

    def __init__(self, scale_factor=3, num_channels=1, d=32, s=12, m=4):
        super(SuperRes, self).__init__()
        self.scale_factor = scale_factor

        # 5x5 conv with padding 2 keeps H x W.
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=5 // 2),
            nn.PReLU(d),
        )

        # 1x1 convs and 3x3/dilation-2/padding-2 convs all keep H x W.
        mid_layers = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
        for _ in range(m):
            mid_layers.extend(
                [nn.Conv2d(s, s, kernel_size=3, padding=2, dilation=2), nn.PReLU(s)]
            )
        mid_layers.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])
        self.mid_part = nn.Sequential(*mid_layers)

        # Spatial size per step (input H x W): dw_block uses padding=1, so
        # dilated 3x3 and plain 5x5 kernels shrink by 2. The padding=2 on the
        # last two 1x1 convs grows it back so dp1's output matches x1.
        # NOTE(review): the border re-added by a padded 1x1 conv holds only
        # bias values; padding the dw_blocks for "same" output would avoid it.
        self.dp1 = nn.Sequential(
            dw_block(d, kernel_size=3, dilation=2),            # H-2
            nn.PReLU(d),
            pointwise_conv(nin=d, nout=24),
            nn.PReLU(24),
            dw_block(24, kernel_size=3, dilation=2),           # H-4
            nn.PReLU(24),
            pointwise_conv(nin=24, nout=16),
            nn.PReLU(16),
            dw_block(16, kernel_size=3),                       # H-4
            nn.PReLU(16),
            pointwise_conv(nin=16, nout=8),
            nn.PReLU(8),
            dw_block(8, kernel_size=3),                        # H-4
            nn.PReLU(8),
            pointwise_conv(nin=8, nout=16),
            nn.PReLU(16),
            dw_block(16, kernel_size=5),                       # H-6
            nn.PReLU(16),
            pointwise_conv(nin=16, nout=24, padding=2),        # H-2
            nn.PReLU(24),
            dw_block(24, kernel_size=5),                       # H-4
            nn.PReLU(24),
            pointwise_conv(nin=24, nout=d, padding=2),         # H
            nn.PReLU(d),
        )

        # PixelShuffle(r) needs num_channels * r**2 input channels
        # (9 for the defaults).
        self.conv = nn.Conv2d(d, num_channels * scale_factor ** 2, 3, 1, 1)
        self.last_part = nn.PixelShuffle(scale_factor)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise each Conv2d in first_part, mid_part and dp1.

        Weights are drawn from N(0, sqrt(2 / (out_channels * kh * kw))) and
        biases are zeroed. ``self.conv`` keeps PyTorch's default init.
        """
        for part in (self.first_part, self.mid_part, self.dp1):
            # .modules() recurses into wrappers. dp1's Conv2d layers live inside
            # dw_block / pointwise_conv, so iterating dp1 directly found none.
            for module in part.modules():
                if isinstance(module, nn.Conv2d):
                    kh, kw = module.kernel_size
                    std = math.sqrt(2 / (module.out_channels * kh * kw))
                    nn.init.normal_(module.weight, mean=0.0, std=std)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

    def forward(self, x):
        """Upscale ``x`` by ``scale_factor`` in height and width."""
        global_residual = x
        x1 = self.first_part(x)
        x2 = self.mid_part(x1)
        x3 = self.dp1(x2)
        x4 = x3 + x1  # local residual around mid_part + dp1
        x = self.conv(x4)
        # PixelShuffle maps channels [c*r*r, (c+1)*r*r) to output channel c, so
        # repeating each input channel r*r times adds it to every sub-pixel of
        # its own output channel. For num_channels == 1 this equals the
        # broadcast add used previously.
        x = x + global_residual.repeat_interleave(self.scale_factor ** 2, dim=1)
        return self.last_part(x)
if __name__ == "__main__":
    # Build the default network and print its layer summary.
    print(SuperRes())