Commit 213c7f8a authored by 魏博昱

1

parent b934062a
*.JPG
*.jpg
*.txt
datasets/
checkpoints/
output/
results/
.vscode/
log/
logs/
*.swp
*.pth
*.pyc
.idea/
*-checkpoint.py
*.ipynb_checkpoints/
masks/
resized_paris/
fakeB/
shifted/
MIT License
Copyright (c) 2018 Zhaoyi-Yan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# ShiftNet4LibTorch
# New training strategy
I release a new training strategy that helps deal with random-mask training by reducing color shifting, at the cost of about 30% extra training time. It is especially useful when we perform face inpainting.
Set `--which_model_netG='face_unet_shift_triple'`, `--model='face_shiftnet'` and `--batchSize=1` to carry out the strategy.
See some examples below; many approaches suffer from such `color shifting` when trained with random masks on face datasets.
<table style="float:center">
<tr>
 <th><B>Input</B></th> <th><B>Naive Shift</B></th> <th><B>Flip Shift</B></th> <th><B>Ground-truth</B></th>
</tr>
<tr>
<td>
<img src='./imgs/compare/13_real_A.png' >
</td>
<td>
<img src='./imgs/compare/13_fake_B.png'>
</td>
<td>
<img src='./imgs/compare/13_fake_B_flip.png'>
</td>
<td>
<img src='./imgs/compare/13_real_B.png'>
</td>
</tr>
<tr>
<td>
<img src='./imgs/compare/18_real_A.png' >
</td>
<td>
<img src='./imgs/compare/18_fake_B.png'>
</td>
<td>
<img src='./imgs/compare/18_fake_B_flip.png'>
</td>
<td>
<img src='./imgs/compare/18_real_B.png'>
</td>
</tr>
</table>
Note: the `face_flip` training strategy has some minor drawbacks:
1. It is not fully parallel compared with the original shift.
2. It can only be trained on the CPU or a single GPU, and the batch size must be 1; otherwise it raises an error.
If you want to overcome these drawbacks, you can optimize it by referring to the original shift. It is not difficult; however, I do not have time to do it.
# Architecture
<img src="architecture.png" width="1000"/>
# Shift layer
<img src="shift_layer.png" width="800"/>
## Prerequisites
- Linux or Windows.
- Python 2 or Python 3.
- CPU or NVIDIA GPU + CUDA cuDNN.
- Tested on PyTorch >= **1.2**.
## Getting Started
### Installation
- Install PyTorch and dependencies from http://pytorch.org/
- Install python libraries [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate).
```bash
pip install visdom
pip install dominate
```
- Clone this repo:
```bash
git clone https://github.com/Zhaoyi-Yan/Shift-Net_pytorch
cd Shift-Net_pytorch
```
# Trained models
Usually, I would suggest you just pull the latest code and train by following the instructions.
However, several models have been trained and uploaded.
| Mask | Paris | CelebaHQ_256 |
| ---- | ---- | ---- |
| center-mask | ok | ok |
| random mask (from **partial conv**) | ok | ok |
For CelebaHQ_256 dataset:
I select the first 2k images of CelebaHQ_256 for testing; the rest are for training.
```
python train.py --loadSize=256 --batchSize=1 --model='face_shiftnet' --name='celeb256' --which_model_netG='face_unet_shift_triple' --niter=30 --dataroot='./datasets/celeba-256/train'
```
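The 2k test split described above can be sketched as follows (a hypothetical helper, assuming zero-padded image filenames; it is not part of this repo):

```python
# Hypothetical sketch of the CelebaHQ-256 split described above: the first
# 2000 images, sorted by filename, are held out for testing; the rest train.
def split_celebahq(filenames, n_test=2000):
    files = sorted(filenames)
    return files[n_test:], files[:n_test]  # (train, test)

names = ['%05d.png' % i for i in range(30000)]
train, test = split_celebahq(names)
print(len(train), len(test))  # 28000 2000
```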
Note: **`loadSize` should be `256` for face datasets, meaning the input image is directly resized to `256x256`.**
Below are some results on CelebaHQ-256 and Paris.
In particular, for training models with random masks, we adopt the masks of **partial conv** (only masks whose masked-region ratio is 20~30% are used).
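The 20~30% selection can be expressed as a small filter; this is an illustrative sketch (the helper names are assumptions, not code from this repo):

```python
# Select only masks whose masked-region ratio falls in the 20~30% range.
def mask_ratio(mask):
    """Fraction of masked (non-zero) pixels in a 2-D binary mask."""
    total = sum(len(row) for row in mask)
    masked = sum(1 for row in mask for v in row if v)
    return masked / total

def keep_mask(mask, lo=0.2, hi=0.3):
    return lo <= mask_ratio(mask) <= hi

# A 10x10 mask with a 5x5 masked square has ratio 0.25 and is kept.
m = [[1 if (r < 5 and c < 5) else 0 for c in range(10)] for r in range(10)]
print(keep_mask(m))  # True
```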
<table style="float:center">
<tr>
 <th><B>Input</B></th> <th><B>Results</B></th> <th><B>Ground-truth</B></th>
</tr>
<tr>
<td>
<img src='./imgs/face_center/106_real_A.png' >
</td>
<td>
<img src='./imgs/face_center/106_fake_B.png'>
</td>
<td>
<img src='./imgs/face_center/106_real_B.png'>
</td>
</tr>
<tr>
<td>
<img src='./imgs/face_center/111_real_A.png' >
</td>
<td>
<img src='./imgs/face_center/111_fake_B.png'>
</td>
<td>
<img src='./imgs/face_center/111_real_B.png'>
</td>
</tr>
<tr>
<td>
<img src='./imgs/face_random/0_real_A.png' >
</td>
<td>
<img src='./imgs/face_random/0_fake_B.png'>
</td>
<td>
<img src='./imgs/face_random/0_real_B.png'>
</td>
</tr>
<tr>
<td>
<img src='./imgs/face_random/1_real_A.png' >
</td>
<td>
<img src='./imgs/face_random/1_fake_B.png'>
</td>
<td>
<img src='./imgs/face_random/1_real_B.png'>
</td>
</tr>
<tr>
<td>
<img src='./imgs/paris_center/048_im_real_A.png' >
</td>
<td>
<img src='./imgs/paris_center/048_im_fake_B.png'>
</td>
<td>
<img src='./imgs/paris_center/048_im_real_B.png'>
</td>
</tr>
<tr>
<td>
<img src='./imgs/paris_center/004_im_real_A.png' >
</td>
<td>
<img src='./imgs/paris_center/004_im_fake_B.png'>
</td>
<td>
<img src='./imgs/paris_center/004_im_real_B.png'>
</td>
</tr>
<tr>
<td>
<img src='./imgs/paris_random/006_im_real_A.png' >
</td>
<td>
<img src='./imgs/paris_random/006_im_fake_B.png'>
</td>
<td>
<img src='./imgs/paris_random/006_im_real_B.png'>
</td>
</tr>
<tr>
<td>
<img src='./imgs/paris_random/073_im_real_A.png' >
</td>
<td>
<img src='./imgs/paris_random/073_im_fake_B.png'>
</td>
<td>
<img src='./imgs/paris_random/073_im_real_B.png'>
</td>
</tr>
</table>
For testing, please read the document carefully.
Pretrained models for face center inpainting are available:
```bash
bash download_models.sh
```
Rename `face_center_mask.pth` to `30_net_G.pth` and put it in the folder `./log/face_center_mask_20_30` (create it if it does not exist).
```bash
python test.py --which_model_netG='unet_shift_triple' --model='shiftnet' --name='face_center_mask_20_30' --which_epoch=30
```
For face random inpainting, the model is trained with `--which_model_netG='face_unet_shift_triple'` and `--model='face_shiftnet'`. Rename `face_flip_random.pth` to `30_net_G.pth`, and set `--which_model_netG='face_unet_shift_triple'` and `--model='face_shiftnet'` when testing.
Similarly, for Paris random inpainting, rename `paris_random_mask_20_30.pth` to `30_net_G.pth` and put it in the folder `./log/paris_random_mask_20_30` (create it if it does not exist).
Then test the model:
```
python test.py --which_epoch=30 --name='paris_random_mask_20_30' --offline_loading_mask=1 --testing_mask_folder='masks' --dataroot='./datasets/celeba-256/test' --norm='instance'
```
Note: your own masks should be prepared in the folder given by `--testing_mask_folder` in advance.
For other models, I think you know how to evaluate them.
For models trained with a center mask, make sure `--mask_type='center' --offline_loading_mask=0`.
## Train models
- Download your own inpainting datasets.
- Train a model:
Please read this paragraph carefully before running the code.
Usually, we train/test `naive shift-net` with a `center` mask.
```bash
python train.py --batchsize=1 --use_spectral_norm_D=1 --which_model_netD='basic' --mask_type='center' --which_model_netG='unet_shift_triple' --model='shiftnet' --shift_sz=1 --mask_thred=1
```
For some datasets, such as `CelebA`, some images are smaller than `256*256`, so you need to add `--loadSize=256` when training. **It is important.**
- To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. The checkpoints will be saved in `./log` by default.
**DO NOT** set the batch size larger than 1 for `square` mask training; the performance degrades a lot (I don't know why...).
For `random mask` training (`mask_sub_type` is NOT `rect`, or your own random masks), the batch size can be larger than 1 without hurting performance.
Random-mask training (both online and offline) is also supported.
Personally, I suggest loading the masks offline (similar to **partial conv**). Please refer to the section **Masks**.
## Test the model
**Keep the same settings as in the training phase to avoid errors or bad performance.**
For example, if you trained `patch soft shift-net`, then the following testing command is appropriate:
```bash
python test.py --fuse=1/0 --which_model_netG='patch_soft_unet_shift_triple' --model='patch_soft_shiftnet' --shift_sz=3 --mask_thred=4
```
The test results will be saved to an HTML file in `./results/`.
## Masks
Usually, **keep the same mask settings between training and testing.**
This is because performance is highly related to the masks applied during training.
Consistency between training and testing masks is crucial for good performance.
| training | testing |
| ---- | ---- |
| center-mask | center-mask |
| random-square| All |
| random | All|
| your own masks| your own masks|
This means that if you train a model with `center-mask`, then you should test it with `center-mask` (without even a one-pixel offset). For more info, refer to https://github.com/Zhaoyi-Yan/Shift-Net_pytorch/issues/125
### Training with online-generated masks
We offer three types of online-generated masks: `center-mask`, `random_square` and `random_mask`.
If you want to train on your own masks, similar to **partial conv**, refer to **Training on your own masks**.
### Training on your own masks
Both online generation and offline loading are now supported for training and testing.
We generate masks online by default; set `--offline_loading_mask=1` when you want to train/test with your own prepared masks.
**The prepared masks should be put in the folders given by `--training_mask_folder` and `--testing_mask_folder`.**
### Masks when training
For each batch:
- Generating online: masks are the same for every image in a batch (to save computation).
- Loading offline: masks are loaded randomly for each image in a batch.
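The two behaviors can be sketched like this (hypothetical function names, not code from this repo):

```python
import random

# Online generation produces one mask shared by the whole batch;
# offline loading draws a random mask path independently per image.
def batch_masks(batch_size, offline, mask_paths=None, generate=lambda: 'generated'):
    if offline:
        return [random.choice(mask_paths) for _ in range(batch_size)]
    mask = generate()          # generated once, reused for every image
    return [mask] * batch_size

online = batch_masks(4, offline=False)
print(len(set(online)))  # 1  (same mask for every image in the batch)
```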
## Using Switchable Norm instead of Instance/Batch Norm
For fixed-mask training, `Switchable Norm` delivers better stability when batchSize > 1. **Please use switchable norm when you want to train with a large batch size; it is much more stable than instance norm or batch norm!**
### Extra variants
**These 3 models are just for fun**
For `res patch soft shift-net`:
```bash
python train.py --batchSize=1 --which_model_netG='res_patch_soft_unet_shift_triple' --model='res_patch_soft_shiftnet' --shift_sz=3 --mask_thred=4
```
For `res naive shift-net`:
```bash
python train.py --which_model_netG='res_unet_shift_triple' --model='res_shiftnet' --shift_sz=1 --mask_thred=1
```
For `patch soft shift-net`:
```bash
python train.py --which_model_netG='patch_soft_unet_shift_triple' --model='patch_soft_shiftnet' --shift_sz=3 --mask_thred=4
```
DO NOT change `shift_sz` and `mask_thred`; otherwise, it errors out with high probability.
For `patch soft shift-net` or `res patch soft shift-net`, you may set `fuse=1` to see whether it delivers better results (note: you need to keep the same setting between training and testing).
## New things that I want to add
- [x] Make U-Net handle inputs of any size.
- [x] Add more GANs, like spectral norm and relativistic GAN.
- [x] Boost the efficiency of the shift layer.
- [x] Directly resize the global_mask to get the mask in feature space.
- [x] Visualization of flow. It is still experimental now.
- [x] Extensions of Shift-Net. Still active in absorbing new features.
- [x] Fix bug in guidance loss when adopting it in multi-GPU.
- [x] Add composite L1 loss between mask loss and non-mask loss.
- [x] Finish optimizing soft-shift.
- [x] Add mask variants in a batch.
- [x] Support online-generating/offline-loading prepared masks for training/testing.
- [x] Add VGG loss and TV loss.
- [x] Fix performance degradation when batch size is larger than 1.
- [x] Make it compatible with PyTorch 1.2.
- [ ] Training with mixed types of masks.
- [ ] Try AMP training.
- [ ] Try a self-attention discriminator (maybe it helps).
## Somethings extra I have tried
**Gated Conv**: I have tried gated conv (replacing the normal convs of U-Net with gated conv, except the innermost/outermost layers).
However, I obtained no benefits. Maybe I should try replacing all layers with gated conv. I will try again when I am free.
**Non-local block**: I added it, but the results seem worse. Maybe I haven't added the blocks in the proper position. (It increases the training time a lot, so I am not in favor of it.)
## Citation
If you find this work useful or it gives you some insights, please cite:
```
@InProceedings{Yan_2018_Shift,
author = {Yan, Zhaoyi and Li, Xiaoming and Li, Mu and Zuo, Wangmeng and Shan, Shiguang},
title = {Shift-Net: Image Inpainting via Deep Feature Rearrangement},
booktitle = {The European Conference on Computer Vision (ECCV)},
month = {September},
year = {2018}
}
```
## Acknowledgments
We benefit a lot from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
#-*-coding:utf-8-*-
import os.path
import random
import torchvision.transforms as transforms
import torch
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image
class AlignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.dir_A = opt.dataroot
self.A_paths = sorted(make_dataset(self.dir_A))
if self.opt.offline_loading_mask:
self.mask_folder = self.opt.training_mask_folder if self.opt.isTrain else self.opt.testing_mask_folder
self.mask_paths = sorted(make_dataset(self.mask_folder))
assert(opt.resize_or_crop == 'resize_and_crop')
transform_list = [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
self.transform = transforms.Compose(transform_list)
def __getitem__(self, index):
A_path = self.A_paths[index]
A = Image.open(A_path).convert('RGB')
w, h = A.size
## Crop only ##
A = self.transform(A)
nw = int(w / self.opt.fineSize * 2)
nh = int(h / self.opt.fineSize * 2)
nw0 = int(w % self.opt.fineSize)
nh0 = int(h % self.opt.fineSize)
step = int(self.opt.fineSize / 2)
A_temp = torch.FloatTensor(nw * nh, 3, self.opt.fineSize, self.opt.fineSize).zero_()
for iw in range(nw):
for ih in range(nh):
if iw == nw-1 and ih == nh-1:
A_temp[iw * nh + ih, :, :, :] = A[:, w - self.opt.fineSize:w, h-self.opt.fineSize:h]
continue
if iw == nw-1 and ih != nh-1:
A_temp[iw * nh + ih, :, :, :] = A[:, w- self.opt.fineSize:w,ih * step:ih * step + self.opt.fineSize]
continue
if iw != nw-1 and ih == nh-1:
A_temp[iw * nh + ih, :, :, :] = A[:, iw * step:iw * step+self.opt.fineSize, h-self.opt.fineSize:h]
continue
A_temp[iw * nh + ih, :, :, :] = A[:, iw * step:iw * step+self.opt.fineSize, ih * step:ih * step+self.opt.fineSize]
A = A_temp
###end####
"重置大小,切割图像 bg"
'''
if w < h:
ht_1 = self.opt.loadSize * h // w
wd_1 = self.opt.loadSize
A = A.resize((wd_1, ht_1), Image.BICUBIC)
else:
wd_1 = self.opt.loadSize * w // h
ht_1 = self.opt.loadSize
A = A.resize((wd_1, ht_1), Image.BICUBIC)
A = self.transform(A)
h = A.size(1)
w = A.size(2)
w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
A = A[:, h_offset:h_offset + self.opt.fineSize,
w_offset:w_offset + self.opt.fineSize]
'''
"重置大小,切割图像 end"
if (not self.opt.no_flip) and random.random() < 0.5:
A = torch.flip(A, [2])
# let B directly equal A
B = A.clone()
A_flip = torch.flip(A, [2])
B_flip = A_flip.clone()
# Just zero the mask is fine if not offline_loading_mask.
mask = A.clone().zero_()
if self.opt.offline_loading_mask:
if self.opt.isTrain:
mask = Image.open(self.mask_paths[random.randint(0, len(self.mask_paths)-1)])
else:
mask = Image.open(self.mask_paths[index % len(self.mask_paths)])
mask = mask.resize((self.opt.fineSize, self.opt.fineSize), Image.NEAREST)
mask = transforms.ToTensor()(mask)
return {'A': A, 'B': B, 'A_F': A_flip, 'B_F': B_flip, 'M': mask,
'A_paths': A_path, 'im_size': [w, h]}
def __len__(self):
return len(self.A_paths)
def name(self):
return 'AlignedDataset'
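# Illustration (a hypothetical helper, not used by the dataset above): the
# overlapping-tile crop in __getitem__ strides by half of `fineSize` and
# anchors the final tile at the image border. The per-axis tile origins it
# visits can be written as:
def tile_origins(length, fine_size):
    n = int(length / fine_size * 2)      # number of tiles along this axis
    step = fine_size // 2                # half-size stride
    return [i * step for i in range(n - 1)] + [length - fine_size]
# e.g. tile_origins(256, 128) gives [0, 64, 128, 128]: the last tile
# repeats the border anchor, matching the edge cases in the loop above.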
#-*-coding:utf-8-*-
import os.path
import random
import torchvision.transforms as transforms
import torch
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image
class AlignedDatasetResized(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
self.dir_A = opt.dataroot # More Flexible for users
self.A_paths = sorted(make_dataset(self.dir_A))
assert(opt.resize_or_crop == 'resize_and_crop')
transform_list = [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
self.transform = transforms.Compose(transform_list)
def __getitem__(self, index):
A_path = self.A_paths[index]
A = Image.open(A_path).convert('RGB')
A = A.resize((self.opt.fineSize, self.opt.fineSize), Image.BICUBIC)
A = self.transform(A)
#if (not self.opt.no_flip) and random.random() < 0.5:
# idx = [i for i in range(A.size(2) - 1, -1, -1)] # size(2)-1, size(2)-2, ... , 0
# idx = torch.LongTensor(idx)
# A = A.index_select(2, idx)
# let B directly equal A
B = A.clone()
return {'A': A, 'B': B,
'A_paths': A_path}
def __len__(self):
return len(self.A_paths)
def name(self):
return 'AlignedDatasetResized'
class BaseDataLoader():
def __init__(self):
pass
def initialize(self, opt):
self.opt = opt
pass
def load_data():
return None
import torch.utils.data as data
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return 'BaseDataset'
def initialize(self, opt):
pass
#-*-coding:utf-8-*-
import torch.utils.data
from data.base_data_loader import BaseDataLoader
def CreateDataset(opt):
dataset = None
if opt.dataset_mode == 'aligned':
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()
elif opt.dataset_mode == 'aligned_resized':
from data.aligned_dataset_resized import AlignedDatasetResized
dataset = AlignedDatasetResized()
elif opt.dataset_mode == 'single':
from data.single_dataset import SingleDataset
dataset = SingleDataset()
else:
raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))
def load_data(self):
return self
def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)
def __iter__(self):
for i, data in enumerate(self.dataloader):
if i*self.opt.batchSize >= self.opt.max_dataset_size:
break
yield data
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
from PIL import Image
import os
import os.path
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
return images
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img
def __len__(self):
return len(self.imgs)
import os.path
import torchvision.transforms as transforms
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image
class SingleDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
self.dir_A = os.path.join(opt.dataroot)
# make_dataset returns paths of all images in one folder
self.A_paths = make_dataset(self.dir_A)
self.A_paths = sorted(self.A_paths)
transform_list = []
if opt.resize_or_crop == 'resize_and_crop':
transform_list.append(transforms.Resize(opt.loadSize))  # Scale is deprecated in newer torchvision
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.RandomHorizontalFlip())
if opt.resize_or_crop != 'no_resize':
transform_list.append(transforms.RandomCrop(opt.fineSize))
# Map values to [-1, 1], because [(0-0.5)/0.5, (1-0.5)/0.5]
transform_list += [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
self.transform = transforms.Compose(transform_list)
def __getitem__(self, index):
A_path = self.A_paths[index]
A_img = Image.open(A_path).convert('RGB')
A = self.transform(A_img)
if self.opt.which_direction == 'BtoA':
input_nc = self.opt.output_nc
else:
input_nc = self.opt.input_nc
return {'A': A, 'A_paths': A_path}
def __len__(self):
return len(self.A_paths)
def name(self):
return 'SingleImageDataset'
# face model (Trained on CelebaHQ-256, the first 2k images are for testing, the rest are for training.)
wget -c https://drive.google.com/open?id=1qvsWHVO9iXpEAPtwyRB25mklTmD0jgPV
# face random mask model
wget -c https://drive.google.com/open?id=1Pz9gkm2VYaEK3qMXnszJufsvRqbcXrjS
# paris random mask model
wget -c https://drive.google.com/open?id=14MzixaqYUdJNL5xGdVhSKI9jOfvGdr3M
# paris center mask model
wget -c https://drive.google.com/open?id=1nDkCdsqUdiEXfSjZ_P915gWeZELK0fo_
import torch
# import numpy as np
from options.train_options import TrainOptions
import util.util as util
import os
from PIL import Image
import glob
mask_folder = 'masks/testing_masks'
test_folder = './datasets/Paris/test'
util.mkdir(mask_folder)
opt = TrainOptions().parse()
f = glob.glob(test_folder+'/*.png')
print(f)
for fl in f:
mask = torch.zeros(opt.fineSize, opt.fineSize)
if opt.mask_sub_type == 'fractal':
raise NotImplementedError("It is broken now...")
mask = util.create_walking_mask() # create an initial random mask.
elif opt.mask_sub_type == 'rect':
mask, rand_t, rand_l = util.create_rand_mask(opt)
elif opt.mask_sub_type == 'island':
mask = util.wrapper_gmask(opt)
print('Generating mask for test image: '+os.path.basename(fl))
util.save_image(mask.squeeze().numpy()*255, os.path.join(mask_folder, os.path.splitext(os.path.basename(fl))[0]+'_mask.png'))
def create_model(opt):
model = None
print(opt.model)
if opt.model == 'shiftnet':
assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
from models.shift_net.shiftnet_model import ShiftNetModel
model = ShiftNetModel()
'''
elif opt.model == 'res_shiftnet':
assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
from models.res_shift_net.shiftnet_model import ResShiftNetModel
model = ResShiftNetModel()
elif opt.model == 'patch_soft_shiftnet':
assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
from models.patch_soft_shift.patch_soft_shiftnet_model import PatchSoftShiftNetModel
model = PatchSoftShiftNetModel()
elif opt.model == 'res_patch_soft_shiftnet':
assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
from models.res_patch_soft_shift.res_patch_soft_shiftnet_model import ResPatchSoftShiftNetModel
model = ResPatchSoftShiftNetModel()
else:
raise ValueError("Model [%s] not recognized." % opt.model)
'''
model.initialize(opt)
print("model [%s] was created" % (model.name()))
return model
from .discrimators import *
from .losses import *
from .modules import *
from .shift_unet import *
from .unet import *
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from .modules import *
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
def densenet121(pretrained=False, use_spectral_norm=True, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), use_spectral_norm=use_spectral_norm,
**kwargs)
if pretrained:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet121'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict, strict=False)
return model
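# Standalone illustration of the key remapping above: old DenseNet
# checkpoints store keys like 'norm.1.weight' that newer module naming
# rejects, and the regex collapses the extra dot. (The variable names
# below are illustrative only.)
import re as _re
_pat = _re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
_old = 'features.denseblock1.denselayer1.norm.1.weight'
_m = _pat.match(_old)
# 'norm.1.weight' becomes 'norm1.weight':
_new = _m.group(1) + _m.group(2)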
def densenet169(pretrained=False, **kwargs):
r"""Densenet-169 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
**kwargs)
if pretrained:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet169'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model
def densenet201(pretrained=False, **kwargs):
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
**kwargs)
if pretrained:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet201'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model
def densenet161(pretrained=False, **kwargs):
r"""Densenet-161 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
**kwargs)
if pretrained:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet161'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model
class _DenseLayer(nn.Sequential):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, use_spectral_norm):
super(_DenseLayer, self).__init__()
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
self.add_module('relu1', nn.ReLU()),
self.add_module('conv1', spectral_norm(nn.Conv2d(num_input_features, bn_size *
growth_rate, kernel_size=1, stride=1, bias=False), use_spectral_norm)),
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
self.add_module('relu2', nn.ReLU()),
self.add_module('conv2', spectral_norm(nn.Conv2d(bn_size * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1, bias=False), use_spectral_norm)),
self.drop_rate = drop_rate
def forward(self, x):
new_features = super(_DenseLayer, self).forward(x)
if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
return torch.cat([x, new_features], 1)
class _DenseBlock(nn.Sequential):
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, use_spectral_norm):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, use_spectral_norm)
self.add_module('denselayer%d' % (i + 1), layer)
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features, use_spectral_norm):
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU())
self.add_module('conv', spectral_norm(nn.Conv2d(num_input_features, num_output_features,
kernel_size=1, stride=1, bias=False), use_spectral_norm))
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
class DenseNet(nn.Module):
r"""Densenet-BC model class, based on
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
growth_rate (int) - how many filters to add each layer (`k` in paper)
block_config (list of 4 ints) - how many layers in each pooling block
num_init_features (int) - the number of filters to learn in the first convolution layer
bn_size (int) - multiplicative factor for number of bottle neck layers
(i.e. bn_size * k features in the bottleneck layer)
drop_rate (float) - dropout rate after each dense layer
num_classes (int) - number of classification classes
"""
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), use_spectral_norm=True,
num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
super(DenseNet, self).__init__()
# First convolution
self.features = nn.Sequential(OrderedDict([
('conv0', spectral_norm(nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False), use_spectral_norm)),
('norm0', nn.BatchNorm2d(num_init_features)),
('relu0', nn.ReLU()),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]))
# Each denseblock
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, use_spectral_norm=use_spectral_norm)
self.features.add_module('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, use_spectral_norm=use_spectral_norm)
self.features.add_module('transition%d' % (i + 1), trans)
num_features = num_features // 2
# Final batch norm
self.features.add_module('norm5', nn.BatchNorm2d(num_features))
self.conv_last = spectral_norm(nn.Conv2d(num_features, 256, kernel_size=3), use_spectral_norm)
# Linear layer
# self.classifier = nn.Linear(num_features, num_classes)
# Official init from torch repo.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight.data)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
features = self.features(x)
features = self.conv_last(features)
return features
import functools
import torch.nn as nn
from .denset_net import *
from .modules import *
################################### This is for D ###################################
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_spectral_norm=True):
super(NLayerDiscriminator, self).__init__()
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
kw = 4
padw = 1
sequence = [
spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), use_spectral_norm),
nn.LeakyReLU(0.2, True)
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias), use_spectral_norm),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=1, padding=padw, bias=use_bias), use_spectral_norm),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw), use_spectral_norm)]
if use_sigmoid:
sequence += [nn.Sigmoid()]
self.model = nn.Sequential(*sequence)
def forward(self, input):
return self.model(input)
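The layer stack above fixes the discriminator's patch-score map size; a quick sketch of the convolution arithmetic, assuming a hypothetical 256x256 input and the default `n_layers=3` (three stride-2 convs, then two stride-1 convs, all with `kw=4`, `padw=1`):

```python
# Output-size sketch for NLayerDiscriminator with n_layers=3.
# Each k=4, p=1 conv halves the map at stride 2 and shrinks it
# by one pixel at stride 1, so 256 -> 128 -> 64 -> 32 -> 31 -> 30.
def conv_out(size, k=4, s=2, p=1):
    return (size + 2 * p - k) // s + 1

size = 256
for stride in (2, 2, 2, 1, 1):  # 3 strided layers + 2 final convs
    size = conv_out(size, s=stride)
print(size)  # 30: a 30x30 grid of patch scores
```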
# Defines a densenet-inspired discriminator (should give it a stronger representation).
class DenseNetDiscrimator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_spectral_norm=True):
super(DenseNetDiscrimator, self).__init__()
self.model = densenet121(pretrained=True, use_spectral_norm=use_spectral_norm)
self.use_sigmoid = use_sigmoid
if self.use_sigmoid:
self.sigmoid = nn.Sigmoid()
def forward(self, input):
if self.use_sigmoid:
return self.sigmoid(self.model(input))
else:
return self.model(input)
import torch
import torch.nn as nn
import numpy as np
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically the same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input.
class GANLoss(nn.Module):
def __init__(self, gan_type='wgan_gp', target_real_label=1.0, target_fake_label=0.0):
super(GANLoss, self).__init__()
self.register_buffer('real_label', torch.tensor(target_real_label))
self.register_buffer('fake_label', torch.tensor(target_fake_label))
self.gan_type = gan_type
if gan_type == 'wgan_gp':
self.loss = nn.MSELoss()  # placeholder; the wgan_gp loss is computed directly in __call__
elif gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif gan_type == 'vanilla':
self.loss = nn.BCELoss()
#######################################################################
### Relativistic GAN - https://github.com/AlexiaJM/RelativisticGAN ###
#######################################################################
# When using `BCEWithLogitsLoss()`, remove the sigmoid layer in D.
elif gan_type == 're_s_gan':
self.loss = nn.BCEWithLogitsLoss()
elif gan_type == 're_avg_gan':
self.loss = nn.BCEWithLogitsLoss()
else:
raise ValueError("GAN type [%s] not recognized." % gan_type)
def get_target_tensor(self, prediction, target_is_real):
if target_is_real:
target_tensor = self.real_label
else:
target_tensor = self.fake_label
return target_tensor.expand_as(prediction)
def __call__(self, prediction, target_is_real):
if self.gan_type == 'wgan_gp':
if target_is_real:
loss = -prediction.mean()
else:
loss = prediction.mean()
else:
target_tensor = self.get_target_tensor(prediction, target_is_real)
loss = self.loss(prediction, target_tensor)
return loss
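The `wgan_gp` branch above bypasses `self.loss` entirely and uses the critic's raw mean score: real samples are pushed toward high scores (negated mean), fake samples toward low scores. A minimal numpy sketch of that sign convention, with hypothetical scores:

```python
import numpy as np

# Mirrors the wgan_gp branch of GANLoss.__call__ above, with numpy
# arrays standing in for torch tensors.
def wgan_loss(prediction, target_is_real):
    mean = float(np.mean(prediction))
    return -mean if target_is_real else mean

scores = np.array([0.5, 1.5, 2.0, 0.0])  # hypothetical critic outputs
print(wgan_loss(scores, True))   # -1.0 (maximize score on real)
print(wgan_loss(scores, False))  # 1.0 (minimize score on fake)
```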
################# Discounting loss #########################
class Discounted_L1(nn.Module):
def __init__(self, opt):
super(Discounted_L1, self).__init__()
# Register discounting template as a buffer
self.register_buffer('discounting_mask', torch.tensor(spatial_discounting_mask(opt.fineSize//2 - opt.overlap * 2, opt.fineSize//2 - opt.overlap * 2, 0.9, opt.discounting)))
self.L1 = nn.L1Loss()
def forward(self, input, target):
self._assert_no_grad(target)
input_tmp = input * self.discounting_mask
target_tmp = target * self.discounting_mask
return self.L1(input_tmp, target_tmp)
def _assert_no_grad(self, variable):
assert not variable.requires_grad, \
"nn criterions don't compute the gradient w.r.t. targets - please " \
"mark these tensors as not requiring gradients"
def spatial_discounting_mask(mask_width, mask_height, discounting_gamma, discounting=1):
"""Generate spatial discounting mask constant.
Spatial discounting mask is first introduced in publication:
Generative Image Inpainting with Contextual Attention, Yu et al.
Returns:
tf.Tensor: spatial discounting mask
"""
gamma = discounting_gamma
shape = [1, 1, mask_width, mask_height]
if discounting:
print('Use spatial discounting l1 loss.')
mask_values = np.ones((mask_width, mask_height), dtype='float32')
for i in range(mask_width):
for j in range(mask_height):
mask_values[i, j] = max(
gamma**min(i, mask_width-i),
gamma**min(j, mask_height-j))
mask_values = np.expand_dims(mask_values, 0)
mask_values = np.expand_dims(mask_values, 1)
else:
mask_values = np.ones(shape, dtype='float32')
return mask_values
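A tiny worked example of the discounting weights produced above, on a hypothetical 4x4 mask with `gamma=0.9`: pixels on the hole border keep weight 1.0, and pixels deeper inside decay geometrically, so the L1 loss focuses on the region near the known content.

```python
import numpy as np

# Re-computes the inner loop of spatial_discounting_mask above for a
# 4x4 mask: the weight is gamma raised to the distance from the
# nearest border, taken along whichever axis is closer.
gamma, w, h = 0.9, 4, 4
m = np.ones((w, h), dtype='float32')
for i in range(w):
    for j in range(h):
        m[i, j] = max(gamma**min(i, w - i), gamma**min(j, h - j))
# m[0, 0] is 1.0 (on the border); m[2, 2] is ~0.81 (two steps inside)
```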
class TVLoss(nn.Module):
def __init__(self, tv_loss_weight=1):
super(TVLoss, self).__init__()
self.tv_loss_weight = tv_loss_weight
def forward(self, x):
bz, _, h, w = x.size()
count_h = self._tensor_size(x[:, :, 1:, :])
count_w = self._tensor_size(x[:, :, :, 1:])
h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h - 1, :]), 2).sum()
w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w - 1]), 2).sum()
return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / bz
@staticmethod
def _tensor_size(t):
return t.size(1) * t.size(2) * t.size(3)
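The total-variation computation in `TVLoss.forward` above can be checked by hand; a numpy re-sketch on a hypothetical 1x1x2x2 input:

```python
import numpy as np

# Numpy mirror of TVLoss.forward above: squared differences between
# neighbouring pixels along H and W, each normalised by its per-sample
# element count, then scaled by 2 * weight / batch_size.
def tv_loss(x, weight=1.0):
    bz, c, h, w = x.shape
    h_tv = float(np.sum((x[:, :, 1:, :] - x[:, :, :-1, :]) ** 2))
    w_tv = float(np.sum((x[:, :, :, 1:] - x[:, :, :, :-1]) ** 2))
    count_h = c * (h - 1) * w  # elements in the H-difference tensor
    count_w = c * h * (w - 1)  # elements in the W-difference tensor
    return weight * 2 * (h_tv / count_h + w_tv / count_w) / bz

x = np.arange(4, dtype='float32').reshape(1, 1, 2, 2)  # [[0,1],[2,3]]
print(tv_loss(x))  # 2 * (8/2 + 2/2) = 10.0
```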
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
class Self_Attn(nn.Module):
""" Self attention Layer
https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py
"""
def __init__(self, in_dim, activation, with_attention=False):
super(Self_Attn, self).__init__()
self.chanel_in = in_dim
self.activation = activation
self.with_attention = with_attention
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X W X H)
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
m_batchsize, C, width, height = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B x C' x N
energy = torch.bmm(proj_query, proj_key)  # B x N x N
attention = self.softmax(energy)  # B x N x N
proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B x C x N
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, width, height)
out = self.gamma * out + x
if self.with_attention:
return out, attention
else:
return out
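The attention arithmetic in `Self_Attn.forward` above reduces to two batched matrix products and a softmax; a numpy sketch on hypothetical tiny shapes (B=1, C=4, H=W=2, projected dim 2):

```python
import numpy as np

# energy = Q @ K over the N = H*W spatial positions, softmax over the
# last axis, then the values are re-mixed by the transposed attention
# map - exactly the bmm pattern used in Self_Attn.forward.
rng = np.random.default_rng(0)
B, C, N = 1, 4, 4                    # N = H * W spatial positions
q = rng.standard_normal((B, N, 2))   # projected queries, B x N x C'
k = rng.standard_normal((B, 2, N))   # projected keys,    B x C' x N
v = rng.standard_normal((B, C, N))   # projected values,  B x C x N
energy = q @ k                                               # B x N x N
attention = np.exp(energy) / np.exp(energy).sum(-1, keepdims=True)
out = v @ attention.transpose(0, 2, 1)                       # B x C x N
print(out.shape)  # (1, 4, 4), reshaped back to B x C x H x W in the layer
```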
def l2normalize(v, eps=1e-12):
return v / (v.norm() + eps)
def spectral_norm(module, mode=True):
if mode:
return nn.utils.spectral_norm(module)
return module
class SwitchNorm2d(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.9, using_moving_average=True, using_bn=True,
last_gamma=False):
super(SwitchNorm2d, self).__init__()
self.eps = eps
self.momentum = momentum
self.using_moving_average = using_moving_average
self.using_bn = using_bn
self.last_gamma = last_gamma
self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
if self.using_bn:
self.mean_weight = nn.Parameter(torch.ones(3))
self.var_weight = nn.Parameter(torch.ones(3))
else:
self.mean_weight = nn.Parameter(torch.ones(2))
self.var_weight = nn.Parameter(torch.ones(2))
if self.using_bn:
self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
self.register_buffer('running_var', torch.zeros(1, num_features, 1))
self.reset_parameters()
def reset_parameters(self):
if self.using_bn:
self.running_mean.zero_()
self.running_var.zero_()
if self.last_gamma:
self.weight.data.fill_(0)
else:
self.weight.data.fill_(1)
self.bias.data.zero_()
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
def forward(self, x):
self._check_input_dim(x)
N, C, H, W = x.size()
x = x.view(N, C, -1)
mean_in = x.mean(-1, keepdim=True)
var_in = x.var(-1, keepdim=True)
mean_ln = mean_in.mean(1, keepdim=True)
temp = var_in + mean_in ** 2
var_ln = temp.mean(1, keepdim=True) - mean_ln ** 2
if self.using_bn:
if self.training:
mean_bn = mean_in.mean(0, keepdim=True)
var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2
if self.using_moving_average:
self.running_mean.mul_(self.momentum)
self.running_mean.add_((1 - self.momentum) * mean_bn.data)
self.running_var.mul_(self.momentum)
self.running_var.add_((1 - self.momentum) * var_bn.data)
else:
self.running_mean.add_(mean_bn.data)
self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
else:
# use the tracked running statistics at eval time
mean_bn = self.running_mean
var_bn = self.running_var
softmax = nn.Softmax(0)
mean_weight = softmax(self.mean_weight)
var_weight = softmax(self.var_weight)
if self.using_bn:
mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn
else:
mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln
var = var_weight[0] * var_in + var_weight[1] * var_ln
x = (x-mean) / (var+self.eps).sqrt()
x = x.view(N, C, H, W)
return x * self.weight + self.bias
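The core idea of `SwitchNorm2d.forward` above is blending IN, LN (and optionally BN) statistics with learned softmax weights; a numpy sketch of the `using_bn=False` case, with a hypothetical N=1, C=2, H*W=4 input:

```python
import numpy as np

# Per-channel (instance norm) and per-sample (layer norm) means are
# blended with softmax weights before normalisation, mirroring the
# mean_in / mean_ln computation in SwitchNorm2d.forward.
x = np.arange(8, dtype='float64').reshape(1, 2, 4)  # N x C x (H*W)
mean_in = x.mean(-1, keepdims=True)       # IN stats: one mean per channel
mean_ln = mean_in.mean(1, keepdims=True)  # LN stats: one mean per sample
logits = np.zeros(2)                      # untrained mean_weight
w = np.exp(logits) / np.exp(logits).sum() # softmax -> [0.5, 0.5]
mean = w[0] * mean_in + w[1] * mean_ln
print(mean.ravel())  # [2.5 4.5]: halfway between IN and LN means
```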
class PartialConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(PartialConv, self).__init__()
self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
stride, padding, dilation, groups, bias)
self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
stride, padding, dilation, groups, False)
#self.input_conv.apply(weights_init('kaiming'))
torch.nn.init.constant_(self.mask_conv.weight, 1.0)
# mask is not updated
for param in self.mask_conv.parameters():
param.requires_grad = False
def forward(self, input, mask):
output = self.input_conv(input * mask)
if self.input_conv.bias is not None:
output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(
output)
else:
output_bias = torch.zeros_like(output)
with torch.no_grad():
output_mask = self.mask_conv(mask)
no_update_holes = output_mask == 0
mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)
output_pre = (output - output_bias) / mask_sum + output_bias
output = output_pre.masked_fill_(no_update_holes, 0.0)
new_mask = torch.ones_like(output)
new_mask = new_mask.masked_fill_(no_update_holes, 0.0)
return output, new_mask
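The renormalisation in `PartialConv.forward` above can be seen on a single window: the convolution only sees valid (mask == 1) pixels, and the response is rescaled by the window's mask sum so windows with few valid pixels are not artificially dimmed. A numpy sketch with hypothetical values:

```python
import numpy as np

# One flattened 2x2 window of a partial convolution: mask out the
# hole pixel, convolve, then divide by the mask_conv response
# (the count of valid pixels under an all-ones kernel).
patch = np.array([1.0, 2.0, 3.0, 4.0])  # flattened input window
mask = np.array([1.0, 0.0, 1.0, 1.0])   # one hole pixel
weight = np.ones(4)                      # all-ones kernel, flattened
raw = weight @ (patch * mask)            # masked response: 1 + 3 + 4 = 8
mask_sum = weight @ mask                 # mask_conv output: 3 valid pixels
out = raw / mask_sum                     # renormalised response, 8 / 3
```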
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, use_bias):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_bias)
def build_conv_block(self, dim, padding_type, norm_layer, use_bias):
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim),
nn.ReLU(True)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
import torch
import torch.nn as nn
import torch.nn.functional as F
# For original shift
from models.shift_net.InnerShiftTriple import InnerShiftTriple
from models.shift_net.InnerCos import InnerCos
# for face shift
from models.face_shift_net.InnerFaceShiftTriple import InnerFaceShiftTriple
# For res shift
from models.res_shift_net.innerResShiftTriple import InnerResShiftTriple
# For patch patch shift
from models.patch_soft_shift.innerPatchSoftShiftTriple import InnerPatchSoftShiftTriple
# For res patch patch shift
from models.res_patch_soft_shift.innerResPatchSoftShiftTriple import InnerResPatchSoftShiftTriple
from .unet import UnetSkipConnectionBlock
from .modules import *
################################### *************************** #####################################
################################### Shift_net #####################################
################################### *************************** #####################################
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGeneratorShiftTriple(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, opt, innerCos_list, shift_list, mask_global, ngf=64,
norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
super(UnetGeneratorShiftTriple, self).__init__()
# construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
innermost=True, use_spectral_norm=use_spectral_norm)
print(unet_block)
for i in range(num_downs - 5): # There are 3 inner layers (spatial size 512*512) for unet_256.
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_shift_block = UnetSkipConnectionShiftBlock(ngf * 2, ngf * 4, opt, innerCos_list, shift_list,
mask_global, input_nc=None, \
submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm, layer_to_last=3) # passing in unet_shift_block
'''
unet_block = UnetSkipConnectionBlock(ngf*2, ngf * 4, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)  # this layer replaces the shift layer above, for testing
'''
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_shift_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
self.model = unet_block
def forward(self, input):
return self.model(input)
# Note: the TripleBlock differs in its `upconv` definition.
# 'cos' means that we add an `innerCos` layer in the block.
class UnetSkipConnectionShiftBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, opt, innerCos_list, shift_list, mask_global, input_nc, \
submodule=None, shift_layer=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d,
use_spectral_norm=False, layer_to_last=3):
super(UnetSkipConnectionShiftBlock, self).__init__()
self.outermost = outermost
if input_nc is None:
input_nc = outer_nc
downconv = spectral_norm(nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1), use_spectral_norm)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
device = 'cpu' if len(opt.gpu_ids) == 0 else 'gpu'
# The downconv layer takes outer_nc channels in and gives inner_nc out,
# so the shift layer is defined as follows:
shift = InnerShiftTriple(opt.shift_sz, opt.stride, opt.mask_thred,
opt.triple_weight, layer_to_last=layer_to_last, device=device)
shift.set_mask(mask_global)
shift_list.append(shift)
# Add latent constraint
# Then add the constraint to the constrain layer list!
innerCos = InnerCos(strength=opt.strength, skip=opt.skip, layer_to_last=layer_to_last, device=device)
innerCos.set_mask(mask_global) # Here we need to set mask for innerCos layer too.
innerCos_list.append(innerCos)
# Different position only has differences in `upconv`
# for the outermost, the special is `tanh`
'''
if outermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
# for the innermost, the special is `inner_nc` instead of `inner_nc*2`
elif innermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downrelu, downconv] # for the innermost, no submodule, and delete the bn
up = [uprelu, upconv, upnorm]
model = down + up
# else, the normal
else:
'''
# The shift triple differs here: it is `*3`, not `*2`.
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 3, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downrelu, downconv, downnorm]
# shift should be placed after uprelu
# NB: innerCos is placed before shift, so the latent gradient is added
# to the former part.
up = [uprelu, innerCos, shift, upconv, upnorm]
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost: # if it is the outermost, directly pass the input in.
return self.model(x)
else:
x_latter = self.model(x)
_, _, h, w = x.size()
if h != x_latter.size(2) or w != x_latter.size(3):
x_latter = F.interpolate(x_latter, (h, w), mode='bilinear', align_corners=False)
return torch.cat([x_latter, x], 1) # cat in the C channel
################################### *************************** #####################################
################################### Face Shift_net #####################################
################################### *************************** #####################################
class FaceUnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, innerCos_list, shift_list, mask_global, opt, ngf=64,
norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
super(FaceUnetGenerator, self).__init__()
# Encoder layers
self.e1_c = spectral_norm(nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e2_c = spectral_norm(nn.Conv2d(ngf, ngf*2, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e2_norm = norm_layer(ngf*2)
self.e3_c = spectral_norm(nn.Conv2d(ngf*2, ngf*4, kernel_size=6, stride=2, padding=2), use_spectral_norm)
self.e3_norm = norm_layer(ngf*4)
self.e4_c = spectral_norm(nn.Conv2d(ngf*4, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e4_norm = norm_layer(ngf*8)
self.e5_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e5_norm = norm_layer(ngf*8)
self.e6_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e6_norm = norm_layer(ngf*8)
self.e7_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e7_norm = norm_layer(ngf*8)
self.e8_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
# Decoder layers
self.d1_dc = spectral_norm(nn.ConvTranspose2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d1_norm = norm_layer(ngf*8)
self.d2_dc = spectral_norm(nn.ConvTranspose2d(ngf*8*2 , ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d2_norm = norm_layer(ngf*8)
self.d3_dc = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d3_norm = norm_layer(ngf*8)
self.d4_dc = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d4_norm = norm_layer(ngf*8)
self.d5_dc = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*4, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d5_norm = norm_layer(ngf*4)
# shift before this layer
self.d6_dc = spectral_norm(nn.ConvTranspose2d(ngf*4*3, ngf*2, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d6_norm = norm_layer(ngf*2)
self.d7_dc = spectral_norm(nn.ConvTranspose2d(ngf*2*2, ngf, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d7_norm = norm_layer(ngf)
self.d8_dc = spectral_norm(nn.ConvTranspose2d(ngf*2, output_nc, kernel_size=4, stride=2, padding=1), use_spectral_norm)
# construct shift and innerCos
device = 'cpu' if len(opt.gpu_ids) == 0 else 'gpu'
self.shift = InnerFaceShiftTriple(opt.shift_sz, opt.stride, opt.mask_thred,
opt.triple_weight, layer_to_last=3, device=device)
self.shift.set_mask(mask_global)
shift_list.append(self.shift)
self.innerCos = InnerCos(strength=opt.strength, skip=opt.skip, layer_to_last=3, device=device)
self.innerCos.set_mask(mask_global) # Here we need to set mask for innerCos layer too.
innerCos_list.append(self.innerCos)
# In this case, we have a very flexible unet construction mode.
def forward(self, input, flip_feat=None):
# Encoder
# No norm on the first layer
e1 = self.e1_c(input)
e2 = self.e2_norm(self.e2_c(F.leaky_relu_(e1, negative_slope=0.2)))
e3 = self.e3_norm(self.e3_c(F.leaky_relu_(e2, negative_slope=0.2)))
e4 = self.e4_norm(self.e4_c(F.leaky_relu_(e3, negative_slope=0.2)))
e5 = self.e5_norm(self.e5_c(F.leaky_relu_(e4, negative_slope=0.2)))
e6 = self.e6_norm(self.e6_c(F.leaky_relu_(e5, negative_slope=0.2)))
e7 = self.e7_norm(self.e7_c(F.leaky_relu_(e6, negative_slope=0.2)))
# No norm in the inner_most layer
e8 = self.e8_c(F.leaky_relu_(e7, negative_slope=0.2))
# Decoder
d1 = self.d1_norm(self.d1_dc(F.relu_(e8)))
d2 = self.d2_norm(self.d2_dc(F.relu_(self.cat_feat(d1, e7))))
d3 = self.d3_norm(self.d3_dc(F.relu_(self.cat_feat(d2, e6))))
d4 = self.d4_norm(self.d4_dc(F.relu_(self.cat_feat(d3, e5))))
d5 = self.d5_norm(self.d5_dc(F.relu_(self.cat_feat(d4, e4))))
tmp, innerFeat = self.shift(self.innerCos(F.relu_(self.cat_feat(d5, e3))), flip_feat)
d6 = self.d6_norm(self.d6_dc(tmp))
d7 = self.d7_norm(self.d7_dc(F.relu_(self.cat_feat(d6, e2))))
# No norm on the last layer
d8 = self.d8_dc(F.relu_(self.cat_feat(d7, e1)))
d8 = torch.tanh(d8)
return d8, innerFeat
def cat_feat(self, de_feat, en_feat):
_, _, h1, w1 = de_feat.size()
_, _, h2, w2 = en_feat.size()
if h1 != h2 or w1 != w2:
de_feat = F.interpolate(de_feat, (h2, w2), mode='bilinear', align_corners=False)
return torch.cat([de_feat, en_feat], 1)
################################### *************************** #####################################
################################### Res Shift_net #####################################
################################### *************************** #####################################
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class ResUnetGeneratorShiftTriple(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, opt, innerCos_list, shift_list, mask_global, ngf=64,
norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
super(ResUnetGeneratorShiftTriple, self).__init__()
# construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
innermost=True, use_spectral_norm=use_spectral_norm)
print(unet_block)
for i in range(num_downs - 5): # There are 3 inner layers (spatial size 512*512) for unet_256.
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_shift_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, opt, innerCos_list, shift_list,
mask_global, input_nc=None, \
submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm, layer_to_last=3) # passing in unet_shift_block
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_shift_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
self.model = unet_block
def forward(self, input):
return self.model(input)
# Note: the TripleBlock differs in its `upconv` definition.
# 'cos' means that we add an `innerCos` layer in the block.
class ResUnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, opt, innerCos_list, shift_list, mask_global, input_nc, \
submodule=None, shift_layer=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d,
use_spectral_norm=False, layer_to_last=3):
super(ResUnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if input_nc is None:
input_nc = outer_nc
downconv = spectral_norm(nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1), use_spectral_norm)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
device = 'cpu' if len(opt.gpu_ids) == 0 else 'gpu'
# The downconv layer takes outer_nc channels in and gives inner_nc out,
# so the shift layer is defined as follows:
shift = InnerResShiftTriple(inner_nc, opt.shift_sz, opt.stride, opt.mask_thred,
opt.triple_weight, layer_to_last=layer_to_last, device=device)
shift.set_mask(mask_global)
shift_list.append(shift)
# Add latent constraint
# Then add the constraint to the constrain layer list!
innerCos = InnerCos(strength=opt.strength, skip=opt.skip, layer_to_last=layer_to_last, device=device)
innerCos.set_mask(mask_global) # Here we need to set mask for innerCos layer too.
innerCos_list.append(innerCos)
# Different position only has differences in `upconv`
# for the outermost, the special is `tanh`
if outermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
# for the innermost, the special is `inner_nc` instead of `inner_nc*2`
elif innermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downrelu, downconv] # for the innermost, no submodule, and delete the bn
up = [uprelu, upconv, upnorm]
model = down + up
# else, the normal
else:
# Res shift differs from the other shifts here: it is `*2`, not `*3`.
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downrelu, downconv, downnorm]
# shift should be placed after uprelu
# NB: innerCos is placed before shift, so the latent gradient is added
# to the former part.
up = [uprelu, innerCos, shift, upconv, upnorm]
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost: # if it is the outermost, directly pass the input in.
return self.model(x)
else:
x_latter = self.model(x)
_, _, h, w = x.size()
if h != x_latter.size(2) or w != x_latter.size(3):
x_latter = F.interpolate(x_latter, (h, w), mode='bilinear', align_corners=False)
return torch.cat([x_latter, x], 1) # cat in the C channel
################################### *************************** #####################################
################################### patch soft shift_net #####################################
################################### *************************** #####################################
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class PatchSoftUnetGeneratorShiftTriple(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, opt, innerCos_list, shift_list, mask_global, ngf=64,
norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
super(PatchSoftUnetGeneratorShiftTriple, self).__init__()
# construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
innermost=True, use_spectral_norm=use_spectral_norm)
print(unet_block)
for i in range(num_downs - 5): # There are 3 inner layers (spatial size 512*512) for unet_256.
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_shift_block = PatchSoftUnetSkipConnectionShiftTriple(ngf * 2, ngf * 4, opt, innerCos_list, shift_list,
mask_global, input_nc=None, \
submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm, layer_to_last=3) # passing in unet_shift_block
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_shift_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
self.model = unet_block
def forward(self, input):
return self.model(input)
# Note: the TripleBlock differs in its `upconv` definition.
# 'cos' means that we add an `innerCos` layer in the block.
class PatchSoftUnetSkipConnectionShiftTriple(nn.Module):
def __init__(self, outer_nc, inner_nc, opt, innerCos_list, shift_list, mask_global, input_nc, \
submodule=None, shift_layer=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d,
use_spectral_norm=False, layer_to_last=3):
super(PatchSoftUnetSkipConnectionShiftTriple, self).__init__()
self.outermost = outermost
if input_nc is None:
input_nc = outer_nc
downconv = spectral_norm(nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1), use_spectral_norm)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
device = 'cpu' if len(opt.gpu_ids) == 0 else 'gpu'
        # The downconv layer takes outer_nc channels in and produces inner_nc out,
        # so the shift layer is defined as follows:
shift = InnerPatchSoftShiftTriple(opt.shift_sz, opt.stride, opt.mask_thred,
opt.triple_weight, opt.fuse, layer_to_last=layer_to_last, device=device)
shift.set_mask(mask_global)
shift_list.append(shift)
        # Add the latent constraint, then append it to the constraint layer list.
innerCos = InnerCos(strength=opt.strength, skip=opt.skip, layer_to_last=layer_to_last, device=device)
innerCos.set_mask(mask_global) # Here we need to set mask for innerCos layer too.
innerCos_list.append(innerCos)
        # The three positions differ only in `upconv`.
        # For the outermost block, the distinguishing part is the final `tanh`.
if outermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
        # For the innermost block, upconv takes `inner_nc` channels instead of `inner_nc*2`.
elif innermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
            down = [downrelu, downconv] # The innermost block has no submodule and omits the norm layer.
up = [uprelu, upconv, upnorm]
model = down + up
        # Otherwise, a normal intermediate block.
else:
            # The shift-triple block differs here: upconv takes `inner_nc * 3`, not `* 2`.
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 3, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downrelu, downconv, downnorm]
            # The shift layer must be placed after uprelu.
            # NB: innerCos is placed before shift, so its latent gradient
            # flows back into the former part of the network.
up = [uprelu, innerCos, shift, upconv, upnorm]
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost: # if it is the outermost, directly pass the input in.
return self.model(x)
else:
x_latter = self.model(x)
_, _, h, w = x.size()
if h != x_latter.size(2) or w != x_latter.size(3):
x_latter = F.interpolate(x_latter, (h, w), mode='bilinear')
return torch.cat([x_latter, x], 1) # cat in the C channel
################################### *************************** #####################################
################################### Res patch soft shift_net #####################################
################################### *************************** #####################################
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class ResPatchSoftUnetGeneratorShiftTriple(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, opt, innerCos_list, shift_list, mask_global, ngf=64,
norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
super(ResPatchSoftUnetGeneratorShiftTriple, self).__init__()
# construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
innermost=True, use_spectral_norm=use_spectral_norm)
print(unet_block)
        for i in range(num_downs - 5): # For unet_256 (num_downs == 8), this adds 3 inner ngf*8 layers.
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_shift_block = ResPatchSoftUnetSkipConnectionShiftTriple(ngf * 2, ngf * 4, opt, innerCos_list, shift_list,
mask_global, input_nc=None, \
submodule=unet_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm, layer_to_last=3) # passing in unet_shift_block
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_shift_block,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
self.model = unet_block
def forward(self, input):
return self.model(input)
# Note: the TripleBlock differs in its `upconv` definition.
# 'cos' means we add an `innerCos` layer in the block.
class ResPatchSoftUnetSkipConnectionShiftTriple(nn.Module):
def __init__(self, outer_nc, inner_nc, opt, innerCos_list, shift_list, mask_global, input_nc, \
submodule=None, shift_layer=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d,
use_spectral_norm=False, layer_to_last=3):
super(ResPatchSoftUnetSkipConnectionShiftTriple, self).__init__()
self.outermost = outermost
if input_nc is None:
input_nc = outer_nc
downconv = spectral_norm(nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1), use_spectral_norm)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
device = 'cpu' if len(opt.gpu_ids) == 0 else 'gpu'
        # The downconv layer takes outer_nc channels in and produces inner_nc out,
        # so the shift layer is defined as follows:
shift = InnerResPatchSoftShiftTriple(inner_nc, opt.shift_sz, opt.stride, opt.mask_thred,
opt.triple_weight, opt.fuse, layer_to_last=layer_to_last, device=device)
shift.set_mask(mask_global)
shift_list.append(shift)
        # Add the latent constraint, then append it to the constraint layer list.
innerCos = InnerCos(strength=opt.strength, skip=opt.skip, layer_to_last=layer_to_last, device=device)
innerCos.set_mask(mask_global) # Here we need to set mask for innerCos layer too.
innerCos_list.append(innerCos)
        # The three positions differ only in `upconv`.
        # For the outermost block, the distinguishing part is the final `tanh`.
if outermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
        # For the innermost block, upconv takes `inner_nc` channels instead of `inner_nc*2`.
elif innermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
            down = [downrelu, downconv] # The innermost block has no submodule and omits the norm layer.
up = [uprelu, upconv, upnorm]
model = down + up
        # Otherwise, a normal intermediate block.
else:
            # The Res shift differs from the other shifts here: upconv takes `inner_nc * 2`, not `* 3`.
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downrelu, downconv, downnorm]
            # The shift layer must be placed after uprelu.
            # NB: innerCos is placed before shift, so its latent gradient
            # flows back into the former part of the network.
up = [uprelu, innerCos, shift, upconv, upnorm]
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost: # if it is the outermost, directly pass the input in.
return self.model(x)
else:
x_latter = self.model(x)
_, _, h, w = x.size()
if h != x_latter.size(2) or w != x_latter.size(3):
x_latter = F.interpolate(x_latter, (h, w), mode='bilinear')
return torch.cat([x_latter, x], 1) # cat in the C channel
import torch
import torch.nn as nn
import torch.nn.functional as F
from .modules import spectral_norm
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
super(UnetGenerator, self).__init__()
# construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, use_spectral_norm=use_spectral_norm)
for i in range(num_downs - 5):
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
self.model = unet_block
def forward(self, input):
return self.model(input)
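# Illustrative sketch (not part of the model): the |num_downs| note above can
# be checked with plain arithmetic, since each UNet level halves the spatial size.

```python
def bottleneck_size(input_size, num_downs):
    """Spatial size at the UNet bottleneck after num_downs stride-2 convs."""
    size = input_size
    for _ in range(num_downs):
        size //= 2  # each 4x4 conv with stride 2, padding 1 halves H and W
    return size

assert bottleneck_size(128, 7) == 1  # the example in the comment above
assert bottleneck_size(256, 8) == 1  # the `unet_256` configuration
```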
# construct network from the inside to the outside.
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, input_nc,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if input_nc is None:
input_nc = outer_nc
downconv = spectral_norm(nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1), use_spectral_norm)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
        # The three positions differ only in `upconv`.
        # For the outermost block, the distinguishing part is the final `tanh`.
if outermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
        # For the innermost block, upconv takes `inner_nc` channels instead of `inner_nc*2`.
elif innermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
            down = [downrelu, downconv] # The innermost block has no submodule and omits the norm layer.
up = [uprelu, upconv, upnorm]
model = down + up
        # Otherwise, a normal intermediate block.
else:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost: # if it is the outermost, directly pass the input in.
return self.model(x)
else:
x_latter = self.model(x)
_, _, h, w = x.size()
if h != x_latter.size(2) or w != x_latter.size(3):
x_latter = F.interpolate(x_latter, (h, w), mode='bilinear')
return torch.cat([x_latter, x], 1) # cat in the C channel
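# Channel bookkeeping for the skip connection above, as a torch-free sketch
# (`block_out_channels` is a hypothetical helper for illustration, not part of the repo):

```python
def block_out_channels(outer_nc, input_nc=None, outermost=False):
    """Channels a UnetSkipConnectionBlock returns for a given configuration."""
    if input_nc is None:
        input_nc = outer_nc        # mirrors the default in __init__
    if outermost:
        return outer_nc            # the outermost block returns self.model(x) directly
    return outer_nc + input_nc     # torch.cat([x_latter, x], 1) stacks channels

assert block_out_channels(64) == 128                      # why the parent sees ngf * 2 channels
assert block_out_channels(3, input_nc=4, outermost=True) == 3
```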
# A plain UNet written out directly, instead of composing UnetSkipConnectionBlocks.
# This way, everything is clearer and more flexible for extension.
class EasyUnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64,
norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
super(EasyUnetGenerator, self).__init__()
# Encoder layers
self.e1_c = spectral_norm(nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e2_c = spectral_norm(nn.Conv2d(ngf, ngf*2, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e2_norm = norm_layer(ngf*2)
self.e3_c = spectral_norm(nn.Conv2d(ngf*2, ngf*4, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e3_norm = norm_layer(ngf*4)
self.e4_c = spectral_norm(nn.Conv2d(ngf*4, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e4_norm = norm_layer(ngf*8)
self.e5_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e5_norm = norm_layer(ngf*8)
self.e6_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e6_norm = norm_layer(ngf*8)
self.e7_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e7_norm = norm_layer(ngf*8)
self.e8_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
        # Decoder layers
self.d1_c = spectral_norm(nn.ConvTranspose2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d1_norm = norm_layer(ngf*8)
self.d2_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2 , ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d2_norm = norm_layer(ngf*8)
self.d3_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d3_norm = norm_layer(ngf*8)
self.d4_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d4_norm = norm_layer(ngf*8)
self.d5_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*4, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d5_norm = norm_layer(ngf*4)
self.d6_c = spectral_norm(nn.ConvTranspose2d(ngf*4*2, ngf*2, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d6_norm = norm_layer(ngf*2)
self.d7_c = spectral_norm(nn.ConvTranspose2d(ngf*2*2, ngf, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d7_norm = norm_layer(ngf)
self.d8_c = spectral_norm(nn.ConvTranspose2d(ngf*2, output_nc, kernel_size=4, stride=2, padding=1), use_spectral_norm)
    # This direct construction makes the UNet very flexible to modify.
def forward(self, input):
# Encoder
# No norm on the first layer
e1 = self.e1_c(input)
e2 = self.e2_norm(self.e2_c(F.leaky_relu_(e1, negative_slope=0.2)))
e3 = self.e3_norm(self.e3_c(F.leaky_relu_(e2, negative_slope=0.2)))
e4 = self.e4_norm(self.e4_c(F.leaky_relu_(e3, negative_slope=0.2)))
e5 = self.e5_norm(self.e5_c(F.leaky_relu_(e4, negative_slope=0.2)))
e6 = self.e6_norm(self.e6_c(F.leaky_relu_(e5, negative_slope=0.2)))
e7 = self.e7_norm(self.e7_c(F.leaky_relu_(e6, negative_slope=0.2)))
# No norm on the inner_most layer
e8 = self.e8_c(F.leaky_relu_(e7, negative_slope=0.2))
# Decoder
d1 = self.d1_norm(self.d1_c(F.relu_(e8)))
d2 = self.d2_norm(self.d2_c(F.relu_(torch.cat([d1, e7], dim=1))))
d3 = self.d3_norm(self.d3_c(F.relu_(torch.cat([d2, e6], dim=1))))
d4 = self.d4_norm(self.d4_c(F.relu_(torch.cat([d3, e5], dim=1))))
d5 = self.d5_norm(self.d5_c(F.relu_(torch.cat([d4, e4], dim=1))))
d6 = self.d6_norm(self.d6_c(F.relu_(torch.cat([d5, e3], dim=1))))
d7 = self.d7_norm(self.d7_c(F.relu_(torch.cat([d6, e2], dim=1))))
# No norm on the last layer
d8 = self.d8_c(F.relu_(torch.cat([d7, e1], 1)))
d8 = torch.tanh(d8)
return d8
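# Illustrative check (not part of the model): the eight stride-2 encoder convs
# above (e1_c .. e8_c) shrink an assumed 256x256 input to a 1x1 bottleneck,
# and the eight transposed convs mirror it back up.

```python
sizes = [256]
for _ in range(8):                 # e1_c .. e8_c each halve H and W
    sizes.append(sizes[-1] // 2)
assert sizes == [256, 128, 64, 32, 16, 8, 4, 2, 1]
```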
#-*-coding:utf-8-*-
from torch.nn import init
from torch.optim import lr_scheduler
from torchvision import models
from .modules import *
###############################################################################
# Functions
###############################################################################
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=True, track_running_stats=False)
elif norm_type == 'switchable':
norm_layer = functools.partial(SwitchNorm2d)
elif norm_type == 'none':
norm_layer = None
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def get_scheduler(optimizer, opt):
if opt.lr_policy == 'lambda':
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
elif opt.lr_policy == 'plateau':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
elif opt.lr_policy == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
else:
return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
return scheduler
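# The 'lambda' policy above keeps the LR constant for the first `niter` epochs,
# then decays it linearly to zero over `niter_decay` epochs. A torch-free sketch,
# assuming option values niter=100, niter_decay=100, epoch_count=1:

```python
def lambda_rule_sketch(epoch, niter=100, niter_decay=100, epoch_count=1):
    """LR multiplier mirroring lambda_rule, with assumed option values."""
    return 1.0 - max(0, epoch + 1 + epoch_count - niter) / float(niter_decay + 1)

assert lambda_rule_sketch(0) == 1.0     # full LR during the first phase
assert lambda_rule_sketch(98) == 1.0    # still full LR just before `niter`
assert lambda_rule_sketch(199) == 0.0   # fully decayed at the end
```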
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids)
init_weights(net, init_type, gain=init_gain)
return net
# Note: Adding SN to G tends to give inferior results. Need more checking.
def define_G(input_nc, output_nc, ngf, which_model_netG, opt, mask_global, norm='batch', use_spectral_norm=False, init_type='normal', gpu_ids=[], init_gain=0.02):
netG = None
norm_layer = get_norm_layer(norm_type=norm)
innerCos_list = []
shift_list = []
print('input_nc {}'.format(input_nc))
print('output_nc {}'.format(output_nc))
print('which_model_netG {}'.format(which_model_netG))
    # Here we need to initialize an artificial mask_global to construct the initial model.
    # When training, we must first set the mask for special layers (mostly Shift layers).
    # If the mask is fixed during training, we only need to set it for these layers once;
    # otherwise we must set the masks each iteration: generate new random masks, mask the
    # input, and set the masks for these special layers.
    print('[CREATING] MODEL')
if which_model_netG == 'unet_256':
netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'easy_unet_256':
netG = EasyUnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'face_unet_shift_triple':
netG = FaceUnetGenerator(input_nc, output_nc, innerCos_list, shift_list, mask_global, opt, \
ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'unet_shift_triple':
netG = UnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'res_unet_shift_triple':
netG = ResUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'patch_soft_unet_shift_triple':
netG = PatchSoftUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'res_patch_soft_unet_shift_triple':
netG = ResPatchSoftUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
print('[CREATED] MODEL')
print('Constraint in netG:')
print(innerCos_list)
print('Shift in netG:')
print(shift_list)
print('NetG:')
print(netG)
return init_net(netG, init_type, init_gain, gpu_ids), innerCos_list, shift_list
def define_D(input_nc, ndf, which_model_netD,
n_layers_D=3, norm='batch', use_sigmoid=False, use_spectral_norm=False, init_type='normal', gpu_ids=[], init_gain=0.02):
netD = None
norm_layer = get_norm_layer(norm_type=norm)
if which_model_netD == 'basic':
netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm)
elif which_model_netD == 'n_layers':
netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm)
elif which_model_netD == 'densenet':
netD = DenseNetDiscrimator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD)
print('NetD:')
print(netD)
return init_net(netD, init_type, init_gain, gpu_ids)
import torch.nn as nn
import torch
import util.util as util
from .innerPatchSoftShiftTripleModule import InnerPatchSoftShiftTripleModule
# TODO: make it compatible with show_flow.
class InnerPatchSoftShiftTriple(nn.Module):
    def __init__(self, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, fuse=True, layer_to_last=3, device='cpu'):
        super(InnerPatchSoftShiftTriple, self).__init__()
        self.device = device # Accepted to match the caller in define_G.
self.shift_sz = shift_sz
self.stride = stride
self.mask_thred = mask_thred
self.triple_weight = triple_weight
        self.show_flow = False # Default False. Do not enable it: it is computation-heavy.
        self.flow_srcs = None # The flow sources (pixels in the non-masked region that shift into the masked region).
self.fuse = fuse
self.layer_to_last = layer_to_last
self.softShift = InnerPatchSoftShiftTripleModule()
def set_mask(self, mask_global):
mask = util.cal_feat_mask(mask_global, self.layer_to_last)
self.mask = mask
return self.mask
    # If the mask changes, cal_fix_flag must be set True each iteration.
def forward(self, input):
_, self.c, self.h, self.w = input.size()
# Just pass self.mask in, instead of self.flag.
final_out = self.softShift(input, self.stride, self.triple_weight, self.mask, self.mask_thred, self.shift_sz, self.show_flow, self.fuse)
if self.show_flow:
self.flow_srcs = self.softShift.get_flow_src()
return final_out
def get_flow(self):
return self.flow_srcs
def set_flow_true(self):
self.show_flow = True
def set_flow_false(self):
self.show_flow = False
def __repr__(self):
return self.__class__.__name__+ '(' \
+ ' ,triple_weight ' + str(self.triple_weight) + ')'
from util.NonparametricShift import Modified_NonparametricShift
from torch.nn import functional as F
import torch.nn as nn
import torch
import util.util as util
class InnerPatchSoftShiftTripleModule(nn.Module):
def forward(self, input, stride, triple_w, mask, mask_thred, shift_sz, show_flow, fuse=True):
assert input.dim() == 4, "Input Dim has to be 4"
assert mask.dim() == 4, "Mask Dim has to be 4"
self.triple_w = triple_w
self.mask = mask
self.mask_thred = mask_thred
self.show_flow = show_flow
self.bz, self.c, self.h, self.w = input.size()
        # NOTE: is_available must be called; the bare function object is always truthy.
        self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
self.ind_lst = self.Tensor(self.bz, self.h * self.w, self.h * self.w).zero_()
        # former and latter are both tensors.
former_all = input.narrow(1, 0, self.c//2) ### decoder feature
latter_all = input.narrow(1, self.c//2, self.c//2) ### encoder feature
shift_masked_all = torch.Tensor(former_all.size()).type_as(former_all) # addition feature
self.mask = self.mask.to(input)
# extract patches from latter.
latter_all_pad = F.pad(latter_all, [shift_sz//2, shift_sz//2, shift_sz//2, shift_sz//2], 'constant', 0)
latter_all_windows = latter_all_pad.unfold(2, shift_sz, stride).unfold(3, shift_sz, stride)
latter_all_windows = latter_all_windows.contiguous().view(self.bz, -1, self.c//2, shift_sz, shift_sz)
# Extract patches from mask
        # Note: the mask here must be 1*1*H*W.
m_pad = F.pad(self.mask, (shift_sz//2, shift_sz//2, shift_sz//2, shift_sz//2), 'constant', 0)
m = m_pad.unfold(2, shift_sz, stride).unfold(3, shift_sz, stride)
m = m.contiguous().view(self.bz, 1, -1, shift_sz, shift_sz)
        # This implements similar functionality to `cal_flag_given_mask_thred`,
        # but the meaning of `mm` is inverted:
        # here mm is 0 in the masked region and 1 in the non-masked region,
        # while in `cal_flag_given_mask_thred` it is the opposite.
m = torch.mean(torch.mean(m, dim=3, keepdim=True), dim=4, keepdim=True)
mm = m.le(self.mask_thred/(1.*shift_sz**2)).float() # bz*1*(32*32)*1*1
fuse_weight = torch.eye(shift_sz).view(1, 1, shift_sz, shift_sz).type_as(input)
self.shift_offsets = []
for idx in range(self.bz):
mm_cur = mm[idx]
latter_win = latter_all_windows.narrow(0, idx, 1)[0]
former = former_all.narrow(0, idx, 1)
# normalize latter for each patch.
latter_den = torch.sqrt(torch.einsum("bcij,bcij->b", [latter_win, latter_win]))
latter_den = torch.max(latter_den, self.Tensor([1e-4]))
latter_win_normed = latter_win/latter_den.view(-1, 1, 1, 1)
y_i = F.conv2d(former, latter_win_normed, stride=1, padding=shift_sz//2)
# conv implementation for fuse scores to encourage large patches
if fuse:
y_i = y_i.view(1, 1, self.h*self.w, self.h*self.w) # make all of depth of spatial resolution.
y_i = F.conv2d(y_i, fuse_weight, stride=1, padding=1)
y_i = y_i.contiguous().view(1, self.h, self.w, self.h, self.w)
y_i = y_i.permute(0, 2, 1, 4, 3)
y_i = y_i.contiguous().view(1, 1, self.h*self.w, self.h*self.w)
y_i = F.conv2d(y_i, fuse_weight, stride=1, padding=1)
y_i = y_i.contiguous().view(1, self.w, self.h, self.w, self.h)
y_i = y_i.permute(0, 2, 1, 4, 3)
y_i = y_i.contiguous().view(1, self.h*self.w, self.h, self.w) # 1*(32*32)*32*32
            # First, wash away the masked region:
            # multiplying by `mm` zeroes out (:, masked_index, :, :).
y_i = y_i * mm_cur
# Then apply softmax to the nonmasked region.
cosine = F.softmax(y_i*10, dim=1)
            # Finally, dummy weights from the masked region are filtered out.
cosine = cosine * mm_cur
# paste
            # Divide by 9: each output pixel is covered by 9 overlapping patches (this assumes shift_sz == 3).
            shift_i = F.conv_transpose2d(cosine, latter_win, stride=1, padding=shift_sz//2)/9.
shift_masked_all[idx] = shift_i
# Addition: show shift map
# TODO: fix me.
# cosine here is a full size of 32*32, not only the masked region in `shift_net`,
# which results in non-direct reusing the code.
# torch.set_printoptions(threshold=2015)
# if self.show_flow:
# _, indexes = torch.max(cosine, dim=1)
# # calculate self.flag from self.m
# self.flag = (1 - mm).view(-1)
# torch.set_printoptions(threshold=1025)
# print(self.flag)
# non_mask_indexes = (self.flag == 0.).nonzero()
# non_mask_indexes = non_mask_indexes[indexes]
# print('ll')
# print(non_mask_indexes.size())
# print(non_mask_indexes)
# # Here non_mask_index is too large, should be 192.
# shift_offset = torch.stack([non_mask_indexes.squeeze() // self.w, non_mask_indexes.squeeze() % self.w], dim=-1)
# print(shift_offset.size())
# self.shift_offsets.append(shift_offset)
# print('cc')
# if self.show_flow:
# # Note: Here we assume that each mask is the same for the same batch image.
# self.shift_offsets = torch.cat(self.shift_offsets, dim=0).float() # make it cudaFloatTensor
# # Assume mask is the same for each image in a batch.
# mask_nums = self.shift_offsets.size(0)//self.bz
# self.flow_srcs = torch.zeros(self.bz, 3, self.h, self.w).type_as(input)
# for idx in range(self.bz):
# shift_offset = self.shift_offsets.narrow(0, idx*mask_nums, mask_nums)
# # reconstruct the original shift_map.
# shift_offsets_map = torch.zeros(1, self.h, self.w, 2).type_as(input)
# print(shift_offsets_map.size())
# print(shift_offset.unsqueeze(0).size())
# print(shift_offsets_map[:, (self.flag == 1).nonzero().squeeze() // self.w, (self.flag == 1).nonzero().squeeze() % self.w, :].size())
# shift_offsets_map[:, (self.flag == 1).nonzero().squeeze() // self.w, (self.flag == 1).nonzero().squeeze() % self.w, :] = \
# shift_offset.unsqueeze(0)
# # It is indicating the pixels(non-masked) that will shift the the masked region.
# flow_src = util.highlight_flow(shift_offsets_map, self.flag.unsqueeze(0))
# self.flow_srcs[idx] = flow_src
return torch.cat((former_all, latter_all, shift_masked_all), 1)
def get_flow_src(self):
return self.flow_srcs
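# The `mm` threshold in forward() above marks which patches may serve as shift
# sources. A torch-free sketch of that test for a single patch
# (`patch_is_source` is a hypothetical helper for illustration):

```python
def patch_is_source(mask_patch, mask_thred, shift_sz):
    """mask_patch: flat 0/1 values of one shift_sz x shift_sz mask patch."""
    mean_val = sum(mask_patch) / float(shift_sz ** 2)
    # mm = m.le(mask_thred / shift_sz**2): at most mask_thred masked pixels.
    return mean_val <= mask_thred / float(shift_sz ** 2)

# One masked pixel in a 3x3 patch, threshold 1 -> still a valid source patch.
assert patch_is_source([0, 0, 0, 0, 1, 0, 0, 0, 0], mask_thred=1, shift_sz=3)
# Two masked pixels exceed the threshold -> excluded by `mm`.
assert not patch_is_source([1, 1, 0, 0, 0, 0, 0, 0, 0], mask_thred=1, shift_sz=3)
```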
from models.shift_net.shiftnet_model import ShiftNetModel
class PatchSoftShiftNetModel(ShiftNetModel):
def name(self):
return 'PatchSoftShiftNetModel'
import torch.nn as nn
import torch
import util.util as util
from models.patch_soft_shift.innerPatchSoftShiftTripleModule import InnerPatchSoftShiftTripleModule
# TODO: Make it compatible for show_flow.
#
class InnerResPatchSoftShiftTriple(nn.Module):
    def __init__(self, inner_nc, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, fuse=True, layer_to_last=3, device='cpu'):
        super(InnerResPatchSoftShiftTriple, self).__init__()
        self.device = device # Accepted to match the caller in define_G.
self.shift_sz = shift_sz
self.stride = stride
self.mask_thred = mask_thred
self.triple_weight = triple_weight
        self.show_flow = False # Default False. Do not enable it: it is computation-heavy.
        self.flow_srcs = None # The flow sources (pixels in the non-masked region that shift into the masked region).
self.fuse = fuse
self.layer_to_last = layer_to_last
self.softShift = InnerPatchSoftShiftTripleModule()
# Additional for ResShift.
self.inner_nc = inner_nc
self.res_net = nn.Sequential(
nn.Conv2d(inner_nc*2, inner_nc, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm2d(inner_nc),
nn.ReLU(True),
nn.Conv2d(inner_nc, inner_nc, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm2d(inner_nc)
)
def set_mask(self, mask_global):
mask = util.cal_feat_mask(mask_global, self.layer_to_last)
self.mask = mask
return self.mask
    # If the mask changes, cal_fix_flag must be set True each iteration.
def forward(self, input):
_, self.c, self.h, self.w = input.size()
# Just pass self.mask in, instead of self.flag.
        # This is made faster by avoiding `cal_flag_given_mask_thred`.
shift_out = self.softShift(input, self.stride, self.triple_weight, self.mask, self.mask_thred, self.shift_sz, self.show_flow, self.fuse)
c_out = shift_out.size(1)
# get F_c, F_s, F_shift
F_c = shift_out.narrow(1, 0, c_out//3)
F_s = shift_out.narrow(1, c_out//3, c_out//3)
F_shift = shift_out.narrow(1, c_out*2//3, c_out//3)
F_fuse = F_c * F_shift
F_com = torch.cat([F_c, F_fuse], dim=1)
res_out = self.res_net(F_com)
F_c = F_c + res_out
final_out = torch.cat([F_c, F_s], dim=1)
if self.show_flow:
self.flow_srcs = self.softShift.get_flow_src()
return final_out
def get_flow(self):
return self.flow_srcs
def set_flow_true(self):
self.show_flow = True
def set_flow_false(self):
self.show_flow = False
def __repr__(self):
return self.__class__.__name__+ '(' \
+ ' ,triple_weight ' + str(self.triple_weight) + ')'
from models.shift_net.shiftnet_model import ShiftNetModel
class ResPatchSoftShiftNetModel(ShiftNetModel):
def name(self):
return 'ResPatchSoftShiftNetModel'
import torch.nn as nn
import torch
import util.util as util
from models.shift_net.InnerShiftTripleFunction import InnerShiftTripleFunction
class InnerResShiftTriple(nn.Module):
def __init__(self, inner_nc, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, layer_to_last=3):
super(InnerResShiftTriple, self).__init__()
self.shift_sz = shift_sz
self.stride = stride
self.mask_thred = mask_thred
self.triple_weight = triple_weight
        self.show_flow = False # Default False. Do not enable it: it is computation-heavy.
        self.flow_srcs = None # The flow sources (pixels in the non-masked region that shift into the masked region).
self.layer_to_last = layer_to_last
# Additional for ResShift.
self.inner_nc = inner_nc
self.res_net = nn.Sequential(
nn.Conv2d(inner_nc*2, inner_nc, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm2d(inner_nc),
nn.ReLU(True),
nn.Conv2d(inner_nc, inner_nc, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm2d(inner_nc)
)
def set_mask(self, mask_global):
mask = util.cal_feat_mask(mask_global, self.layer_to_last)
self.mask = mask.squeeze()
return self.mask
    # If the mask changes, cal_fix_flag must be set True each iteration.
def forward(self, input):
#print(input.shape)
_, self.c, self.h, self.w = input.size()
self.flag = util.cal_flag_given_mask_thred(self.mask, self.shift_sz, self.stride, self.mask_thred)
shift_out = InnerShiftTripleFunction.apply(input, self.shift_sz, self.stride, self.triple_weight, self.flag, self.show_flow)
c_out = shift_out.size(1)
# get F_c, F_s, F_shift
F_c = shift_out.narrow(1, 0, c_out//3)
F_s = shift_out.narrow(1, c_out//3, c_out//3)
F_shift = shift_out.narrow(1, c_out*2//3, c_out//3)
F_fuse = F_c * F_shift
F_com = torch.cat([F_c, F_fuse], dim=1)
res_out = self.res_net(F_com)
F_c = F_c + res_out
final_out = torch.cat([F_c, F_s], dim=1)
if self.show_flow:
self.flow_srcs = InnerShiftTripleFunction.get_flow_src()
return final_out
def get_flow(self):
return self.flow_srcs
def set_flow_true(self):
self.show_flow = True
def set_flow_false(self):
self.show_flow = False
def __repr__(self):
return self.__class__.__name__+ '(' \
+ ' ,triple_weight ' + str(self.triple_weight) + ')'
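The shift layers above all split their channel dimension into thirds with `narrow` (content, style, shifted). A minimal standalone sketch, using toy shapes chosen purely for illustration:

```python
import torch

# Toy tensor standing in for `shift_out`: batch 2, 6 channels, 4x4 spatial.
shift_out = torch.arange(2 * 6 * 4 * 4, dtype=torch.float32).view(2, 6, 4, 4)
c_out = shift_out.size(1)

# F_c: decoder features, F_s: encoder features, F_shift: shifted features.
F_c = shift_out.narrow(1, 0, c_out // 3)
F_s = shift_out.narrow(1, c_out // 3, c_out // 3)
F_shift = shift_out.narrow(1, c_out * 2 // 3, c_out // 3)

# narrow() returns views; concatenating F_c and F_s drops the shifted third.
final_out = torch.cat([F_c, F_s], dim=1)
```

`narrow(dim, start, length)` avoids copies, so the residual branch can reuse `F_c` and `F_shift` cheaply before the final concatenation.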
from models.shift_net.shiftnet_model import ShiftNetModel
class ResShiftNetModel(ShiftNetModel):
def name(self):
return 'ResShiftNetModel'
import torch.nn as nn
import torch
import torch.nn.functional as F
import util.util as util
from .InnerCosFunction import InnerCosFunction
class InnerCos(nn.Module):
def __init__(self, crit='MSE', strength=1, skip=0, layer_to_last=3, device='gpu'):
super(InnerCos, self).__init__()
self.crit = crit
self.criterion = torch.nn.MSELoss() if self.crit == 'MSE' else torch.nn.L1Loss()
self.strength = strength
# To define whether this layer is skipped.
self.skip = skip
self.layer_to_last = layer_to_last
self.device = device
# Initializing with a dummy value is fine.
self.target = torch.tensor(1.0)
self.bz = 0
self.c = 0
self.cur_mask = torch.tensor(0)
self.output = torch.tensor(0)
def set_mask(self, mask_global):
mask_all = util.cal_feat_mask(mask_global, self.layer_to_last)
self.mask_all = mask_all.float()
def _split_mask(self, cur_bsize):
# Get the current GPU index and assign the correct mask slice to this device's set of images.
cur_device = torch.cuda.current_device()
self.cur_mask = self.mask_all[cur_device*cur_bsize:(cur_device+1)*cur_bsize, :, :, :]
def forward(self, in_data):
self.bz = in_data.size(0)
self.c = in_data.size(1)
self.cur_mask = self.mask_all
self.cur_mask = self.cur_mask.to(in_data)
# if not self.skip:
# # It works like this:
# # Each iteration contains two forward passes. In the first, we input a GT image just to get the target.
# # In the second, we input the corresponding corrupted image and back-propagate; the guidance loss then works as expected.
# self.output = InnerCosFunction.apply(in_data, self.criterion, self.strength, self.target, self.cur_mask)
# self.target = in_data.narrow(1, self.c // 2, self.c // 2).detach() # the latter part
# else:
self.output = in_data
return self.output
def __repr__(self):
skip_str = 'True' if self.skip else 'False'
return self.__class__.__name__+ '(' \
+ 'skip: ' + skip_str \
+ ', layer ' + str(self.layer_to_last) + ' to last' \
+ ', strength: ' + str(self.strength) + ')'
import torch
import torch.nn as nn
class InnerCosFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, criterion, strength, target, mask):
ctx.c = input.size(1)
ctx.strength = strength
ctx.criterion = criterion
if len(target.size()) == 0: # For the first iteration.
target = target.expand_as(input.narrow(1, ctx.c // 2, ctx.c // 2)).type_as(input)
ctx.save_for_backward(input, target, mask)
return input
@staticmethod
def backward(ctx, grad_output):
with torch.enable_grad():
input, target, mask = ctx.saved_tensors
former = input.narrow(1, 0, ctx.c//2)
former_in_mask = torch.mul(former, mask)
if former_in_mask.size() != target.size(): # For the last iteration of one epoch
target = target.narrow(0, 0, 1).expand_as(former_in_mask).type_as(former_in_mask)
former_in_mask_clone = former_in_mask.clone().detach().requires_grad_(True)
ctx.loss = ctx.criterion(former_in_mask_clone, target) * ctx.strength
ctx.loss.backward()
grad_output[:,0:ctx.c//2, :,:] += former_in_mask_clone.grad
return grad_output, None, None, None, None
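The trick behind `InnerCosFunction` is an identity forward pass whose backward injects the gradient of an auxiliary criterion on part of the input. A hedged sketch of that pattern; the shapes, the zero target, and the plain MSE criterion here are assumptions for illustration, not the model's actual configuration:

```python
import torch

class GuidedIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, target, strength):
        ctx.strength = strength
        ctx.save_for_backward(input, target)
        return input  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        c = input.size(1)
        # Re-derive the guidance gradient on the former half of the channels.
        former = input.narrow(1, 0, c // 2).detach().requires_grad_(True)
        with torch.enable_grad():
            loss = torch.nn.functional.mse_loss(former, target) * ctx.strength
        (g,) = torch.autograd.grad(loss, former)
        grad_output = grad_output.clone()
        grad_output[:, : c // 2] += g  # inject the guidance gradient
        return grad_output, None, None

torch.manual_seed(0)
x = torch.randn(1, 4, 8, 8, requires_grad=True)
target = torch.zeros(1, 2, 8, 8)
y = GuidedIdentity.apply(x, target, 1.0)
y.sum().backward()
```

The latter half of `x.grad` is the plain upstream gradient (all ones from `sum`), while the former half additionally carries the MSE guidance term.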
import torch.nn as nn
import torch
import util.util as util
from testFun import InnerShiftTripleFunction
class InnerShiftTriple(nn.Module):
def __init__(self, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, layer_to_last=3, device='gpu'):
super(InnerShiftTriple, self).__init__()
self.shift_sz = torch.tensor(shift_sz)
self.stride = torch.tensor(stride)
self.mask_thred = torch.tensor(mask_thred)
self.triple_weight = triple_weight
self.layer_to_last = layer_to_last
self.device = device
self.show_flow = False # Default is False. Do not set it to True: computing the flow is computation-heavy.
self.flow_srcs = None # The flow sources (pixels in the non-masked region that will shift into the masked region).
self.bz = 0
self.c = 0
self.h = 0
self.w = 0
self.cur_mask = torch.tensor(0)
self.flag = torch.tensor(0)
def set_mask(self, mask_global):
self.mask_all = util.cal_feat_mask(mask_global, self.layer_to_last)
def _split_mask(self, cur_bsize):
# Get the current GPU index and assign the correct mask slice to this device's set of images.
cur_device = torch.cuda.current_device()
self.cur_mask = self.mask_all[cur_device*cur_bsize:(cur_device+1)*cur_bsize, :, :, :]
# If the mask changes, cal_fix_flag needs to be set to True each iteration.
def forward(self, input):
self.bz = input.size(0)
self.c = input.size(1)
self.h = input.size(2)
self.w = input.size(3)
self.cur_mask = self.mask_all
self.flag = util.cal_flag_given_mask_thred(self.cur_mask, self.shift_sz, self.stride, self.mask_thred)
# final_out = InnerShiftTripleFunction.apply(input, self.shift_sz, self.stride, self.triple_weight, self.flag, self.show_flow)
# if self.show_flow:
# self.flow_srcs = InnerShiftTripleFunction.get_flow_src()
final_out = InnerShiftTripleFunction(input, self.shift_sz, self.stride, torch.tensor(self.triple_weight), self.flag, torch.tensor(self.show_flow))
return final_out
def get_flow(self):
return self.flow_srcs
def set_flow_true(self):
self.show_flow = True
def set_flow_false(self):
self.show_flow = False
def __repr__(self):
return self.__class__.__name__+ '(' \
+ ' ,triple_weight ' + str(self.triple_weight) + ')'
import numpy as np
from util.NonparametricShift import Modified_NonparametricShift, Batch_NonShift
import torch
import util.util as util
import time
class InnerShiftTripleFunction(torch.autograd.Function):
ctx = None
@staticmethod
def forward(ctx, input, shift_sz, stride, triple_w, flag, show_flow):
InnerShiftTripleFunction.ctx = ctx
assert input.dim() == 4, "Input Dim has to be 4"
ctx.triple_w = triple_w
ctx.flag = flag
ctx.show_flow = show_flow
ctx.bz, c_real, ctx.h, ctx.w = input.size()
c = c_real
ctx.ind_lst = torch.Tensor(ctx.bz, ctx.h * ctx.w, ctx.h * ctx.w).zero_().to(input)
# former and latter are all tensors
former_all = input.narrow(1, 0, c//2) ### decoder feature
latter_all = input.narrow(1, c//2, c//2) ### encoder feature
shift_masked_all = torch.Tensor(former_all.size()).type_as(former_all).zero_() # addition feature
ctx.flag = ctx.flag.to(input).long()
# Batch version of the non-parametric shift (a non-batch reference version is commented out below).
bNonparm = Batch_NonShift()
ctx.shift_offsets = []
cosine, latter_windows, i_2, i_3, i_1 = bNonparm.cosine_similarity(former_all.clone(), latter_all.clone(), 1, stride, flag)
_, indexes = torch.max(cosine, dim=2)
mask_indexes = (flag==1).nonzero(as_tuple=False)[:, 1].view(ctx.bz, -1)
non_mask_indexes = (flag==0).nonzero(as_tuple=False)[:, 1].view(ctx.bz, -1).gather(1, indexes)
idx_b = torch.arange(ctx.bz).long().unsqueeze(1).expand(ctx.bz, mask_indexes.size(1))
# Set the elements indexed by [mask_indexes, non_mask_indexes] to 1.
# This is the batch version.
ctx.ind_lst[(idx_b, mask_indexes, non_mask_indexes)] = 1
shift_masked_all = bNonparm._paste(latter_windows, ctx.ind_lst, i_2, i_3, i_1)
# --- Non-batch version ----
#for idx in range(ctx.bz):
# flag_cur = ctx.flag[idx]
# latter = latter_all.narrow(0, idx, 1) ### encoder feature
# former = former_all.narrow(0, idx, 1) ### decoder feature
# #GET COSINE, RESHAPED LATTER AND ITS INDEXES
# cosine, latter_windows, i_2, i_3, i_1 = Nonparm.cosine_similarity(former.clone().squeeze(), latter.clone().squeeze(), 1, stride, flag_cur)
# ## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY
# _, indexes = torch.max(cosine, dim=1)
# # SET TRANSITION MATRIX
# mask_indexes = (flag_cur == 1).nonzero()
# non_mask_indexes = (flag_cur == 0).nonzero()[indexes]
# ctx.ind_lst[idx][mask_indexes, non_mask_indexes] = 1
# # GET FINAL SHIFT FEATURE
# shift_masked_all[idx] = Nonparm._paste(latter_windows, ctx.ind_lst[idx], i_2, i_3, i_1)
# if ctx.show_flow:
# shift_offset = torch.stack([non_mask_indexes.squeeze() // ctx.w, non_mask_indexes.squeeze() % ctx.w], dim=-1)
# ctx.shift_offsets.append(shift_offset)
if ctx.show_flow:
assert 1==2, "The `show_flow` functionality is no longer maintained."
ctx.shift_offsets = torch.cat(ctx.shift_offsets, dim=0).float() # make it cudaFloatTensor
# Assume mask is the same for each image in a batch.
mask_nums = ctx.shift_offsets.size(0)//ctx.bz
ctx.flow_srcs = torch.zeros(ctx.bz, 3, ctx.h, ctx.w).type_as(input)
for idx in range(ctx.bz):
shift_offset = ctx.shift_offsets.narrow(0, idx*mask_nums, mask_nums)
# reconstruct the original shift_map.
shift_offsets_map = torch.zeros(1, ctx.h, ctx.w, 2).type_as(input)
shift_offsets_map[:, (flag_cur == 1).nonzero(as_tuple=False).squeeze() // ctx.w, (flag_cur == 1).nonzero(as_tuple=False).squeeze() % ctx.w, :] = \
shift_offset.unsqueeze(0)
# It indicates the (non-masked) pixels that will shift into the masked region.
flow_src = util.highlight_flow(shift_offsets_map, flag_cur.unsqueeze(0))
ctx.flow_srcs[idx] = flow_src
return torch.cat((former_all, latter_all, shift_masked_all), 1)
@staticmethod
def get_flow_src():
return InnerShiftTripleFunction.ctx.flow_srcs
@staticmethod
def backward(ctx, grad_output):
ind_lst = ctx.ind_lst
c = grad_output.size(1)
# # The former and the latter parts are kept as-is. Only the third part is shifted.
# C: content, pixels in masked region of the former part.
# S: style, pixels in the non-masked region of the latter part.
# N: the shifted feature, the new feature that will be used as the third part of features maps.
# W_mat: ind_lst[idx], shift matrix.
# Note: **only the masked region in N has values**.
# The gradient of shift feature should be added back to the latter part(to be precise: S).
# `ind_lst[idx][i,j] = 1` means that the i-th pixel will **be replaced** by the j-th pixel in the forward pass.
# When applying `S mm W_mat`, S is transferred to N.
# (Pixels in the non-masked region of the latter part are shifted to the masked region in the third part.)
# However, we need to transfer the gradient of the third part back to S.
# This means the gradient in S will **`be replaced`(to be precise, enhanced)** by N.
grad_former_all = grad_output[:, 0:c//3, :, :]
grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
grad_shifted_all = grad_output[:, c*2//3:c, :, :].clone()
W_mat_t = ind_lst.permute(0, 2, 1).contiguous()
grad = grad_shifted_all.view(ctx.bz, c//3, -1).permute(0, 2, 1)
grad_shifted_weighted = torch.bmm(W_mat_t, grad)
grad_shifted_weighted = grad_shifted_weighted.permute(0, 2, 1).contiguous().view(ctx.bz, c//3, ctx.h, ctx.w)
grad_latter_all = torch.add(grad_latter_all, grad_shifted_weighted.mul(ctx.triple_w))
# ----- 'Non_batch version here' --------------------
# for idx in range(ctx.bz):
# # So we need to transpose `W_mat`
# W_mat_t = ind_lst[idx].t()
# grad = grad_shifted_all[idx].view(c//3, -1).t()
# grad_shifted_weighted = torch.mm(W_mat_t, grad)
# # Then transpose it back
# grad_shifted_weighted = grad_shifted_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
# grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_shifted_weighted.mul(ctx.triple_w))
# Note: the input and output channel counts are both c, as there is no mask input for now.
grad_input = torch.cat([grad_former_all, grad_latter_all], 1)
return grad_input, None, None, None, None, None, None
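The backward pass above relies on a standard identity: if the shifted feature is `N = W @ S` in flattened form (with `W` the 0/1 transition matrix `ind_lst`), then the gradient flowing back into `S` is `W^T @ dN`. A small self-check with toy sizes (the shapes here are assumptions for illustration):

```python
import torch

bz, n, c = 2, 5, 3                       # batch, spatial positions, channels
W = torch.zeros(bz, n, n)
W[:, 0, 3] = 1                           # "position 0 is filled from position 3"
S = torch.randn(bz, n, c, requires_grad=True)

N = torch.bmm(W, S)                      # forward shift
dN = torch.randn_like(N)
N.backward(dN)                           # autograd's gradient for S

# The manual rule used in backward() above: transpose W and apply it to dN.
manual = torch.bmm(W.permute(0, 2, 1), dN)
```

Autograd and the manual `W^T @ dN` rule agree, which is why `backward()` can route the shifted-part gradient back onto the encoder features via `W_mat_t`.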
import os
import torch
from collections import OrderedDict
from torchsummary import summary
class BaseModel():
def name(self):
return 'BaseModel'
def initialize(self, opt):
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
if opt.resize_or_crop != 'scale_width':
torch.backends.cudnn.benchmark = True
self.loss_names = []
self.model_names = []
self.visual_names = []
self.image_paths = []
def set_input(self, input):
self.input = input
def forward(self):
pass
# used in test time, wrapping `forward` in no_grad() so we don't save
# intermediate steps for backprop
def test(self):
with torch.no_grad():
self.forward()
# get image paths
def get_image_paths(self):
return self.image_paths
def optimize_parameters(self):
pass
# update learning rate (called once every epoch)
def update_learning_rate(self):
for scheduler in self.schedulers:
scheduler.step()
lr = self.optimizers[0].param_groups[0]['lr']
print('learning rate = %.7f' % lr)
# return visualization images. train.py will display these images, and save the images to a html
def get_current_visuals(self):
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str):
visual_ret[name] = getattr(self, name)
return visual_ret
# return training losses/errors. train.py will print out these errors as debugging information
def get_current_losses(self):
errors_ret = OrderedDict()
for name in self.loss_names:
if isinstance(name, str):
# float(...) works for both scalar tensor and float number
errors_ret[name] = float(getattr(self, 'loss_' + name))
return errors_ret
# save models to the disk
def save_networks(self, which_epoch):
for name in self.model_names:
if isinstance(name, str):
save_filename = '%s_net_%s.pth' % (which_epoch, name)
save_path = os.path.join(self.save_dir, save_filename)
net = getattr(self, 'net' + name)
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
torch.save(net.module.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
############################################
if name == 'G':
save_filename_pt_st = '%s_net_%s_st.pt' % (which_epoch, name)
### unet256: succeeded ###
# example = torch.rand(1, 4, 256, 256)
# example = example.cuda()
# net.eval()
# traced_script_module = torch.jit.trace(net.module.cuda().eval(), example)
# traced_script_module.save(save_filename_pt_st)
example1 = torch.ones(1, 4, 64, 64).cuda()
example2 = torch.ones(2, 4, 64, 64).cuda()
# result1 = net.module(example)
net.eval()
# traced_script_module = torch.jit.trace(net.module.cuda().eval(), example)
traced_script_module = torch.jit.script(net.module)
result1 = traced_script_module(example1)
result2 = traced_script_module(example2)
print(result1)
print(result2)
# torch.jit.export_opnames(traced_script_module)
traced_script_module.save(save_filename_pt_st)
else:
torch.save(net.cpu().state_dict(), save_path)
############################################
if name == 'G':
save_filename_pt_st = '%s_net_%s_st.pt' % (which_epoch, name)
# torch.save(net.cpu(), save_filename_pt_st)
# example = torch.rand(1, 4, 256, 256)
# example = example.cuda()
# traced_script_module = torch.jit.trace(net.module, (example, example))
# traced_script_module.save(save_filename_pt_st)
example = torch.zeros(1, 4, 128, 128)
net.eval()
traced_script_module = torch.jit.trace(net.cpu(), example)
traced_script_module.save(save_filename_pt_st)
'''
model = torch.load(save_path)
net.load_state_dict(model)
net.eval()
example = torch.rand(1, 4, 256, 256).cuda() # generate a random input with the expected dimensions
traced_script_module = torch.jit.trace(net, example)
traced_script_module.save(save_filename_pt_st)
model = torch.load(save_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary(model, input_size=(1, 4, 256, 256))
model = model.to(device)
traced_script_module = torch.jit.trace(model, torch.ones(1, 4, 256, 256).to(device))
traced_script_module.save(save_filename_pt_st)
'''
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
else:
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
# load models from the disk
def load_networks(self, which_epoch):
for name in self.model_names:
if isinstance(name, str):
load_filename = '%s_net_%s.pth' % (which_epoch, name)
load_path = os.path.join(self.save_dir, load_filename)
net = getattr(self, 'net' + name)
if isinstance(net, torch.nn.DataParallel):
net = net.module
# if you are using PyTorch newer than 0.4 (e.g., built from
# GitHub source), you can remove str() on self.device
state_dict = torch.load(load_path, map_location=str(self.device))
#state_dict = torch.load(load_path)
# patch InstanceNorm checkpoints prior to 0.4
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
net.load_state_dict(state_dict)
# print network information
def print_networks(self, verbose):
print('---------- Networks initialized -------------')
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
num_params = 0
for param in net.parameters():
num_params += param.numel()
if verbose:
print(net)
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
print('-----------------------------------------------')
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
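`save_networks()` above scripts the generator (rather than tracing it) and checks it with two different batch sizes. The reason: `torch.jit.script` compiles the module's actual Python control flow, so one exported module handles varying input shapes, whereas a trace bakes in the example's shapes along data-dependent paths. A minimal sketch with a tiny stand-in network (the module itself is an assumption for illustration):

```python
import torch
import torch.nn as nn

# Tiny 4-channel-input network standing in for netG.
net = nn.Sequential(nn.Conv2d(4, 8, 3, padding=1), nn.ReLU())
net.eval()

scripted = torch.jit.script(net)
out1 = scripted(torch.ones(1, 4, 16, 16))
out2 = scripted(torch.ones(2, 4, 16, 16))  # a different batch size still works
```

Running the scripted module on two batch sizes before saving, as the code above does with `example1`/`example2`, is a cheap sanity check that the export is shape-generic.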
import torch
from torch.nn import functional as F
import util.util as util
from models import networks
from models.shift_net.base_model import BaseModel
import time
import torchvision.transforms as transforms
import os
import numpy as np
from PIL import Image
class ShiftNetModel(BaseModel):
def name(self):
return 'ShiftNetModel'
def create_random_mask(self):
if self.opt.mask_type == 'random':
if self.opt.mask_sub_type == 'fractal':
assert 1==2, "It is broken somehow; please use another mask_sub_type."
mask = util.create_walking_mask() # create an initial random mask.
elif self.opt.mask_sub_type == 'rect':
mask, rand_t, rand_l = util.create_rand_mask(self.opt)
self.rand_t = rand_t
self.rand_l = rand_l
return mask
elif self.opt.mask_sub_type == 'island':
mask = util.wrapper_gmask(self.opt)
return mask
def initialize(self, opt):
BaseModel.initialize(self, opt)
self.opt = opt
self.isTrain = opt.isTrain
# specify the training losses you want to print out. The program will call base_model.get_current_losses
self.loss_names = ['G_GAN', 'G_L1', 'D', 'style', 'content', 'tv']
# specify the images you want to save/display. The program will call base_model.get_current_visuals
if self.opt.show_flow:
self.visual_names = ['real_A', 'fake_B', 'real_B', 'flow_srcs']
else:
self.visual_names = ['real_A', 'fake_B', 'real_B']
# specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
if self.isTrain:
self.model_names = ['G', 'D']
else: # during test time, only load Gs
self.model_names = ['G']
# batchsize should be 1 for mask_global
self.mask_global = torch.zeros((self.opt.batchSize, 1, \
opt.fineSize, opt.fineSize), dtype=torch.bool)
# Here we need to set an artificial mask_global (a center hole is fine).
self.mask_global.zero_()
self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\
int(self.opt.fineSize/4) + self.opt.overlap: int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1
if len(opt.gpu_ids) > 0:
self.mask_global = self.mask_global.to(self.device)
# load/define networks
# self.ng_innerCos_list is the guidance loss list in netG inner layers.
# self.ng_shift_list is the mask list constructing shift operation.
if opt.add_mask2input:
input_nc = opt.input_nc + 1
else:
input_nc = opt.input_nc
self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(input_nc, opt.output_nc, opt.ngf,
opt.which_model_netG, opt, self.mask_global, opt.norm, opt.use_spectral_norm_G, opt.init_type, self.gpu_ids, opt.init_gain)
if self.isTrain:
use_sigmoid = False
if opt.gan_type == 'vanilla':
use_sigmoid = True # only vanilla GAN using BCECriterion
# don't use cGAN
self.netD = networks.define_D(opt.input_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid, opt.use_spectral_norm_D, opt.init_type, self.gpu_ids, opt.init_gain)
# add style extractor
self.vgg16_extractor = util.VGG16FeatureExtractor()
if len(opt.gpu_ids) > 0:
self.vgg16_extractor = self.vgg16_extractor.to(self.gpu_ids[0])
self.vgg16_extractor = torch.nn.DataParallel(self.vgg16_extractor, self.gpu_ids)
if self.isTrain:
self.old_lr = opt.lr
# define loss functions
self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(self.device)
self.criterionL1 = torch.nn.L1Loss()
self.criterionL1_mask = networks.Discounted_L1(opt).to(self.device) # make weights/buffers transfer to the correct device
# VGG loss
self.criterionL2_style_loss = torch.nn.MSELoss()
self.criterionL2_content_loss = torch.nn.MSELoss()
# TV loss
self.tv_criterion = networks.TVLoss(self.opt.tv_weight)
# initialize optimizers
self.schedulers = []
self.optimizers = []
if self.opt.gan_type == 'wgan_gp':
opt.beta1 = 0
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.9))
self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.9))
else:
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
for optimizer in self.optimizers:
self.schedulers.append(networks.get_scheduler(optimizer, opt))
if not self.isTrain or opt.continue_train:
self.load_networks(opt.which_epoch)
self.print_networks(opt.verbose)
def set_input(self, input):
self.image_paths = input['A_paths']
real_A = input['A'].to(self.device)
real_B = input['B'].to(self.device)
# directly load mask offline
self.mask_global = input['M'].to(self.device).byte()
self.mask_global = self.mask_global.narrow(1,0,1).bool()
# create mask online
if not self.opt.offline_loading_mask:
if self.opt.mask_type == 'center':
self.mask_global.zero_()
self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\
int(self.opt.fineSize/4) + self.opt.overlap: int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1
self.rand_t, self.rand_l = int(self.opt.fineSize/4) + self.opt.overlap, int(self.opt.fineSize/4) + self.opt.overlap
elif self.opt.mask_type == 'random':
self.mask_global = self.create_random_mask().type_as(self.mask_global).view(1, *self.mask_global.size()[-3:])
# Generating random masks online is computation-heavy,
# so we just generate one random mask for a batch of images.
self.mask_global = self.mask_global.expand(self.opt.batchSize, *self.mask_global.size()[-3:])
else:
raise ValueError("Mask_type [%s] not recognized." % self.opt.mask_type)
# For loading mask offline, we also need to change 'opt.mask_type' and 'opt.mask_sub_type'
# to avoid forgetting such settings.
else:
self.opt.mask_type = 'random'
self.opt.mask_sub_type = 'island'
self.set_latent_mask(self.mask_global)
real_A.narrow(1,0,1).masked_fill_(self.mask_global, 0.)#2*123.0/255.0 - 1.0
real_A.narrow(1,1,1).masked_fill_(self.mask_global, 0.)#2*104.0/255.0 - 1.0
real_A.narrow(1,2,1).masked_fill_(self.mask_global, 0.)#2*117.0/255.0 - 1.0
if self.opt.add_mask2input:
# Add a 4th channel.
# Note: in the extra channel, the masked part is filled with 0 and the non-masked part with 1.
real_A = torch.cat((real_A, (~self.mask_global).expand(real_A.size(0), 1, real_A.size(2), real_A.size(3)).type_as(real_A)), dim=1)
self.real_A = real_A
self.real_B = real_B
def set_latent_mask(self, mask_global):
for ng_shift in self.ng_shift_list: # ITERATE OVER THE LIST OF ng_shift_list
ng_shift.set_mask(mask_global)
for ng_innerCos in self.ng_innerCos_list: # ITERATE OVER THE LIST OF ng_innerCos_list:
ng_innerCos.set_mask(mask_global)
def set_gt_latent(self):
if not self.opt.skip:
if self.opt.add_mask2input:
# Add a 4th channel.
# Note: in the extra channel, the masked part is filled with 0 and the non-masked part with 1.
real_B = torch.cat([self.real_B, (~self.mask_global).expand(self.real_B.size(0), 1, self.real_B.size(2), self.real_B.size(3)).type_as(self.real_B)], dim=1)
else:
real_B = self.real_B
self.netG(real_B) # input ground truth
def forward(self):
self.set_gt_latent()
self.fake_B = self.netG(self.real_A)
# Just assume one shift layer.
def set_flow_src(self):
self.flow_srcs = self.ng_shift_list[0].get_flow()
self.flow_srcs = F.interpolate(self.flow_srcs, scale_factor=8, mode='nearest')
# Just to avoid forgetting to set show_map_false.
self.set_show_map_false()
# Just assume one shift layer.
def set_show_map_true(self):
self.ng_shift_list[0].set_flow_true()
def set_show_map_false(self):
self.ng_shift_list[0].set_flow_false()
def get_image_paths(self):
return self.image_paths
def backward_D(self):
fake_B = self.fake_B
# Real
real_B = self.real_B # GroundTruth
# It has been verified that, for square masks, letting D discriminate only the masked patch improves the results.
if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
# Using the cropped fake_B as the input of D.
fake_B = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
real_B = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
self.pred_fake = self.netD(fake_B.detach())
self.pred_real = self.netD(real_B)
if self.opt.gan_type == 'wgan_gp':
gradient_penalty, _ = util.cal_gradient_penalty(self.netD, real_B, fake_B.detach(), self.device, constant=1, lambda_gp=self.opt.gp_lambda)
self.loss_D_fake = torch.mean(self.pred_fake)
self.loss_D_real = -torch.mean(self.pred_real)
self.loss_D = self.loss_D_fake + self.loss_D_real + gradient_penalty
else:
if self.opt.gan_type in ['vanilla', 'lsgan']:
self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
self.loss_D_real = self.criterionGAN(self.pred_real, True)
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
elif self.opt.gan_type == 're_s_gan':
self.loss_D = self.criterionGAN(self.pred_real - self.pred_fake, True)
self.loss_D.backward()
def backward_G(self):
# First, G(A) should fake the discriminator
fake_B = self.fake_B
# It has been verified that, for square masks, letting D discriminate only the masked patch improves the results.
if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
# Using the cropped fake_B as the input of D.
fake_B = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
real_B = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
else:
real_B = self.real_B
pred_fake = self.netD(fake_B)
if self.opt.gan_type == 'wgan_gp':
self.loss_G_GAN = -torch.mean(pred_fake)
else:
if self.opt.gan_type in ['vanilla', 'lsgan']:
self.loss_G_GAN = self.criterionGAN(pred_fake, True) * self.opt.gan_weight
elif self.opt.gan_type == 're_s_gan':
pred_real = self.netD(real_B)
self.loss_G_GAN = self.criterionGAN(pred_fake - pred_real, True) * self.opt.gan_weight
elif self.opt.gan_type == 're_avg_gan':
self.pred_real = self.netD(real_B)
self.loss_G_GAN = (self.criterionGAN(self.pred_real - torch.mean(pred_fake), False) \
+ self.criterionGAN(pred_fake - torch.mean(self.pred_real), True)) / 2.
self.loss_G_GAN *= self.opt.gan_weight
# If we change the mask to 'center with random position', then we can replace loss_G_L1_m with 'Discounted L1'.
self.loss_G_L1, self.loss_G_L1_m = 0, 0
self.loss_G_L1 += self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A
# Calculate the mask reconstruction loss.
# When mask_type is 'center' or 'random_with_rect', we can add an additional mask-region reconstruction loss (traditional L1).
# Only when 'discounting_loss' is 1 does the mask-region loss change to 'discounting L1' instead of normal L1.
if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
mask_patch_fake = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
mask_patch_real = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
# Using Discounting L1 loss
self.loss_G_L1_m += self.criterionL1_mask(mask_patch_fake, mask_patch_real)*self.opt.mask_weight_G
self.loss_G = self.loss_G_L1 + self.loss_G_L1_m + self.loss_G_GAN
# Then, add TV loss
self.loss_tv = self.tv_criterion(self.fake_B*self.mask_global.float())
# Finally, add style loss
vgg_ft_fakeB = self.vgg16_extractor(fake_B)
vgg_ft_realB = self.vgg16_extractor(real_B)
self.loss_style = 0
self.loss_content = 0
for i in range(3):
self.loss_style += self.criterionL2_style_loss(util.gram_matrix(vgg_ft_fakeB[i]), util.gram_matrix(vgg_ft_realB[i]))
self.loss_content += self.criterionL2_content_loss(vgg_ft_fakeB[i], vgg_ft_realB[i])
self.loss_style *= self.opt.style_weight
self.loss_content *= self.opt.content_weight
self.loss_G += (self.loss_style + self.loss_content + self.loss_tv)
self.loss_G.backward()
def optimize_parameters(self):
self.forward()
# update D
self.set_requires_grad(self.netD, True)
self.optimizer_D.zero_grad()
self.backward_D()
self.optimizer_D.step()
# update G
self.set_requires_grad(self.netD, False)
self.optimizer_G.zero_grad()
self.backward_G()
self.optimizer_G.step()
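`optimize_parameters()` above follows the standard alternating GAN update: update D on detached fakes, then freeze D and update G through it. A minimal sketch with toy linear networks (the nets, data, and vanilla BCE loss are assumptions for illustration, not the model's actual architecture):

```python
import torch
import torch.nn as nn

G = nn.Linear(8, 8)                      # toy generator
D = nn.Linear(8, 1)                      # toy discriminator (logit output)
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCEWithLogitsLoss()

real = torch.randn(4, 8)
fake = G(torch.randn(4, 8))

# Update D on detached fake samples so no gradient reaches G.
for p in D.parameters():
    p.requires_grad = True
opt_D.zero_grad()
loss_D = 0.5 * (bce(D(real), torch.ones(4, 1))
                + bce(D(fake.detach()), torch.zeros(4, 1)))
loss_D.backward()
opt_D.step()

# Update G with D frozen: gradients flow *through* D into G only.
for p in D.parameters():
    p.requires_grad = False
opt_G.zero_grad()
loss_G = bce(D(fake), torch.ones(4, 1))
loss_G.backward()
opt_G.step()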
{
"cells": [
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import os\n",
"import torch\n",
"import argparse\n",
"import matplotlib.pyplot as plt\n",
"import sys\n",
"sys.path.append('../')\n",
"from models.soft_shift_net.innerSoftShiftTriple import InnerSoftShiftTriple\n",
"#from models.accelerated_shift_net.accelerated_InnerShiftTriple import AcceleratedInnerShiftTriple\n",
"from options.train_options import TrainOptions \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CREATE DEFAULT OPTIONS TO INITIALIZE THE SHIFTMODEL"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"dataroot = '/mnt/hdd2/AIM/DAGM/Class4_def/' # ENTER HERE THE PATH YOU WANT TO USE AS DATAROOT\n",
"options = '--dataroot {}'.format(dataroot).split(' ')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def get_parser(options=None):\n",
" parser = TrainOptions()\n",
" parser.parse(options=options)\n",
" return parser"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------- Options ---------------\n",
" add_mask2input: False \n",
" batchSize: 1 \n",
" beta1: 0.5 \n",
" bottleneck: 512 \n",
" checkpoints_dir: ./log \n",
" constrain: MSE \n",
" continue_train: False \n",
" dataroot: /mnt/hdd2/AIM/DAGM/Class4_def/\t[default: ./datasets/Paris/train]\n",
" dataset_mode: aligned \n",
" display_freq: 10 \n",
" display_id: 1 \n",
" display_ncols: 4 \n",
" display_port: 8097 \n",
" display_server: http://localhost \n",
"display_single_pane_ncols: 0 \n",
" display_winsize: 256 \n",
" epoch_count: 1 \n",
" fineSize: 256 \n",
" fixed_mask: 1 \n",
" gan_type: vanilla \n",
" gan_weight: 0.2 \n",
" gp_lambda: 10.0 \n",
" gpu_ids: 0 \n",
" init_gain: 0.02 \n",
" init_type: normal \n",
" input_nc: 3 \n",
" isTrain: True \t[default: None]\n",
" lambda_A: 100 \n",
" loadSize: 350 \n",
" lr: 0.0002 \n",
" lr_decay_iters: 50 \n",
" lr_policy: lambda \n",
" mask_sub_type: island \n",
" mask_thred: 1 \n",
" mask_type: random \n",
" max_dataset_size: inf \n",
" model: accelerated_shiftnet \n",
" nThreads: 2 \n",
" n_layers_D: 3 \n",
" name: \n",
" ncritic: 5 \n",
" ndf: 64 \n",
" ngf: 64 \n",
" niter: 10000000 \n",
" niter_decay: 0 \n",
" no_flip: False \n",
" no_html: False \n",
" norm: instance \n",
" only_lastest: True \n",
" output_nc: 3 \n",
" overlap: 4 \n",
" phase: train \n",
" print_freq: 50 \n",
" resize_or_crop: resize_and_crop \n",
" save_epoch_freq: 2 \n",
" save_latest_freq: 5000 \n",
" serial_batches: False \n",
" shift_sz: 1 \n",
" skip: 0 \n",
" strength: 1 \n",
" stride: 1 \n",
" suffix: \n",
" threshold: 0.3125 \n",
" triple_weight: 1 \n",
" update_html_freq: 1000 \n",
" use_dropout: False \n",
" verbose: False \n",
" which_epoch: latest \n",
" which_model_netD: densenet \n",
" which_model_netG: acc_unet_shift_triple \n",
"----------------- End -------------------\n"
]
}
],
"source": [
"parser = get_parser(options=options)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CREATE INNER_SHIFT_TRIPLE LAYER"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"#from models.InnerShiftTriple import InnerShiftTriple"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"opt = parser.opt\n",
"#opt"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"inner_shift_triple = InnerSoftShiftTriple(opt.threshold, opt.fixed_mask)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"InnerSoftShiftTriple(threshold: 0.3125 ,triple_weight 1)"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inner_shift_triple.cuda()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# EVALUATE FORWARD SPEED"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### THE SIZE OF THE INPUT TENSOR IS (BATCH_SIZE, 256 * 2 (former | latter), 32, 32). LET'S CREATE A RANDOM TENSOR AND EVALUATE ITS FORWARD PASS FIRST"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"torch.cuda.is_available()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### NOW WE NEED TO SET UP THE MASK"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f1979918588>"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADJBJREFUeJzt3H+s3XV9x/Hna7TUDF2E4ZpSmoGm+wOXrJIbJJEYFzKBZknxHwJ/SGdI6h+YaOL+qPqH/GPilqmZyUZSI7EuTiRTQ/9gU2xMjH+oFFKBwtCKJbQrdA6CZCYV8L0/7rd47Pve3tt7z7nn3Pl8JDfnez/ne+5595vmme/5mapCkkb9wbQHkDR7DIOkxjBIagyDpMYwSGoMg6RmYmFIcmOSp5IcTbJ3UvcjafwyifcxJLkA+AnwV8Bx4CHgtqp6Yux3JmnsJnXGcA1wtKqerqpfA/cCuyZ0X5LGbMOE/u5W4NmR348D71xs5wuzqd7ARRMaRRLAy7z4i6p6y3L2nVQYlpRkD7AH4A38Ie/M9dMaRfq98J36t2eWu++kHkqcALaN/H75sPa6qtpXVXNVNbeRTRMaQ9JKTCoMDwHbk1yZ5ELgVuDAhO5L0phN5KFEVb2a5EPAt4ALgHuq6sgk7kvS+E3sOYaqegB4YFJ/X9Lk+M5HSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUbFjNjZMcA14GXgNeraq5JJcAXwOuAI4Bt1TVi6sbU9JaGscZw19W1Y6qmht+3wscrKrtwMHhd0nryCQeSuwC9g/b+4GbJ3AfkiZotWEo4NtJHk6yZ1jbXFUnh+3ngM0L3TDJniSHkhx6hdOrHEPSOK3qOQbguqo6keRPgAeT/OfolVVVSWqhG1bVPmAfwB/lkgX3kTQdqzpjqKoTw+Up4JvANcDzSbYADJenVjukpLW14jAkuSjJm85sA+8FHgcOALuH3XYD9692SElrazUPJTYD30xy5u/8a1X9R5KHgPuS3AE8A9yy+jElraUVh6Gqngb+YoH1/wGuX81QkqbLdz5KagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGo2LLVDknuAvwZOVdWfD2uXAF8DrgCOAbdU1YtJAvwjsBP4FfA3VfXIZEbX2b71X4enPcJE3HDZjmmP8HtnOWcMXwJuPGttL3CwqrYDB4ffAW4Ctg8/e4C7xzOmpLW0ZBiq6nvAC2ct7wL2D9v7gZtH1r9c834AvDnJlnENK2ltrPQ5hs1VdXLYfg7YPGxvBZ4d2e/4sCZpHVn1k49VVUCd7+2S7ElyKMmhVzi92jEkjdFKw/D8mYcIw+WpYf0EsG1kv8uHtaaq9lXVXFXNbWTTCseQNAkrDcMBYPewvRu4f2T99sy7Fnhp5CGHpHViOS9XfhV4D3BpkuPAJ4FPA/cluQN4Brhl2P0B5l+qPMr8y5UfmMDMkiZsyTBU1W2LXHX9AvsWcOdqh5I0Xb7zUVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNUuGIck9SU4leXxk7a4kJ5IcHn52jlz3sSRHkzyV5IZJDS5pcpZzxvAl4MYF1j9XVTuGnwcAklwF3Aq8fbjNPy
e5YFzDSlobS4ahqr4HvLDMv7cLuLeqTlfVz4GjwDWrmE/SFKzmOYYPJXl0eKhx8bC2FXh2ZJ/jw1qTZE+SQ0kOvcLpVYwhadxWGoa7gbcBO4CTwGfO9w9U1b6qmququY1sWuEYkiZhRWGoquer6rWq+g3wBX77cOEEsG1k18uHNUnryIrCkGTLyK/vA868YnEAuDXJpiRXAtuBH61uRElrbcNSOyT5KvAe4NIkx4FPAu9JsgMo4BjwQYCqOpLkPuAJ4FXgzqp6bTKjS5qUJcNQVbctsPzFc+z/KeBTqxlK0nT5zkdJjWGQ1Cz5UELrxw2X7Zj2CPp/wjMGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1CwZhiTbknw3yRNJjiT58LB+SZIHk/x0uLx4WE+Szyc5muTRJFdP+h8habyWc8bwKvDRqroKuBa4M8lVwF7gYFVtBw4OvwPcBGwffvYAd499akkTtWQYqupkVT0ybL8MPAlsBXYB+4fd9gM3D9u7gC/XvB8Ab06yZeyTS5qY83qOIckVwDuAHwKbq+rkcNVzwOZheyvw7MjNjg9rktaJZYchyRuBrwMfqapfjl5XVQXU+dxxkj1JDiU59Aqnz+emkiZsWWFIspH5KHylqr4xLD9/5iHCcHlqWD8BbBu5+eXD2u+oqn1VNVdVcxvZtNL5JU3Acl6VCPBF4Mmq+uzIVQeA3cP2buD+kfXbh1cnrgVeGnnIIWkd2LCMfd4FvB94LMnhYe3jwKeB+5LcATwD3DJc9wCwEzgK/Ar4wFgnljRxS4ahqr4PZJGrr19g/wLuXOVckqbIdz5KagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6RmyTAk2Zbku0meSHIkyYeH9buSnEhyePjZOXKbjyU5muSpJDdM8h8gafw2LGOfV4GPVtUjSd4EPJzkweG6z1XVP4zunOQq4Fbg7cBlwHeS/FlVvTbOwSVNzpJnDFV1sqoeGbZfBp4Etp7jJruAe6vqdFX9HDgKXDOOYSWtjfN6jiHJFcA7gB8OSx9K8miSe5JcPKxtBZ4dudlxFghJkj1JDiU59Aqnz3twSZOz7DAkeSPwdeAjVfVL4G7gbcAO4CTwmfO546raV1VzVTW3kU3nc1NJE7asMCTZyHwUvlJV3wCoquer6rWq+g3wBX77cOEEsG3k5pcPa5LWieW8KhHgi8CTVfXZkfUtI7u9D3h82D4A3JpkU5Irge3Aj8Y3sqRJW86rEu8C3g88luTwsPZx4LYkO4ACjgEfBKiqI0nuA55g/hWNO31FQlpfUlXTnoEk/w38L/CLac+yDJeyPuaE9TOrc47fQrP+aVW9ZTk3nokwACQ5VFVz055jKetlTlg/szrn+K12Vt8SLakxDJKaWQrDvmkPsEzrZU5YP7M65/itataZeY5B0uyYpTMGSTNi6mFIcuPw8eyjSfZOe56zJTmW5LHho+WHhrVLkjyY5KfD5cVL/Z0JzHVPklNJHh9ZW3CuzPv8cIwfTXL1DMw6cx/bP8dXDMzUcV2Tr0Koqqn9ABcAPwPeClwI/Bi4apozLTDjMeDSs9b+Htg7bO8F/m4Kc70buBp4fKm5gJ3AvwMBrgV+OAOz3gX87QL7XjX8P9gEXDn8/7hgjebcAlw9bL8J+Mkwz0wd13PMObZjOu0zhmuAo1X1dFX9GriX+Y9tz7pdwP5hez9w81oPUFXfA144a3mxuXYBX655PwDefNZb2idqkVkXM7WP7dfiXzEwU8f1HHMu5ryP6bTDsK
yPaE9ZAd9O8nCSPcPa5qo6OWw/B2yezmjNYnPN6nFe8cf2J+2srxiY2eM6zq9CGDXtMKwH11XV1cBNwJ1J3j16Zc2fq83cSzuzOteIVX1sf5IW+IqB183ScR33VyGMmnYYZv4j2lV1Yrg8BXyT+VOw58+cMg6Xp6Y34e9YbK6ZO841ox/bX+grBpjB4zrpr0KYdhgeArYnuTLJhcx/V+SBKc/0uiQXDd9zSZKLgPcy//HyA8DuYbfdwP3TmbBZbK4DwO3Ds+jXAi+NnBpPxSx+bH+xrxhgxo7rYnOO9ZiuxbOoSzzDupP5Z1V/Bnxi2vOcNdtbmX8298fAkTPzAX8MHAR+CnwHuGQKs32V+dPFV5h/zHjHYnMx/6z5Pw3H+DFgbgZm/ZdhlkeH/7hbRvb/xDDrU8BNazjndcw/THgUODz87Jy143qOOcd2TH3no6Rm2g8lJM0gwyCpMQySGsMgqTEMkhrDIKkxDJIawyCp+T/ngnElFjH4JwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"c, h, w = (1, 256, 256)\n",
"hh = h//2\n",
"wh = w//2\n",
"hm_size = 32\n",
"mask = np.zeros((1, c, h, w))\n",
"mask[..., hh - hm_size:hh + hm_size, wh - hm_size:wh + hm_size] = 1\n",
"#mask[..., h - hh:, :] = 1\n",
"mask_global=torch.ByteTensor(mask).cuda()#.cpu()\n",
"plt.imshow(np.squeeze(mask))"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" ...,\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0]], device='cuda:0', dtype=torch.uint8)"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inner_shift_triple.set_mask(mask_global=mask_global, threshold=opt.threshold, layer_to_last=3)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"x_np = np.random.normal(0, 1, (1, 512, 32, 32))\n",
"x_tr = torch.FloatTensor(x_np)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.05 ms ± 846 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit output = inner_shift_triple(x_tr.cuda())\n",
"#output = inner_shift_triple(x_tr.cuda())\n",
"#flag, indexes, ind_lst = inner_shift_triple(x_tr.cuda())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idx = tuple((np.where(flag_n == 1), f1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"transition_matrx[idx] = 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(transition_matrx)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tmp = tmp[:, cp]\n",
"tmp.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cp = np.where(flag == 0)[0][indexes][0] == flag"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"indexes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tmp[:, 0, indexes] = 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.sum(tmp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.sum(transition_matrx)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.sum(flag)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"indexes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from util import util\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def create_random_mask(opt):\n",
" gMask_opts = {}\n",
" mask_global = torch.ByteTensor(1, 1, \\\n",
" opt.fineSize, opt.fineSize)\n",
"\n",
" # Here we need to set an artificial mask_global(not to make it broken, so center hole is ok.)\n",
" mask_global.zero_()\n",
" mask_global[:, :, int(opt.fineSize/4) + opt.overlap : int(opt.fineSize/2) + int(opt.fineSize/4) - opt.overlap,\\\n",
" int(opt.fineSize/4) + opt.overlap: int(opt.fineSize/2) + int(opt.fineSize/4) - opt.overlap] = 1 \n",
" \n",
" res = 0.06 # the lower it is, the more continuous the output will be. 0.01 is too small and 0.1 is too large\n",
" density = 0.25\n",
" MAX_SIZE = 300\n",
" maxPartition = 30\n",
" low_pattern = torch.rand(1, 1, int(res*MAX_SIZE), int(res*MAX_SIZE)).mul(255)\n",
" pattern = F.upsample(low_pattern, (MAX_SIZE, MAX_SIZE), mode='bilinear').data\n",
" low_pattern = None\n",
" pattern.div_(255)\n",
" pattern = torch.lt(pattern,density).byte() # 25% 1s and 75% 0s\n",
" pattern = torch.squeeze(pattern).byte()\n",
" gMask_opts['pattern'] = pattern\n",
" gMask_opts['MAX_SIZE'] = MAX_SIZE\n",
" gMask_opts['fineSize'] = opt.fineSize\n",
" gMask_opts['maxPartition'] = maxPartition\n",
" gMask_opts['mask_global'] = mask_global\n",
" mask_global = util.create_gMask(gMask_opts) # create an initial random mask. \n",
" return mask_global"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%time mask_global = create_random_mask(opt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mask_global.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(np.squeeze(mask_global))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mask_global"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inner_shift_triple.set_mask(mask_global=mask_global, threshold=opt.threshold, layer_to_last=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%timeit output = inner_shift_triple.forward(x_tr.cuda())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# THE ENTIRE PROCESS IS PRETTY FAST; THE BOTTLENECK WAS THE MASK GENERATOR"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# IMPLEMENT AN ACCELERATED MODULE"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt = parser.opt\n",
"opt.shift_sz = 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from models.accelerated_InnerShiftTriple import AcceleratedInnerShiftTriple\n",
"acce_inner_shift_triple = AcceleratedInnerShiftTriple(opt.threshold, opt.fixed_mask)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"acce_inner_shift_triple.cuda()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"acce_inner_shift_triple.set_mask(mask_global=mask_global, threshold=opt.threshold, layer_to_last=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%timeit output = acce_inner_shift_triple(x_tr.cuda())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"acce_inner_shift_triple.__dict__"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('THE SPEED UP IS {} FOLD'.format(582/115))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import argparse
import os
from util import util
import torch
class BaseOptions():
def __init__(self):
self.initialized = False
def initialize(self, parser):
parser.add_argument('--dataroot', default='F:\\santan\\super-resolution\\CTDataset\\DatesetCt\\real\\', help='path to training/testing images')
parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
parser.add_argument('--loadSize', type=int, default=350, help='scale images to this size')
parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD, [basic|densenet]')
parser.add_argument('--which_model_netG', type=str, default='unet_shift_triple', help='selects model to use for netG [unet_256| unet_shift_triple| \
res_unet_shift_triple|patch_soft_unet_shift_triple| \
res_patch_soft_unet_shift_triple| face_unet_shift_triple]')
parser.add_argument('--model', type=str, default='shiftnet', \
help='chooses which model to use. [shiftnet|res_shiftnet|patch_soft_shiftnet|res_patch_soft_shiftnet|test]')
parser.add_argument('--triple_weight', type=float, default=1, help='The weight on the gradient of skip connections from the gradient of shifted')
parser.add_argument('--name', type=str, default='exp', help='name of the experiment. It decides where to store samples and models')
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2, use \'-1 \' for cpu training/testing')
parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [aligned | aligned_resized | single]')
parser.add_argument('--nThreads', default=1, type=int, help='# threads for loading data')
parser.add_argument('--checkpoints_dir', type=str, default='./log', help='models are saved here')
parser.add_argument('--norm', type=str, default='instance', help='[instance|batch|switchable] normalization')
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}')
parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width]')
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
parser.add_argument('--show_flow', type=int, default=0, help='show the flow information. WARNING: set display_freq a large number as it is quite slow when showing flow')
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
## model specific
parser.add_argument('--mask_type', type=str, default='random',
help='the type of mask you want to apply, \'center\' or \'random\'')
parser.add_argument('--mask_sub_type', type=str, default='island',
help='the type of mask you want to apply, \'rect \' or \'fractal \' or \'island \'')
parser.add_argument('--lambda_A', type=int, default=100, help='weight on L1 term in objective')
parser.add_argument('--stride', type=int, default=1, help='should be dense, 1 is a good option.')
parser.add_argument('--shift_sz', type=int, default=1, help='shift_sz>1 only for \'soft_shift_patch\'.')
parser.add_argument('--mask_thred', type=int, default=1, help='number to decide whether a patch is masked')
parser.add_argument('--overlap', type=int, default=4, help='the overlap for center mask')
parser.add_argument('--bottleneck', type=int, default=512, help='number of neurons in the fully-connected bottleneck')
parser.add_argument('--gp_lambda', type=float, default=10.0, help='gradient penalty coefficient')
parser.add_argument('--constrain', type=str, default='MSE', help='guidance loss type')
parser.add_argument('--strength', type=float, default=1, help='the weight of guidance loss')
parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
parser.add_argument('--skip', type=int, default=1, help='Whether to skip the guidance loss; skipping it trains tens of percent faster at the cost of some quality')
parser.add_argument('--fuse', type=int, default=0, help='Fuse may encourage large patches shifting when using \'patch_soft_shift\'')
parser.add_argument('--gan_type', type=str, default='vanilla', help='wgan_gp, '
'lsgan, '
'vanilla, '
're_s_gan (Relativistic Standard GAN), ')
parser.add_argument('--gan_weight', type=float, default=0.2, help='the weight of gan loss')
# New added
parser.add_argument('--style_weight', type=float, default=10.0, help='the weight of style loss')
parser.add_argument('--content_weight', type=float, default=1.0, help='the weight of content loss')
parser.add_argument('--tv_weight', type=float, default=0.0, help='the weight of tv loss, you can set a small value, such as 0.1/0.01')
parser.add_argument('--offline_loading_mask', type=int, default=0, help='whether to load mask offline randomly')
parser.add_argument('--mask_weight_G', type=float, default=400.0, help='the weight of the mask part in the output of G, you can try different mask_weight')
parser.add_argument('--discounting', type=int, default=1, help='the loss type of mask part, whether using discounting l1 loss or normal l1')
parser.add_argument('--use_spectral_norm_D', type=int, default=1, help='whether to add spectral norm to D, it helps improve results')
parser.add_argument('--use_spectral_norm_G', type=int, default=0, help='whether to add spectral norm in G. Seems very bad when adding SN to G')
parser.add_argument('--only_lastest', type=int, default=0,
help='If True, it will save only the latest weights')
parser.add_argument('--add_mask2input', type=int, default=1,
help='If True, it will add the mask as a fourth input channel')
self.initialized = True
return parser
def gather_options(self, options=None):
# initialize parser with basic options
if not self.initialized:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = self.initialize(parser)
self.parser = parser
if options is None:
return parser.parse_args()
else:
return parser.parse_args(options)
def print_options(self, opt):
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(k)
if v != default:
comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------'
print(message)
# save to the disk
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, 'opt.txt')
with open(file_name, 'wt') as opt_file:
opt_file.write(message)
opt_file.write('\n')
def parse(self, options=None):
opt = self.gather_options(options=options)
opt.isTrain = self.isTrain # train or test
# process opt.suffix
if opt.suffix:
suffix = '_' + opt.suffix.format(**vars(opt))
opt.name = opt.name + suffix
self.print_options(opt)
# set gpu ids
os.environ["CUDA_VISIBLE_DEVICES"]=opt.gpu_ids
str_ids = opt.gpu_ids.split(',')
opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
opt.gpu_ids.append(id)
# renumber locally: CUDA_VISIBLE_DEVICES remaps the selected GPUs to 0..n-1
opt.gpu_ids = list(range(len(opt.gpu_ids)))
if len(opt.gpu_ids) > 0:
torch.cuda.set_device(opt.gpu_ids[0])
self.opt = opt
return self.opt
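The gpu-id handling in `parse()` above does two things: it exports the raw id string to `CUDA_VISIBLE_DEVICES`, then renumbers the ids to a contiguous `0..n-1` range, because once visibility is restricted the driver presents the selected devices under new local indices. A minimal sketch of that logic (the function name is illustrative):

```python
import os

def parse_gpu_ids(gpu_ids: str):
    """'1,3' -> exports CUDA_VISIBLE_DEVICES='1,3' and returns [0, 1].

    After CUDA_VISIBLE_DEVICES restricts visibility, the selected GPUs are
    renumbered by the driver to 0..n-1, so the local ids are renumbered too.
    '-1' means CPU and yields an empty list.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids
    ids = [int(s) for s in gpu_ids.split(',') if int(s) >= 0]
    return list(range(len(ids)))
```

For example, `parse_gpu_ids('1,3')` returns `[0, 1]`: inside the process, physical GPUs 1 and 3 become devices 0 and 1.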
from .base_options import BaseOptions
class TestOptions(BaseOptions):
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser)
parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
parser.add_argument('--which_epoch', type=str, default='9', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--how_many', type=int, default=1000, help='how many test images to run')
parser.add_argument('--testing_mask_folder', type=str, default='F:\\santan\\repaint\\dataset\\test3\\', help='prepared masks for testing')
self.isTrain = False
return parser
from .base_options import BaseOptions
# Here is the options especially for training
class TrainOptions(BaseOptions):
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser)
parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
parser.add_argument('--display_ncols', type=int, default=5, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
parser.add_argument('--print_freq', type=int, default=50, help='frequency of showing training results on console')
parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
parser.add_argument('--save_latest_freq', type=int, default=100, help='frequency of saving the latest results')
parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
parser.add_argument('--continue_train', type=int, default=0, help='continue training: load the latest model')
parser.add_argument('--epoch_count', type=int, default=9, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--niter', type=int, default=30, help='# of iter at starting learning rate')
parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero')
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
parser.add_argument('--training_mask_folder', type=str, default='F:\\santan\\repaint\\1\\train\\', help='prepared masks for training')
self.isTrain = True
return parser
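`TrainOptions` and `TestOptions` specialize `BaseOptions` by extending the shared parser and setting `isTrain`. A condensed, self-contained sketch of that subclassing pattern (arguments trimmed to a representative few):

```python
import argparse

class BaseOptions:
    def initialize(self, parser):
        # Arguments shared by training and testing.
        parser.add_argument('--batchSize', type=int, default=1)
        parser.add_argument('--lr', type=float, default=0.0002)
        return parser

    def parse(self, options=None):
        # options=None falls back to sys.argv, mirroring gather_options above.
        parser = self.initialize(argparse.ArgumentParser())
        opt = parser.parse_args(options)
        opt.isTrain = self.isTrain
        return opt

class TrainOptions(BaseOptions):
    def initialize(self, parser):
        # Extend the shared parser with training-only flags.
        parser = BaseOptions.initialize(self, parser)
        parser.add_argument('--niter', type=int, default=30)
        self.isTrain = True
        return parser

opt = TrainOptions().parse(['--lr', '0.001'])
```

Passing an explicit argv-style list to `parse()` is what the notebook above relies on when it builds `options` from a format string instead of the command line.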
from options.train_options import TrainOptions
from models import create_model
opt = TrainOptions().parse()
model = create_model(opt)
model.save_networks(0)
# -*- coding: UTF-8 -*-
import torch
import util.util as util
from util.NonparametricShift import Modified_NonparametricShift
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
bz = 1
c = 2 # at least 2
w = 4
h = 4
feature_size = [bz, c, w, h]
former = torch.rand(c*h*w).mul_(50).reshape(c, h, w).int().float()
latter = torch.rand(c*h*w).mul_(50).reshape(c, h, w).int().float()
flag = torch.zeros(h,w).byte()
flag[h//4:h//2+1, h//4:h//2+1] = 1
flag = flag.view(h*w)
ind_lst = torch.FloatTensor(h*w, h*w).zero_()
shift_offsets = []
Nonparm = Modified_NonparametricShift()
cosine, latter_windows, i_2, i_3, i_1, i_4 = Nonparm.cosine_similarity(former, latter, 1, 1, flag)
## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY
_, indexes = torch.max(cosine, dim=1)
# SET TRANSITION MATRIX
mask_indexes = (flag == 1).nonzero()
non_mask_indexes = (flag == 0).nonzero()[indexes]
ind_lst[mask_indexes, non_mask_indexes] = 1
# GET FINAL SHIFT FEATURE
shift_masked_all = Nonparm._paste(latter_windows, ind_lst, i_2, i_3, i_1, i_4)
print('flag')
print(flag.reshape(h,w))
print('ind_lst')
print(ind_lst)
print('out')
print(shift_masked_all)
# get the shift offsets
shift_offset = torch.stack([non_mask_indexes.squeeze() // w, torch.fmod(non_mask_indexes.squeeze(), w)], dim=-1)
shift_offsets.append(shift_offset)
shift_offsets = torch.cat(shift_offsets, dim=0).float()
print('shift_offset')
print(shift_offset)
print(shift_offset.size()) # n_masked x 2: one (row, col) source position per masked point
shift_offsets_cl = shift_offsets.clone()
#visualize which pixels are attended
print(flag.size()) # h*w
# global and N*C*H*W
# put shift_offsets_cl back to the global map.
shift_offsets_map = torch.zeros(bz, h, w, 2).float()
print(shift_offsets_map.size()) # bz*h*w*2
# mask_indexes are the positions of the points inside the mask region.
# shift_offsets are the positions of the outside points that will be shifted into the mask region.
shift_offsets_map[:, mask_indexes.squeeze() // w, mask_indexes.squeeze() % w, :] = shift_offsets_cl.unsqueeze(0)
# At this point, shift_offsets_map is complete: only the inside of the mask holds values, meaning each such point will be replaced by some outside point whose coordinates are stored as the point's value (2 channels).
print('global shift_offsets_map')
print(shift_offsets_map)
print(shift_offsets_map.size())
print(shift_offsets_map.type())
flow2 = util.highlight_flow(shift_offsets_map, flag.unsqueeze(0))
print('flow2 size')
print(flow2.size())
# upflow = F.interpolate(flow, scale_factor=4, mode='nearest')
upflow2 = F.interpolate(flow2, scale_factor=1, mode='nearest')
print('**After upsample flow2 size**')
print(upflow2.size())
# upflow = upflow.squeeze().permute(1,2,0)
upflow2 = upflow2.squeeze().permute(1,2,0)
print(upflow2.size())
# print('flow 1')
# print(upflow)
# print(upflow.size())
# print('flow 2')
# print(upflow2)
# print(upflow2.size())
plt.imshow(upflow2/255.)
# # axs[0].imshow(upflow)
# axs[1].imshow(upflow2)
plt.show()
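The script above leans on `Modified_NonparametricShift`, but the core operation is simple: each masked feature vector is matched to its nearest unmasked neighbour under cosine similarity, the choice is recorded as a one-hot row of a transition matrix (`ind_lst`), and the matched column is pasted back. A minimal NumPy sketch of that core idea (all names here are illustrative, not the repo's API):

```python
import numpy as np

rng = np.random.default_rng(0)
c, h, w = 2, 4, 4
feat = rng.random((c, h * w))            # features as c x (h*w) columns
flag = np.zeros(h * w, dtype=bool)
flag[5:7] = True                         # mark two positions as "masked"

def unit(x):
    # Normalize columns to unit length so dot products are cosines.
    return x / np.linalg.norm(x, axis=0, keepdims=True)

# Cosine similarity of every masked column against every unmasked column.
cosine = unit(feat[:, flag]).T @ unit(feat[:, ~flag])   # n_masked x n_unmasked
best = cosine.argmax(axis=1)             # best unmasked match per masked point

# One-hot transition matrix: row = masked position, col = chosen source.
mask_idx = np.flatnonzero(flag)
src_idx = np.flatnonzero(~flag)[best]
ind_lst = np.zeros((h * w, h * w))
ind_lst[mask_idx, src_idx] = 1

# "Paste": fill each masked column with its matched unmasked column.
shifted = feat.copy()
shifted[:, mask_idx] = feat[:, src_idx]
```

The `shift_offset` computed in the script corresponds to `src_idx` unflattened back into (row, col) coordinates.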
import time
import os
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models import create_model
from util.visualizer import save_images
from util import html
import torch
from options.train_options import TrainOptions
from collections import OrderedDict
def concat(visuals_list, w, h):
opt = TrainOptions().parse()
w = w.numpy()[0]
h = h.numpy()[0]
real_A = torch.FloatTensor(1, 4, w, h).zero_()
fake_B = torch.FloatTensor(1, 3, w, h).zero_()
real_B = torch.FloatTensor(1, 3, w, h).zero_()
nw = int(w / opt.fineSize * 2)
nh = int(h / opt.fineSize * 2)
nw0 = int(w % opt.fineSize)
nh0 = int(h % opt.fineSize)
step = int(opt.fineSize / 2)
step_mid = int(step/2)
#for i, visual in enumerate(visuals_list):
for y in range(nh):
for x in range(nw):
visual = visuals_list[y * nw + x]
a = visual['real_A']
b = visual['fake_B']
real_b = visual['real_B']
if x == 0 and y == 0:
real_A[:, :, 0:opt.fineSize,0:opt.fineSize] = a[:, :, :, :]
fake_B[:, :, 0:opt.fineSize, 0:opt.fineSize] = b[:, :, :, :]
continue
if x == 0 and y != nh-1:
real_A[:, :, y * step + step_mid:y * step + opt.fineSize, 0:opt.fineSize] = a[:, :, step_mid:opt.fineSize, :]
fake_B[:, :, y * step + step_mid:y * step + opt.fineSize, 0:opt.fineSize] = b[:, :, step_mid:opt.fineSize, :]
continue
if x == 0 and y == nh - 1:
real_A[:, :, h - opt.fineSize + step_mid:h, 0:opt.fineSize] = a[:, :, step_mid:opt.fineSize, :]
fake_B[:, :, h - opt.fineSize + step_mid:h, 0:opt.fineSize] = b[:, :, step_mid:opt.fineSize, :]
continue
if y ==0 and x != nw-1:
real_A[:, :, 0:opt.fineSize, x * step + step_mid:x * step + opt.fineSize] = a[:, :, :, step_mid:opt.fineSize]
fake_B[:, :, 0:opt.fineSize, x * step + step_mid:x * step + opt.fineSize] = b[:, :, :, step_mid:opt.fineSize]
continue
if y ==0 and x == nw-1:
real_A[:, :, 0:opt.fineSize, w-opt.fineSize + step_mid:w] = a[:, :, :, step_mid:opt.fineSize]
fake_B[:, :, 0:opt.fineSize, w-opt.fineSize + step_mid:w] = b[:, :, :, step_mid:opt.fineSize]
continue
if y == nh-1 and x == nw-1:
real_A[:, :, h-opt.fineSize + step_mid:h, w-opt.fineSize + step_mid:w] = a[:, :, step_mid:opt.fineSize, step_mid:opt.fineSize]
fake_B[:, :, h-opt.fineSize + step_mid:h, w-opt.fineSize + step_mid:w] = b[:, :, step_mid:opt.fineSize, step_mid:opt.fineSize]
continue
if y == nh-1:
real_A[:, :, h-opt.fineSize + step_mid:h, x * step + step_mid:x * step + opt.fineSize] = a[:, :, step_mid:opt.fineSize, step_mid:opt.fineSize]
fake_B[:, :, h-opt.fineSize + step_mid:h, x * step + step_mid:x * step + opt.fineSize] = b[:, :, step_mid:opt.fineSize, step_mid:opt.fineSize]
continue
if x == nw-1:
real_A[:, :, y * step + step_mid:y * step + opt.fineSize, w-opt.fineSize + step_mid:w] = a[:, :, step_mid:opt.fineSize, step_mid:opt.fineSize]
fake_B[:, :, y * step + step_mid:y * step + opt.fineSize, w-opt.fineSize + step_mid:w] = b[:, :, step_mid:opt.fineSize, step_mid:opt.fineSize]
continue
real_A[:, :, y * step + step_mid:y * step + opt.fineSize, x * step + step_mid:x * step + opt.fineSize] = a[:, :,step_mid:opt.fineSize, step_mid:opt.fineSize]
fake_B[:, :, y * step + step_mid:y * step + opt.fineSize, x * step + step_mid:x * step + opt.fineSize] = b[:, :,step_mid:opt.fineSize, step_mid:opt.fineSize]
visual_ret = OrderedDict()
visual_ret['real_A'] = real_A
visual_ret['real_B'] = real_B
visual_ret['fake_B'] = fake_B
return visual_ret
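The `concat` routine above stitches half-overlapping crops back into a full image by pasting each tile from a small interior margin onward, so later tiles overwrite the seams of earlier ones. A minimal numpy sketch of the same overwrite-from-margin idea (the 8x8 canvas, 4x4 tiles, and stride 2 below are illustrative, not tied to `opt.fineSize`):

```python
import numpy as np

# tile size f, stride step = f // 2, margin m = step // 2 (all illustrative)
f, step = 4, 2
m = step // 2
h = w = 8
img = np.arange(h * w).reshape(h, w).astype(float)

out = np.zeros_like(img)
for y in range(0, h - f + 1, step):
    for x in range(0, w - f + 1, step):
        tile = img[y:y + f, x:x + f]
        ty = 0 if y == 0 else m      # keep the full border on leading edges
        tx = 0 if x == 0 else m
        out[y + ty:y + f, x + tx:x + f] = tile[ty:, tx:]

assert np.array_equal(out, img)      # every pixel is covered exactly
```

Because tiles taken from the same image agree on their overlaps, the stitched result equals the original; with generated tiles, the overwrite simply hides the border artifacts of earlier tiles.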
'''
nw = int(w / self.opt.fineSize * 2)
nh = int(h / self.opt.fineSize * 2)
nw0 = int(w % self.opt.fineSize)
nh0 = int(h % self.opt.fineSize)
step = int(self.opt.fineSize / 2)
A_temp = torch.FloatTensor(nw * nh, 3, self.opt.fineSize, self.opt.fineSize).zero_()
for iw in range(nw):
for ih in range(nh):
if iw == nw - 1 and ih == nh - 1:
A_temp[iw * nh + ih, :, :, :] = A[:, w - self.opt.fineSize:w, h - self.opt.fineSize:h]
continue
if iw == nw - 1 and ih != nh - 1:
A_temp[iw * nh + ih, :, :, :] = A[:, w - self.opt.fineSize:w, ih * step:ih * step + self.opt.fineSize]
continue
if iw != nw - 1 and ih == nh - 1:
A_temp[iw * nh + ih, :, :, :] = A[:, iw * step:iw * step + self.opt.fineSize, h - self.opt.fineSize:h]
continue
A_temp[iw * nh + ih, :, :, :] = A[:, iw * step:iw * step + self.opt.fineSize,
ih * step:ih * step + self.opt.fineSize]
A = A_temp
'''
import torchvision.transforms as transforms
if __name__ == "__main__":
opt = TestOptions().parse()
opt.nThreads = 1 # test code only supports nThreads = 1
opt.batchSize = 1 # test code only supports batchSize = 1
opt.serial_batches = True # no shuffle
opt.no_flip = True # no flip
opt.display_id = -1 # no visdom display
#opt.loadSize = opt.fineSize # Do not scale!
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
# test
'''
for i, data in enumerate(dataset):
if i >= opt.how_many:
break
t1 = time.time()
model.set_input(data)
model.test()
t2 = time.time()
print(t2-t1)
visuals = model.get_current_visuals()
img_path = model.get_image_paths()
print('process image... %s' % img_path)
save_images(webpage, visuals, img_path, 0, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
webpage.save()
'''
for i, data in enumerate(dataset):
if i >= opt.how_many:
break
A_path = data['A_paths']
A = data['A'][0]
B = data['B'][0]
mask = data['M'][0]
A_flip = data['A_F'][0]
B_flip = data['B_F'][0]
w = data['im_size'][0]
h = data['im_size'][1]
visuals_list = []
for j in range(A.shape[0]):
data_temp = {'A': A[j].unsqueeze(0), 'B': B[j].unsqueeze(0), 'A_F': A_flip[j].unsqueeze(0), 'B_F': B_flip[j].unsqueeze(0), 'M': mask[j].unsqueeze(0),
'A_paths': A_path}
t1 = time.time()
model.set_input(data_temp)
model.test()
t2 = time.time()
print(t2-t1)
visuals = model.get_current_visuals()
visuals_list.append(visuals)
#img_path = model.get_image_paths()
#print('process image... %s' % img_path)
#save_images(webpage, visuals, img_path, j, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
img_path = model.get_image_paths()
print('process image... %s' % img_path)
visual_ret = concat(visuals_list, w, h)
save_images(webpage, visual_ret, img_path, 0, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
webpage.save()
import numpy as np
import torch
import torch.nn as nn
from time import time
def unfold(img, patch_size, stride, with_indexes=torch.tensor(False, dtype=torch.bool)):
n_dim = 4
assert img.dim() == n_dim, 'image must be of dimension 4.'
kH, kW = patch_size, patch_size
dH, dW = stride, stride
input_windows = img.unfold(2, kH, dH).unfold(3, kW, dW)
i_0, i_1, i_2, i_3, i_4, i_5 = input_windows.size()
if with_indexes:
input_windows = input_windows.permute(0, 2, 3, 1, 4, 5).contiguous().view(i_0, i_2 * i_3, i_1)
return input_windows, i_2, i_3, i_1
else:
input_windows = input_windows.permute(0, 2, 3, 1, 4, 5).contiguous().view(i_0, i_2 * i_3, i_1, i_4, i_5)
return input_windows, 0, 0, 0
def filter(input_windows, flag, value):
## EXTRACT MASKED OR UNMASKED PATCHES DEPENDING ON VALUE
assert flag.dim() == 2, "flag should be batch version"
input_window = input_windows[flag == value]
bz = flag.size(0)
return input_window.view(bz, input_window.size(0) // bz, -1)
def cosine_similarity(former, latter, patch_size, stride, flag, with_former=torch.tensor(False, dtype=torch.bool)):
former_windows, _, _, _ = unfold(former, patch_size, stride)
former = filter(former_windows, flag, torch.tensor(1))
latter_windows, i_2, i_3, i_1 = unfold(latter, patch_size, stride,
with_indexes=torch.tensor(True, dtype=torch.bool))
latter = filter(latter_windows, flag, torch.tensor(0))
num = torch.einsum('bik,bjk->bij', [former, latter])
norm_latter = torch.einsum("bij,bij->bi", [latter, latter])
norm_former = torch.einsum("bij,bij->bi", [former, former])
den = torch.sqrt(torch.einsum('bi,bj->bij', [norm_former, norm_latter]))
if not with_former:
return num / den, latter_windows, torch.tensor(0), i_2, i_3, i_1
else:
return num / den, latter_windows, former_windows, i_2, i_3, i_1
# delete i_4, as i_4 is 1
def paste(input_windows, transition_matrix, i_2, i_3, i_1):
## TRANSFER FEATURES TO THEIR NEW POSITIONS
bz = input_windows.size(0)
input_windows = torch.bmm(transition_matrix, input_windows)
## RESIZE TO CORRECT CONV FEATURES FORMAT
input_windows = input_windows.view(bz, i_2, i_3, i_1)
input_windows = input_windows.permute(0, 3, 1, 2)
return input_windows
def InnerShiftTripleFunction(input, shift_sz, stride, triple_w, flag, show_flow):
# InnerShiftTripleFunction.ctx = ctx
assert input.dim() == 4, "Input Dim has to be 4"
# ctx.triple_w = triple_w
# ctx.flag = flag
# ctx.show_flow = show_flow
bz = input.size(0)
c_real = input.size(1)
h = input.size(2)
w = input.size(3)
c = c_real
# ind_lst = torch.Tensor(bz, h * w, h * w).zero_().to(input)
ind_lst = torch.zeros(bz, h * w, h * w).to(input)
# former and latter are all tensors
former_all = input.narrow(1, 0, c // 2) ### decoder feature
latter_all = input.narrow(1, c // 2, c // 2) ### encoder feature
# shift_masked_all = torch.tensor(former_all.size()).type_as(former_all).zero_() # addition feature
flag = flag.to(input).long()
# Non-batch version
# bNonparm = Batch_NonShift()
shift_offsets = []
# batch version
cosine, latter_windows, form, i_2, i_3, i_1 = cosine_similarity(former_all.clone(), latter_all.clone(),
torch.tensor(1),
stride, flag)
_, indexes = torch.max(cosine, dim=2)
mask_indexes = (flag == 1).nonzero()[:, 1].view(bz, -1)
non_mask_indexes = (flag == 0).nonzero()[:, 1].view(bz, -1).gather(1, indexes)
idx_b = torch.arange(bz).long().unsqueeze(1).expand(bz, mask_indexes.size(1))
# Set the elements indexed by [mask_indexes, non_mask_indexes] to 1.
# This is the batch version.
ind_lst[(idx_b, mask_indexes, non_mask_indexes)] = torch.ones(
ind_lst[(idx_b, mask_indexes, non_mask_indexes)].shape).to(input)
shift_masked_all = paste(latter_windows, ind_lst, torch.tensor(i_2), torch.tensor(i_3), torch.tensor(i_1))
return torch.cat((former_all, latter_all, shift_masked_all), 1)
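With patch_size = 1, the shift above reduces to replacing each masked feature vector with its most cosine-similar unmasked one. A self-contained toy sketch of that lookup (the feature values and mask layout below are made up for illustration):

```python
import torch

# each row is one pixel's feature vector (patch_size = 1, C = 2)
feat = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [0.9, 0.1],
                     [0.8, 0.2]])
flag = torch.tensor([0, 0, 0, 1])   # the last pixel lies inside the mask

former = feat[flag == 1]            # masked query
latter = feat[flag == 0]            # unmasked candidates
cosine = (former @ latter.t()) / (
    former.norm(dim=1, keepdim=True) * latter.norm(dim=1))
_, idx = cosine.max(dim=1)
shifted = latter[idx]               # feature shifted into the masked spot
assert idx.item() == 2              # [0.9, 0.1] points in the closest direction
```

The transition matrix `ind_lst` encodes exactly this argmax as a 0/1 matrix, so the batched `bmm` in `paste` performs the same lookup for every masked position at once.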
import torch
import util.util as util
from util.NonparametricShift import Modified_NonparametricShift, Batch_NonShift
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
bz = 2
c = 3 # at least 2
w = 16
h = 16
feature_size = [bz, c, w, h]
former = torch.rand(bz*c*h*w).mul_(50).reshape(bz, c, h, w).int().float()
latter = torch.rand(bz*c*h*w).mul_(50).reshape(bz, c, h, w).int().float()
flag = torch.zeros(bz, h, w).byte()
flag[:, h//4:h//2+1, h//4:h//2+1] = 1
flag = flag.view(bz, h*w)
ind_lst = torch.FloatTensor(bz, h*w, h*w).zero_()
shift_offsets = []
#Nonparm = Modified_NonparametricShift()
bNonparm = Batch_NonShift()
cosine, latter_windows, i_2, i_3, i_1 = bNonparm.cosine_similarity(former.clone(), latter.clone(), 1, 1, flag)
print(cosine.size())
print(latter_windows.size())
## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY
_, indexes = torch.max(cosine, dim=2)
print('indexes dim')
print(indexes.size())
# SET TRANSITION MATRIX
mask_indexes = (flag == 1).nonzero()
mask_indexes = mask_indexes[:,1] # remove indexes that indicate the batch dim
mask_indexes = mask_indexes.view(bz, -1)
# Also remove indexes of batch
tmp = (flag==0).nonzero()[:,1]
tmp = tmp.view(bz, -1)
print('tmp size')
print(tmp.size())
idx_tmp = indexes + torch.arange(indexes.size(0)).view(-1,1) * tmp.size(1)
non_mask_indexes = tmp.view(-1)[idx_tmp]
# Original method
non_mask_indexes_2 = []
for i in range(bz):
non_mask_indexes_tmp = tmp[i][indexes[i]]
non_mask_indexes_2.append(non_mask_indexes_tmp)
non_mask_indexes_2 = torch.stack(non_mask_indexes_2, dim=0)
print('These two methods should be the same, as the error is 0!')
print(torch.sum(non_mask_indexes-non_mask_indexes_2))
ind_lst2 = ind_lst.clone()
for i in range(bz):
ind_lst[i][mask_indexes[i], non_mask_indexes[i]] = 1
print(ind_lst.sum())
print(ind_lst)
for i in range(bz):
for mi, nmi in zip(mask_indexes[i], non_mask_indexes[i]):
print('The %d\t-th pixel in the %d-th tensor will shift to %d\t-th coordinate' %(nmi, i, mi))
print('~~~')
# GET FINAL SHIFT FEATURE
shift_masked_all = bNonparm._paste(latter_windows, ind_lst, i_2, i_3, i_1)
print(shift_masked_all.size())
raise SystemExit('debug stop: inspect the shift feature above before continuing')
# print('flag')
# print(flag.reshape(h,w))
# print('ind_lst')
# print(ind_lst)
# print('out')
# print(shift_masked_all)
# get shift offsets (row, col)
shift_offset = torch.stack([non_mask_indexes.squeeze() // w, torch.fmod(non_mask_indexes.squeeze(), w)], dim=-1)
print('shift_offset')
print(shift_offset)
print(shift_offset.size())
shift_offsets.append(shift_offset)
shift_offsets = torch.cat(shift_offsets, dim=0).float()
print(shift_offsets.size())
print(shift_offsets)
shift_offsets_cl = shift_offsets.clone()
lt = (flag==1).nonzero()[0]
rb = (flag==1).nonzero()[-1]
mask_h = rb//w+1 - lt//w
mask_w = rb%w+1 - lt%w
shift_offsets = shift_offsets.view([bz] + [2] + [mask_h, mask_w]) # So only appropriate for square mask.
print(shift_offsets.size())
print(shift_offsets)
h_add = torch.arange(0, float(h)).view([1, 1, h, 1]).float()
h_add = h_add.expand(bz, 1, h, w)
w_add = torch.arange(0, float(w)).view([1, 1, 1, w]).float()
w_add = w_add.expand(bz, 1, h, w)
com_map = torch.cat([h_add, w_add], dim=1)
print('com_map')
print(com_map)
com_map_crop = com_map[:, :, lt//w:rb//w+1, lt%w:rb%w+1]
print('com_map crop')
print(com_map_crop)
shift_offsets = shift_offsets - com_map_crop
print('final shift_offsets')
print(shift_offsets)
# to flow image
flow = torch.from_numpy(util.flow_to_image(shift_offsets.permute(0,2,3,1).cpu().data.numpy()))
flow = flow.permute(0,3,1,2)
#visualize which pixels are attended
print(flag.size())
print(shift_offsets.size())
# global and N*C*H*W
# put shift_offsets_cl back to the global map.
shift_offsets_map = flag.clone().view(-1)
shift_offsets_map[indexes] = shift_offsets_cl.view(-1)
print(shift_offsets_map)
raise SystemExit('debug stop: inspect shift_offsets_map above before continuing')
flow2 = torch.from_numpy(util.highlight_flow((shift_offsets_cl).numpy()))
upflow = F.interpolate(flow, scale_factor=4, mode='nearest')
upflow2 = F.interpolate(flow2, scale_factor=4, mode='nearest')
upflow = upflow.squeeze().permute(1,2,0)
upflow2 = upflow2.squeeze().permute(1,2,0)
print('flow 1')
print(upflow)
print(upflow.size())
print('flow 2')
print(upflow2)
print(upflow2.size())
fig, axs = plt.subplots(ncols=2)
axs[0].imshow(upflow)
axs[1].imshow(upflow2)
plt.show()
#python -m visdom.server
import time
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models import create_model
import torch
import torchvision
if __name__ == "__main__":
############################
'''
model = torchvision.models.resnet101(pretrained=True)
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
output = traced_script_module(torch.ones(1, 3, 224, 224))
print(type(output), output[0, :10], output.shape)
traced_script_module.save("traced_resnet_model1.5.0.pt")
'''
##############################
opt = TrainOptions().parse()
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)
model = create_model(opt)
total_steps = 0
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
iter_data_time = time.time()
epoch_iter = 0
for i, data in enumerate(dataset):
A_path = data['A_paths']
A = data['A'][0]
B = data['B'][0]
mask = data['M'][0]
A_flip = data['A_F'][0]
B_flip = data['B_F'][0]
for j in range(A.shape[0]):
data_temp = {'A': A[j].unsqueeze(0), 'B': B[j].unsqueeze(0), 'A_F': A_flip[j].unsqueeze(0),
'B_F': B_flip[j].unsqueeze(0), 'M': mask[j].unsqueeze(0),
'A_paths': A_path}
iter_start_time = time.time()
if total_steps % opt.print_freq == 0:
t_data = iter_start_time - iter_data_time
total_steps += opt.batchSize
epoch_iter += opt.batchSize
model.set_input(data_temp) # it not only sets the input data with mask, but also sets the latent mask.
# Additionally, it should be set before 'optimize_parameters()'.
if total_steps % opt.display_freq == 0:
if opt.show_flow:
model.set_show_map_true()
model.optimize_parameters()
if total_steps % opt.display_freq == 0:
save_result = total_steps % opt.update_html_freq == 0
if opt.show_flow:
model.set_flow_src()
model.set_show_map_false()
if total_steps % opt.print_freq == 0:
losses = model.get_current_losses()
t = (time.time() - iter_start_time) / opt.batchSize
if total_steps % opt.save_latest_freq == 0:
print('saving the latest model (epoch %d, total_steps %d)' %
(epoch, total_steps))
model.save_networks('latest')
iter_data_time = time.time()
if epoch % opt.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' %
(epoch, total_steps))
model.save_networks('latest')
if not opt.only_lastest:
model.save_networks(epoch)
print('End of epoch %d / %d \t Time Taken: %d sec' %
(epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
model.update_learning_rate()
import random
import math
import numpy as np
import torch
import torch.nn as nn
from time import time
# These three functions only work when patch_size = 1x1
class Modified_NonparametricShift(object):
def _extract_patches_from_flag(self, img, patch_size, stride, flag, value):
input_windows = self._unfold(img, patch_size, stride)
input_windows = self._filter(input_windows, flag, value)
return self._norm(input_windows)
# former: content, to be replaced.
# latter: style, source pixels.
def cosine_similarity(self, former, latter, patch_size, stride, flag, with_former=False):
former_windows = self._unfold(former, patch_size, stride)
former = self._filter(former_windows, flag, 1)
latter_windows, i_2, i_3, i_1 = self._unfold(latter, patch_size, stride, with_indexes=True)
latter = self._filter(latter_windows, flag, 0)
num = torch.einsum('ik,jk->ij', [former, latter])
norm_latter = torch.einsum("ij,ij->i", [latter, latter])
norm_former = torch.einsum("ij,ij->i", [former, former])
den = torch.sqrt(torch.einsum('i,j->ij', [norm_former, norm_latter]))
if not with_former:
return num / den, latter_windows, i_2, i_3, i_1
else:
return num / den, latter_windows, former_windows, i_2, i_3, i_1
def _paste(self, input_windows, transition_matrix, i_2, i_3, i_1):
## TRANSFER FEATURES TO THEIR NEW POSITIONS
input_windows = torch.mm(transition_matrix, input_windows)
## RESIZE TO CORRECT CONV FEATURES FORMAT
input_windows = input_windows.view(i_2, i_3, i_1)
input_windows = input_windows.permute(2, 0, 1).unsqueeze(0)
return input_windows
def _unfold(self, img, patch_size, stride, with_indexes=False):
n_dim = 3
assert img.dim() == n_dim, 'image must be of dimension 3.'
kH, kW = patch_size, patch_size
dH, dW = stride, stride
input_windows = img.unfold(1, kH, dH).unfold(2, kW, dW)
i_1, i_2, i_3, i_4, i_5 = input_windows.size()
if with_indexes:
input_windows = input_windows.permute(1, 2, 0, 3, 4).contiguous().view(i_2 * i_3, i_1)
return input_windows, i_2, i_3, i_1
else:
input_windows = input_windows.permute(1, 2, 0, 3, 4).contiguous().view(i_2 * i_3, i_1, i_4, i_5)
return input_windows
def _filter(self, input_windows, flag, value):
## EXTRACT MASKED OR UNMASKED PATCHES DEPENDING ON VALUE
input_window = input_windows[flag == value]
return input_window.view(input_window.size(0), -1)
def _norm(self, input_window):
# The commented norm below is incorrect: it returns the norms themselves,
# rather than the patches normalized by their norms.
#return torch.norm(input_window, dim=1, keepdim=True)
for i in range(input_window.size(0)):
input_window[i] = input_window[i]*(1/(input_window[i].norm(2)+1e-8))
return input_window
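The einsum chain in `cosine_similarity` builds the full pairwise cosine matrix between masked and unmasked patches. A quick sanity check of that chain against `F.cosine_similarity` with broadcasting (the shapes and random values here are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
former = torch.randn(3, 8)   # e.g. 3 masked patches, 8 features each
latter = torch.randn(5, 8)   # e.g. 5 unmasked patches

# same einsum chain as in cosine_similarity above
num = torch.einsum('ik,jk->ij', [former, latter])
norm_f = torch.einsum('ij,ij->i', [former, former])
norm_l = torch.einsum('ij,ij->i', [latter, latter])
cosine = num / torch.sqrt(torch.einsum('i,j->ij', [norm_f, norm_l]))

# reference: pairwise cosine similarity via broadcasting
ref = F.cosine_similarity(former.unsqueeze(1), latter.unsqueeze(0), dim=2)
assert torch.allclose(cosine, ref, atol=1e-5)
```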
class Batch_NonShift(object):
def _extract_patches_from_flag(self, img, patch_size, stride, flag, value):
input_windows = self._unfold(img, patch_size, stride)
input_windows = self._filter(input_windows, flag, value)
return self._norm(input_windows)
# former: content, to be replaced.
# latter: style, source pixels.
def cosine_similarity(self, former, latter, patch_size, stride, flag, with_former=False):
former_windows = self._unfold(former, patch_size, stride)
former = self._filter(former_windows, flag, 1)
latter_windows, i_2, i_3, i_1 = self._unfold(latter, patch_size, stride, with_indexes=True)
latter = self._filter(latter_windows, flag, 0)
num = torch.einsum('bik,bjk->bij', [former, latter])
norm_latter = torch.einsum("bij,bij->bi", [latter, latter])
norm_former = torch.einsum("bij,bij->bi", [former, former])
den = torch.sqrt(torch.einsum('bi,bj->bij', [norm_former, norm_latter]))
if not with_former:
return num / den, latter_windows, i_2, i_3, i_1
else:
return num / den, latter_windows, former_windows, i_2, i_3, i_1
# delete i_4, as i_4 is 1
def _paste(self, input_windows, transition_matrix, i_2, i_3, i_1):
## TRANSFER FEATURES TO THEIR NEW POSITIONS
bz = input_windows.size(0)
input_windows = torch.bmm(transition_matrix, input_windows)
## RESIZE TO CORRECT CONV FEATURES FORMAT
input_windows = input_windows.view(bz, i_2, i_3, i_1)
input_windows = input_windows.permute(0, 3, 1, 2)
return input_windows
def _unfold(self, img, patch_size, stride, with_indexes=False):
n_dim = 4
assert img.dim() == n_dim, 'image must be of dimension 4.'
kH, kW = patch_size, patch_size
dH, dW = stride, stride
input_windows = img.unfold(2, kH, dH).unfold(3, kW, dW)
i_0, i_1, i_2, i_3, i_4, i_5 = input_windows.size()
if with_indexes:
input_windows = input_windows.permute(0, 2, 3, 1, 4, 5).contiguous().view(i_0, i_2 * i_3, i_1)
return input_windows, i_2, i_3, i_1
else:
input_windows = input_windows.permute(0, 2, 3, 1, 4, 5).contiguous().view(i_0, i_2 * i_3, i_1, i_4, i_5)
return input_windows
def _filter(self, input_windows, flag, value):
## EXTRACT MASKED OR UNMASKED PATCHES DEPENDING ON VALUE
assert flag.dim() == 2, "flag should be batch version"
input_window = input_windows[flag == value]
bz = flag.size(0)
return input_window.view(bz, input_window.size(0)//bz, -1)
# Deprecated code
class NonparametricShift(object):
def buildAutoencoder(self, target_img, normalize, interpolate, nonmask_point_idx, patch_size=1, stride=1):
nDim = 3
assert target_img.dim() == nDim, 'target image must be of dimension 3.'
C = target_img.size(0)
self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
patches_all, patches_part = self._extract_patches(target_img, patch_size, stride, nonmask_point_idx)
npatches_part = patches_part.size(0)
npatches_all = patches_all.size(0)
conv_enc_non_mask, conv_dec_non_mask = self._build(patch_size, stride, C, patches_part, npatches_part, normalize, interpolate)
conv_enc_all, conv_dec_all = self._build(patch_size, stride, C, patches_all, npatches_all, normalize, interpolate)
return conv_enc_all, conv_enc_non_mask, conv_dec_all, conv_dec_non_mask
def _build(self, patch_size, stride, C, target_patches, npatches, normalize, interpolate):
# for each patch, divide by its L2 norm.
enc_patches = target_patches.clone()
for i in range(npatches):
enc_patches[i] = enc_patches[i]*(1/(enc_patches[i].norm(2)+1e-8))
conv_enc = nn.Conv2d(C, npatches, kernel_size=patch_size, stride=stride, bias=False)
conv_enc.weight.data = enc_patches
# normalize is not needed, it doesn't change the result!
if normalize:
raise NotImplementedError
if interpolate:
raise NotImplementedError
conv_dec = nn.ConvTranspose2d(npatches, C, kernel_size=patch_size, stride=stride, bias=False)
conv_dec.weight.data = target_patches
return conv_enc, conv_dec
def _extract_patches(self, img, patch_size, stride, nonmask_point_idx):
n_dim = 3
assert img.dim() == n_dim, 'image must be of dimension 3.'
kH, kW = patch_size, patch_size
dH, dW = stride, stride
input_windows = img.unfold(1, kH, dH).unfold(2, kW, dW)
i_1, i_2, i_3, i_4, i_5 = input_windows.size(0), input_windows.size(1), input_windows.size(2), input_windows.size(3), input_windows.size(4)
input_windows = input_windows.permute(1,2,0,3,4).contiguous().view(i_2*i_3, i_1, i_4, i_5)
patches_all = input_windows
patches = input_windows.index_select(0, nonmask_point_idx) # returns a new tensor holding the patches extracted from the non-masked region
return patches_all, patches
import dominate
from dominate.tags import *
import os
class HTML:
def __init__(self, web_dir, title, refresh=0):
self.title = title
self.web_dir = web_dir
self.img_dir = os.path.join(self.web_dir, 'images')
if not os.path.exists(self.web_dir):
os.makedirs(self.web_dir)
if not os.path.exists(self.img_dir):
os.makedirs(self.img_dir)
# print(self.img_dir)
self.doc = dominate.document(title=title)
if refresh > 0:
with self.doc.head:
meta(http_equiv="refresh", content=str(refresh))
def get_image_dir(self):
return self.img_dir
def add_header(self, str):
with self.doc:
h3(str)
def add_table(self, border=1):
self.t = table(border=border, style="table-layout: fixed;")
self.doc.add(self.t)
def add_images(self, ims, txts, links, width=400):
self.add_table()
with self.t:
with tr():
for im, txt, link in zip(ims, txts, links):
with td(style="word-wrap: break-word;", halign="center", valign="top"):
with p():
with a(href=os.path.join('images', link)):
img(style="width:%dpx" % width, src=os.path.join('images', im))
br()
p(txt)
def save(self):
html_file = '%s/index.html' % self.web_dir
f = open(html_file, 'wt')
f.write(self.doc.render())
f.close()
if __name__ == '__main__':
html = HTML('web/', 'test_html')
html.add_header('hello world')
ims = []
txts = []
links = []
for n in range(4):
ims.append('image_%d.png' % n)
txts.append('text_%d' % n)
links.append('image_%d.png' % n)
html.add_images(ims, txts, links)
html.save()
import struct
import zlib
def encode(buf, width, height):
""" buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """
assert (width * height * 3 == len(buf))
bpp = 3
def raw_data():
# reverse the vertical line order and add null bytes at the start
row_bytes = width * bpp
for row_start in range((height - 1) * width * bpp, -1, -row_bytes):
yield b'\x00'
yield buf[row_start:row_start + row_bytes]
def chunk(tag, data):
return [
struct.pack("!I", len(data)),
tag,
data,
struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag)))
]
SIGNATURE = b'\x89PNG\r\n\x1a\n'
COLOR_TYPE_RGB = 2
COLOR_TYPE_RGBA = 6
bit_depth = 8
return b''.join(
[ SIGNATURE ] +
chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) +
chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) +
chunk(b'IEND', b'')
)
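Each chunk written by `encode` is length + tag + data + CRC-32 computed over tag and data; the streaming form `zlib.crc32(data, zlib.crc32(tag))` is equivalent to a single CRC over the concatenation, which is what a PNG reader recomputes to validate the chunk:

```python
import struct
import zlib

# one chunk, built exactly like chunk() above: length + tag + data + CRC
tag = b'IHDR'
data = struct.pack("!2I5B", 2, 2, 8, 2, 0, 0, 0)   # 2x2, 8-bit, RGB
crc = 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag))
raw = struct.pack("!I", len(data)) + tag + data + struct.pack("!I", crc)

# the streaming CRC equals a single CRC over tag + data
assert crc == (zlib.crc32(tag + data) & 0xFFFFFFFF)
assert len(raw) == 4 + 4 + len(data) + 4
```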
import numpy as np
import scipy.sparse
import cv2
import pyamg
# pre-process the mask array so that uint64 types from opencv.imread can be adapted
def prepare_mask(mask):
if type(mask[0][0]) is np.ndarray:
result = np.ndarray((mask.shape[0], mask.shape[1]), dtype=np.uint8)
for i in range(mask.shape[0]):
for j in range(mask.shape[1]):
if sum(mask[i][j]) > 0:
result[i][j] = 1
else:
result[i][j] = 0
mask = result
return mask
def blend(img_target, img_source, img_mask, offset=(0, 0)):
# compute regions to be blended
region_source = (
max(-offset[0], 0),
max(-offset[1], 0),
min(img_target.shape[0]-offset[0], img_source.shape[0]),
min(img_target.shape[1]-offset[1], img_source.shape[1]))
region_target = (
max(offset[0], 0),
max(offset[1], 0),
min(img_target.shape[0], img_source.shape[0]+offset[0]),
min(img_target.shape[1], img_source.shape[1]+offset[1]))
region_size = (region_source[2]-region_source[0], region_source[3]-region_source[1])
# clip and normalize mask image
img_mask = img_mask[region_source[0]:region_source[2], region_source[1]:region_source[3]]
img_mask = prepare_mask(img_mask)
img_mask[img_mask==0] = False
img_mask[img_mask!=False] = True
# create coefficient matrix
A = scipy.sparse.identity(np.prod(region_size), format='lil')
for y in range(region_size[0]):
for x in range(region_size[1]):
if img_mask[y,x]:
index = x+y*region_size[1]
A[index, index] = 4
if index+1 < np.prod(region_size):
A[index, index+1] = -1
if index-1 >= 0:
A[index, index-1] = -1
if index+region_size[1] < np.prod(region_size):
A[index, index+region_size[1]] = -1
if index-region_size[1] >= 0:
A[index, index-region_size[1]] = -1
A = A.tocsr()
# create poisson matrix for b
P = pyamg.gallery.poisson(img_mask.shape)
# for each layer (ex. RGB)
for num_layer in range(img_target.shape[2]):
# get subimages
t = img_target[region_target[0]:region_target[2], region_target[1]:region_target[3],num_layer]
s = img_source[region_source[0]:region_source[2], region_source[1]:region_source[3],num_layer]
t = t.flatten()
s = s.flatten()
# create b
b = P * s
for y in range(region_size[0]):
for x in range(region_size[1]):
if not img_mask[y,x]:
index = x+y*region_size[1]
b[index] = t[index]
# solve Ax = b
x = pyamg.solve(A,b,verb=False,tol=1e-10)
# assign x to target image
x = np.reshape(x, region_size)
x[x>255] = 255
x[x<0] = 0
x = np.array(x, img_target.dtype)
img_target[region_target[0]:region_target[2],region_target[1]:region_target[3],num_layer] = x
return img_target
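The interior rows of the coefficient matrix `A` built in `blend` follow the 5-point Laplacian stencil: 4 on the diagonal and -1 on each of the four neighbour columns. A pure-numpy restatement on a 3x3 toy region (dense here only for easy inspection; `blend` uses `scipy.sparse`):

```python
import numpy as np

n = 3                       # 3x3 toy region, row-major indexing
A = np.eye(n * n)           # identity rows for pixels outside the mask
mask = np.zeros((n, n), dtype=bool)
mask[1, 1] = True           # only the centre pixel is "inside" the mask

for y in range(n):
    for x in range(n):
        if mask[y, x]:
            i = x + y * n
            A[i, i] = 4
            for j in (i + 1, i - 1, i + n, i - n):
                if 0 <= j < n * n:
                    A[i, j] = -1

assert A[4, 4] == 4
assert A[4, 1] == A[4, 3] == A[4, 5] == A[4, 7] == -1
```

Pixels outside the mask keep identity rows, so the solve pins them to the target values filled into `b`; masked pixels get the Poisson equation instead.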
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import random
import inspect, re
import os
import collections
import math
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from skimage.transform import resize
def create_masks(opt, N=10):
masks = []
masks_resized = []
for _ in range(N):
mask = wrapper_gmask(opt).cpu().numpy()
masks.append(mask)
mask_resized = resize(np.squeeze(mask), (64, 64))
masks_resized.append(mask_resized)
return np.array(masks_resized), np.array(masks)
class OptimizerMask:
'''
This class is designed to speed up inference by covering the whole image with the minimum number of masks generated during training.
It is used in the notebook to create masks covering the entire image.
'''
def __init__(self, masks, stop_criteria=0.85):
self.masks = masks
self.indexes = []
self.stop_criteria = stop_criteria
def get_iou(self):
intersection = np.matmul(self.masks, self.masks.T)
diag = np.diag(intersection)
outer_add = np.add.outer(diag, diag)
self.iou = intersection / outer_add
self.shape = self.iou.shape
def _is_finished(self):
masks = self.masks[self.indexes]
masks = np.sum(masks, axis=0)
masks[masks > 0] = 1
area_coverage = np.sum(masks) / np.prod(masks.shape)
print(area_coverage)
return area_coverage >= self.stop_criteria
def mean(self):
_mean = np.mean(np.sum(self.masks[self.indexes], axis=-1)) / (64 * 64)
print(_mean)
def _get_next_indexes(self):
ious = self.iou[self.indexes]
_mean_iou = np.mean(ious, axis=0)
self.indexes = np.append(self.indexes, np.argmin(_mean_iou))
def _solve(self):
self.indexes = list(np.unravel_index(np.argmin(self.iou), self.shape))
# print(self.indexes)
while not self._is_finished():
self._get_next_indexes()
def get_masks(self):
masks = self.masks[self.indexes]
full = np.ones_like(masks[0])
left = full - (np.mean(masks, axis=0) > 0)
return left.reshape((64, 64))
def solve(self):
self._solve()
# Converts a Tensor into an image array (numpy)
# |imtype|: the desired type of the converted numpy array
def tensor2im(input_image, imtype=np.uint8):
if isinstance(input_image, torch.Tensor):
image_tensor = input_image.data
else:
return input_image
image_numpy = image_tensor[0].cpu().float().numpy()
if image_numpy.shape[0] == 1:
image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
return image_numpy.astype(imtype)
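`tensor2im` assumes network outputs in [-1, 1] and maps them linearly to uint8 via `(x + 1) / 2 * 255`. The same mapping on a hand-made tensor (values chosen to hit the endpoints):

```python
import numpy as np
import torch

# one channel of values in [-1, 1], replicated to 3 channels
t = torch.tensor([[-1.0, 0.0], [1.0, 0.5]]).view(1, 1, 2, 2).repeat(1, 3, 1, 1)
img = t[0].cpu().float().numpy()
img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
img = img.astype(np.uint8)
assert img[0, 0, 0] == 0      # -1.0 -> 0
assert img[1, 0, 0] == 255    #  1.0 -> 255
assert img[1, 1, 0] == 191    #  0.5 -> 191 (191.25 truncated)
```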
# Remove the extra (mask) channel from a tensor, keeping the first 3 channels.
# Works for both 3-dim (C,H,W) and 4-dim (N,C,H,W) inputs.
def rm_extra_dim(image):
if image.dim() == 3:
return image[:3, :, :]
elif image.dim() == 4:
return image[:, :3, :, :]
else:
raise NotImplementedError
def diagnose_network(net, name='network'):
mean = 0.0
count = 0
for param in net.parameters():
if param.grad is not None:
mean += torch.mean(torch.abs(param.grad.data))
count += 1
if count > 0:
mean = mean / count
print(name)
print(mean)
def wrapper_gmask(opt):
# batchsize should be 1 for mask_global
mask_global = torch.ByteTensor(1, 1, \
opt.fineSize, opt.fineSize)
res = 0.06 # the lower it is, the more continuous the output will be. 0.01 is too small and 0.1 is too large
density = 0.15
MAX_SIZE = 350
maxPartition = 30
low_pattern = torch.rand(1, 1, int(res * MAX_SIZE), int(res * MAX_SIZE)).mul(255)
pattern = F.interpolate(low_pattern, (MAX_SIZE, MAX_SIZE), mode='bilinear').detach()
low_pattern = None
pattern.div_(255)
pattern = torch.lt(pattern, density).byte() # roughly a `density` fraction of 1s (~15% here)
pattern = torch.squeeze(pattern).byte()
'''
import matplotlib.pyplot as plt
img = pattern
img = img.cpu().numpy()
plt.imshow(img)
plt.show()
'''
gMask_opts = {}
gMask_opts['pattern'] = pattern
gMask_opts['MAX_SIZE'] = MAX_SIZE
gMask_opts['fineSize'] = opt.fineSize
gMask_opts['maxPartition'] = maxPartition
gMask_opts['mask_global'] = mask_global
return create_gMask(gMask_opts) # create an initial random mask.
from torchvision import transforms
import matplotlib.pyplot as plt
def create_gMask(gMask_opts, limit_cnt=1):
pattern = gMask_opts['pattern']
mask_global = gMask_opts['mask_global']
MAX_SIZE = gMask_opts['MAX_SIZE']
fineSize = gMask_opts['fineSize']
maxPartition=gMask_opts['maxPartition']
if pattern is None:
raise ValueError
wastedIter = 0
while wastedIter <= limit_cnt:
x = random.randint(1, MAX_SIZE-fineSize)
y = random.randint(1, MAX_SIZE-fineSize)
mask = pattern[y:y+fineSize, x:x+fineSize]
'''
import matplotlib.pyplot as plt
img = mask
img = img.cpu().numpy()
plt.imshow(img)
plt.show()
'''
area = mask.sum()*100./(fineSize*fineSize)
if area>20 and area<maxPartition:
break
wastedIter += 1
if mask_global.dim() == 3:
mask_global = mask.expand(1, mask.size(0), mask.size(1))
else:
mask_global = mask.expand(1, 1, mask.size(0), mask.size(1))
return mask_global
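A minimal NumPy sketch of the idea behind `wrapper_gmask`/`create_gMask` above: threshold upsampled low-resolution noise to get a blobby pattern, then retry cropping a window until the masked-area ratio lands in the target range. The helper name `sketch_gmask` and the nearest-neighbour upsampling (bilinear in the original) are illustrative assumptions, not part of this repo.

```python
import numpy as np

def sketch_gmask(fine_size=64, max_size=350, res=0.06, density=0.15, seed=0):
    """Illustrative sketch (not the repo's implementation): thresholded
    low-res noise, cropped until the hole density is acceptable."""
    rng = np.random.default_rng(seed)
    side = int(res * max_size)                     # low-res noise grid (~21x21)
    low = rng.random((side, side))
    # nearest-neighbour upsample to max_size x max_size (the original uses bilinear)
    idx = np.arange(max_size) * side // max_size
    pattern = (low[np.ix_(idx, idx)] < density).astype(np.uint8)
    mask = pattern[:fine_size, :fine_size]
    for _ in range(100):                           # retry cropping, as create_gMask does
        x = rng.integers(1, max_size - fine_size)
        y = rng.integers(1, max_size - fine_size)
        mask = pattern[y:y + fine_size, x:x + fine_size]
        area = mask.sum() * 100.0 / (fine_size * fine_size)
        if 20 < area < 30:                         # between 20% and maxPartition=30%
            break
    return mask
```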
# Create a square mask with random position.
def create_rand_mask(opt):
h, w = opt.fineSize, opt.fineSize
mask = np.zeros((h, w))
maxt = h - opt.overlap - h // 2
maxl = w - opt.overlap - w // 2
rand_t = np.random.randint(opt.overlap, maxt)
rand_l = np.random.randint(opt.overlap, maxl)
mask[rand_t:rand_t+opt.fineSize//2-2*opt.overlap, rand_l:rand_l+opt.fineSize//2-2*opt.overlap] = 1
return torch.ByteTensor(mask), rand_t, rand_l
action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
def random_walk(canvas, ini_x, ini_y, length):
x = ini_x
y = ini_y
img_size = canvas.shape[-1]
x_list = []
y_list = []
for i in range(length):
r = random.choice(range(len(action_list)))
x = np.clip(x + action_list[r][0], a_min=0, a_max=img_size - 1)
y = np.clip(y + action_list[r][1], a_min=0, a_max=img_size - 1)
x_list.append(x)
y_list.append(y)
canvas[np.array(x_list), np.array(y_list)] = 0
return canvas
def create_mask():
canvas = np.ones((256, 256)).astype("i")
ini_x = random.randint(0, 255)
ini_y = random.randint(0, 255)
print(ini_x, ini_y)
return random_walk(canvas, ini_x, ini_y, 128 ** 2)
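A self-contained restatement of the `random_walk`/`create_mask` pair above, scaled down so it runs quickly: carve a free-form hole (0s) into a canvas of 1s with a clipped 4-neighbour random walk. The function name and the smaller default sizes are illustrative, not the repo's API.

```python
import random
import numpy as np

ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # mirrors action_list above

def random_walk_mask(size=64, length=512, seed=0):
    """Sketch of random_walk/create_mask: start at a random pixel and
    zero out every pixel the walk visits."""
    random.seed(seed)
    canvas = np.ones((size, size), dtype="i")
    x, y = random.randint(0, size - 1), random.randint(0, size - 1)
    for _ in range(length):
        dx, dy = random.choice(ACTIONS)
        x = int(np.clip(x + dx, 0, size - 1))
        y = int(np.clip(y + dy, 0, size - 1))
        canvas[x, y] = 0
    return canvas
```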
# inMask: tensor of shape bz*1*256*256, float
# Returns: ByteTensor
def cal_feat_mask(inMask, nlayers):
assert inMask.dim() == 4, "mask must be 4 dimensions"
inMask = inMask.float()
ntimes = 2**nlayers
inMask = F.interpolate(inMask, (inMask.size(2)//ntimes, inMask.size(3)//ntimes), mode='nearest')
inMask = inMask.detach().byte()
return inMask
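A quick illustration of what `cal_feat_mask` computes: nearest-neighbour downsampling of the pixel-space mask by `2**nlayers`, giving a coarse mask at feature-map resolution. The toy sizes below are assumptions for the demo.

```python
import torch
import torch.nn.functional as F

# An 8x8 mask with a 4x4 hole, downsampled by 2**nlayers = 4 (as in cal_feat_mask).
mask = torch.zeros(1, 1, 8, 8)
mask[:, :, 2:6, 2:6] = 1
feat_mask = F.interpolate(mask.float(), (2, 2), mode='nearest').byte()
# feat_mask marks, at 2x2 resolution, which feature locations fall in the hole.
```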
# Designed for patch_size=1; it also works correctly for patch_size > 1, though for
# that case we adopt another implementation in `patch_soft_shift/innerPatchSoftShiftTripleModule.py`
# to get the masked region.
# Returns: flag marking masked locations with 1s; flag size: bz*(h*w)
def cal_flag_given_mask_thred(mask, patch_size, stride, mask_thred):
assert mask.dim() == 4, "mask must be 4 dimensions"
assert mask.size(1) == 1, "the size of the dim=1 must be 1"
mask = mask.float()
b = mask.size(0)
p_tensor = int(patch_size // 2)
mask = F.pad(mask, [p_tensor, p_tensor, p_tensor, p_tensor], 'constant', 0.0)
m = mask.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
m = m.contiguous().view(b, 1, -1, patch_size, patch_size)
m = torch.mean(torch.mean(m, dim=3, keepdim=True), dim=4, keepdim=True)
mm = m.ge(mask_thred/(1.*patch_size**2)).long()
flag = mm.view(b, -1)
return flag
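A compact restatement of `cal_flag_given_mask_thred` for the `patch_size=1` case it targets: unfold the mask into patches, average each patch, and flag patches whose mean reaches `mask_thred/patch_size**2`. The helper name `patch_flags` is an assumption for this sketch.

```python
import torch
import torch.nn.functional as F

def patch_flags(mask, patch_size=1, stride=1, mask_thred=1):
    """Sketch: flag (with 1s) every patch whose masked fraction
    reaches mask_thred / patch_size**2."""
    b = mask.size(0)
    p = patch_size // 2
    m = F.pad(mask.float(), [p, p, p, p], 'constant', 0.0)
    m = m.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
    m = m.contiguous().view(b, 1, -1, patch_size, patch_size)
    m = m.mean(dim=(3, 4))                       # mean mask value per patch
    return m.ge(mask_thred / (1.0 * patch_size ** 2)).long().view(b, -1)

mask = torch.zeros(1, 1, 4, 4)
mask[:, :, 1:3, 1:3] = 1
flag = patch_flags(mask)   # patch_size=1: flag is simply the flattened mask
```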
def save_image(image_numpy, image_path):
image_pil = Image.fromarray(image_numpy)
image_pil.save(image_path)
def info(object, spacing=10, collapse=1):
"""Print methods and doc strings.
Takes module, class, list, dictionary, or string."""
    methodList = [e for e in dir(object) if callable(getattr(object, e))]
processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
print( "\n".join(["%s %s" %
(method.ljust(spacing),
processFunc(str(getattr(object, method).__doc__)))
for method in methodList]) )
def varname(p):
for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
if m:
return m.group(1)
def print_numpy(x, val=True, shp=False):
x = x.astype(np.float64)
if shp:
print('shape,', x.shape)
if val:
x = x.flatten()
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def hist_match(source, template):
"""
Adjust the pixel values of a grayscale image such that its histogram
matches that of a target image
Arguments:
-----------
source: np.ndarray
Image to transform; the histogram is computed over the flattened
array
template: np.ndarray
Template image; can have different dimensions to source
Returns:
-----------
matched: np.ndarray
The transformed output image
"""
oldshape = source.shape
source = source.ravel()
template = template.ravel()
# get the set of unique pixel values and their corresponding indices and
# counts
s_values, bin_idx, s_counts = np.unique(source, return_inverse=True,
return_counts=True)
t_values, t_counts = np.unique(template, return_counts=True)
# take the cumsum of the counts and normalize by the number of pixels to
# get the empirical cumulative distribution functions for the source and
# template images (maps pixel value --> quantile)
s_quantiles = np.cumsum(s_counts).astype(np.float64)
s_quantiles /= s_quantiles[-1]
t_quantiles = np.cumsum(t_counts).astype(np.float64)
t_quantiles /= t_quantiles[-1]
# interpolate linearly to find the pixel values in the template image
# that correspond most closely to the quantiles in the source image
interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)
return interp_t_values[bin_idx].reshape(oldshape)
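A quick sanity check of the histogram-matching recipe above, restated compactly so the block is self-contained (the inline helper `hist_match_sketch` and the synthetic images are assumptions for the demo): mapping a dark source through the CDF of a bright template should pull every output value into the template's range.

```python
import numpy as np

def hist_match_sketch(source, template):
    """Compact restatement of hist_match: map source pixel quantiles
    onto template pixel values via the empirical CDFs."""
    oldshape = source.shape
    s_values, bin_idx, s_counts = np.unique(source.ravel(),
                                            return_inverse=True,
                                            return_counts=True)
    t_values, t_counts = np.unique(template.ravel(), return_counts=True)
    s_q = np.cumsum(s_counts).astype(np.float64); s_q /= s_q[-1]
    t_q = np.cumsum(t_counts).astype(np.float64); t_q /= t_q[-1]
    return np.interp(s_q, t_q, t_values)[bin_idx].reshape(oldshape)

rng = np.random.default_rng(0)
src = rng.integers(0, 64, (32, 32)).astype(np.float64)      # dark image
tpl = rng.integers(192, 256, (32, 32)).astype(np.float64)   # bright template
out = hist_match_sketch(src, tpl)
```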
'''
https://github.com/WonwoongCho/Generative-Inpainting-pytorch/blob/master/util.py#L229-L333
'''
def flow_to_image(flow):
"""Transfer flow map to image.
Part of code forked from flownet.
"""
out = []
maxu = -999.
maxv = -999.
minu = 999.
minv = 999.
maxrad = -1
for i in range(flow.shape[0]):
u = flow[i, :, :, 0]
v = flow[i, :, :, 1]
idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7)
u[idxunknow] = 0
v[idxunknow] = 0
maxu = max(maxu, np.max(u))
minu = min(minu, np.min(u))
maxv = max(maxv, np.max(v))
minv = min(minv, np.min(v))
rad = np.sqrt(u ** 2 + v ** 2)
maxrad = max(maxrad, np.max(rad))
u = u/(maxrad + np.finfo(float).eps)
v = v/(maxrad + np.finfo(float).eps)
img = compute_color(u, v)
out.append(img)
return np.float32(np.uint8(out))
"""
flow: N*h*w*2
Indicating which pixel will shift to the location.
mask: N*(h*w)
"""
def highlight_flow(flow, mask):
"""Convert flow into middlebury color code image.
"""
assert flow.dim() == 4 and mask.dim() == 2
assert flow.size(0) == mask.size(0)
assert flow.size(3) == 2
bz, h, w, _ = flow.shape
out = torch.zeros(bz, 3, h, w).type_as(flow)
for idx in range(bz):
mask_index = (mask[idx] == 1).nonzero()
img = torch.ones(3, h, w).type_as(flow) * 144.
u = flow[idx, :, :, 0]
v = flow[idx, :, :, 1]
# It is quite slow here.
for h_i in range(h):
for w_j in range(w):
p = h_i*w + w_j
#If it is a masked pixel, we get which pixel that will replace it.
# DO NOT USE `if p in mask_index:`, it is slow.
if torch.sum(mask_index == p).item() != 0:
ui = u[h_i,w_j]
vi = v[h_i,w_j]
img[:, int(ui), int(vi)] = 255.
img[:, h_i, w_j] = 200. # Also indicating where the mask is.
out[idx] = img
return out
def compute_color(u,v):
h, w = u.shape
img = np.zeros([h, w, 3])
nanIdx = np.isnan(u) | np.isnan(v)
u[nanIdx] = 0
v[nanIdx] = 0
colorwheel = make_color_wheel()
ncols = np.size(colorwheel, 0)
rad = np.sqrt(u**2+v**2)
a = np.arctan2(-v, -u) / np.pi
fk = (a+1) / 2 * (ncols - 1) + 1
k0 = np.floor(fk).astype(int)
k1 = k0 + 1
k1[k1 == ncols+1] = 1
f = fk - k0
for i in range(np.size(colorwheel,1)):
tmp = colorwheel[:, i]
col0 = tmp[k0-1] / 255
col1 = tmp[k1-1] / 255
col = (1-f) * col0 + f * col1
idx = rad <= 1
col[idx] = 1-rad[idx]*(1-col[idx])
notidx = np.logical_not(idx)
col[notidx] *= 0.75
img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
return img
def make_color_wheel():
RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros([ncols, 3])
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
col += RY
# YG
colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
colorwheel[col:col+YG, 1] = 255
col += YG
# GC
colorwheel[col:col+GC, 1] = 255
colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
col += GC
# CB
colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
colorwheel[col:col+CB, 2] = 255
col += CB
# BM
colorwheel[col:col+BM, 2] = 255
colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
    col += BM
# MR
colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
colorwheel[col:col+MR, 0] = 255
return colorwheel
################# Style loss #########################
######################################################
class VGG16FeatureExtractor(nn.Module):
def __init__(self):
super(VGG16FeatureExtractor, self).__init__()
vgg16 = models.vgg16(pretrained=True)
self.enc_1 = nn.Sequential(*vgg16.features[:5])
self.enc_2 = nn.Sequential(*vgg16.features[5:10])
self.enc_3 = nn.Sequential(*vgg16.features[10:17])
# print(self.enc_1)
# print(self.enc_2)
# print(self.enc_3)
# fix the encoder
for i in range(3):
for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
param.requires_grad = False
def forward(self, image):
results = [image]
for i in range(3):
func = getattr(self, 'enc_{:d}'.format(i + 1))
results.append(func(results[-1]))
return results[1:]
def total_variation_loss(image):
# shift one pixel and get difference (for both x and y direction)
loss = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))
return loss
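The behaviour of the total-variation loss above is easy to check on extremes: a constant image has no neighbour differences and scores 0, while a checkerboard (every neighbour differs by 1) scores 2.0 (1.0 per direction). The toy tensors below are assumptions for the demo.

```python
import torch

def tv_loss(image):
    # same shift-and-difference form as total_variation_loss above
    return (torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
            torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :])))

flat = torch.ones(1, 3, 8, 8)          # constant image: zero TV
checker = torch.zeros(1, 3, 8, 8)      # checkerboard: maximal TV for 0/1 images
checker[:, :, ::2, ::2] = 1
checker[:, :, 1::2, 1::2] = 1
```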
def gram_matrix(feat):
(batch, ch, h, w) = feat.size()
feat = feat.view(batch, ch, h*w)
feat_t = feat.transpose(1, 2)
gram = torch.bmm(feat, feat_t) / (ch * h * w)
return gram
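A small check of the Gram matrix used by the style loss: for a batch of feature maps of shape (batch, ch, h, w) it yields a (batch, ch, ch) matrix of channel correlations, which is symmetric by construction. The inline `gram` helper restates `gram_matrix` so the block is self-contained.

```python
import torch

def gram(feat):
    b, ch, h, w = feat.size()
    feat = feat.view(b, ch, h * w)
    # bmm of feat with its transpose gives channel-by-channel correlations
    return torch.bmm(feat, feat.transpose(1, 2)) / (ch * h * w)

feat = torch.randn(2, 4, 8, 8)
g = gram(feat)   # shape (2, 4, 4), symmetric in the last two dims
```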
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
"""Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
Arguments:
netD (network) -- discriminator network
real_data (tensor array) -- real images
fake_data (tensor array) -- generated images from the generator
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
type (str) -- if we mix real and fake data or not [real | fake | mixed].
constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2
lambda_gp (float) -- weight for this loss
Returns the gradient penalty loss
"""
if lambda_gp > 0.0:
if type == 'real': # either use real images, fake images, or a linear interpolation of two.
interpolatesv = real_data
elif type == 'fake':
interpolatesv = fake_data
elif type == 'mixed':
alpha = torch.rand(real_data.shape[0], 1)
alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
alpha = alpha.to(device)
interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
else:
raise NotImplementedError('{} not implemented'.format(type))
interpolatesv.requires_grad_(True)
disc_interpolates = netD(interpolatesv)
gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
grad_outputs=torch.ones(disc_interpolates.size()).to(device),
create_graph=True, retain_graph=True, only_inputs=True)
        gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
return gradient_penalty, gradients
else:
return 0.0, None
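To see the gradient-penalty formula in isolation, take a hypothetical discriminator `netD(x) = sum(x)` (an assumption for this sketch, not part of the repo): its input gradient is all ones, so per-sample `||grad||_2 = sqrt(n)` for `n` features, and with `n = 16`, `constant = 1`, `lambda_gp = 10` the penalty is `10 * (4 - 1)^2 = 90`.

```python
import torch

# Hypothetical linear "discriminator" whose input gradient is all ones.
netD = lambda x: x.sum(dim=1, keepdim=True)

real = torch.randn(4, 16)
fake = torch.randn(4, 16)
alpha = torch.rand(4, 1)
# 'mixed' interpolation between real and fake, as in cal_gradient_penalty
interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
out = netD(interp)
grads = torch.autograd.grad(outputs=out, inputs=interp,
                            grad_outputs=torch.ones_like(out),
                            create_graph=True)[0].view(4, -1)
gp = (((grads + 1e-16).norm(2, dim=1) - 1.0) ** 2).mean() * 10.0
```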
import numpy as np
import os
import ntpath
import time
import sys
from subprocess import Popen, PIPE
from . import util, html
from scipy.misc import imresize  # note: removed in SciPy 1.3; requires scipy<1.3 (or a PIL-based resize)
if sys.version_info[0] == 2:
VisdomExceptionBase = Exception
else:
VisdomExceptionBase = ConnectionError
def save_images(webpage, visuals, image_path, j, aspect_ratio=1.0, width=256):
    """Save images to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
        image_path (str)         -- the string is used to create image paths
        j (int)                  -- an index prefixed to the saved image names
        aspect_ratio (float)     -- the aspect ratio of saved images
        width (int)              -- the images will be resized to width x width

    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
    """
image_dir = webpage.get_image_dir()
short_path = ntpath.basename(image_path[0])
name = os.path.splitext(short_path)[0]
name = str(j) + "-" + name
webpage.add_header(name)
ims, txts, links = [], [], []
for label, im_data in visuals.items():
im = util.tensor2im(im_data)
image_name = '%s_%s.png' % (name, label)
save_path = os.path.join(image_dir, image_name)
h, w, _ = im.shape
if aspect_ratio > 1.0:
im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
if aspect_ratio < 1.0:
im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
util.save_image(im, save_path)
ims.append(image_name)
txts.append(label)
links.append(image_name)
webpage.add_images(ims, txts, links, width=width)
class Visualizer():
def __init__(self, opt):
self.display_id = opt.display_id
self.use_html = opt.isTrain and not opt.no_html
self.win_size = opt.display_winsize
self.name = opt.name
self.port = opt.display_port
self.opt = opt
self.saved = False
if self.display_id > 0:
import visdom
self.ncols = opt.display_ncols
self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
if not self.vis.check_connection():
self.create_visdom_connections()
if self.use_html:
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
self.img_dir = os.path.join(self.web_dir, 'images')
print('create web directory %s...' % self.web_dir)
util.mkdirs([self.web_dir, self.img_dir])
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(self.log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
def reset(self):
self.saved = False
def create_visdom_connections(self):
"""If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
print('Command: %s' % cmd)
Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
# |visuals|: dictionary of images to display or save
def display_current_results(self, visuals, epoch, save_result):
if self.display_id > 0: # show images in the browser
ncols = self.ncols
if ncols > 0:
ncols = min(ncols, len(visuals))
h, w = next(iter(visuals.values())).shape[:2]
table_css = """<style>
table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
</style>""" % (w, h)
title = self.name
label_html = ''
label_html_row = ''
images = []
idx = 0
for label, image in visuals.items():
image = util.rm_extra_dim(image) # remove the dummy dim
image_numpy = util.tensor2im(image)
label_html_row += '<td>%s</td>' % label
images.append(image_numpy.transpose([2, 0, 1]))
idx += 1
if idx % ncols == 0:
label_html += '<tr>%s</tr>' % label_html_row
label_html_row = ''
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
while idx % ncols != 0:
images.append(white_image)
label_html_row += '<td></td>'
idx += 1
if label_html_row != '':
label_html += '<tr>%s</tr>' % label_html_row
try:
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
padding=2, opts=dict(title=title + ' images'))
label_html = '<table>%s</table>' % label_html
self.vis.text(table_css + label_html, win=self.display_id + 2,
opts=dict(title=title + ' labels'))
except VisdomExceptionBase:
self.create_visdom_connections()
else:
idx = 1
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
win=self.display_id + idx)
idx += 1
if self.use_html and (save_result or not self.saved): # save images to a html file
self.saved = True
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
util.save_image(image_numpy, img_path)
# update website
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
for n in range(epoch, 0, -1):
webpage.add_header('epoch [%d]' % n)
ims, txts, links = [], [], []
                for label, image_numpy in visuals.items():
                    image_numpy = util.tensor2im(image_numpy)
img_path = 'epoch%.3d_%s.png' % (n, label)
ims.append(img_path)
txts.append(label)
links.append(img_path)
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()
# losses: dictionary of error labels and values
def plot_current_losses(self, epoch, counter_ratio, opt, losses):
if not hasattr(self, 'plot_data'):
self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
self.vis.line(
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',
'legend': self.plot_data['legend'],
'xlabel': 'epoch',
'ylabel': 'loss'},
win=self.display_id)
# losses: same format as |losses| of plot_current_losses
def print_current_losses(self, epoch, i, losses, t, t_data):
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)
print(message)
with open(self.log_name, "a") as log_file:
log_file.write('%s\n' % message)
<!DOCTYPE html>
<html>
<head>
<title>test_html</title>
</head>
<body>
<h3>hello world</h3>
<table border="1" style="table-layout: fixed;">
<tr>
<td halign="center" style="word-wrap: break-word;" valign="top">
<p>
<a href="images\image_0.png">
<img src="images\image_0.png" style="width:400px">
</a><br>
<p>text_0</p>
</p>
</td>
<td halign="center" style="word-wrap: break-word;" valign="top">
<p>
<a href="images\image_1.png">
<img src="images\image_1.png" style="width:400px">
</a><br>
<p>text_1</p>
</p>
</td>
<td halign="center" style="word-wrap: break-word;" valign="top">
<p>
<a href="images\image_2.png">
<img src="images\image_2.png" style="width:400px">
</a><br>
<p>text_2</p>
</p>
</td>
<td halign="center" style="word-wrap: break-word;" valign="top">
<p>
<a href="images\image_3.png">
<img src="images\image_3.png" style="width:400px">
</a><br>
<p>text_3</p>
</p>
</td>
</tr>
</table>
</body>
</html>