برنامه‌نویسی

AutoAugment در PyTorch – Community Dev

from torchvision.datasets import OxfordIIITPet
from torchvision.transforms.v2 import AutoAugment
from torchvision.transforms.v2 import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode

# Build the transform with no arguments first; this instance is replaced
# by the explicit construction on the next statement.
aa = AutoAugment()
# Re-create it with policy, interpolation and fill passed explicitly.
aa = AutoAugment(policy=AutoAugmentPolicy.IMAGENET,
                 interpolation = InterpolationMode.NEAREST,
                 fill=None)
# Bare expressions below are notebook/REPL-style echoes; the comment under
# each one records the displayed value.
aa
# AutoAugment(interpolation=InterpolationMode.NEAREST,
#             policy=AutoAugmentPolicy.IMAGENET)

aa.policy
# <AutoAugmentPolicy.IMAGENET: 'imagenet'>

aa.interpolation
# <InterpolationMode.NEAREST: 'nearest'>

print(aa.fill)
# None

def _pets(transform=None):
    """Open the Oxford-IIIT Pet dataset under "data" with the given transform."""
    return OxfordIIITPet(root="data", transform=transform)


# Naming scheme: `p<POLICY>` is the AutoAugment policy and `f<...>` is the
# `fill` argument (a single number, or one value per list entry).

# Untransformed reference dataset.
origin_data = _pets(transform=None)

# One dataset per policy, `fill` left unset.
pIMAGENET_data = _pets(AutoAugment(policy=AutoAugmentPolicy.IMAGENET))
pCIFAR10_data = _pets(AutoAugment(policy=AutoAugmentPolicy.CIFAR10))
pSVHN_data = _pets(AutoAugment(policy=AutoAugmentPolicy.SVHN))

# IMAGENET policy with a scalar fill and with a three-element fill.
pIMAGENETf150_data = _pets(
    AutoAugment(policy=AutoAugmentPolicy.IMAGENET, fill=150)
)
pIMAGENETf160_32_240_data = _pets(
    AutoAugment(policy=AutoAugmentPolicy.IMAGENET, fill=[160, 32, 240])
)

# CIFAR10 policy with the same two fill settings.
pCIFAR10f150_data = _pets(
    AutoAugment(policy=AutoAugmentPolicy.CIFAR10, fill=150)
)
pCIFAR10f160_32_240_data = _pets(
    AutoAugment(policy=AutoAugmentPolicy.CIFAR10, fill=[160, 32, 240])
)

# SVHN policy with the same two fill settings.
pSVHNf150_data = _pets(
    AutoAugment(policy=AutoAugmentPolicy.SVHN, fill=150)
)
pSVHNf160_32_240_data = _pets(
    AutoAugment(policy=AutoAugmentPolicy.SVHN, fill=[160, 32, 240])
)

import matplotlib.pyplot as plt

def show_images1(data, main_title=None):
    """Show up to the first five images of `data` in a single-row figure.

    Args:
        data: Iterable yielding two-element samples; the first element is
            drawn as the image and the second is ignored.
        main_title: Text passed to the figure's super-title.
    """
    columns = 5
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    # zip() with the 1-based slot range stops after `columns` samples.
    for slot, sample in zip(range(1, columns + 1), data):
        picture, _ = sample
        plt.subplot(1, columns, slot)
        plt.imshow(X=picture)
        # Hide axis ticks so only the image is visible.
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

# Reference row: the untransformed images.
show_images1(data=origin_data, main_title="origin_data")

# Each policy-only dataset is drawn three times in a row (presumably to
# show that the augmentation differs between passes).
for title, dataset in (
    ("pIMAGENET_data", pIMAGENET_data),
    ("pCIFAR10_data", pCIFAR10_data),
    ("pSVHN_data", pSVHN_data),
):
    print()
    for _ in range(3):
        show_images1(data=dataset, main_title=title)

# Per policy: the scalar-fill dataset followed by the list-fill dataset.
for fill_pair in (
    (
        ("pIMAGENETf150_data", pIMAGENETf150_data),
        ("pIMAGENETf160_32_240_data", pIMAGENETf160_32_240_data),
    ),
    (
        ("pCIFAR10f150_data", pCIFAR10f150_data),
        ("pCIFAR10f160_32_240_data", pCIFAR10f160_32_240_data),
    ),
    (
        ("pSVHNf150_data", pSVHNf150_data),
        ("pSVHNf160_32_240_data", pSVHNf160_32_240_data),
    ),
):
    print()
    for title, dataset in fill_pair:
        show_images1(data=dataset, main_title=title)

# ↓ ↓ ↓ ↓ ↓ ↓ The code below is an equivalent alternative to the code above: it applies AutoAugment at display time instead of inside the dataset. ↓ ↓ ↓ ↓ ↓ ↓
def show_images2(data, main_title=None, p=None,
                 ip=InterpolationMode.NEAREST,
                 f=None):
    """Show up to the first five images of `data`, optionally auto-augmented.

    When `p` is given, each image is passed through
    ``AutoAugment(policy=p, interpolation=ip, fill=f)`` before being drawn;
    when `p` is None the images are drawn unchanged and `ip` / `f` are
    not used.

    Args:
        data: Iterable yielding two-element samples; the first element is
            drawn as the image and the second is ignored.
        main_title: Text passed to the figure's super-title.
        p: AutoAugment policy, or None to skip augmentation.
        ip: Interpolation mode forwarded to AutoAugment.
        f: Fill value forwarded to AutoAugment.
    """
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    # Build the transform once instead of once per image; it depends only
    # on the arguments, not on the sample. `is not None` is the identity
    # check PEP 8 prescribes for None.
    augment = None
    if p is not None:
        augment = AutoAugment(policy=p, interpolation=ip, fill=f)
    for i, (im, _) in zip(range(1, 6), data):
        plt.subplot(1, 5, i)
        plt.imshow(X=im if augment is None else augment(im))
        # Hide axis ticks so only the image is visible.
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

# Reference row: no policy, so the images are drawn unchanged.
show_images2(data=origin_data, main_title="origin_data")

# Each policy is applied to the original dataset three times in a row
# (presumably to show that the augmentation differs between passes).
for title, policy in (
    ("pIMAGENET_data", AutoAugmentPolicy.IMAGENET),
    ("pCIFAR10_data", AutoAugmentPolicy.CIFAR10),
    ("pSVHN_data", AutoAugmentPolicy.SVHN),
):
    print()
    for _ in range(3):
        show_images2(data=origin_data, main_title=title, p=policy)

# Per policy: a scalar fill followed by a three-element fill.
for policy, titled_fills in (
    (
        AutoAugmentPolicy.IMAGENET,
        (
            ("pIMAGENETf150_data", 150),
            ("pIMAGENETf160_32_240_data", [160, 32, 240]),
        ),
    ),
    (
        AutoAugmentPolicy.CIFAR10,
        (
            ("pCIFAR10f150_data", 150),
            ("pCIFAR10f160_32_240_data", [160, 32, 240]),
        ),
    ),
    (
        AutoAugmentPolicy.SVHN,
        (
            ("pSVHNf150_data", 150),
            ("pSVHNf160_32_240_data", [160, 32, 240]),
        ),
    ),
):
    print()
    for title, fill in titled_fills:
        show_images2(data=origin_data, main_title=title, p=policy, f=fill)
حالت تمام صفحه را وارد کنید

از حالت تمام صفحه خارج شوید

نوشته‌های مشابه

دیدگاهتان را بنویسید

نشانی ایمیل شما منتشر نخواهد شد. بخش‌های موردنیاز علامت‌گذاری شده‌اند *

دکمه بازگشت به بالا