from fastapi import APIRouter, File, UploadFile
from fastapi.responses import Response
from PIL import Image
import io
import torch
from torchvision import models, transforms
import numpy as np
router = APIRouter()
print("✂️ 모델(DeepLabV3)을 로딩 중입니다...")
# 모델의 최적의 가중치를 가져옵니다
weights = models.DeepLabV3_ResNet50_Weights.DEFAULT
# 최적의 가중치를 가진 모델을 탑제합니다
model = models.deeplabv3_resnet50(weights=weights)
# 모델을 실전모드로 전환해라 문제를 틀리지말고 정석대로 진행해라
model.eval()
# 입력할 이미지를 어떻게 변환해야할지 규칙을 정한다
preprocess = weights.transforms()
# 서로 다른 색깔을 랜덤하게 배정하는 함수
def get_palette(num_classes=21):
palette = torch.tensor([2 ** 25 -1, 2 ** 15 -1, 2 ** 21 - 1])
colors = torch.arange(num_classes)[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")
return colors
@router.post("/segment")
async def segment_image(file: UploadFile = File(...)):
# 파일을 읽고 형태를 갖추는작업
image_data = await file.read()
image = Image.open(io.BytesIO(image_data)).convert("RGB")
# 이미지를 규칙대로 자른뒤 이미지 1장을 못받으니 배치차원을 하나 추가하여 1장이라고 표시함
input_tensor = preprocess(image).unsqueeze(0)
# 모델학습도중에 저장하지 않을것
with torch.no_grad():
output = model(input_tensor)['out'][0]
# argmax(중요) 21개의 기준으로 각각 점수를 매깁니다 (ex. 배경 10점 사람 90점 강아지 5점) 가장 높은 점수를 가진번호를 뽑습니다 사람 당선
output_predictions = output.argmax(0)
# 결과가 숫자가 큰값으로 나와서 byte로 작은 용량으로 줄이고 cpu로 전환해 gpu사용을 하지말고 (쉬운작업이니) 파이토치 작업에서 numpy포멧으로 변경해라
pred_np = output_predictions.byte().cpu().numpy()
# 숫자로된 정보를 이제 색칠해라
result_image = Image.fromarray(pred_np).convert("P")
result_image.putpalette(get_palette(21))
# 결과 이미지를 사이즈 조정하는것? 학습때 이미지 변형을 했으니 하는거같음
result_image = result_image.resize(image.size)
# 메모리위에 가상의 파일 폴더를 만든다
img_byte_arr = io.BytesIO()
# png로 이미지를 저장하는코드
result_image.save(img_byte_arr, format="PNG")
return Response(content=img_byte_arr.getvalue(), media_type="image/png")