深度学习(计算数据集均值标准差)

   

深度学习中有些数据集可能不符合imagenet计算出的均值和标准差,需要根据自己的数据集单独计算。

下面这个脚本能够计算当前数据集均值和标准差。 

import torch
import os
from PIL import Image
from torchvision import transforms

# trans = transforms.Compose([
#     transforms.Resize((256, 256)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# ])

toTensor = transforms.ToTensor()
root = './imgs/'
torch.set_printoptions(precision=10)

def get_file_names(directory):
    file_names = []
    for file_name in os.listdir(directory):
        if os.path.isfile(os.path.join(directory, file_name)):
            file_names.append(file_name)
    return file_names

filenames = get_file_names(root)

mean = torch.zeros(3)
std = torch.zeros(3)
#tensor([0.4526, 0.4316, 0.3995]) tensor([0.2419, 0.2364, 0.2406])

count = 0
for file in filenames:
    imgname = root + file
    image = Image.open(imgname)
    tensor = toTensor(image)
    for c in range(3):
        mean[c] += tensor[c,:,:].mean()
        std[c] += tensor[c,:,:].std()
    count+=1
    print(mean/count,std/count)

         

请登录后发表评论

    没有回复内容