1. เกริ่นนำ

จากความสามารถของ CNN ที่เคยเขียนไปใน อธิบายพื้นฐาน CNN พร้อมโค้ด PyTorch🔥 เห็นได้ว่า output สุดท้ายที่ต้องการจากโมเดลคือ class ของรูปภาพนั้น แต่มีอีกหลายงานที่ต้องการทราบตำแหน่งของ class ที่สนใจในรูปภาพด้วย ซึ่งหมายความว่าโมเดลต้องสามารถระบุ class ของแต่ละ pixel ได้ เราเรียกสิ่งนี้ว่าการทำ “segmentation” จึงมีการคิดค้น “fully convolutional network” (FCN) โดยสิ่งที่แตกต่างจาก CNN คือนำ convolutional layer มาแทนที่ FC layer ในช่วงท้ายของโมเดล เพื่อทำ upsampling โดยใช้ deconvolution หรือ up convolution หรือ transposed convolution หรือ fractional stride convolution (มีหลายชื่อ) เพื่อให้ output มีขนาดใหญ่ขึ้นจนกลับมาเป็นรูปภาพอีกครั้ง ตามภาพด้านล่าง

เปรียบเทียบโครงสร้าง CNN (บน) กับ FCN (ล่าง)

ต่อมาได้มีการทำโครงสร้างของ FCN มาปรับปรุงเพื่อให้สามารถทำงานได้แม่นยำขึ้นถึงแม้จะมีรูปภาพสำหรับ train น้อย ซึ่งโมเดลนั้นก็คือ “U-Net”

2. โครสร้างของ U-Net

จากภาพด้านล่าง เห็นได้ว่าโครงสร้างของโมเดลเป็นรูปตัว U ดังนั้นจึงเป็นที่มาของชื่อ “U-Net”, U-Net ประกอบด้วย 2 ส่วนคือ

  1. Contracting Path (encoder) คือส่วนโค้งลง
  2. Expanding Path (decoder) คือส่วนโค้งขึ้น
โครงสร้าง U-Net

2.1 Contracting Path (encoder)

หรือ downsampling หลักการเหมือนกับใน CNN คือดึง feature ที่สำคัญผ่าน convolution layer และทำให้ข้อมูลเล็กลงโดยใช้ pooling layer วิธีการคำนวณขนาด, ค่าที่อยู่ใน output รวมถึงการเขียนออกมาโดยใช้ PyTorch เหมือนกับที่อธิบายไว้ในส่วนของ CNN

encoder สามารถแบ่งออกมาได้ 5 blocks ตามจำนวน filter ได้แก่ 64, 128, 256, 512 และ 1024 ตามภาพด้านล่าง

2.2 Expanding Path (decoder)

หรือ upsampling คือการสร้างรูปภาพใหม่ขึ้นมาอีกครั้ง (reconstruction) จากรูปภาพขนาดเล็ก มี 3 ขั้นตอนย่อยคือ

  1. ทำให้ข้อมูลขนาดใหญ่ขึ้น (จำนวน channel เท่าเดิม แต่เพิ่ม height กับ width)
  2. นำไปรวมกับข้อมูลฝั่ง encoder ผ่าน skip connection
  3. นำข้อมูลเข้า convolutional layer

1. ทำให้ข้อมูลขนาดใหญ่ขึ้น

ถ้าเป็น U-Net แบบดั้งเดิม (ที่ถูกเผยแพร่เมื่อปี 2015) ใช้วิธี up convolution หรือ transpose convolution แต่ถ้าเป็น U-Net สมัยใหม่มักใช้วิธีอื่นเพื่อลดข้อผิดพลาดในการ upsampling เช่น bilinear interpolation ด้านล่างต่อไปนี้จะอธิบายทั้งวิธี up convolution และ bilinear interpolation

Up convolution

คือการนำ kernel คำนวณ output ออกมา, parameter ที่สำคัญก็เช่นเดียวกับของ convolutional layer ที่อยู่ใน CNN สูตรหาขนาดของ output ที่ออกมาคือ

output size = (input size -1)×stride - 2×padding + kernel size

ซึ่งตาม paper ใช้ kernel ขนาด 2 × 2

ตัวอย่างการคำนวณ output โดยใช้ up convolution

ตัวอย่างที่ 1

กำหนด input ขนาด 2 × 2 มีค่าด้านในคือ

กำหนด kernel ขนาด 2 × 2 มีค่าด้านในคือ

กำหนด stride = 1, padding = 0

ได้ว่าขนาดของ output ที่จะออกมาคือ (2-1)×1- 2×0 + 2 = 3

Step1

เริ่มต้นกำหนดให้ค่าใน output คือ 0 ทั้งหมด

Step2

นำค่าใน input มุมซ้ายสุดคูณกับ kernel และนำผลลัพธ์ไปรวมกับค่าใน output โดยเริ่มจากมุมบนซ้ายสุด

Step3

นำ input ตัวถัดมาคูณกับ kernel และนำผลลัพธ์ไปรวมกับค่าใน output โดยตำแหน่งของ output ที่บวกด้วยเลื่อนไปทางซ้ายตตามค่า stride ซึ่งในที่นี้คือ 1 ดังนั้นเลื่อนไปทางขวา 1 ช่อง

Step4

ทำเหมือนกับ step3 ในกรณีที่ใน output ไม่สามารถเลื่อนไปทางขวาตามขนาด stride ได้แล้วก็เลื่อนลงด้านล่างตามขนาด stride ซึ่งในที่นี้คือ 1 ดังนั้นเลื่อนลงล่าง 1 ช่อง

Step5

ทำแบบเดิมไปเรื่อยๆจน ทุกค่าใน input ถูกคำนวณ

จากตัวอย่างด้านบนสามารถใช้ PyTorch เขียนออกมาได้ว่า

import torch
import torch.nn as nn

# ใช้ nn.ConvTranspose2d สร้าง up convolution สำหรับ 2D CNN
up_conv = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0)

# กำหนดค่าใน input
input_tensor = torch.tensor([[[[0., 3.],
[1., 2.]]]])

# กำหนดค่าที่อยู่ใน kernel
up_conv.weight.data = torch.tensor([[[[0., 2.],
[1., 1.]]]])

up_conv.bias.data = torch.tensor([0.0])

output_tensor = up_conv(input_tensor)

ได้ผลลัพธ์ออกมาคือ

tensor([[[[0., 0., 6.],
[0., 5., 7.],
[1., 3., 2.]]]], grad_fn=<ConvolutionBackward0>)

ตัวอย่างที่ 2

กำหนด input ขนาด 3 × 3 มีค่าด้านในคือ

กำหนด kernel ขนาด 2 × 2 มีค่าด้านในคือ

กำหนด stride = 1, padding = 1

ได้ว่าขนาดของ output ที่จะออกมาคือ (3–1)×1- 2×1 + 2 = 2

Step1

เหมือนกับตัวอย่างที่ 1 กำหนดให้ค่าใน output คือ 0 ทั้งหมด

Step2

ใช้วิธีเดียวกับตัวอย่าง1 คำนวณผลลัพธ์ระหว่าง input กับ kernel แต่เมื่อไปรวมกับ output เวลาทาบต้องเหลือขอบตามค่า padding ด้วย ในที่นี้ padding เท่ากับ 1 ดังนั้นเหลือขอบไว้ 1 ช่อง

Step3

นำ input ตัวถัดมาคูณกับ kernel และนำผลลัพธ์ไปรวมกับค่าใน output โดยตำแหน่งของ output ที่บวกด้วยเลื่อนไปทางซ้ายตตามค่า stride ซึ่งในที่นี้คือ 1 ดังนั้นเลื่อนไปทางขวา 1 ช่อง

Step4

ทำแบบเดิมไปเรื่อยๆจน ทุกค่าใน input ถูกคำนวณ

จากตัวอย่างด้านบนสามารถใช้ PyTorch เขียนออกมาได้ว่า

import torch
import torch.nn as nn

up_conv = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=1)

input_tensor = torch.tensor([[[[1., 2., 0.],
[2., 1., 0.],
[2., 0., 1.]]]])


up_conv.weight.data = torch.tensor([[[[2., 0.],
[1., 2.]]]])

up_conv.bias.data = torch.tensor([0.0])

output_tensor = up_conv(input_tensor)

ได้ผลลัพธ์ออกมาคือ

tensor([[[[6., 4.],
[5., 4.]]]], grad_fn=<ConvolutionBackward0>)

เช่นเดียวกับ CNN ค่าใน kernel ที่เรากำหนดใน 2 ตัวอย่างด้านบนนั้น แท้จริงแล้วคือ weight ของโมเดลซึ่งจะมีการปรับเองระหว่าง train model

Bilinear interpolation

คือการหาค่าที่ไม่รู้จากค่าที่รู้ในตำแหน่งที่ใกล้กัน มีสูตรคือ กำหนดให้ จุด Q₁₁ อยู่ที่ (x₁, y₁), จุด Q₂₁ อยู่ที่ (x₂, y₁) ค่าที่ต้องการหาอยู่ที่จุด R₁ ซึ่งอยู่ตำแหน่ง (x, y₁) ตามภาพด้านล่าง

กำหดให้ f(Q₁₁) คือค่าที่อยู่บนจุด Q₁₁ และ f(Q₂₁) คือค่าที่อยู่บนจุด Q₂₁ ดังนั้ได้ว่าค่าที่อยู่บนจุด R₁ คือ

จากสูตรด้านบนเป็นกรณีที่ค่า y คงที่ แต่ถ้าในกรณีที่ x คงที่ก็สลับตำแหน่งระหว่าง x กับ y ตามสูตร

จาก paper, upsampling ของ U-Net ต้องการให้ output มีขนาดเป็น 2 เท่าของ input (factor =2) ดังนั้นตัวอย่างด้านล่างต่อไปนี้จะเสนองกรณีที่ได้ output มีขนาด 2 เท่าจาก input

ตัวอย่างการคำนวณ output โดยใช้ bilinear interpolation

ตัวอย่างที่ 1

กำหนด input ขนาด 2 × 2 มีค่าด้านในคือ

ดังนั้น output ที่ออกมามีขนาด 4 × 4

Step1

นำค่าทั้ง 4 ของ input อยู่ที่มุมทั้ง 4 ด้านของ output กำหนดเลข row และ column แทนพิกัดของและค่า

Step2

ใช้สูตรด้านบนคำนวณหาค่าในช่องที่ไม่ทราบค่า เช่นช่อง (C, R) = (2, 1) มีตำแหน่งที่ทราบค่าอยู่แล้วที่ (1, 1) คือ 1 และพิกัด (4, 1) คือ 2 ดังนั้นได้ [(4–2) / (4–1)]×1 + [(2–1) / (4–1)]×2 = 2/3 + 2/3 = 4/3 ≈ 1.33

Step3

ในการณีที่อยู่ column เดียวกัน แต่คนละ row เช่นต้องการทราบค่าของ (1, 2) ต้องใช้ค่าของ (1, 1) กับ (1, 4) ก็สลับค่า x, y ของสูตรด้านบน ได้คำตอบคือ [(4–2) / (4–1)]×1 + [(2–1) / (4–1)]×3 = 2/3 + 3/3 = 5/3 ≈ 1.67

Step4

ค่าที่ไม่ได้อยู่บริเวณขอบก็ใช้หลักการเดียวกับ step 2 กับ 3 ในการค่า เช่นต้องการหาค่าที่ (2, 2) ก็สามารถใช้ค่าจาก (1, 2) กับ (4, 2) มาคำนวณได้ [(4–2) / (4–1)]×5/3 + [(2–1) / (4–1)]×8/3 = 10/9 + 8/9 = 18/9 = 2

Step5

คำนวณไปเรื่อยๆจนหาค่าครบทุกช่อง

จากตัวอย่างด้านบนสามารถใช้ PyTorch เขียนออกมาได้ว่า

import torch
import torch.nn.functional as F

input_tensor = torch.tensor([[[[1., 2.],
[3., 4.]]]])
# ใช้ torch.nn.functional เพื่อคำนวณ bilinear interpolation เมื่อ factor=2
output_tensor = F.interpolate(input=input_tensor, scale_factor=2, mode="bilinear", align_corners=True)

ได้ผลลัพธ์ออกมาคือ

tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
[1.6667, 2.0000, 2.3333, 2.6667],
[2.3333, 2.6667, 3.0000, 3.3333],
[3.0000, 3.3333, 3.6667, 4.0000]]]])

ตัวอย่างที่ 2

กำหนด input ขนาด 3 × 3 มีค่าด้านในคือ

ดังนั้น output ที่ออกมามีขนาด 6 × 6

Step1

คล้ายกับตัวอย่างที่ 1 นำค่าของ input อยู่ที่มุมของ output แต่เรื่องจาก output ที่ออกมามีขนาดเป็นเลขคู่ ดังนั้นค่าที่อยู่ตรงกลางของ input จึงไม่มีช่องให้อยู่ใน output ดังนั้นกำหนดให้ค่า input เหล่านั้นอยู่ในตำแหน่งกึ่งกลางระหว่างมุม ทั้ง 2 ด้าน

Step2

ใช้วิธีเดียวกับตัวอย่าง1, หาค่าในตำแหน่งที่ยังไม่ทราบค่า เช่นต้องการทราบค่าของพิกัด (2, 1) เพราะว่าตำแหน่งที่ทราบค่าที่อยู่ใกล้ (2, 1) คือ (1, 1) และ (3.5, 1) ดังนั้นคำนวณออกมาได้ [(3.5–2) / (3.5–1)]×1 + [(2–1) / (3.5–1)]×0 = 1.5/2.5 = 0.6

Step3

คำนวณไปเรื่อยๆจนหาค่าครบทุกช่อง

จากตัวอย่างด้านบนสามารถใช้ PyTorch เขียนออกมาได้ว่า

import torch
import torch.nn.functional as F

input_tensor = torch.tensor([[[[1., 0., 2.],
[3., 1., 2.],
[0., 2., 1.]]]])

output_tensor = F.interpolate(input=input_tensor, scale_factor=2, mode="bilinear", align_corners=True)

ได้ผลลัพธ์ออกมาคือ

tensor([[[[1.0000, 0.6000, 0.2000, 0.4000, 1.2000, 2.0000],
[1.8000, 1.2400, 0.6800, 0.7200, 1.3600, 2.0000],
[2.6000, 1.8800, 1.1600, 1.0400, 1.5200, 2.0000],
[2.4000, 1.9200, 1.4400, 1.3200, 1.5600, 1.8000],
[1.2000, 1.3600, 1.5200, 1.5600, 1.4800, 1.4000],
[0.0000, 0.8000, 1.6000, 1.8000, 1.4000, 1.0000]]]])

2. นำไปรวมกับข้อมูลฝั่ง encoder ผ่าน skip connection

เนื่องจากระหว่าง encoder อาจมีบาง feature ที่หายไป ดังนั้นจึงจำเป็นต้องใช้ skip connection เพื่อรื้อฟื้นมาอีกครั้ง โดย skip connection ถูกกล่าวถึงครั้งแรกในโมเดล Residual Network (ResNet)

U-Net ฝั่ง decoder สามารถแบ่งเป็น 5 block เช่นเดียวกับ encoder ตามภาพด้านล่าง

โดย encoder แต่ละ block จะเชื่อมกับ decoder ที่อยู่ block เดียวกัน ผ่าน skip connection

จากภาพด้านล่างเห็นได้ว่าขนาดของรูปภาพฝั่ง encoder ใหญ่กว่าฝั่ง decoder เพราะว่า U-Net ใน paper ซึ่งคือ U-Net แบบดั้งเดิม ไม่มีการทำ padding ทั้งใน encoder และใน decoder แต่ถ้าเป็น U-Net สมัยใหม่บางโมเดลก็จะมีการทำ padding ทั้งสองส่วน เพื่อควบคุมให้รูปภาพ 2 ฝั่งมีขนาดเท่ากัน

วิธีนำข้อมูลฝั่ง encoder มาเชื่อมกับ decoder ทั้งที่ขนาด 2 ฝั่งไม่เท่ากับคือตัดเอาเฉพาะภาพตรงกลาง (center crop) ของ encoder ให้ออกมาขนาดเท่ากับ decoder แล้วค่อยเอามาต่อกัน

ยกตัวอย่างกรณี block 1 ฝั่ง encoder มีรูปร่าง (64, 568, 568) ฝั่ง decoder หลังไปรวมกับ encoder ได้รูปร่างคือ (128, 392, 392) สามารถใช้ PyTorch เขียนออกมาได้ว่า

import torch
import torch.nn as nn

# สร้าง function สำหรับทำ center crop
def center_crop(encoder, target_size):
b, c, h, w = encoder.shape
tar_h, tar_w = target_size

start_h = (h-tar_h) // 2
start_w = (w-tar_w) // 2

cropped_encoder = encoder[:, :, start_h:start_h+tar_h, start_w:start_w+tar_w]
return cropped_encoder

# สร้าง encoder และ decoder จำลองขึ้นมา โดยสุ่มค่าที่อยู่ด้านใน
# จากภาพด้านบน encoder รูปร่าง (B, C, H, W) = (1, 64, 568, 568)
encoder = torch.randn(1, 64, 568, 568)
# จาพภาพด้านบน ขนาด decoder คือ (H, W) = (392, 392) กำหนดให้มี 64 filter เพราะเมื่อนำไปต่อกับ encoder จะได้มี 64+64=128 filters ตามภาพด้านบน
decoder = torch.randn(1, 64, 392, 392)

cropped_encoder = center_crop(encoder, decoder.shape[2:])

# ใช้ torch.cat เพื่อนำ encoder ต่อกับ decoder
out_tensor = torch.cat((cropped_encoder, decoder), dim=1)
print(out_tensor.shape)

ได้ผลลัพธ์ออกมาคือ

torch.Size([1, 128, 392, 392])

3. นำข้อมูลเข้า convolutional layer

หลังจาก upsampling และรวมกับ encoder แล้ว ก็ยังคงต้องเข้า convolutional layer อีกเพื่อดึง feature ที่สำคัญออกมาเช่นเดิม หลักการของ layer นี้เหมือน convolutional layer ที่อยู่ใน CNN

3. ตัวอย่างการสร้าง U-Net ด้วย PyTorch อย่างง่าย

เราจะสร้างทั้งหมด 2 โมเดล คือ U-Net แบบดั้งเดิมตามที่อยู่ใน paper กับ U-Net สมัยใหม่ที่ปรับปรุงโครงสร้างบางส่วนเพื่อเพิ่มประสิทธิภาพ โดยสร้างเป็น class ย่อยๆของแต่ละส่วน แล้วค่อยเอามารวมกัน

3.1 U-Net แบบดั้งเดิม

Double Convolution

import torch.nn as nn

class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0),
nn.ReLU(inplace=True),

nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0),
nn.ReLU(inplace=True)
)

def forward(self, x):
return self.double_conv(x)

Down

import torch.nn as nn

class Down(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.down = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
# นำ class DoubleConv ที่เคยสร้างไว้มาใช้ต่อเลย
DoubleConv(in_channels=in_channels, out_channels=out_channels)

)

def forward(self, x):
return self.down(x)

Up

import torch
import torch.nn as nn

class Up(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.up_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=2, stride=2, padding=0)
self.double_conv = DoubleConv(in_channels=in_channels, out_channels=out_channels)

def center_crop(self, encoder, target_size):
h, w = encoder.shape[2:]
tar_h, tar_w = target_size

start_h = (h - tar_h) // 2
start_w = (w - tar_w) // 2

cropped_encoder = encoder[:, :, start_h:start_h + tar_h, start_w:start_w + tar_w]
return cropped_encoder

def forward(self, x1, x2):
x1 = self.up_conv(x1)

x2 = self.center_crop(x2, x1.shape[2:])
return self.double_conv(torch.cat((x2, x1), dim=1))

Out Convolution

import torch.nn as nn

class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

def forward(self, x):
return self.conv(x)

สุดท้ายเอาทุกส่วนมารวมกันเป็น U-Net

class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super().__init__()
self.inc = DoubleConv(in_channels=n_channels, out_channels=64)
self.down1 = Down(in_channels=64, out_channels=128)
self.down2 = Down(in_channels=128, out_channels=256)
self.down3 = Down(in_channels=256, out_channels=512)
self.down4 = Down(in_channels=512, out_channels=1024)

self.up1 = Up(in_channels=1024, out_channels=512)
self.up2 = Up(in_channels=512, out_channels=256)
self.up3 = Up(in_channels=256, out_channels=128)
self.up4 = Up(in_channels=128, out_channels=64)
self.outc = OutConv(in_channels=64, out_channels=n_classes)

def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)

x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)

return self.outc(x)

เมื่อใช้ torchsummary.summary แสดง output ของแต่ละ layer จะเห็นว่าตรงกับโครงสร้างในรูปภาพด้านบน

from torchsummary import summary

model = UNet(n_channels=1, n_classes=2)
summary(model, (1, 572, 572))

ได้ผลลัพธ์คือ

----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 570, 570] 640
ReLU-2 [-1, 64, 570, 570] 0
Conv2d-3 [-1, 64, 568, 568] 36,928
ReLU-4 [-1, 64, 568, 568] 0
DoubleConv-5 [-1, 64, 568, 568] 0
MaxPool2d-6 [-1, 64, 284, 284] 0
Conv2d-7 [-1, 128, 282, 282] 73,856
ReLU-8 [-1, 128, 282, 282] 0
Conv2d-9 [-1, 128, 280, 280] 147,584
ReLU-10 [-1, 128, 280, 280] 0
DoubleConv-11 [-1, 128, 280, 280] 0
Down-12 [-1, 128, 280, 280] 0
MaxPool2d-13 [-1, 128, 140, 140] 0
Conv2d-14 [-1, 256, 138, 138] 295,168
ReLU-15 [-1, 256, 138, 138] 0
Conv2d-16 [-1, 256, 136, 136] 590,080
ReLU-17 [-1, 256, 136, 136] 0
DoubleConv-18 [-1, 256, 136, 136] 0
Down-19 [-1, 256, 136, 136] 0
MaxPool2d-20 [-1, 256, 68, 68] 0
Conv2d-21 [-1, 512, 66, 66] 1,180,160
ReLU-22 [-1, 512, 66, 66] 0
Conv2d-23 [-1, 512, 64, 64] 2,359,808
ReLU-24 [-1, 512, 64, 64] 0
DoubleConv-25 [-1, 512, 64, 64] 0
Down-26 [-1, 512, 64, 64] 0
MaxPool2d-27 [-1, 512, 32, 32] 0
Conv2d-28 [-1, 1024, 30, 30] 4,719,616
ReLU-29 [-1, 1024, 30, 30] 0
Conv2d-30 [-1, 1024, 28, 28] 9,438,208
ReLU-31 [-1, 1024, 28, 28] 0
DoubleConv-32 [-1, 1024, 28, 28] 0
Down-33 [-1, 1024, 28, 28] 0
ConvTranspose2d-34 [-1, 512, 56, 56] 2,097,664
Conv2d-35 [-1, 512, 54, 54] 4,719,104
ReLU-36 [-1, 512, 54, 54] 0
Conv2d-37 [-1, 512, 52, 52] 2,359,808
ReLU-38 [-1, 512, 52, 52] 0
DoubleConv-39 [-1, 512, 52, 52] 0
Up-40 [-1, 512, 52, 52] 0
ConvTranspose2d-41 [-1, 256, 104, 104] 524,544
Conv2d-42 [-1, 256, 102, 102] 1,179,904
ReLU-43 [-1, 256, 102, 102] 0
Conv2d-44 [-1, 256, 100, 100] 590,080
ReLU-45 [-1, 256, 100, 100] 0
DoubleConv-46 [-1, 256, 100, 100] 0
Up-47 [-1, 256, 100, 100] 0
ConvTranspose2d-48 [-1, 128, 200, 200] 131,200
Conv2d-49 [-1, 128, 198, 198] 295,040
ReLU-50 [-1, 128, 198, 198] 0
Conv2d-51 [-1, 128, 196, 196] 147,584
ReLU-52 [-1, 128, 196, 196] 0
DoubleConv-53 [-1, 128, 196, 196] 0
Up-54 [-1, 128, 196, 196] 0
ConvTranspose2d-55 [-1, 64, 392, 392] 32,832
Conv2d-56 [-1, 64, 390, 390] 73,792
ReLU-57 [-1, 64, 390, 390] 0
Conv2d-58 [-1, 64, 388, 388] 36,928
ReLU-59 [-1, 64, 388, 388] 0
DoubleConv-60 [-1, 64, 388, 388] 0
Up-61 [-1, 64, 388, 388] 0
Conv2d-62 [-1, 2, 388, 388] 130
OutConv-63 [-1, 2, 388, 388] 0
================================================================
Total params: 31,030,658
Trainable params: 31,030,658
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 2683.55
Params size (MB): 118.37
Estimated Total Size (MB): 2803.17
----------------------------------------------------------------

3.2 U-Net สมัยใหม่

ตัวอย่างโครงสร้างของ U-Net สมัยใหม่

U-Net สมัยใหม่ได้รับการต่อยอดจาก U-Net แบบดั้งเดิมที่ถูกเผยแพร่ตั้งแต่ปี 2015 ซึ่งมีโครงสร้างแต่ต่างกันตามวัตุประสงค์การใช้งาน ไม่ว่าจะเป็นจำนวน layer, วิธีการทำ upsampling, การเชื่อมของ skip connection แต่สิ่งสำคัญที่ต่อยอดจาก U-Net แบบดั้งเดิงคือการทำ normalization ระหว่าง layer และใช้ padding ใน convolutional layer สามารถดูตัวอย่าง U-Net สมัยใหม่ใน github นี้ ซึ่งแต่ละส่วนของ U-Net โมเดลนี้ ได้แก่

Double Convolution

class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""

def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
# ทำ padding ใน convolutional layer
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
# ทำ batch normalization หลังออกจาก convolutional layer
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)

def forward(self, x):
return self.double_conv(x)

Down

class Down(nn.Module):
"""Downscaling with maxpool then double conv"""

def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)

def forward(self, x):
return self.maxpool_conv(x)

Up

class Up(nn.Module):
"""Upscaling then double conv"""

def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()

# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)

def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]

x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])

x = torch.cat([x2, x1], dim=1)
return self.conv(x)

Out Convolution

class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

def forward(self, x):
return self.conv(x)

สุดท้ายเอาทุกส่วนมารวมกันเป็น U-Net

class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=False):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear

self.inc = (DoubleConv(n_channels, 64))
self.down1 = (Down(64, 128))
self.down2 = (Down(128, 256))
self.down3 = (Down(256, 512))
factor = 2 if bilinear else 1
self.down4 = (Down(512, 1024 // factor))
self.up1 = (Up(1024, 512 // factor, bilinear))
self.up2 = (Up(512, 256 // factor, bilinear))
self.up3 = (Up(256, 128 // factor, bilinear))
self.up4 = (Up(128, 64, bilinear))
self.outc = (OutConv(64, n_classes))

def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits

Nuttaset kuapanich
Nuttaset kuapanich

Written by Nuttaset kuapanich

กำลังศึกษาระดับปริญญาตรี คณะปัญญาประดิษฐ์ มหาวิยาลัยซุนยัดเซ็น Email: kuapanich@mail2.sysu.edu.cn

No responses yet

Write a response