Operasi Wavelet dengan Tensor – Pytorch

By | June 12, 2023
425 Views

Penggunaan wavelet pada image processing banyak digunakan saat ini terutama pada proses ciri fitur untuk mengurangi noise/derau. Salah satu library yang terkenal yaitu pywavelet. PyWavelet mengadopsi operasi matrix berbasis numpy. Namun bila bekerja pada deep learning berbasis tensor alih-alih menggunakan numpy.

Tapi jangan kuatir, saat ini sudah ada yang membuat package pywavelet berbasis tensor yaitu 2D Wavelet Transforms in Pytorch, Kami sudah mencobanya dan memang compliance sekali dengan pywavelet. Namun penggunanya agak sedikit berbeda yaitu mengikuti framework Pytorch dengan format NCHW – N -> ukuran batch, C -> untuk channel, H –> height, dan W –> width.

Hasil PyWavelet

Sebagai bahan uji coba, kita akan membandingkan hasil antar PyWavelet dengan pytorch_wavelets, agar lebih mudah kita buat sebuah array seperti berikut

import numpy as np
lines = np.arange(0,20)
lines2 = np.tile(lines,(20,1))

Sekarang kita akan hitung array tersebut

import pywt
from pywt import wavedec

LL, (LH, HL, HH) = pywt.dwt2(lines2, 'db32',mode='symmetric')

pytorch_wavelets

Mari kita coba dengan pytorch_wavelets namun kita akan sesuaikan dengan format NCHW

import torch
#bikin tensor dari numpy
lines3 = torch.tensor(lines2)
#bikin dummy dengan ukuran batch = 1, channel =1
lines4 = torch.rand([1,1,20,20])
#masukan ke batch ke 1
lines4[0,:,:,:] = lines3

Sekarang kita hitung nilai wavelet nya

from pytorch_wavelets import DWTForward, DWTInverse # (or import DWT, IDWT)
xfm = DWTForward(wave='db32',mode='symmetric')
Yl, Yh = xfm(lines4)

Yang perlu kalian ketahuai yaitu

  • Yl adalah LL dan
  • Yh adalah tuple dengan isi LH, HL, dan HH

Hasil Wavelet antar PyWavelet dan versi torch

Kita bisa ngecek hasil antar LL dan Yl sebagai berikut (nilai nya sama) hanya berbeda koma saja

LL
Out[42]: 
array([[29.54730837, 25.71705718, 21.47321198, ..., 36.95568489,
        33.62611836, 29.54730837],
       [29.54730837, 25.71705718, 21.47321198, ..., 36.95568489,
        33.62611836, 29.54730837],
       [29.54730837, 25.71705718, 21.47321198, ..., 36.95568489,
        33.62611836, 29.54730837],
       ...,
       [29.54730837, 25.71705718, 21.47321198, ..., 36.95568489,
        33.62611836, 29.54730837],
       [29.54730837, 25.71705718, 21.47321198, ..., 36.95568489,
        33.62611836, 29.54730837],
       [29.54730837, 25.71705718, 21.47321198, ..., 36.95568489,
        33.62611836, 29.54730837]])

Yl
Out[43]: 
tensor([[[[29.5473, 25.7171, 21.4732,  ..., 36.9557, 33.6261, 29.5473],
          [29.5473, 25.7171, 21.4732,  ..., 36.9557, 33.6261, 29.5473],
          [29.5473, 25.7171, 21.4732,  ..., 36.9557, 33.6261, 29.5473],
          ...,
          [29.5473, 25.7171, 21.4732,  ..., 36.9557, 33.6261, 29.5473],
          [29.5473, 25.7171, 21.4732,  ..., 36.9557, 33.6261, 29.5473],
          [29.5473, 25.7171, 21.4732,  ..., 36.9557, 33.6261, 29.5473]]]])

Sedangkan untuk akses

  • LH yaitu Yh[0][0][0][0]
  • HL yaitu Yh[0][0][0][1]
  • HH yaitu Yh[0][0][0][2]
See also  Decompose Citra dalam bit-planes

Kalian jangan bingung karena Yh merupakan tuple dengan ukuran sebagai berikut

Yh[0].shape
Out[51]: torch.Size([1, 1, 3, 41, 41])

Sedangkan untuk Invers nya gimana? gampang kok

ifm = DWTInverse(wave='db32',mode='symmetric')
Y = ifm((Yl, Yh))

Mari kita bandingkan antar lines4 dan Y sebagai berikut apakah hasil invers kembali seperti semula?

(lines4==Y).sum()

Makan hasilnya adalah 400 (karena ukuran matrix adalah 20 x 20) yang menandakan tidak ada perbedaan hasil

Bagaimana dengan penerapan di Gambar?

Untuk digunakan pada gambar, langkahnya cukup mudah kok

from PIL import Image
from torchvision import transforms
import PIL
file = 'gambar.jpg'

img = Image.open(file).convert("RGB")

Jangan lupa kita buat tensor dulu

#dibikin tensor
transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor()])

img_tensor = transform(img)

Apakah sudah betul? kita cek gambarnya dalam bentuk tampilan di matplotlib

from matplotlib import pyplot as plt
plt.figure()
plt.imshow(img_tensor.permute(1,2,0),cmap='gray')
plt.show()

Kita lakukan operasi wavelet yang terlebih dahulu buat format NCHW

#blank tensor
blank_tensor = torch.rand([1,
                           img_tensor.size(0),
                           img_tensor.size(1),
                           img_tensor.size(2)])

blank_tensor[0,:,:,:] = img_tensor

Yuk kita olah saja

xfm = DWTForward(wave='db32',mode='symmetric')
ifm = DWTInverse(wave='db32',mode='symmetric')
Yl, Yh = xfm(blank_tensor)

titles = ['Approximation', ' Horizontal detail',
          'Vertical detail', 'Diagonal detail']



H = [Y[0].permute(1,2,0),
     Yh[0][0][0][0],
     Yh[0][0][0][1],
     Yh[0][0][0][2]
     ]




fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(20, 15))
for i, a in enumerate(H):
    ax[i].imshow(a, cmap=plt.cm.gray)
    ax[i].set_title(titles[i], fontsize=10)
    ax[i].set_xticks([])
    ax[i].set_yticks([])

fig.tight_layout()
plt.show()

 

Untuk inversnya seperti sebelumnya yaitu gunakan kode berikut

#invers
Y = ifm((Yl, Yh))

plt.figure()
plt.imshow(Y[0].permute(1,2,0),cmap='gray')
plt.xticks([])
plt.yticks([])
fig.tight_layout()
plt.show()