bubbliiiing / classification-pytorch
Source code for classification models built on various backbone networks; it can be used to train your own classification model.
License: MIT License
I see that the mobilenet network loads its pretrained weights from 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', using this code:
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], model_dir='./model_data', progress=progress)
model.load_state_dict(state_dict)
How can I instead load, from a local path, the weights I have already downloaded from the net disk? Thanks!
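One possible approach, sketched under assumptions: read the downloaded file with `torch.load` and pass the result to `model.load_state_dict`. The tiny `nn.Linear` model and the file name `mobilenet_v2-local.pth` below are stand-ins invented for the example, not names from the repository.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Tiny stand-in model; in practice this would be the repo's mobilenet_v2 instance.
model = nn.Linear(4, 2)

# Pretend this file is the checkpoint downloaded by hand (hypothetical path).
weights_path = os.path.join(tempfile.mkdtemp(), "mobilenet_v2-local.pth")
torch.save(model.state_dict(), weights_path)

# Read the file from disk instead of calling load_state_dict_from_url.
fresh_model = nn.Linear(4, 2)
state_dict = torch.load(weights_path, map_location="cpu")
fresh_model.load_state_dict(state_dict)
```

As far as I know, `load_state_dict_from_url` also looks in `model_dir` for a file named like the last part of the URL before downloading, so copying `mobilenet_v2-b0353104.pth` into `./model_data` may be enough without any code change.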
For example, when the input is img/cat.jpg, the prediction is Class: cat Probability: 0.998. I would like the prediction to be Class: cat Probability: 0.998, dog Probability: 0.002 instead.
#---------------------------------------------------#
#---------------------------------------------------#
class_name = self.class_names[np.argmax(preds)]
probability = np.max(preds)
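I think this is the part of the code that needs to change, but I don't know where to start... I'm a beginner, so please bear with me.
I would be very grateful for a reply. Thanks for the great work!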
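A minimal sketch of one way to report every class rather than only the top one. It assumes `preds` is a 1-D sequence of probabilities in the same order as `class_names`; the function name `format_predictions` is made up for this example.

```python
def format_predictions(class_names, preds):
    """Return every class with its probability, highest probability first."""
    # Pair each class name with its probability and sort in descending order.
    ranked = sorted(zip(class_names, preds), key=lambda pair: pair[1], reverse=True)
    parts = ["{} Probability: {:.3f}".format(name, float(p)) for name, p in ranked]
    return "Class: " + ", ".join(parts)


print(format_predictions(["cat", "dog"], [0.998, 0.002]))
# prints: Class: cat Probability: 0.998, dog Probability: 0.002
```

The same function should also accept a 1-D NumPy array for `preds`, since it only iterates over the values.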
https://download.pytorch.org/models/resnet50-19c8e357.pt
https://download.pytorch.org/models/resnet50-0676ba61.pth
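I'd like to ask why these two weight files differ. I found the resnet.py file under torchvision, and the network structure defined there is the same for both. Why does loading them behave differently?
Loading the weights that don't match raises this error: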
RuntimeError: Error(s) in loading state_dict for ResNet:
size mismatch for layer1.0.downsample.1.weight: copying a param with shape torch.Size([256, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for layer2.0.downsample.1.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for layer3.0.downsample.1.weight: copying a param with shape torch.Size([1024, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024]).
size mismatch for layer4.0.downsample.1.weight: copying a param with shape torch.Size([2048, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([2048]).
Any guidance would be appreciated.
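To narrow down an error like this, it may help to list every parameter whose shape differs between the model and the checkpoint before calling `load_state_dict`. The helper below is a sketch that works on plain `{name: shape}` dictionaries; `find_shape_mismatches` is a hypothetical name, and the two sample entries are copied from the first line of the error above.

```python
def find_shape_mismatches(model_shapes, checkpoint_shapes):
    """Compare two {parameter name: shape} mappings and report the differences."""
    report = {"missing_in_checkpoint": [], "unexpected_in_checkpoint": [], "size_mismatch": []}
    for name, shape in model_shapes.items():
        if name not in checkpoint_shapes:
            report["missing_in_checkpoint"].append(name)
        elif tuple(checkpoint_shapes[name]) != tuple(shape):
            # Record (name, shape in checkpoint, shape in model).
            report["size_mismatch"].append((name, tuple(checkpoint_shapes[name]), tuple(shape)))
    for name in checkpoint_shapes:
        if name not in model_shapes:
            report["unexpected_in_checkpoint"].append(name)
    return report


# With PyTorch, each mapping could be built as
# {k: tuple(v.shape) for k, v in state_dict.items()}.
model_shapes = {"layer1.0.downsample.1.weight": (256,)}
checkpoint_shapes = {"layer1.0.downsample.1.weight": (256, 64, 1, 1)}
report = find_shape_mismatches(model_shapes, checkpoint_shapes)
print(report["size_mismatch"])
```

In the error above, the checkpoint holds a 4-D convolution-style weight where the model expects a 1-D weight, which suggests the two `downsample` blocks are laid out differently.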
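My feeling is that letterbox should either be added to the Dataloader or removed from predict and the other validation code, so that training and validation images are processed consistently. Is that right?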
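The preprocessing in the code resizes to [224, 224] and then applies CenterCrop [224, 224]. Since the image size is the same before and after CenterCrop, does that mean CenterCrop has no effect? In many other repositories, training crops randomly (Randcrop) straight to [224, 224], while prediction or evaluation first resizes to [256, 256] and then applies CenterCrop to [224, 224]. I'd like to ask about this.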
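The arithmetic behind the CenterCrop question can be checked with a small sketch. This is simplified integer math written for illustration, not torchvision's implementation, and `center_crop_box` is a hypothetical helper name.

```python
def center_crop_box(width, height, crop_width, crop_height):
    """Return the (left, top, right, bottom) box of a centered crop."""
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    return (left, top, left + crop_width, top + crop_height)


# Resize to 224x224, then crop 224x224: the box covers the whole image.
print(center_crop_box(224, 224, 224, 224))  # (0, 0, 224, 224)

# Resize to 256x256, then crop 224x224: a 16-pixel border is discarded.
print(center_crop_box(256, 256, 224, 224))  # (16, 16, 240, 240)
```

Under these assumptions, a crop of the same size as the resized image keeps every pixel, so only the resize changes the image.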
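Also, this repository normalizes by dividing by 127.5 and then subtracting 1, which maps all pixel values into [-1, 1], whereas other repositories divide by 255 and then normalize with the ImageNet-1K mean and std. Which approach does the official paper use?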
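The two normalization schemes can be compared side by side. This sketch works on single pixel values in plain Python; the mean and std constants are the commonly used ImageNet RGB statistics, and both function names are invented for the example.

```python
# Commonly used ImageNet per-channel statistics (RGB order).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def scale_to_unit_range(pixel):
    """Map a value in [0, 255] to [-1, 1] by dividing by 127.5 and subtracting 1."""
    return pixel / 127.5 - 1.0


def imagenet_normalize(pixel_rgb):
    """Divide an RGB pixel by 255, then standardize each channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(pixel_rgb, IMAGENET_MEAN, IMAGENET_STD))


print(scale_to_unit_range(0), scale_to_unit_range(255))
print(imagenet_normalize((0, 0, 0)))
print(imagenet_normalize((255, 255, 255)))
```

Whichever scheme is chosen, pretrained weights generally expect the same normalization at inference that was used when they were trained.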
The learning rate stays at 0.0001 and never changes. Is there a way to fix this?
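If no schedule is being applied, one option is to compute the learning rate per epoch and write it into the optimizer. The sketch below is a generic cosine decay, not the repository's own scheduler; `cosine_lr`, `init_lr`, and `min_lr` are names and values assumed for illustration.

```python
import math


def cosine_lr(epoch, total_epochs, init_lr=1e-4, min_lr=1e-6):
    """Cosine decay from init_lr at epoch 0 to min_lr at the last epoch."""
    progress = epoch / max(total_epochs - 1, 1)
    return min_lr + 0.5 * (init_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for epoch in (0, 50, 99):
    lr = cosine_lr(epoch, total_epochs=100)
    # With a PyTorch optimizer, the value could be applied as:
    # for group in optimizer.param_groups:
    #     group["lr"] = lr
    print(epoch, lr)
```

Printing the value each epoch, as above, is also a quick way to confirm whether the rate is really constant or only looks constant in the log.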
I don't know how to solve this or what went wrong. The exact message is:
RuntimeWarning: invalid value encountered in true_divide
F1 = (2 * Recall * Precision) / (Recall + Precision)
Save Recall out to metrics_out\Recall.png
Save Precision out to metrics_out\Precision.png
Traceback (most recent call last):
File "F:\ImgClassification\eval.py", line 58, in
top1, top5, Recall, Precision, F1= evaluteTop1_5(classfication, lines, metrics_out_path)
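The `RuntimeWarning` on the F1 line is what NumPy emits for a 0/0 division, which happens when both Recall and Precision are zero for some class. The traceback above is cut off before the actual exception, so the following only addresses the warning: a sketch of an F1 computation that returns 0 in that case. `safe_f1` is a hypothetical helper working on scalars.

```python
def safe_f1(recall, precision):
    """F1 score that returns 0.0 instead of NaN when recall + precision is 0."""
    denominator = recall + precision
    if denominator == 0:
        return 0.0
    return 2.0 * recall * precision / denominator


# Per-class values; the middle class was never predicted correctly.
recalls = [1.0, 0.0, 0.5]
precisions = [1.0, 0.0, 1.0]
f1_scores = [safe_f1(r, p) for r, p in zip(recalls, precisions)]
print(f1_scores)
```

A class with zero recall and zero precision usually means no sample of that class was predicted correctly, which is worth checking in the dataset and the class list.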
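I modified the later layers of the SSD network and trained for 100 epochs, but the mAP is only 61.9%. Is that normal? Before training, with the original weights, it was 79.3%.
Error message: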
There is no pretrained model for vit_b_16
Traceback (most recent call last):
File "train.py", line 225, in <module>
model = get_model_from_name[backbone](input_shape = input_shape, num_classes = num_classes, pretrained = pretrained)
File "/data/xly/hhs/classification-pytorch-1/nets/vision_transformer.py", line 223, in vit_b_16
model.load_state_dict(torch.load("/data/xly/hhs/classification-pytorch-1/model_data/vit-patch_16.pth"))
File "/root/anaconda3/envs/torch38/lib/python3.8/site-packages/torch/serialization.py", line 595, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/root/anaconda3/envs/torch38/lib/python3.8/site-packages/torch/serialization.py", line 749, in _legacy_load
return legacy_load(f)
File "/root/anaconda3/envs/torch38/lib/python3.8/site-packages/torch/serialization.py", line 674, in legacy_load
tar.extract('storages', path=tmpdir)
File "/root/anaconda3/envs/torch38/lib/python3.8/tarfile.py", line 2272, in extract
tarinfo = self._get_extract_tarinfo(member, filter_function, path)
File "/root/anaconda3/envs/torch38/lib/python3.8/tarfile.py", line 2279, in _get_extract_tarinfo
tarinfo = self.getmember(member)
File "/root/anaconda3/envs/torch38/lib/python3.8/tarfile.py", line 1962, in getmember
raise KeyError("filename %r not found" % name)
KeyError: "filename 'storages' not found"
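The traceback shows `torch.load` going through its legacy tar reader and failing to find a member named `storages`. A quick check of the file's container format might help decide whether the file is the expected checkpoint or something else, such as an incomplete download or an unrelated archive. This sketch uses only the standard library; `describe_checkpoint_file` is a hypothetical name, and the temporary files exist only to make the example runnable.

```python
import os
import tarfile
import tempfile
import zipfile


def describe_checkpoint_file(path):
    """Classify a file as 'zip', 'tar', or 'other' by its container format."""
    if zipfile.is_zipfile(path):
        return "zip"  # the default format written by torch.save since PyTorch 1.6
    if tarfile.is_tarfile(path):
        return "tar"  # torch.load treats tar files as a very old legacy format
    return "other"  # e.g. a plain pickle, an HTML error page, or a truncated file


# Demo with throwaway files; replace with the real checkpoint path in practice.
tmp_dir = tempfile.mkdtemp()
zip_path = os.path.join(tmp_dir, "demo.zip")
with zipfile.ZipFile(zip_path, "w") as archive:
    archive.writestr("data.txt", "hello")
plain_path = os.path.join(tmp_dir, "demo.bin")
with open(plain_path, "wb") as handle:
    handle.write(b"not an archive")

print(describe_checkpoint_file(zip_path), describe_checkpoint_file(plain_path))
```

Comparing the file size with the size listed at the download source is another inexpensive check.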
I'd like the dataset!
This problem occurs when running the program with a torchvision version that is too new. To fix it, change "from torchvision.models.utils import load_state_dict_from_url" to "from torch.hub import load_state_dict_from_url".
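To keep the code working on both older and newer torchvision releases, the import could be wrapped in a fallback. This is a sketch of that pattern, using only the two import paths quoted above.

```python
try:
    # Older torchvision releases exposed the helper in this module.
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    # Newer releases no longer ship that module; torch.hub offers a
    # function with the same name.
    from torch.hub import load_state_dict_from_url

print(load_state_dict_from_url.__name__)
```

`ModuleNotFoundError` is a subclass of `ImportError`, so the fallback also triggers when the module is missing entirely.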
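As the title says.
Hi, does this PyTorch version of the classification code not produce TensorBoard logs? If I want to record the learning rate and accuracy curves in TensorBoard, where should I add the code?
Which dataset were the published .pth model files trained on?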
The code snippet is as follows:
if not random:
    image = self.resize(image)
    image = self.center_crop(image)
    return image
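What is the purpose of preprocessing the validation images, and do the validation results affect how hyperparameters are updated on the training set?
Thanks!
Baidu Netdisk is too slow without a membership. Could you share a Google Drive link?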
Traceback (most recent call last):
File "train.py", line 452, in
fit_one_epoch(model_train, model, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, save_period, save_dir, local_rank)
File "/mnt/class_master/utils/utils_fit.py", line 116, in fit_one_epoch
if (math.floor(epoch) + 1) % save_period == 0 or math.floor(epoch)+1 == Epoch:
TypeError: unsupported operand type(s) for %: 'int' and 'str'
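The `TypeError` means `save_period` is a string when the `%` is evaluated. That could happen if the value was set as a string, or if the positional arguments passed to `fit_one_epoch` do not line up with its parameters so that a path lands in `save_period`; the traceback alone does not say which. A sketch of a defensive check, with `should_save` as a hypothetical helper:

```python
def should_save(epoch, save_period, total_epochs):
    """Return True when a checkpoint should be written after this epoch."""
    # int() accepts 10 or "10" but raises ValueError for a path such as "logs".
    save_period = int(save_period)
    if save_period <= 0:
        raise ValueError("save_period must be a positive integer")
    return (epoch + 1) % save_period == 0 or epoch + 1 == total_epochs


print(should_save(9, "10", 100))  # True
print(should_save(0, 10, 100))    # False
```

If `int()` raises here, the value is not numeric at all, which would point to misaligned arguments rather than a quoting mistake.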