torchvision分類介紹
Torchvision高版本支持各種SOTA的圖像分類模型,同時還支持不同數據集分類模型的預訓練模型的切換。使用起來十分方便快捷,Pytroch中支持兩種遷移學習方式,分別是:
- Finetune模式 基于預訓練模型,全鏈路調優參數 - 凍結特征層模式 這種方式只修改輸出層的參數,CNN部分的參數凍結上述兩種遷移方式,分別適合大量數據跟少量數據,前一種方式計算跟訓練時間會比第二種方式要長點,但是針對大量自定義分類數據效果會比較好。
自定義分類模型修改與訓練
加載模型之后,feature_extracting 為true表示凍結模式,否則為finetune模式,相關的代碼如下:
def set_parameter_requires_grad(model, feature_extracting): if feature_extracting: for param in model.parameters(): param.requires_grad = False以resnet18為例,修改之后的自定義訓練代碼如下:
model_ft=models.resnet18(pretrained=True) num_ftrs=model_ft.fc.in_features #Herethesizeofeachoutputsampleissetto5. #Alternatively,itcanbegeneralizedtonn.Linear(num_ftrs,len(class_names)). model_ft.fc=nn.Linear(num_ftrs,5) model_ft=model_ft.to(device) criterion=nn.CrossEntropyLoss() #Observethatallparametersarebeingoptimized optimizer_ft=optim.SGD(model_ft.parameters(),lr=0.001,momentum=0.9) #DecayLRbyafactorof0.1every7epochs exp_lr_scheduler=lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1) model_ft=train_model(model_ft,criterion,optimizer_ft,exp_lr_scheduler, num_epochs=25)
數據集是flowers-dataset,有五個分類分別是:
daisy dandelion roses sunflowers tulips
全鏈路調優,遷移學習訓練CNN部分的權重參數
Epoch0/24 ---------- trainLoss:1.3993Acc:0.5597 validLoss:1.8571Acc:0.7073 Epoch1/24 ---------- trainLoss:1.0903Acc:0.6580 validLoss:0.6150Acc:0.7805 Epoch2/24 ---------- trainLoss:0.9095Acc:0.6991 validLoss:0.4386Acc:0.8049 Epoch3/24 ---------- trainLoss:0.7628Acc:0.7349 validLoss:0.9111Acc:0.7317 Epoch4/24 ---------- trainLoss:0.7107Acc:0.7669 validLoss:0.4854Acc:0.8049 Epoch5/24 ---------- trainLoss:0.6231Acc:0.7793 validLoss:0.6822Acc:0.8049 Epoch6/24 ---------- trainLoss:0.5768Acc:0.8033 validLoss:0.2748Acc:0.8780 Epoch7/24 ---------- trainLoss:0.5448Acc:0.8110 validLoss:0.4440Acc:0.7561 Epoch8/24 ---------- trainLoss:0.5037Acc:0.8170 validLoss:0.2900Acc:0.9268 Epoch9/24 ---------- trainLoss:0.4836Acc:0.8360 validLoss:0.7108Acc:0.7805 Epoch10/24 ---------- trainLoss:0.4663Acc:0.8369 validLoss:0.5868Acc:0.8049 Epoch11/24 ---------- trainLoss:0.4276Acc:0.8504 validLoss:0.6998Acc:0.8293 Epoch12/24 ---------- trainLoss:0.4299Acc:0.8529 validLoss:0.6449Acc:0.8049 Epoch13/24 ---------- trainLoss:0.4256Acc:0.8567 validLoss:0.7897Acc:0.7805 Epoch14/24 ---------- trainLoss:0.4062Acc:0.8559 validLoss:0.5855Acc:0.8293 Epoch15/24 ---------- trainLoss:0.4030Acc:0.8545 validLoss:0.7336Acc:0.7805 Epoch16/24 ---------- trainLoss:0.3786Acc:0.8730 validLoss:1.0429Acc:0.7561 Epoch17/24 ---------- trainLoss:0.3699Acc:0.8763 validLoss:0.4549Acc:0.8293 Epoch18/24 ---------- trainLoss:0.3394Acc:0.8788 validLoss:0.2828Acc:0.9024 Epoch19/24 ---------- trainLoss:0.3300Acc:0.8834 validLoss:0.6766Acc:0.8537 Epoch20/24 ---------- trainLoss:0.3136Acc:0.8906 validLoss:0.5893Acc:0.8537 Epoch21/24 ---------- trainLoss:0.3110Acc:0.8901 validLoss:0.4909Acc:0.8537 Epoch22/24 ---------- trainLoss:0.3141Acc:0.8931 validLoss:0.3930Acc:0.9024 Epoch23/24 ---------- trainLoss:0.3106Acc:0.8887 validLoss:0.3079Acc:0.9024 Epoch24/24 ---------- trainLoss:0.3143Acc:0.8923 validLoss:0.5122Acc:0.8049 Trainingcompletein25m34s BestvalAcc:0.926829
凍結CNN部分,只訓練全連接分類權重
Paramstolearn: fc.weight fc.bias Epoch0/24 ---------- trainLoss:1.0217Acc:0.6465 validLoss:1.5317Acc:0.8049 Epoch1/24 ---------- trainLoss:0.9569Acc:0.6947 validLoss:1.2450Acc:0.6829 Epoch2/24 ---------- trainLoss:1.0280Acc:0.6999 validLoss:1.5677Acc:0.7805 Epoch3/24 ---------- trainLoss:0.8344Acc:0.7426 validLoss:1.1053Acc:0.7317 Epoch4/24 ---------- trainLoss:0.9110Acc:0.7250 validLoss:1.1148Acc:0.7561 Epoch5/24 ---------- trainLoss:0.9049Acc:0.7346 validLoss:1.1541Acc:0.6341 Epoch6/24 ---------- trainLoss:0.8538Acc:0.7465 validLoss:1.4098Acc:0.8293 Epoch7/24 ---------- trainLoss:0.9041Acc:0.7349 validLoss:0.9604Acc:0.7561 Epoch8/24 ---------- trainLoss:0.8885Acc:0.7468 validLoss:1.2603Acc:0.7561 Epoch9/24 ---------- trainLoss:0.9257Acc:0.7333 validLoss:1.0751Acc:0.7561 Epoch10/24 ---------- trainLoss:0.8637Acc:0.7492 validLoss:0.9748Acc:0.7317 Epoch11/24 ---------- trainLoss:0.8686Acc:0.7517 validLoss:1.0194Acc:0.8049 Epoch12/24 ---------- trainLoss:0.8492Acc:0.7572 validLoss:1.0378Acc:0.7317 Epoch13/24 ---------- trainLoss:0.8773Acc:0.7432 validLoss:0.7224Acc:0.8049 Epoch14/24 ---------- trainLoss:0.8919Acc:0.7473 validLoss:1.3564Acc:0.7805 Epoch15/24 ---------- trainLoss:0.8634Acc:0.7490 validLoss:0.7822Acc:0.7805 Epoch16/24 ---------- trainLoss:0.8069Acc:0.7644 validLoss:1.4132Acc:0.7561 Epoch17/24 ---------- trainLoss:0.8589Acc:0.7492 validLoss:0.9812Acc:0.8049 Epoch18/24 ---------- trainLoss:0.7677Acc:0.7688 validLoss:0.7176Acc:0.8293 Epoch19/24 ---------- trainLoss:0.8044Acc:0.7514 validLoss:1.4486Acc:0.7561 Epoch20/24 ---------- trainLoss:0.7916Acc:0.7564 validLoss:1.0575Acc:0.8049 Epoch21/24 ---------- trainLoss:0.7922Acc:0.7647 validLoss:1.0406Acc:0.7805 Epoch22/24 ---------- trainLoss:0.8187Acc:0.7647 validLoss:1.0965Acc:0.7561 Epoch23/24 ---------- trainLoss:0.8443Acc:0.7503 validLoss:1.6163Acc:0.7317 Epoch24/24 ---------- trainLoss:0.8165Acc:0.7583 validLoss:1.1680Acc:0.7561 Trainingcompletein20m7s BestvalAcc:0.829268
測試結果:
零代碼訓練演示
我已經完成torchvision中分類模型自定義數據集遷移學習的代碼封裝與開發,支持基于收集到的數據集,零代碼訓練,生成模型。圖示如下:

審核編輯:彭靜
聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發燒友網立場。文章及其配圖僅供工程師學習之用,如有內容侵權或者其他違規問題,請聯系本站處理。
舉報投訴
-
數據
+關注
關注
8文章
7335瀏覽量
94756 -
模型
+關注
關注
1文章
3751瀏覽量
52099 -
遷移學習
+關注
關注
0文章
74瀏覽量
5850
原文標題:tochvision輕松支持十種圖像分類模型遷移學習
文章出處:【微信號:CVSCHOOL,微信公眾號:OpenCV學堂】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
熱點推薦
linux配置mysql的兩種方式
兩種方式:a、$ find / -name mysql–print 查看是否有mysql文件夾b、$ netstat -a –n 查看是否打開3306端口
發表于 07-26 07:46
SQL語言的兩種使用方式
SQL語言的兩種使用方式在終端交互方式下使用,稱為交互式SQL嵌入在高級語言的程序中使用,稱為嵌入式SQL―高級語言如C、Java等,稱為宿主語言嵌入式SQL的實現方式源程序(用主語言
發表于 12-20 06:51
WiMAX系統中兩種多天線技術的原理和特點詳述
802.16e協議中支持MIMO(多入多出)和AAS(自適應天線系統)兩種不同的多天線實現方式。本文在介紹MIMO和AAS原理的基礎上。分析了各自的特點和性能,并且進行了比較。802.16e協議
發表于 12-13 07:43
?4548次閱讀
兩種UVLED封裝方式COB和DOB的區別
目前市面上,UVLED常見的封裝方式是COB和DOB兩種,這兩種封裝方式的區別主要體現在封裝物料、生產工藝、光性能、電性能以及熱性能這幾方面。昀通科技作為UVLED固化機廠家,在之前的
發表于 10-12 08:44
?7236次閱讀
在MATLAB/simulink中建模時的兩種不同實現方式
導讀:本期文章主要介紹在MATLAB/simulink中建模時的兩種不同實現方式,一種是直接用現成的文件庫中的模塊進行搭建,一種是用Sfunction代碼實現。接下來以電壓型磁鏈觀測器
MATLAB/simulink中兩種實現建模方式的優勢
導讀:本期文章主要介紹在MATLAB/simulink中建模時的兩種不同實現方式,一種是直接用現成的文件庫中的模塊進行搭建,一種是用Sfunction代碼實現。接下來以電壓型磁鏈觀測器
MIMXRT并口連接外圍器件的兩種方式
MIMXRT 有類似Kinetis FlexBUS的接口用于外接FPGA或者并口的液晶屏或者并口采集芯片。可以參考如下的應用筆記,有兩種方式: Flexio方式以及SEMC的DBI總線并口連接
Pytroch中支持的兩種遷移學習方式
評論