リスト4-1に主プログラムを示します。
処理の流れは以下のようになります。
リスト4-1 主プログラムのソースコード(抜粋)
1 def main():
2 # データセット名
3 DATA = 'MNIST'
4
5 # 計算パラメーター
6 num_epochs = 10 # 繰り返し回数
7 batch_size = 50 # ミニバッチサイズ
8
9 # device
10 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11
12 # dataset, 画像サイズ, ラベル文字列
13 if DATA == 'MNIST':
14 train_set, test_set, test0_set = MNIST.dataset()
15 in_size = (1, 28, 28)
16 strclasses = MNIST.strclasses()
17
18 # dataloader
19 train_loader, test_loader = utils.dataloader(train_set, test_set, batch_size=batch_size)
20
21 # model
22 net = myCNN.CNN6(in_size, (32, 32, 64, 64, 128, 128), (0.2, 0.3, 0.4, 0.5), len(strclasses))
23
24 # 損失関数: 交差エントロピー関数
25 criterion = nn.CrossEntropyLoss()
26
27 # 最適化関数
28 #optimizer = optim.SGD(net.parameters(), lr=0.001)
29 #optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
30 optimizer = optim.Adam(net.parameters())
31
32 # 学習
33 history, test_labels = fit.fit(net, device, num_epochs, train_loader, test_loader, optimizer, criterion)
データセットからミニバッチ用のデータローダーを作成するプログラムをリスト3-2に示します。
引数でミニバッチサイズを指定します。
これはデータセットによらない処理です。
リスト4-2 データローダープログラムのソースコード
1 def dataloader(train_set, test_set, batch_size=50):
2 train_loader, test_loader = None, None
3 if train_set is not None:
4 train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
5 if test_set is not None:
6 test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
7
8 return train_loader, test_loader
リスト4-3に学習プログラムを示します。
エポックごとに訓練とテストを行い、損失と正解率を出力します。
これはデータセットによらない処理です。
リスト4-3 学習プログラムのソースコード
1 # 訓練
2 def _train(model, device, train_loader, optimizer, criterion):
3 model.train()
4 for data, target in train_loader:
5 data = data.to(device)
6 target = target.to(device)
7 optimizer.zero_grad()
8 output = model(data)
9 loss = criterion(output, target)
10 loss.backward()
11 optimizer.step()
12
13 # テスト
14 def _test(model, device, test_loader, criterion):
15 model.eval()
16 test_loss = 0
17 test_correct = 0
18 labels = np.zeros(0)
19 with torch.no_grad():
20 for data, target in test_loader:
21 data = data.to(device)
22 target = target.to(device)
23 output = model(data)
24 loss = criterion(output, target)
25 test_loss += loss.item() * len(data) # sum up batch loss
26 pred = output.argmax(dim=1, keepdim=True) # get the index
27 test_correct += pred.eq(target.view_as(pred)).sum().item()
28 labels = np.append(labels, pred.data.to("cpu").numpy())
29
30 return test_loss / len(test_loader.dataset), test_correct, labels
31
32 # 学習
33 def fit(model, device, num_epochs, train_loader, test_loader, optimizer, criterion):
34
35 # 開始時刻
36 t0 = time.time()
37 t1 = t0
38
39 # GPU転送
40 model = model.to(device)
41
42 # 損失と正解率を保存する配列
43 history = np.zeros((0, 2))
44
45 # エポックに関するループ
46 for epoch in range(num_epochs):
47
48 # 訓練
49 _train(model, device, train_loader, optimizer, criterion)
50
51 # テスト
52 test_loss, test_correct, test_labels = _test(model, device, test_loader, criterion)
53 test_accuracy = test_correct / len(test_loader.dataset) # 正解率
54
55 # 損失と正解率を出力する
56 t2 = time.time()
57 print('%3d %.5f %.5f(%d/%d)%8.1f(%5.1f)[sec]'
58 % (epoch, test_loss, test_accuracy, test_correct, len(test_loader.dataset), t2 - t0, t2 - t1))
59 t1 = t2
60
61 # 損失と正解率を配列に代入する
62 item = np.array([test_loss, test_accuracy])
63 history = np.vstack((history, item))
64
65 return history, test_labels
NVIDIAの最近のGPU[10]は半精度演算(16ビット浮動小数点演算:float16)
を用いてテンソル演算を高速化することができます。
計算時間の主要部を半精度で計算し、
その他を単精度で計算する方法を混合精度(Mixed Precision)と呼びます。
リスト4-4に訓練部を混合精度で計算するプログラムを示します[11][12]。
テストした結果によると、
画像のサイズが大きいとき(目安として64x64ピクセル以上)約2倍速くなり、
画像のサイズが小さいときは計算時間は変わりません。
なお、混合演算の有無によって結果は少し変わりますが正解率はほぼ同じです。
リスト4-4 半精度版学習プログラム(訓練部のみ, GPUのとき)
1 from torch.amp import autocast, GradScaler
2 def _train_mixed(model, device, train_loader, optimizer, criterion):
3 scaler = GradScaler('cuda')
4 model.train()
5 for data, target in train_loader:
6 data = data.cuda()
7 target = target.cuda()
8 optimizer.zero_grad()
9 with autocast(device_type='cuda', dtype=torch.float16):
10 output = model(data)
11 loss = criterion(output, target)
12 scaler.scale(loss).backward()
13 scaler.step(optimizer)
14 scaler.update()