import torch
import torch.nn as nn
class BasicBlock(nn.Module):
    """Two-layer residual block used in ResNet-18/34 (3x3 conv -> 3x3 conv).

    Args:
        input_channel: channels of the incoming feature map.
        channel: channels produced by both conv layers.
        stride: stride of the first 3x3 conv (2 halves the spatial size).
    """

    expansion = 1  # BasicBlock does not widen: output channels == `channel`

    def __init__(self, input_channel, channel, stride):
        super(BasicBlock, self).__init__()
        # Project the shortcut with a 1x1 conv whenever the spatial size OR
        # the channel count changes; otherwise use the identity mapping.
        # The original checked only the channel mismatch, so a stride-only
        # change (stride != 1, channels equal) would have crashed on the
        # residual addition in forward().
        if stride != 1 or input_channel != channel:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=input_channel, out_channels=channel,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel),
            )
        else:
            # nn.Identity (instead of a bare lambda) so the shortcut is a
            # registered module and the block prints/serializes cleanly.
            self.downsample = nn.Identity()
        self.relu = nn.ReLU(inplace=True)
        self.convlayers = nn.Sequential(
            nn.Conv2d(in_channels=input_channel, out_channels=channel,
                      kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=channel, out_channels=channel,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(channel),
        )

    def forward(self, x):
        # Residual addition, then the final activation.
        out = self.downsample(x) + self.convlayers(x)
        return self.relu(out)
class Bottleneck(nn.Module):
    """Bottleneck residual block used in ResNet-50/101/152
    (1x1 reduce -> 3x3 -> 1x1 expand by `expansion`).

    Args:
        input_channel: channels of the incoming feature map.
        channel: bottleneck width; output has channel * expansion channels.
        stride: stride of the middle 3x3 conv (2 halves the spatial size).
        expansion: widening factor of the last 1x1 conv (default 4).
    """

    expansion = 4  # class-level default used by ResNet.make_layers

    def __init__(self, input_channel, channel, stride, expansion=4):
        super(Bottleneck, self).__init__()
        self.expansion = expansion
        output_channel = channel * expansion
        # Project the shortcut whenever the spatial size OR the channel
        # count changes; otherwise identity. The original checked only the
        # channel mismatch, which breaks for stride != 1 with equal channels.
        if stride != 1 or input_channel != output_channel:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=input_channel, out_channels=output_channel,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(output_channel),
            )
        else:
            # Registered module instead of a bare lambda.
            self.downsample = nn.Identity()
        self.relu = nn.ReLU(inplace=True)
        self.convlayers = nn.Sequential(
            nn.Conv2d(in_channels=input_channel, out_channels=channel,
                      kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=channel, out_channels=channel,
                      kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=channel, out_channels=output_channel,
                      kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(output_channel),
        )

    def forward(self, x):
        # Residual addition, then the final activation.
        out = self.downsample(x) + self.convlayers(x)
        return self.relu(out)


class ResNet(nn.Module):
    """Generic ResNet: 7x7 stem, four residual stages, global pool, linear head.

    Args:
        block: residual block class (BasicBlock or Bottleneck).
        block_nums: four ints — number of blocks per stage.
        input_channel: channels of the input image (e.g. 3 for RGB).
        class_num: size of the classification output.
    """

    def __init__(self, block, block_nums, input_channel, class_num):
        super(ResNet, self).__init__()
        self.stacklayers = nn.Sequential(
            nn.Conv2d(in_channels=input_channel, out_channels=64,
                      kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # padding=1 matches the reference architecture (He et al. 2015,
            # torchvision.models.resnet): 112 -> 56 rather than 112 -> 55.
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            # Stage 1 keeps the spatial size; stages 2-4 halve it.
            self.make_layers(block=block, input_channel=64,
                             channel=64, stride=1, block_num=block_nums[0]),
            self.make_layers(block=block, input_channel=64 * block.expansion,
                             channel=128, stride=2, block_num=block_nums[1]),
            self.make_layers(block=block, input_channel=128 * block.expansion,
                             channel=256, stride=2, block_num=block_nums[2]),
            self.make_layers(block=block, input_channel=256 * block.expansion,
                             channel=512, stride=2, block_num=block_nums[3]),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512 * block.expansion, class_num),
        )

    def make_layers(self, block, input_channel, channel, stride, block_num,
                    expansion=4, reduction=16):
        """Stack `block_num` blocks; only the first may change stride/width.

        `expansion` and `reduction` are unused; kept with their defaults for
        backward compatibility with existing callers.
        """
        layers = [block(input_channel, channel, stride)]
        input_channel = channel * block.expansion
        for _ in range(1, block_num):
            layers.append(block(input_channel, channel, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.stacklayers(x)


def ResNet_18(input_channel, class_num=1000):
    """ResNet-18: BasicBlock with stage depths [2, 2, 2, 2]."""
    return ResNet(BasicBlock, [2, 2, 2, 2], input_channel, class_num)
def ResNet_34(input_channel, class_num=1000):
    """Build a 34-layer ResNet (BasicBlock, stage depths [3, 4, 6, 3])."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, input_channel, class_num)
def ResNet_50(input_channel, class_num=1000):
    """Build a 50-layer ResNet (Bottleneck, stage depths [3, 4, 6, 3])."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, input_channel, class_num)
def ResNet_101(input_channel, class_num=1000):
    """Build a 101-layer ResNet (Bottleneck, stage depths [3, 4, 23, 3])."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, input_channel, class_num)
def ResNet_152(input_channel, class_num=1000):
    """Build a 152-layer ResNet (Bottleneck, stage depths [3, 8, 36, 3])."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths, input_channel, class_num)