Pytorch實驗代碼的億些小細節
轉載請註明出處
(兄弟們點點贊吧,收藏都是點讚的兩倍了T_T)
序
你是否有過這樣的經歷:煉了一大堆的丹,但過了一周回來看結果,忘記了每個模型對應的配置;改了模型中的一個組件,跑起來一個新的訓練,這時候測試舊模型卻發現結果跟原來不一樣了;把所有的訓練測試代碼寫在一個文件裡,加入各種if else,最後一個文件上千行,一個週末沒看,回來改一個邏輯要找半天……其實這些情況除了深度學習相關的開發,在別的軟件開發中也是很常見的,為了解決這些問題,軟件行業的開發者形成了很多套路,比如設計模式,提高代碼復用性,或者各種最佳實踐,比如穀歌、阿里都有一套Java開發最佳實踐,各種框架比如客戶端的Android,後端的spring,也有各種最佳實踐,讓開發者的代碼更加簡潔,更專注於核心的業務實現。在煉丹領域,從2016年至今,各大訓練框架互相競爭,互相學習,學術界基於這些框架產出了很多論文,質量越高的論文,往往代碼寫得也越有條理,最新的論文代碼也漸漸形成了一些固定的範式。這幾年看了許多論文的開源代碼,基於別人的代碼做過不少的改進,煉丹代碼也漸漸形成了一些風格,今天就講講自己煉丹代碼中的一些能讓實驗更有條理的小習慣。
先上代碼,歡迎star收藏(別光放進zhihu收藏夾吃灰呀):
代碼結構
torch_base torch_base ├── checkpoints # 存放模型的地方├── data # 定义各种用于训练测试的dataset ├── eval.py # 测试代码├── loss.py # 定义各种花里胡哨的loss ├── metrics.py # 定义各种约定俗成的评估指标├── model # 定义各种实验中的模型├── options.py # 定义各种实验参数,以命令行形式传入├── README.md # 介绍一下自己的repo ├── scripts # 各种训练,测试脚本├── train.py # 训练代码└── utils # 各种工具代码
├── checkpoints # 存放模型的地方├── data # 定義各種用於訓練測試的dataset torch_base ├── checkpoints # 存放模型的地方├── data # 定义各种用于训练测试的dataset ├── eval.py # 测试代码├── loss.py # 定义各种花里胡哨的loss ├── metrics.py # 定义各种约定俗成的评估指标├── model # 定义各种实验中的模型├── options.py # 定义各种实验参数,以命令行形式传入├── README.md # 介绍一下自己的repo ├── scripts # 各种训练,测试脚本├── train.py # 训练代码└── utils # 各种工具代码
├── eval.py # 測試代碼├── loss.py # 定義各種花里胡哨的loss torch_base ├── checkpoints # 存放模型的地方├── data # 定义各种用于训练测试的dataset ├── eval.py # 测试代码├── loss.py # 定义各种花里胡哨的loss ├── metrics.py # 定义各种约定俗成的评估指标├── model # 定义各种实验中的模型├── options.py # 定义各种实验参数,以命令行形式传入├── README.md # 介绍一下自己的repo ├── scripts # 各种训练,测试脚本├── train.py # 训练代码└── utils # 各种工具代码
├── metrics.py # 定義各種約定俗成的評估指標├── model # 定義各種實驗中的模型├── options.py # 定義各種實驗參數,以命令行形式傳入├── README.md # 介紹一下自己的repo torch_base ├── checkpoints # 存放模型的地方├── data # 定义各种用于训练测试的dataset ├── eval.py # 测试代码├── loss.py # 定义各种花里胡哨的loss ├── metrics.py # 定义各种约定俗成的评估指标├── model # 定义各种实验中的模型├── options.py # 定义各种实验参数,以命令行形式传入├── README.md # 介绍一下自己的repo ├── scripts # 各种训练,测试脚本├── train.py # 训练代码└── utils # 各种工具代码
checkpoints比較簡單,每次訓練的模型各自放在一個目錄裡,scripts目錄可以放每次訓練或測試用的命令腳本,README.md往往是這個repo的門面,可以放一些介紹性的內容;其他都是代碼目錄,下面會逐一講解。
options
首先要介紹的是options.py這個文件,因為這裡定義了各種實驗參數,其他模塊多多少少都會與它有關,受它控制;通常我們需要把各種參數通過某種方式傳給程序,比如命令行參數,或者yaml配置文件,我比較習慣用命令行參數,配合pycharm的configuration使用,或者寫在scripts目錄的腳本里邊,都很方便清晰。命令行傳參用到了
defparse_common_args(parser):parser.add_argument('--model_type',type=str,default='base_model',help='used in model_entry.py')parser.add_argument('--data_type',type=str,default='base_dataset',help='used in data_entry.py')parser.add_argument('--save_prefix',type=str,default='pref',help='some comment for model or test result dir')parser.add_argument('--load_model_path',type=str,default='checkpoints/base_model_pref/0.pth',help='model path for pretrain or test')parser.add_argument('--load_not_strict',action='store_true',help='allow to load only common state dicts')parser.add_argument('--val_list',type=str,default='/data/dataset1/list/base/val.txt',help='val list in train, test list path in test')parser.add_argument('--gpus',nargs='+',type=int)returnparserdefparse_train_args(parser):parser=parse_common_args(parser)...returnparserdefparse_test_args(parser):parser=parse_common_args(parser)...returnparser
我會在外面初始化一個parser,先用parse_common_args添加訓練測試共用的一些參數,在parse_train_args和parse_test_args中調用這個公共的函數,這樣可以避免有些參數在訓練時寫了,測試時忘了寫,一跑就報錯。 parse_train_args解析訓練相關的參數,parse_test_args解析測試相關的參數;具體參數和用途如下:
- parse_common_args
- model_type
- data_type
- save_prefix
- load_model_path
- load_not_strict如果關閉,就會用torch原本的加載邏輯,要求比較嚴格的參數匹配;
- val_list
- gpus
- parse_train_args
- lr
- model_dir
- train_list
- batch_size主要是出於測試時的可視化需求,往往測試需要一張一張forward,所以我習慣將測試batch size為1
- epochs
- parse_test_args
- save_viz
- result_dir
使用時,調用保存參數這一步十分重要,能夠避免模型訓練完成之後,腳本或命令找不到,忘記自己訓練的模型配置這種尷尬局面。
測試時也類似,調用
data
接下來是data package,在這裡,可以為每種數據集定義一個dataset,最好是每個dataset各自形成一個文件,比如
這裡我們還有一個有了dataset,再用pytorch的dataloader接口包一下,可以支持shuffle,多線程加載數據,非常方便。
通常我們還會在data package裡放一個augment.py,可以把數據擴增操作都放進去,因為往往多個dataset都需要調用相同的augmentor,所以最好獨立出來,在dataset文件中分別調用。
model
這裡放的就是各種花里胡哨的模型啦,也是煉丹工作最主要的部分。建議每個模型創建一個package,比如
與data_entry類似,我們有一個如果大家想看這方面教程,可以留言,我可以補一下對應的代碼。
utils
存放各種可複用的util函數或者類,比如一些通用的可視化代碼放到
Recoder
一個數據統計工具,在循環裡record每次迭代的數據(比如各種評價指標這個工具在訓練時嵌入到Logger中使用,在測試時由於不需要調用tensorboard,所以直接被eval.py調用。
Logger
將tensorboard的SummaryWritter包了一層,包含一個recorder,還有一個SummaryWritter;在訓練或驗證的每個step以name-value的形式record一下對應的曲線數據,name最好用
train.py
終於來到核心的訓練代碼環節,這裡我整了一個trainer,將訓練中固定的操作封裝成一些函數,需要按實際情況修改的操作封裝成另外的函數,這樣有新任務來了,只需要修改這些函數就行。現在依次介紹這些函數:
- __init__
- train
- train_per_epoch最後根據print_freq,每隔一段時間打印日誌方便觀察。
- val_per_epoch
eval.py
最後介紹的是測試代碼,我把測試的過程包成了一個Evaluator,和trainer也比較類似:
- __init__
- eval
總結
至此,如果有什麼建議也歡迎在