[Code Notes] - U-Net
2022-07-22 15:59:00 【chaikeya】
Contents
[Paper Notes] - U-Net - MICCAI 2015
U-Net Model
This implementation uses bilinear interpolation for upsampling instead of the transposed convolution used in the original paper.
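To see the difference between the two upsampling options, here is a minimal sketch (the 512-channel 30x30 tensor is just a hypothetical bottleneck feature map):

import torch
import torch.nn as nn

x = torch.randn(1, 512, 30, 30)  # hypothetical bottleneck feature map
up_bilinear = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
up_transposed = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
print(up_bilinear(x).shape)    # torch.Size([1, 512, 60, 60]) - channels unchanged
print(up_transposed(x).shape)  # torch.Size([1, 256, 60, 60]) - channels halved

Because bilinear upsampling leaves the channel count unchanged, the Up module below compensates through DoubleConv's mid_channels parameter.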
unet.py Network construction
1. First, define a DoubleConv module, which contains two (convolution + BN + ReLU) combinations.
Parameters: in_channels, out_channels, plus an optional intermediate parameter mid_channels.
It corresponds to the pairs of blue arrows in the network diagram, which appear on both the contracting path and the expanding path.
2. Define a Down module, which contains (downsampling + DoubleConv).
Parameters: in_channels, out_channels.
3. Define an Up module, which contains (upsampling + concatenation with the contracting-path feature map + two convolution layers); a small shape sketch follows this list.
Parameters: in_channels, out_channels, and bilinear=True by default (bilinear interpolation is used for upsampling instead of transposed convolution).
The original paper uses transposed convolution for upsampling.
Bilinear interpolation keeps the channel count of the upsampled feature map (1 in the figure) in agreement with that of the contracting-path feature map it is concatenated with (2 in the figure).
4. Define an OutConv module, which contains a single 1x1 convolution layer with no activation function.
Parameters: in_channels, num_classes (the number of classes in the segmentation task).
5. Define the UNet model.
Its initialization assigns the submodules in order: DoubleConv, Down, Down, Down, Down, Up, Up, Up, Up, OutConv.
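Because the input size may not be a multiple of 16, the upsampled feature map can come out smaller than the contracting-path feature map. The shape sketch below (with hypothetical odd sizes) shows the pad-then-concatenate step used in the Up module:

import torch
import torch.nn.functional as F

x1 = torch.randn(1, 64, 25, 25)   # upsampled feature map (hypothetical size)
x2 = torch.randn(1, 64, 27, 27)   # contracting-path feature map (hypothetical size)
diff_y = x2.size(2) - x1.size(2)  # height difference: 2
diff_x = x2.size(3) - x1.size(3)  # width difference: 2
# F.pad order: padding_left, padding_right, padding_top, padding_bottom
x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                diff_y // 2, diff_y - diff_y // 2])
print(torch.cat([x2, x1], dim=1).shape)  # torch.Size([1, 128, 27, 27])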
unet.py Code
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        # If mid_channels is not given, use the same value as out_channels.
        if mid_channels is None:
            mid_channels = out_channels
        # bias=False because every convolution is followed by a BN layer,
        # which makes the convolution bias redundant.
        super(DoubleConv, self).__init__(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )


class Down(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        # Downsampling: max pooling with kernel_size=2 and stride=2, then DoubleConv.
        super(Down, self).__init__(
            nn.MaxPool2d(2, stride=2),
            DoubleConv(in_channels, out_channels)
        )


# Upsampling + concatenation with the contracting-path feature map + two conv layers
class Up(nn.Module):
    # in_channels is the channel count after concatenation; bilinear selects
    # bilinear interpolation instead of transposed convolution for upsampling.
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        if bilinear:
            # Bilinear interpolation keeps the channel count, matching the
            # contracting-path feature map it is concatenated with.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # Transposed convolution halves the channel count while upsampling.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    # x1: feature map to upsample; x2: contracting-path feature map to concatenate with
    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)
        # Tensors are [N, C, H, W]
        diff_y = x2.size()[2] - x1.size()[2]  # height difference
        diff_x = x2.size()[3] - x1.size()[3]  # width difference
        # Pad so the upsampled feature map matches the contracting-path feature map
        # in height and width (needed when the input size is not a multiple of 16).
        # F.pad order: padding_left, padding_right, padding_top, padding_bottom
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class OutConv(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__(
            nn.Conv2d(in_channels, num_classes, kernel_size=1)
        )


class UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,   # pass 3 when instantiating for RGB images
                 num_classes: int = 2,   # number of classes in the segmentation task
                 bilinear: bool = True,  # bilinear interpolation vs. transposed convolution
                 base_c: int = 64):      # number of kernels in the first convolution layer
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear

        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        # bilinear=True -> factor=2: the fourth Down module keeps its output
        # channels equal to its input; bilinear=False -> factor=1: the fourth
        # Down module doubles the channels.
        factor = 2 if bilinear else 1
        self.down4 = Down(base_c * 8, base_c * 16 // factor)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv(base_c, num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.out_conv(x)  # 1x1 convolution
        return {"out": logits}     # returned as a dictionary
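A quick sanity check of the full model (a minimal sketch; the 480x480 RGB input is just an example size divisible by 16):

model = UNet(in_channels=3, num_classes=2, bilinear=True)
x = torch.randn(1, 3, 480, 480)  # dummy batch with one RGB image
out = model(x)["out"]            # forward() returns a dict
print(out.shape)                 # torch.Size([1, 2, 480, 480]) - one logit map per class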
nn.Sequential
In short, nn.Sequential() packages a series of operations such as Conv2d(), ReLU(), MaxPool2d(), and so on. Once packaged, the container is convenient to call: it behaves like a black box, and forward() only needs to invoke that black box.
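For instance, a minimal sketch of a small Sequential block (the layer sizes here are arbitrary):

import torch
import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2)
)
y = block(torch.randn(1, 3, 32, 32))  # the whole block is called like a single layer
print(y.shape)                        # torch.Size([1, 16, 16, 16])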