매직코드
article thumbnail

선행논문 : FCN > (R-CNN) > (Fast R-CNN) > Faster R-CNN > FPN > Mask R-CNN

Mask R-CNN을 이해하기 위해서는 이전에 발표되었는 FCN, Faster R-CNN (ROI), FPN의 개념들과

Object Detection / Semantic Segmentation / Instance Segmentation의 차이를 알아야 한다.

논문을 선정한 이유

  • instance segmentation 에서 가장 많이 활용하는 모델
  • object detection과 sementic segmentation의 장점을 합친 업그레이드 된 모델

논문 읽기

Abstract

 

Mask R-CNN은 Faster R-CNN에 mask를 생성하는 브랜치를 병렬적으로 추가하는 방법이다.

기존 task에 일반적으로 잘 적용할 수 있는 장점이 있다.

 

 

Introduction

<요약>

  • instance segmentation을 빠르게 추론하는 프레임워크 개발이 목표
  •  Faster R-CNN에 mask branch를 추가한 구조
  • RoI Pooling 대신 RoI Align 기법 사용

 

object detection, semantic segmentation, instance segmentation을 나타낸 이미지 - 약초의숲으로놀러오세요

기존에 있는 object detection의 bounding box를 이용한 개별 객체 분류 요소와 semantic segmentation의 객체 인스턴스와 상관없이 픽셀단위로 카테고리를 분류하는 요소를 결합한다.

 

 

Mask R-CNN은 Faster R-CNN에 segmentation mask를 예측하는 mask branch를 추가한 구조다.

그래서 총 3가지의 branch를 가지게 된다.

  • Classification Branch : Faster R-CNN에서 얻은 RoI(Region of Interest)에 대해 객체의 class 예측
  • Bbox Regression Branch : bounding box 예측
  • Mask Branch : segmentation mask 예측

또한 Faster R-CNN과 다른점은 RoI Pooling 대신 RoI Align 기법을 사용했다는 점이다.

Faster R-CNN의 목적은 object detection으로 대상물체의 테두리만 얼추 맞게 찾아내는 모델이라 각 셀을 계산할 때 int형으로 반올림하여 대상물체의 위치를 잡아주는 RoI Pooling을 사용했다면, Mask R-CNN에서는 segmentation으로 대상물체의 위치에 정확하게 색칠해줘야 하기 때문에 각 셀의 모든 위치를 대상으로 계산할 수 있도록 RoI Align을 사용했다.

RoI Align은 아래 추가적으로 설명이 나온다.

 

논문에서는 실험 결과 Mask 예측과 Class 예측을 분리하는것이 필수적이라는 것을 발견했다.

기존의 방법대로 한번에 class를 예측하고 거기에 mask를 씌우는게 성능이 크게 하락했다고 하면서 class를 먼저 예측하고 그 class 안에 서 mask를 생성하는 방법을 추천한다.

 

Related Work

  • R-CNN: Faster R-CNN을 기반으로 RoI Align 사용
  • Instance Segmentation: Class와 Mask 예측을 병렬로 수행, segmentation 우선이 아니라 instance를 우선으로 설정

 

Mask R-CNN

Faster R-CNN

Faster R-CNN에서 첫번째 단계는 RPN(Region Proposal Network)라고 불리며 후보 객체에 대해서 bounding boxs를 제안한다.

첫번째 단계에서 출력된 bounding box 후보들은 두번째 단계에서 RoI Pooling을 사용하여 특징을 추출하고 분류 및 회귀를 수행한다.

Mask R-CNN

Mask R-CNN은 Faster R-CNN의 첫번째단계, 두번째단계를 동일하게 채택한다.

두번째단계에서 Mask R-CNN은 각각의 RoI에 대해서 binary mask도 outputs으로 출력한다.

각 샘플의 RoI는 다음 수식과 같이 정의한다.

 

RoI 수식

Lcls는 클래스 분류 loss, Lbox는 bounding box loss를 의미하고 Lmask는 mask branch의 mask loss를 의미한다.

이 때 Lmask는 각각의 RoI에 대해 Km^2 차원의 output을 출력하는데, m*m 해상도를 가진 이미지에 대해 binary mask class가 K개 있다는 의미이다. K개의 binary mask는 픽셀당 sigmoid를 적용하고, Lmask는 average binary cross entropy loss를 적용한다.

 

Lmask는 모든 클래스에 대해서 경쟁없이 mask를 생성할 수 있어야 한다.

일반적으로 FCN(Fully Convolution Networks)를 적용할 때 픽셀당 softmax를 통해 loss를 계산하는데 이렇게되면 어떤 class가 몇%의 정답을 가져가는지 경쟁하게 된다. Mask R-CNN에서는 각 Class를 먼저 분류하고 그 안에서 mask를 찾기 때문에 위에서 언급한 바와 같이 sigmoid를 이용하여 mask인지 아닌지 판단한다.

 

mask rcnn branch 순서

 

Mask Representation

mask는 공간을 인코딩하기 때문에 짧은 출력 벡터로 축소되는 fully connected layer(fc layer)가 아니라 컨볼루션을 사용하여 pixel to pixel로 대응할 수 있도록 FCN를 이용한다.

FCN은 m*m 사이즈의 mask의 RoI에 대해서 적용한다. 그런데 FCN을 잘 적용하기 위해서는 RoI가 정확해야하고 이를 위해서 RoI Align 레이어를 만들었다.

RoI Align

기존의 RoI Pool은 RoI에서 작은 피쳐맵(주로 7*7 사이즈를 사용하는데 아래 그림에서는 2*2로 예시를 들었다)을 추출하기 위한 작업이다.

float로 표현된 좌표에 대해서 int로 그 값을 수정해주고(반올림이 수행됨) 작은 피쳐맵의 사이즈에 맞춰 RoI를 구분해준다.

이 때 당연히 숫자가 딱 맞아 떨어지게 구분할 수 없기 때문에 구분된 구역은 각각 크기가 다르고 구역별로 max pooling을 통해 작은 피쳐맵을 완성시킨다.

 

출처 유튜브 Taeoh Kim 채널 PR-057: Mask R-CNN

 

float로 이루어진 실제 위치에서 int로 변환한 위치로 이동하다 보니 위치가 딱 맞아떨어지지 않는다는 단점이 있었지만 이 개념은 object detection을 위해서 사용되었기 때문에 위치가 어느정도 빗나가도 큰 문제는 없지만 segmentation의 경우 찾고자 하는 대상의 위치에 딱 맞게 경계를 그어야 하기 때문에 정확한 위치값을 받는게 중요하다.

그래서 이 논문에서는 RoI Pool 대신 RoI Align을 제시한다.

 

 

여기서 quantization 이라는 단어가 나오는데 한국말로 하면 양자화다. 크흠...

양자화는 연속성이 있는 대상을 '양자'라는 단위를 사용하여 '정수'처리한다고 생각하면 된다.

RoI Align에서는 양자화를 수행하지 않는다고 했으니 RoI Pool처럼 정수처리를 하지 않고 픽셀값을 float로 사용한다고 생각하면 된다.

Figure 3에서 점선은 형상지도이고 실선은 2*2 RoI 를 표현한다. RoI의 1개 셀 안에는 4개의 점은 샘플링 포인트를 나타낸다.

점선으로 이루어진 형상지도 1셀의 꼭지점에서 실전으로 이루어진 RoI 내 샘플링 포인트까지 양방향 보간을 통해 샘플링 점 값을 계산한다.

 

출처 유튜브 Taeoh Kim 채널 PR-057: Mask R-CNN

RoI 사이즈가 2*2인데 1개의 셀을 2*2로 또 나눠서 서브셀을 만들어준다.

여기서 말하는 서브셀이 논문에서 말하는 4개의 샘플링 포인트와 동일한 역할을 한다.

각 서브셀 내부에서 bilinear interpolation(양방향 보간법)을 하고 그걸 더한 값이 모여서 하나의 서브셀 값을 만든다.

양방향 보간법을 통해 계산이 완료된 서브셀이 16개가 나오는데 이 때 각 셀에 대하여 max pooling을 통해 RoI를 구한다.

 

Network Architecture

여러가지 아키텍쳐로 Mask R-CNN을 인스턴스화 한다.

먼저 아래 두가지를 명확하게 구별한다.

1) 전체 이미지에 대한 feature extraction를 위해 사용된 컨볼루션 백본 아키텍쳐

2) 각 RoI에 개별적으로 적용되는 bounding box 인식 및 mask 예측을 위한 네트워크 헤드

 

Mask R-CNN을 구현하기 위해서 다음과 같은 백본을 사용한다.

1) Faster R-CNN + ResNet 백본 + fully convolutional mask prediction branch

2) Faster R-CNN + FPN 백본 + fully convolutional mask prediction branch

 

 

왼쪽에 있는 그림이 Faster R-CNN과 ResNet을 이용한 백본에 Mask Branch를 추가한 것이고, 오른쪽에 있는 그림이 Faster R-CNN과 FPN을 이용한 백본에 Mask Branch를 추가한 것이다. FPN은 ResNet을 이용해서 구현할 수 있기 때문에 논문에서 성능 평가하는 부분을 보면 ResNet-50-FPN 이라는 내용이 나오기도 하는데 이는 FPN을 ResNet으로 구현한 것으로 Faster R-CNN + ResNet과는 다르다.

 

Implementation Details

Training

최소 0.5 이상의 ground-truth IOU를 가진 RoI는 positive, 그렇지 않으면 negative로 간주한다.

mask loss인 Lmask는 positive RoI에만 적용한다.

 

image centric training으로 800pixes로 resize한다.

각각의 미니배치는 GPU당 2개의 이미지가 있고 각 이미지에는 N개의 sample RoI가 있다.

이 때 RoI 비율은 positive 1 : negative 3이다.

sample RoI 개수인 N은 백본이 ResNet C4인 경우 64개, FPN 백본인 경우 512개다.

160k iteration을 위해 8GPU를 사용한다. (따라서 minibatch는 16)

learning rate = 0.02, weight decay = 0.0001, momentum = 0.9

 

ResNext를 사용한 경우에 대한 디테일은 논문을 참고하면 된다.

 

Inference

Test를 수행할 때 ResNet-C4의 경우 제안되는 박스는 300개 정도이고, FPN의 경우에는 1000개정도 제안된다.

이렇게 제안된 것들에 대해서 box 예측을 수행한다.

그 다음 score가 가장 높은 100개의 detection box만이 mask branch에 적용된다.

 

mask branch는 RoI당 K개의 mask를 예측할 수 있지만, 논문에서는 k번째의 mask만 사용했다.

여기서 k는 classification branch에서 예측 분류 된 클래스다.

분류된 클래스에 대한 box 안에서 floating mask의 출력은 RoI 사이즈로 resized 되고, 0.5를 기준으로 이진화 된다.

 

 

Experiments: Instance Segmentation

COCO 데이터셋에 Mask R-CNN을 적용한 결과 매우 좋은 성능을 보였다.

또한 여러가지 경우에 대한 학습 결과를 비교하기도 했는데 자세한 내용은 논문을 참고하면 된다.

  • Multinomial vs Independent Masks
  • Class-Specific vs Class-Agnostic Masks
  • RoI Pool vs RoI Align
  • Mask Branch
  • Bounding Box Detection Results
  • Timing

 

Mask R-CNN for Human Pose Estimation

Mask R-CNN을 단순히 instance segmentation으로 끝내는 것이 아니라 사람의 자세 추정으로 확장할 수 있다.

키포인트 위치를 One-Hot mask로 모델링하고 각 키포인트에 대해 mask를 예측하기 위해 Mask R-CNN을 이용하는 형식이다.

Mask R-CNN의 확장성에 대해 이야기하는 부분으로 자세한 내용은 논문을 참고하길 바란다.

 

 

추가적으로 COCO 데이터 세트와 Cityscapes 데이터 세트에 대한 내용은 Appendix로 설명이 되어있다.

 

논문리뷰

저자가 뭘 해내고 싶어했는가?

segmentation을 수행할 때 object detection 처럼 각 객체별로 mask가 생성되게 하는 instance segmentation을 하는것

 

이 연구의 접근에서 중요한 요소는 무엇인가?

object detection의 장점과 (객체별로 class 탐지) semantic segmentation의 장점 (픽셀별로 라벨링)을 잘 조합하는것

따라서 object detection 처럼 객체를 먼저 탐지하고 그 안에서 segmentation을 수행하게 만들었다.

 

논문을 보고 느낀점?

object detection과 semantic segmentation만 나온 상황이라면 instance segmentation을 필요로 하는 경우가 분명히 존재했을 것이고, 그걸 새롭게 만들기 보다 기존에 나와있는 코드의 장점만 가져와서 재조합하여 만들었다는 점이 좋은 아이디어라고 생각한다.

 

어떤 프로젝트에 적용할 수 있는가?

각 객체가 독립적으로 파악되어야 하는 프로젝트(교통사고, 제조과정 등)에 적용하는 경우 원인파악 또는 책임소재를 명확하게 하는 근거가 될 수 있을 것 같다.

 

추가적으로 공부해야할 것 또는 참고하고 싶은 다른 레퍼런스에는 어떤 것이 있는가?

추가적이라기 보다 처음에 언급한 것 처럼 Mask R-CNN을 공부하기 전에 선행학습 되어야하는 논문들을 먼저 공부하고 시작하면 좋을 것 같다. 이후 논문리뷰는 Multi-Modal쪽으로 해보고 싶다.

 

 

코드 구현

코드를 한번에 구현해보려고 했으나 많은 부분이 resnet, faster rcnn, nms 등 기존 패키지에서 이것저것 가져와서 input 또는 output을 수정하여 사용하는 구조라 너무 복잡하여 코드 구현 부분은 건너뛰고 바로 실급으로 들어간다.

 

Mask R-CNN 실습

Mask R-CNN은 torchvision 패키지에 내장되어있어서 힘들게 하나씩 만들 필요는 없다.

데이터세트는 캐글 세포 데이터세트 다운로드 에서 받아왔다.

 

import warnings
warnings.filterwarnings('ignore')

import os
import time
import random
import collections

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

 

변수 설정

def fix_all_seeds(seed):
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
fix_all_seeds(2021)
# 경로
TRAIN_CSV = '../data/paper_review/maskrcnn_cell/train.csv'
TRAIN_PATH = '../data/paper_review/maskrcnn_cell/train'
TEST_PATH = '../data/paper_review/maskrcnn_cell/test'

# 원본이미지 크기
WIDTH = 704
HEIGHT = 520

# True 인 경우 다른 조건을 주기 위함
TEST = False    
NORMALIZE = False
USE_SCHEDULER = False

# 데이터 전처리
resnet_mean = (0.485, 0.456, 0.406)
resnet_std = (0.229, 0.224, 0.225)

# model parameters
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
n_epochs = 10
batch_size = 16
momentum = 0.9
lr = 0.001
weight_decay = 0.0005

# mask rcnn 에서는 0.5를 기준으로 mask가 있는지 없는지 판단
mask_threshold = 0.5

# 이미지 당 detection 최대 갯수
# 세포 이미지의 경우 500대 단위를 많이 사용하는 편으로, 이미지에 따라 수정
box_detections_per_img = 539

# 겹치는 부분 score 기준
# 0 또는 0.5부터 가장 좋은 결과를 가질 때까지 수정
min_score = 0.5

 

Data 전처리 함수

class Compose:
    def __init__(self, transforms):
        self.transforms = transforms
        
    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target
# 이미지 플립
class VerticalFlip:
    def __init__(self, prob):
        self.prob = prob
        
    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-2)
            bbox = target['boxes']
            bbox[:, [1, 3]] = height - bbox[:, [3, 1]]
            target['boxes'] = bbox
            target['masks'] = target['masks'].flip(-2)
        return image, target
    
class HorizontalFlip:
    def __init__(self, prob):
        self.prob = prob
        
    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)
            bbox = target['boxes']
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target['boxes'] = bbox
            target['masks'] = target['masks'].flip(-1)
        return image, target
# 데이터 처리
class Normalize:
    def __call__(self, image, target):
        image = F.normalize(image, RESNET_MEAN, RESNET_STD)
        return image, target
    
class ToTensor:
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target
    
def get_transform(train):
    transforms = [ToTensor()]
    if NORMALIZE:
        transforms.append(Normalize())
    
    # Data augmentation for train
    if train: 
        transforms.append(HorizontalFlip(0.5))
        transforms.append(VerticalFlip(0.5))

    return Compose(transforms)
# 마스크 표기하는 함수
# shape: (height, width)
# return 0:background, 1:mask

def rle_decode(annotation, shape, color=1):
    s = annotation.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for lo, hi in zip(starts, ends):
        img[lo : hi] = color
    return img.reshape(shape)

 

Train Dataset & DataLoader

class CellDataset(Dataset):
    def __init__(self, image_dir, df, transforms=None, resize=False):
        self.transforms = transforms
        self.image_dir = image_dir
        self.df = df
        self.should_resize = resize is not False
        
        if self.should_resize:
            self.height = int(HEIGHT * resize)
            self.width = int(WIDTH * resize)
        else:
            self.height = HEIGHT
            self.width = WIDTH
        
        self.image_info = collections.defaultdict(dict)
        temp_df = self.df.groupby('id')['annotation'].agg(lambda x: list(x)).reset_index()
        
        for index, row in temp_df.iterrows():
            self.image_info[index] = {'image_id': row['id'], 
                                      'image_path': os.path.join(self.image_dir, row['id']+'.png'),
                                      'annotations': row['annotation']}
            
    def get_box(self, a_mask):
        # 주어진 mask로부터 bbox 확보
        pos = np.where(a_mask)
        xmin = np.min(pos[1])
        xmax = np.max(pos[1])
        ymin = np.min(pos[0])
        ymax = np.max(pos[0])
        return [xmin, ymin, xmax, ymax]

    def __getitem__(self, idx):
        img_path = self.image_info[idx]['image_path']
        img = Image.open(img_path).convert('RGB')

        if self.should_resize:
            img = img.resize((self.width, self.height), resample=Image.BILINEAR)

        info = self.image_info[idx]

        n_objects = len(info['annotations'])
        masks = np.zeros((len(info['annotations']), self.height, self.width), dtype=np.uint8)

        # bbox 좌표 얻기
        boxes = []
        for i, annotation in enumerate(info['annotations']):
            a_mask = rle_decode(annotation, (HEIGHT, WIDTH))
            a_mask = Image.fromarray(a_mask)

            if self.should_resize:
                a_mask = a_mask.resize((self.width, self.height), resample=Image.BILINEAR)

            a_mask = np.array(a_mask) > 0
            masks[i, :, :] = a_mask
            boxes.append(self.get_box(a_mask))

        # dummy lables
        labels = [1 for _ in range(n_objects)]

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((n_objects,), dtype=torch.int64)

        target = {'boxes': boxes,
                  'labels': labels,
                  'masks': masks,
                  'image_id': image_id,
                  'area': area,
                  'iscrowd': iscrowd}

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.image_info)
df_train = pd.read_csv(TRAIN_CSV, nrows=5000 if TEST else None)
ds_train = CellDataset(TRAIN_PATH, df_train, resize=False, transforms=get_transform(train=True))
dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=False, num_workers=2, collate_fn=lambda x: tuple(zip(*x)))

 

Model

def get_model():
    NUM_CLASSES = 2
    
    if NORMALIZE:
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True,
                                                                   box_detections_per_img=box_detections_per_img,
                                                                   image_mean=resnet_mean,
                                                                   image_std=resnet_std)
    else:
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True,
                                                                   box_detections_per_img=box_detections_per_img)
    
    # pretrained 모델의 필요한 부분을 가져오고 새로운 학습을 위해서 해서 설정해야 하는 부분 설정해주기
    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, NUM_CLASSES)
    
    # get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # replace mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, NUM_CLASSES)
    
    return model
model = get_model()
model.to(device)

for param in model.parameters():
    param.requires_grad = True
    
model.train()
# pretrained 모델의 구조에서 roi head에 있는 box predictor와 mask predictor의 output이 바뀐걸 확인할 수 있음

mask rcnn pre-trained 모델 구조에서 input data에 맞춰 output 형태를 바꿔준 경우 모델 구조

 

Train

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=lr, momentum=momentum, weight_decay=weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
n_batches = len(dl_train)
for epoch in range(n_epochs):
    
    time_start = time.time()
    loss_accum = 0.0
    loss_mask_accum = 0.0
    
    for batch_idx, (images, targets) in enumerate(dl_train, 1):
        
        # predic
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        
        loss_dict = model(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        
        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # logging
        loss_mask = loss_dict['loss_mask'].item()
        loss_accum += loss.item()
        loss_mask_accum += loss_mask
        
        # if batch_idx % 10 == 0:
        #     print(f'Batch {batch_idx:3d}/{n_batches:3d}  Batch Train Loss {loss.item():7.3f}  Mask Only Loss {loss_mask:7.3f}')
            
    if USE_SCHEDULER:
        lr_scheduler.step()
        
    train_loss = loss_accum/n_batches
    train_loss_mask = loss_mask_accum/n_batches
    
    take_time = time.time()-time_start
    
    # torch.save(model.state_dict(), f'model_{epoch}.bin')
    prefix = f'Epoch {epoch:2d}/{n_epochs:2d}'
    print(f'{prefix} -- Train Loss {train_loss:.3f} -- Train Mask Only Loss {train_loss_mask:.3f} -- Take time {take_time:.0f}sec')
# train 시각화 함수
def analyze_train_sample(model, ds_train, sample_index=10):
    
    # sample image
    img, targets = ds_train[sample_index]
    plt.imshow(img.numpy().transpose((1, 2, 0)))
    plt.title('Sample Image')
    plt.show()
    
    # ground truth
    masks = np.zeros((HEIGHT, WIDTH))
    for mask in targets['masks']:
        masks = np.logical_or(masks, mask)
        
    plt.imshow(img.numpy().transpose((1, 2, 0)))
    plt.imshow(masks, alpha=0.3)
    plt.title('Ground Truth')
    plt.show()
    
    # pred
    model.eval()
    with torch.no_grad():
        preds = model([img.to(device)])[0]
        
    plt.imshow(img.cpu().numpy().transpose((1, 2, 0)))
    all_preds_masks = np.zeros((HEIGHT, WIDTH))
    for mask in preds['masks'].cpu().detach().numpy():
        all_preds_masks = np.logical_or(all_preds_masks, mask[0] > mask_threshold)
    plt.imshow(all_preds_masks, alpha=0.4)
    plt.title('Predictions')
    plt.show
    
    return preds
    
sample_vi = analyze_train_sample(model, ds_train, 20)

모델이 잘 학습되었는지 샘플 시각화

 

Test

처음 다운로드 받은 데이터세트에서 주어진 test 이미지는 총 3장이다.

class CellTestDataset(Dataset):
    def __init__(self, image_dir, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        self.image_ids = [f[:-4] for f in os.listdir(self.image_dir)]
        
    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.image_dir, image_id + '.png')
        image = Image.open(image_path).convert('RGB')
        
        if self.transforms is not None:
            image, _ = self.transforms(image=image, target=None)
        return {'image': image, 'image_id': image_id}
    
    def __len__(self):
        return len(self.image_ids)
ds_test = CellTestDataset(TEST_PATH, transforms=get_transform(train=False))
ds_test[0]

test dataset에는 image 픽셀값과 image_id만 있음을 확인

 

모델학습 진행 후 많이 생성된 예측 mask들 중에 겹치는 부분은 제거하는 함수를 통해 가장 성능이 좋은 mask 하나만 남긴다.

아래 코드는 캐글 대회 제출용 코드로 결과를 dataframe 형식으로 담아낸다.

# get mask 
def rle_encoding(x):
    dots = np.where(x.flatten() == 1)[0]
    run_lengths = []
    prev = -2
    
    for b in dots:
        if (b > prev+1):
            run_lengths.extend((b+1, 0))
            run_lengths[-1] += 1
            prev = b
            
    return ' '.join(map(str, run_lengths))

# 겹치는 부분 지우기
def remove_overlapping_pixels(mask, other_masks):
    for other_mask in other_masks:
        if np.sum(np.logical_and(mask, other_mask)) > 0:
            mask[np.logical_and(mask, other_mask)] = 0
            
    return mask
model.eval()
submission = []

for sample in ds_test:
    img = sample['image']
    image_id = sample['image_id']
    with torch.no_grad():
        result = model([img.to(device)])[0]
        
    previous_masks = []
    for i, mask in enumerate(result['masks']):
        score = result['scores'][i].cpu().item()
        if score < min_score:
            continue
            
        # 가장 연관성 높은 mask만 남기기
        mask = mask.cpu().numpy()
        binary_mask = mask > mask_threshold
        binary_mask = remove_overlapping_pixels(binary_mask, previous_masks)
        previous_masks.append(binary_mask)
        rle = rle_encoding(binary_mask)
        submission.append((image_id, rle))
        
    # 이미지에 대해서 rle가 생성되지 않으면 빈 prediction 추가
    all_images_ids = [image_id for image_id, rle in submission]
    if image_id not in all_images_ids:
        submission.append((image_id, ''))
# 제출
df_sub = pd.DataFrame(submission, columns=['id', 'pred'])
df_sub.to_csv('../data/paper_review/maskrcnn_cell/submission.csv', index=False)
df_sub.head()

 

Dataframe 형식으로 만들 필요 없이 예측 결과를 바로 시각화 하고 싶다면 아래 코드를 작성하면 된다.

이 때 mask의 범위는 처음에 변수로 설정해준 min_score 값을 줄이거나 더하면서 최적의 값을 찾아 조절하면 된다.

# min_score 값 조절
min_score = 0.43

def visualized_test(ds_test, sample_index):
    
    # test sample
    img = ds_test[sample_index]['image']
    image_id = ds_test[sample_index]['image_id']
    plt.imshow(img.numpy().transpose((1, 2, 0)))
    plt.title('Test Sample Image')
    plt.show()
    
    # test predict
    model.eval()
    with torch.no_grad():
        result = model([img.to(device)])[0]
        
    plt.imshow(img.numpy().transpose((1, 2, 0)))
    all_preds_masks = np.zeros((HEIGHT, WIDTH))
    
    previous_masks = []
    for i, mask in enumerate(result['masks']):
        score = result['scores'][i].cpu().item()
        if score < min_score:
            continue
            
        # 가장 연관성 높은 mask만 남기기
        mask = mask.cpu().numpy()
        binary_mask = mask > mask_threshold
        binary_mask = remove_overlapping_pixels(binary_mask, previous_masks)
        previous_masks.append(binary_mask)
        
        for mask in previous_masks:
            all_preds_masks = np.logical_or(all_preds_masks, mask[0] > mask_threshold)
        
    plt.imshow(all_preds_masks, alpha=0.4)
    plt.title('Test Pred')
    plt.show()
visualized_test(ds_test, 1)
# visualized_test(ds_test, 2)
# visualized_test(ds_test, 3)

 

3장의 세포 이미지에 대해서 Test 결과는 위와 같이 나왔다.

정확도가 매우 높아보이지는 않지만 캐글 본대회에서 1등의 점수가 0.3인걸 보면

어려운 부분이 있는 것 같다.

어느정도 mask rcnn을 수행했다는 것에 의의를 두었다.

 

 

참고

profile

매직코드

@개발법사

포스팅이 좋았다면 "좋아요❤️" 또는 "구독👍🏻" 해주세요!