
ssr_net_pytorch's Introduction

author : oukohou
time : 2019-09-26 16:44:48
email : [email protected]

A PyTorch reimplementation of SSR-Net.

The official Keras version is here: SSR-Net.

Results on the MegaAge_Asian dataset:

| version        | split | Loss    | CA_3   | CA_5   |
|----------------|-------|---------|--------|--------|
| version_v1^[1] | train | 22.0870 | 0.5108 | 0.7329 |
| version_v1^[1] | valid | 44.7439 | 0.4268 | 0.6225 |
| version_v1^[1] | test  | 35.6759 | 0.4935 | 0.6902 |
| original paper | test  |         | 0.549  | 0.741  |
| version_v2^[2] | train | 2.9401  | 0.6326 | 0.8123 |
| version_v2^[2] | valid | 4.7221  | 0.4438 | 0.6295 |
| version_v2^[2] | test  | 3.9311  | 0.5151 | 0.7163 |
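CA_3 and CA_5 are cumulative accuracies as used on MegaAge: CA(n) is the fraction of images whose absolute age error is at most n years. Below is a minimal plain-Python sketch of CA(n), plus MAE, computed from predicted and ground-truth ages. It assumes this standard definition; the helper names are hypothetical, and the repo's own evaluation code (for example, whether predictions are rounded first) may differ.

```python
from typing import Sequence


def cumulative_accuracy(preds: Sequence[float], ages: Sequence[float], n: int) -> float:
    """CA(n): fraction of samples whose absolute error |pred - age| is <= n."""
    if len(preds) != len(ages) or len(preds) == 0:
        raise ValueError("preds and ages must be non-empty and of equal length")
    hits = sum(1 for p, a in zip(preds, ages) if abs(p - a) <= n)
    return hits / len(preds)


def mean_absolute_error(preds: Sequence[float], ages: Sequence[float]) -> float:
    """MAE: average absolute error over all samples."""
    if len(preds) != len(ages) or len(preds) == 0:
        raise ValueError("preds and ages must be non-empty and of equal length")
    return sum(abs(p - a) for p, a in zip(preds, ages)) / len(preds)


if __name__ == "__main__":
    preds = [23.4, 31.5, 41.2, 18.9]  # made-up predictions
    ages = [25, 36, 40, 19]           # made-up ground truth
    print(cumulative_accuracy(preds, ages, 3))  # 0.75
    print(cumulative_accuracy(preds, ages, 5))  # 1.0
    print(mean_absolute_error(preds, ages))     # about 1.85
```

Because CA(n) only counts hits within n years, a model can improve CA_5 while MAE barely changes. Reporting both gives a fuller picture.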

Note:

  • This SSR-Net model does not tolerate a large learning rate: the learning rate should be smaller than 0.002, otherwise the model will very likely always output 0. I suspect this is caused by the use of Tanh as the activation function.
  • Batch size can also severely affect the results. One tested set of parameters is:
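    import torch.nn as nn
    import torch.optim as optim

    batch_size = 50
    input_size = 64          # input images are resized to 64x64
    num_epochs = 90
    learning_rate = 0.001    # originally 0.001; keep it below 0.002
    weight_decay = 1e-4      # originally 1e-4
    augment = False
    # params_to_update: the model parameters to optimize, defined in the training script
    optimizer_ft = optim.Adam(params_to_update, lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.L1Loss()  # L1 loss on the predicted age
    # divide the learning rate by 10 every 30 epochs
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1)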
    
  • Dataset preprocessing is quite easy. For the MegaAge_Asian dataset, you can use ./datasets/read_megaasina_data.py directly; for other datasets, just generate a pandas CSV file in a format like:
    filename,age
    1.jpg,23
    ...
    

    is all you need. Also remember to change ./datasets/read_imdb_data.py accordingly.
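For a new dataset, the CSV only needs the two columns shown above. The following is a minimal pandas sketch of producing it. It assumes you already have (image filename, age) pairs from your dataset's own annotations, and the helper name `make_age_csv` is hypothetical. Passing `index=False` keeps pandas from writing an extra unnamed index column, so the header is exactly `filename,age`.

```python
import pandas as pd


def make_age_csv(samples, csv_path=None):
    """Build the filename,age table expected by the dataset readers.

    samples: iterable of (image_filename, age) pairs taken from your own
    annotations (hypothetical input; adapt to your data).
    Returns the CSV text when csv_path is None, otherwise writes the file.
    """
    df = pd.DataFrame(list(samples), columns=["filename", "age"])
    df["age"] = df["age"].astype(int)  # ages as whole years
    # index=False keeps pandas from adding an unnamed index column
    return df.to_csv(csv_path, index=False)


if __name__ == "__main__":
    print(make_age_csv([("1.jpg", 23), ("2.jpg", 31)]))
    # filename,age
    # 1.jpg,23
    # 2.jpg,31
```

Reading the file back with `pd.read_csv(csv_path)` gives columns named exactly `filename` and `age`, and the dataset class has to access them by those names.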

onnxruntime C++ implementation

Thanks to DefTruth's implementation here: How to convert SSRNet to ONNX and implements with onnxruntime c++.

Another small note:

My reading notes on SSR-Net, written in Chinese, can be found here.


ssr_net_pytorch's Issues

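A question

Hello, if my input is not an image but a 3D matrix (a 3D brain volume stored as a .mat file) and I want to do this kind of age estimation, do you think your method and code would work after some modification?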

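Error when running train_SSR-Net.py on the preprocessed IMDB dataset. I'm a beginner and don't know how to fix it; I would be very grateful if the author could help.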

    C:\ProgramData\Anaconda3\python.exe D:/PycharmProjects/SSR_Net_Pytorch-master/train_SSR-Net.py
    Traceback (most recent call last):
      File "D:/PycharmProjects/SSR_Net_Pytorch-master/train_SSR-Net.py", line 214, in <module>
        num_epochs_=num_epochs,
      File "D:/PycharmProjects/SSR_Net_Pytorch-master/train_SSR-Net.py", line 63, in train_model
        for i, (inputs, labels) in enumerate(dataloaders_[phase]):
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 615, in __next__
        batch = self.collate_fn([self.dataset[i] for i in indices])
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 615, in <listcomp>
        batch = self.collate_fn([self.dataset[i] for i in indices])
      File "D:\PycharmProjects\SSR_Net_Pytorch-master\datasets\read_imdb_data.py", line 35, in __getitem__
        image_, image_path_ = self.read_images(index)
      File "D:\PycharmProjects\SSR_Net_Pytorch-master\datasets\read_imdb_data.py", line 52, in read_images
        filename = self.images_df.iloc[index_].Filename
      File "C:\ProgramData\Anaconda3\lib\site-packages\pandas\core\generic.py", line 4372, in __getattr__
        return object.__getattribute__(self, name)
    AttributeError: 'Series' object has no attribute 'Filename'

read_imdb_data.py

    def __getitem__(self, index):
        image_, image_path_ = self.read_images(index)  # line 35 in the traceback
        if self.mode in ['train', ]:
            label = self.images_df.iloc[index].age
        else:
            label = image_path_
        if self.augment:
            image_ = self.augmentor(image_)
        image_ = T.Compose([
            T.ToPILImage(),
            # T.RandomResizedCrop(self.input_size),
            # T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(image_)
        return image_.float(), label

    def read_images(self, index_):
        filename = self.images_df.iloc[index_].Filename  # line 52: raises the AttributeError
        image_path_ = os.path.join(self.base_path, filename)
        image = cv2.imread(image_path_)
        return image, image_path_
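The traceback says pandas found no column called `Filename`. Attribute access on a row (`row.Filename`) only works when the CSV header contains exactly that name, and column names are case-sensitive. A CSV written with the `filename,age` header from the README has a column `filename`, not `Filename`. The sketch below reproduces the error and reads the column by its actual name. The DataFrame is made-up example data, and the helpers `normalize_columns` and `get_filename` are hypothetical, not part of read_imdb_data.py.

```python
import pandas as pd


def normalize_columns(images_df: pd.DataFrame) -> pd.DataFrame:
    """Strip whitespace and lower-case the CSV headers (e.g. 'Filename ' -> 'filename')."""
    images_df = images_df.copy()
    images_df.columns = [str(c).strip().lower() for c in images_df.columns]
    return images_df


def get_filename(images_df: pd.DataFrame, index_: int) -> str:
    """Read row index_'s image filename by column name instead of attribute."""
    return images_df.iloc[index_]["filename"]


if __name__ == "__main__":
    # Made-up rows using the README's header: filename,age
    df = pd.DataFrame({"filename": ["1.jpg", "2.jpg"], "age": [23, 31]})
    try:
        df.iloc[0].Filename  # wrong case, same AttributeError as in the traceback
    except AttributeError as err:
        print("AttributeError:", err)
    print(get_filename(normalize_columns(df), 1))  # 2.jpg
```

Normalizing the headers once, for example right after `pd.read_csv` in the dataset's constructor, tolerates CSVs written as `Filename,Age` or `filename,age`.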

data augment

The method inference_single_image in inference_images.py contains the following code:

    image_ = cv2.imread(image_path_)
    image_ = T.Compose([
        T.ToPILImage(),
        T.Resize((input_size_, input_size_)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(image_)

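OpenCV reads images in BGR channel order, whereas [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225] are the ImageNet mean and standard deviation in RGB channel order. transforms.ToPILImage() only converts the array to PIL format and does not change the channel order. Is this code wrong?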
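The observation is correct as far as it goes: `cv2.imread` returns BGR, and `T.ToPILImage()` does not reorder channels, so the RGB ImageNet statistics land on swapped channels. However, the training dataset code shown in the previous issue also reads with `cv2.imread` and applies the same `Normalize`, so training and inference are at least consistent. Converting with `cv2.cvtColor(image_, cv2.COLOR_BGR2RGB)` would therefore only be correct if it were done in both places and the model retrained. Separately, `T.RandomHorizontalFlip()` makes inference nondeterministic. The numpy-only sketch below shows what the channel swap plus `ToTensor`/`Normalize` computes. The helper name is hypothetical, and the code illustrates the idea rather than reproducing the repo's own pipeline.

```python
import numpy as np

# ImageNet statistics in RGB order, as passed to T.Normalize above
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def bgr_to_normalized_chw(image_bgr: np.ndarray) -> np.ndarray:
    """HxWx3 uint8 BGR image -> 3xHxW float32 RGB array, ImageNet-normalized.

    Numpy equivalent of cv2.cvtColor(..., cv2.COLOR_BGR2RGB) followed by
    T.ToTensor() and T.Normalize(IMAGENET_MEAN, IMAGENET_STD).
    """
    if image_bgr.ndim != 3 or image_bgr.shape[2] != 3:
        raise ValueError("expected an HxWx3 image")
    image_rgb = image_bgr[:, :, ::-1]                      # BGR -> RGB
    scaled = image_rgb.astype(np.float32) / 255.0          # like ToTensor: [0, 255] -> [0, 1]
    normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD   # per channel, RGB order
    return np.ascontiguousarray(normalized.transpose(2, 0, 1))  # HWC -> CHW
```

With a pure red pixel stored as BGR `[0, 0, 255]`, the output's first (R) channel is (1 - 0.485) / 0.229, about 2.25. Without the swap, that value would end up in the B channel.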

Face detection

Sorry to bother you, but which face detector was used to produce the model input, and roughly how should the face be cropped?

Data preprocessing

Hi there,

The official version preprocesses the image dataset, but I cannot find this step in your code. Does your code not need it?

Regards,

About the used datasets of the pretrained model

Hi, oukohou!

Thanks for the pytorch version of SSR net.

I have run the inference_images.py file for prediction. It is pretty easy to use, and I have some questions.

1. code error

In the inference_images.py file, line 44, the code is

image_ = image_.cuda()

but I don't have a GPU, so an error occurs. I think

image_ = image_.to(device)

would be better.
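The suggestion can be taken a step further by choosing the device once and moving both the model and the input to it. The sketch below assumes the model is an ordinary `nn.Module` and the input is a preprocessed CxHxW tensor; `pick_device` and `predict` are hypothetical helpers, not functions from inference_images.py. If the checkpoint was saved on a GPU, loading it on a CPU-only machine also needs `torch.load(checkpoint_path, map_location=device)`.

```python
import torch
import torch.nn as nn


def pick_device() -> torch.device:
    """Use the GPU when one is available, otherwise fall back to the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def predict(model: nn.Module, image_: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Run one preprocessed CxHxW image through the model on `device`."""
    model = model.to(device).eval()
    with torch.no_grad():
        batch = image_.unsqueeze(0).to(device)  # instead of image_.cuda()
        return model(batch).cpu()               # bring the result back to the CPU
```

With a stand-in model such as `nn.Sequential(nn.Flatten(), nn.Linear(48, 1))` and `torch.zeros(3, 4, 4)` as input, `predict` returns a 1x1 tensor on either device.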

2. datasets

I have read your README. If I understand correctly, the only pretrained model was trained on the IMDB and MegaAge datasets. However, in the train_SSR-Net.py file, I saw

from datasets.read_face_age_data import FaceAgeDatasets

Now I'm confused. Did you use a face age dataset for training? This one?
Or did you use only IMDB and MegaAge?

If possible, could you also tell me which choice of datasets gives the best performance?

Looking forward to hearing from you soon.

How to improve the accuracy of a specific age group?

Hello, thanks for your work! I trained on MegaAge and got a CA_5 accuracy of 75%, but I found that the accuracy is not ideal for certain age groups, such as elderly people aged 50-60 and 60-70, and young people aged 0-10 and 10-20. Is this because there is relatively little data at both ends of the age range in the dataset?
Do I need to add more data for these age ranges?
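One way to check whether the weak age groups coincide with sparse data is to compare each group's sample count with its CA(n). The plain-Python sketch below assumes 10-year bins; the bin width and helper names are hypothetical choices. The inverse-frequency weights are one common rebalancing option among several, and are not a technique this repo uses. They could, for example, be passed to `torch.utils.data.WeightedRandomSampler` so rare age groups are drawn more often during training.

```python
from collections import Counter, defaultdict


def age_bin(age: float, width: int = 10) -> int:
    """Lower edge of the age bin, e.g. 57 -> 50 for 10-year bins."""
    return int(age // width) * width


def per_bin_report(preds, ages, n=5, width=10):
    """Return {bin_start: (num_samples, CA(n) within that bin)}."""
    counts = Counter()
    hits = defaultdict(int)
    for p, a in zip(preds, ages):
        b = age_bin(a, width)
        counts[b] += 1
        if abs(p - a) <= n:
            hits[b] += 1
    return {b: (counts[b], hits[b] / counts[b]) for b in sorted(counts)}


def inverse_frequency_weights(ages, width=10):
    """One weight per training sample, proportional to 1 / (size of its age bin)."""
    counts = Counter(age_bin(a, width) for a in ages)
    return [1.0 / counts[age_bin(a, width)] for a in ages]
```

If the low-CA bins are also the bins with few samples, resampling or collecting more data for those ages is a reasonable first step. If they are well populated, the problem is more likely in the model or labels.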

training speed

Hi, roughly how fast does training iterate with your version? When I trained before, data processing felt rather slow.

train model on the IMDB_WIKI?

Hi, I have a few questions:

  1. Did you train the model on IMDB_WIKI? According to the docs, version_v1 is trained from scratch on MegaAge_Asian and version_v2 is fine-tuned on MegaAge_Asian.
  2. Did you compute the model's MAE on the MegaAge_Asian test set? My result is 11.09. Is that a bit high?
