キカベン
機械学習でより便利な世の中へ
G検定対策
お問い合わせ
   

Deeplabv3-ResNet101でセマンティック・セグメンテーションをやってみる

thumb image

ResNetで転移学習の方法を試してみる」でResNetを扱ったので、今回はResNetをバックボーン(特徴量を取り出す部分のネットワーク)として使っているDeeplabv3-ResNet101を紹介します。

もともとの論文は「Rethinking Atrous Convolution for Semantic Image Segmentation」です。

1. セマンティック・セグメンテーションとは🔝

画像にある全てのピクセルを分類するセマンティック・セグメンテーションを行うモデルです。

セグメンテーションとは区分けすることで、例えば道路の部分と車の部分と建物や空などを区別することを指します。

また、セマンティックとは種類の異なるものは区別するけど、同じ種類のものが複数個あっても区別しないということす。

例えば、車が複数台あっても区別はしません。つまり同じクラスの物体は全て同じ色で表現されています。

以下はPASCAL-Context Datasetから引用したイメージとそのセマンティック・セグメンテーションの例になります。

セマンティック・セグメンテーションとは別にインスタンス・セグメンテーションというものもあって、それではクラス区分だけでなく、同じクラスに属する物体でもインスタンスごとに区別します。

例えば、上記の画像では複数の人間がおのおの別の物体(インスタンス)として区別されるのです。

この記事で取り扱うDeeplabv3-ResNet101はセマンティック・セグメンテーションのモデルになります。

2. モデルのダウンロード🔝

Mini Condaなどを使ってPython環境を作りPyTorchが使えるようにしておいてください。

今回もTorch Hubのサンプルコードを参考にしています。

まず、モデルをダウンロードしましょう。

import torch

model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)

model.eval() # 評価モード

モデルは、~/.cache/torch/hub/checkpoints/にダウンロードされるので次回からは素早くロードできるようになります。

次に、テスト用の画像を取り込みます。

import urllib

url = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"
filename = "dog.jpg"

urllib.request.urlretrieve(url, filename)

これで犬の画像がダウンロードされました。

この画像を読み込んでおきます。

from PIL import Image

input_image = Image.open(filename)

input_image.show()
犬

3. 画像の前処理をしてバッチを作る🔝

画像の前処理を行うトランスフォームを作ります。

from torchvision import transforms

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

バックボーンがResNetなので、ResNetで使われる前処理と同じです。

ImageNetの画像の平均のRGB値とその標準偏差で、全てのピクセル値を正規化しています。

この前処理を施された画像データを使って、推論のためのミニバッチを作ります。

input_batch = input_tensor.unsqueeze(0)

print('input_batch', input_batch.shape)

これは、推論では入力画像が一つであってもバッチの形式を求められるからです。

なお、プリントしてバッチのシェイプを確認すると、torch.Size([1, 3, 1213, 1546])と表示され、画像が一つだけのバッチになっているのがわかります。

ちなみに、unsqueeze(0)は次元を追加するときの使います。

もともとの次元は(3, 1213, 1546)でしたが、unsqueeze(0)によって先頭の次元が追加されました。

4. セマンティック・セグメンテーションを行う🔝

推論を実行する前に、GPUがあるならバッチとモデルをcudaデバイスに移動します。

if torch.cuda.is_available():
    device = torch.device('cuda')
    input_batch = input_batch.to(device)
    model.to(device)

GPUがあれば実行が速くなります。

推論を実行します。

with torch.no_grad():
    outputs = model(input_batch)

torch.no_grad()を使うのは勾配(グラディエント)の計算が必要ないからです(スピードが速くなります)。

このoutputsOrderedDictで、キーにoutauxがありますが、auxはトレーニングでのみ使われるロス値に関するものなのでここでは無視します。

よってoutの1番目の出力を取り出します。バッチサイズが1なので出力も1つしかないですね。

out = outputs['out'][0]

print('out', out.shape)

outのシェイプはtorch.Size([21, 1213, 1546])で21はクラスの数です。

また、1213, 1546はサイズに相当します。

つまり、各ピクセルに対して21の値が予測値として計算されていることになります。

ピクセルごとにどのクラスが最大値になっているのかを求めます。

prediction = out.argmax(0)

これによってピクセルごとにクラスの値が決まりました。

Torch Hubのサンプルコードと同様の方法で色を付け表示します。

物体検出(Object Detection)とは違ってBounding Boxではなくピクセルごとに分類が行われています。

5. 全部まとめたソース🔝

import torch
import urllib
from PIL import Image
from torchvision import transforms


# モデルのダウンロード
model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)
model.eval()

# サンプル画像のダウンロード
url = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"
filename = "dog.jpg"

urllib.request.urlretrieve(url, filename)

# 画像を読み込んでおく
input_image = Image.open(filename)
input_image.show()

# 前処理のトランスフォームを作る

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# サンプル画像に前処理を施す
input_tensor = preprocess(input_image)

# ミニバッチを作る(中身は一つの画像)
input_batch = input_tensor.unsqueeze(0)

print('input_batch', input_batch.shape)

# GPUがあるなら、バッチとモデルをcudaデバイスに移動する
if torch.cuda.is_available():
    device = torch.device('cuda')
    input_batch = input_batch.to(device)
    model.to(device)

# 推論の実行
with torch.no_grad():
    outputs = model(input_batch)

# 結果を取り出す
out = outputs['out'][0]
print('out', out.shape)

prediction = out.argmax(0)

# 21のクラスに固有の色を指定する
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")

prediction_numpy = prediction.byte().cpu().numpy()
r = Image.fromarray(prediction_numpy)
r.putpalette(colors)
r.show()


コメントを残す

メールアドレスは公開されません。